]> git.cryptolib.org Git - avr-crypto-lib.git/blob - bigint/bigint.c
some minor improvements and bug fixes
[avr-crypto-lib.git] / bigint / bigint.c
1 /* bigint.c */
2 /*
3     This file is part of the ARM-Crypto-Lib.
4     Copyright (C) 2008  Daniel Otte (daniel.otte@rub.de)
5
6     This program is free software: you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation, either version 3 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License
17     along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19 /**
20  * \file                bigint.c
21  * \author              Daniel Otte
22  * \date                2010-02-22
23  * 
24  * \license         GPLv3 or later
25  * 
26  */
27  
28
29 #define STRING2(x) #x
30 #define STRING(x) STRING2(x)
31 #define STR_LINE STRING(__LINE__)
32
33 #include "bigint.h"
34 #include <string.h>
35
36 #define DEBUG 0
37
38 #if DEBUG
39 #include "cli.h"
40 #include "bigint_io.h"
41 #endif
42
/* NOTE: MAX/MIN evaluate the selected argument twice; do not pass
 * expressions with side effects. */
#ifndef MAX
 #define MAX(a,b) (((a)>(b))?(a):(b))
#endif

#ifndef MIN
 #define MIN(a,b) (((a)<(b))?(a):(b))
#endif

/* Accessors for the `info` member of a bigint_t.
 * The bits selected by BIGINT_FBS_MASK hold the "first bit set" value,
 * BIGINT_NEG_MASK is the sign flag. */
/* replace the FBS bits with v, all other bits of info are kept */
#define SET_FBS(a, v) do{(a)->info &=~BIGINT_FBS_MASK; (a)->info |= (v);}while(0)
#define GET_FBS(a)   ((a)->info&BIGINT_FBS_MASK)
/* fully parenthesised so the expansion stays one expression in any context */
#define SET_NEG(a)   ((a)->info |= BIGINT_NEG_MASK)
#define SET_POS(a)   ((a)->info &= ~BIGINT_NEG_MASK)
/* XOR swap of two integer lvalues; a and b must be distinct objects,
 * swapping an object with itself would zero it */
#define XCHG(a,b)    do{(a)^=(b); (b)^=(a); (a)^=(b);}while(0)
/* XOR swap of two pointer lvalues (same restriction as XCHG) */
#define XCHG_PTR(a,b)    do{ (a) = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b))); \
                             (b) = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b))); \
                             (a) = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b)));}while(0)

/* non-zero if the sign flag is set */
#define GET_SIGN(a) ((a)->info&BIGINT_NEG_MASK)
61
62 /******************************************************************************/
63 void bigint_adjust(bigint_t* a){
64         while(a->length_W!=0 && a->wordv[a->length_W-1]==0){
65                 a->length_W--;
66         }
67         if(a->length_W==0){
68                 a->info=0;
69                 return;
70         }
71         bigint_word_t t;
72         uint8_t i = BIGINT_WORD_SIZE-1;
73         t = a->wordv[a->length_W-1];
74         while((t&(1L<<(BIGINT_WORD_SIZE-1)))==0 && i){
75                 t<<=1;
76                 i--;
77         }
78         SET_FBS(a, i);
79 }
80
81 /******************************************************************************/
82
83 uint16_t bigint_length_b(const bigint_t* a){
84         if(!a->length_W || a->length_W==0){
85                 return 0;
86         }
87         return (a->length_W-1) * BIGINT_WORD_SIZE + GET_FBS(a);
88 }
89
90 /******************************************************************************/
91
92 uint16_t bigint_length_B(const bigint_t* a){
93         return a->length_W * sizeof(bigint_word_t);
94 }
95
96 /******************************************************************************/
97
98 uint32_t bigint_get_first_set_bit(const bigint_t* a){
99         if(a->length_W==0){
100                 return (uint32_t)(-1);
101         }
102         return (a->length_W-1)*sizeof(bigint_word_t)*8+GET_FBS(a);
103 }
104
105
106 /******************************************************************************/
107
108 uint32_t bigint_get_last_set_bit(const bigint_t* a){
109         uint32_t r=0;
110         uint8_t b=0;
111         bigint_word_t x=1;
112         if(a->length_W==0){
113                 return (uint32_t)(-1);
114         }
115         while(a->wordv[r]==0 && r<a->length_W){
116                 ++r;
117         }
118         if(a->wordv[r] == 0){
119                 return (uint32_t)(-1);
120         }
121         while((x&a->wordv[r])==0){
122                 ++b;
123                 x <<= 1;
124         }
125         return r*BIGINT_WORD_SIZE+b;
126 }
127
128 /******************************************************************************/
129
130 void bigint_copy(bigint_t* dest, const bigint_t* src){
131         memcpy(dest->wordv, src->wordv, src->length_W*sizeof(bigint_word_t));
132         dest->length_W = src->length_W;
133         dest->info = src->info;
134 }
135
136 /******************************************************************************/
137
138 /* this should be implemented in assembly */
139 void bigint_add_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
140         uint16_t i;
141         bigint_wordplus_t t = 0LL;
142         if(a->length_W < b->length_W){
143                 XCHG_PTR(a,b);
144         }
145         for(i = 0; i < b->length_W; ++i){
146                 t += a->wordv[i];
147                 t += b->wordv[i];
148                 dest->wordv[i] = (bigint_word_t)t;
149                 t >>= BIGINT_WORD_SIZE;
150         }
151         for(; i < a->length_W; ++i){
152                 t += a->wordv[i];
153                 dest->wordv[i] = (bigint_word_t)t;
154                 t >>= BIGINT_WORD_SIZE;
155         }
156         if(t){
157                 dest->wordv[i++] = (bigint_word_t)t;
158         }
159         dest->length_W = i;
160         bigint_adjust(dest);
161 }
162
163 /******************************************************************************/
164
165 /* this should be implemented in assembly */
166 void bigint_add_scale_u(bigint_t* dest, const bigint_t* a, uint16_t scale){
167         if(a->length_W == 0){
168                 return;
169         }
170         if(scale == 0){
171                 bigint_add_u(dest, dest, a);
172                 return;
173         }
174 #if DEBUG_ADD_SCALE
175         cli_putstr_P(PSTR("\r\nDBG: bigint_add_scale("));
176         bigint_print_hex(dest);
177         cli_putc(',');
178         bigint_print_hex(dest);
179         cli_putc(',');
180         cli_hexdump_rev(&scale, 2);
181         cli_putc(')');
182 #endif
183 #if BIGINT_WORD_SIZE == 8
184         if(scale >= dest->length_W){
185 #if DEBUG_ADD_SCALE
186                 cli_putstr_P(PSTR("\r\n\tpath one"));
187 #endif
188                 memset(dest->wordv + dest->length_W, 0, scale - dest->length_W);
189                 memcpy(dest->wordv + scale, a->wordv, a->length_W);
190                 dest->info = a->info;
191                 dest->length_W = a->length_W + scale;
192                 return;
193         }
194         bigint_t x;
195 #if DEBUG_ADD_SCALE
196         cli_putstr_P(PSTR("\r\n\tpath two"));
197 #endif
198         x.length_W = dest->length_W - scale;
199         x.info = dest->info;
200         x.wordv = dest->wordv + scale;
201         bigint_add_u(&x, &x, a);
202         dest->length_W = x.length_W + scale;
203         dest->info = 0;
204         bigint_adjust(dest);
205 #else
206         bigint_t s;
207         uint16_t word_shift = scale / sizeof(bigint_word_t), byte_shift = scale % sizeof(bigint_word_t);
208         bigint_word_t bv[a->length_W + 1];
209         s.wordv = bv;
210         bv[0] = bv[a->length_W] = 0;
211         memcpy((uint8_t*)bv + byte_shift, a->wordv, a->length_W * sizeof(bigint_word_t));
212         s.length_W = a->length_W + 1;
213         bigint_adjust(&s);
214         memset(dest->wordv + dest->length_W, 0, (MAX(dest->length_W, s.length_W + word_shift) - dest->length_W) * sizeof(bigint_word_t));
215         x.wordv = dest->wordv + word_shift;
216         x.length_W = dest->length_W - word_shift;
217         if((int16_t)x.length_W < 0){
218                 x.length_W = 0;
219                 x.info = 0;
220         }else{
221                 x.info = dest->info;
222         }
223         bigint_add_u(&x, &x, &s);
224         dest->length_W = x.length_W + word_shift;
225         dest->info = 0;
226         bigint_adjust(dest);
227 #endif
228
229
230 /*      uint16_t i,j=0;
231         uint16_t scale_w;
232         bigint_word_t *dst;
233         bigint_wordplus_t t=0;
234         scale_w = (scale+sizeof(bigint_word_t)-1)/sizeof(bigint_word_t);
235         if(scale>dest->length_W*sizeof(bigint_word_t)){
236                 memset(((uint8_t*)dest->wordv)+dest->length_W*sizeof(bigint_word_t), 0, scale-dest->length_W*sizeof(bigint_word_t));
237         }
238         // a->wordv = (const uint32_t*)(((uint8_t*)a->wordv)+(scale&3));
239         dst  = dest->wordv + (scale&(sizeof(bigint_word_t)-1));
240         for(i=scale/sizeof(bigint_word_t); i<a->length_W+scale_w; ++i,++j){
241                 t += a->wordv[j];
242                 if(dest->length_W>i){
243                         t += dst[i];
244                 }
245                 dst[i] = (bigint_word_t)t;
246                 t>>=BIGINT_WORD_SIZE;
247         }
248         while(t){
249                 if(dest->length_W>i){
250                         t += dst[i];
251                 }
252                 dst[i] = (bigint_word_t)t;
253                 t>>=BIGINT_WORD_SIZE;
254                 ++i;
255         }
256         if(dest->length_W < i){
257                 dest->length_W = i;
258         }
259         bigint_adjust(dest);
260         */
261 }
262
263 /******************************************************************************/
264
265 /* this should be implemented in assembly */
266 void bigint_sub_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
267         int8_t borrow=0;
268         int8_t  r;
269         bigint_wordplus_signed_t t=0LL;
270         uint16_t i;
271         r = bigint_cmp_u(a,b);
272         if(r==0){
273                 bigint_set_zero(dest);
274                 return;
275         }
276         if(b->length_W==0){
277                 bigint_copy(dest, a);
278                 SET_POS(dest);
279                 return;
280         }
281         if(a->length_W==0){
282                 bigint_copy(dest, b);
283                 SET_NEG(dest);
284                 return;
285         }
286         if(r<0){
287                 bigint_sub_u(dest, b, a);
288                 SET_NEG(dest);
289                 return;
290         }
291         for(i=0; i < a->length_W; ++i){
292                 t = a->wordv[i];
293                 if(i < b->length_W){
294                         t -= b->wordv[i];
295                 }
296                 t -= borrow;
297                 dest->wordv[i]=(bigint_word_t)t;
298                 if(t<0){
299                         borrow = 1;
300                 }else{
301                         borrow = 0;
302                 }
303         }
304         SET_POS(dest);
305         dest->length_W = i;
306         bigint_adjust(dest);
307 }
308
309 /******************************************************************************/
310
311 int8_t bigint_cmp_u(const bigint_t* a, const bigint_t* b){
312         if(a->length_W > b->length_W){
313                 return 1;
314         }
315         if(a->length_W < b->length_W){
316                 return -1;
317         }
318         if(a->length_W == 0){
319                 return 0;
320         }
321         uint16_t i;
322         i = a->length_W-1;
323         do{
324                 if(a->wordv[i] != b->wordv[i]){
325                         if(a->wordv[i] > b->wordv[i]){
326                                 return 1;
327                         }else{
328                                 return -1;
329                         }
330                 }
331         }while(i--);
332         return 0;
333 }
334
335 /******************************************************************************/
336
337 void bigint_add_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
338         uint8_t s;
339         s  = GET_SIGN(a)?2:0;
340         s |= GET_SIGN(b)?1:0;
341         switch(s){
342                 case 0: /* both positive */
343                         bigint_add_u(dest, a,b);
344                         SET_POS(dest);
345                         break;
346                 case 1: /* a positive, b negative */
347                         bigint_sub_u(dest, a, b);
348                         break;
349                 case 2: /* a negative, b positive */
350                         bigint_sub_u(dest, b, a);
351                         break;
352                 case 3: /* both negative */
353                         bigint_add_u(dest, a, b);
354                         SET_NEG(dest);
355                         break;
356                 default: /* how can this happen?*/
357                         break;
358         }
359 }
360
361 /******************************************************************************/
362
363 void bigint_sub_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
364         uint8_t s;
365         s  = GET_SIGN(a)?2:0;
366         s |= GET_SIGN(b)?1:0;
367         switch(s){
368                 case 0: /* both positive */
369                         bigint_sub_u(dest, a,b);
370                         break;
371                 case 1: /* a positive, b negative */
372                         bigint_add_u(dest, a, b);
373                         SET_POS(dest);
374                         break;
375                 case 2: /* a negative, b positive */
376                         bigint_add_u(dest, a, b);
377                         SET_NEG(dest);
378                         break;
379                 case 3: /* both negative */
380                         bigint_sub_u(dest, b, a);
381                         break;
382                 default: /* how can this happen?*/
383                                         break;
384         }
385
386 }
387
388 /******************************************************************************/
389
390 int8_t bigint_cmp_s(const bigint_t* a, const bigint_t* b){
391         uint8_t s;
392         if(a->length_W==0 && b->length_W==0){
393                 return 0;
394         }
395         s  = GET_SIGN(a)?2:0;
396         s |= GET_SIGN(b)?1:0;
397         switch(s){
398                 case 0: /* both positive */
399                         return bigint_cmp_u(a, b);
400                         break;
401                 case 1: /* a positive, b negative */
402                         return 1;
403                         break;
404                 case 2: /* a negative, b positive */
405                         return -1;
406                         break;
407                 case 3: /* both negative */
408                         return bigint_cmp_u(b, a);
409                         break;
410                 default: /* how can this happen?*/
411                                         break;
412         }
413         return 0; /* just to satisfy the compiler */
414 }
415
416 /******************************************************************************/
417
418 void bigint_shiftleft(bigint_t* a, uint16_t shift){
419         uint16_t byteshift, words_to_shift;
420         int16_t i;
421         uint8_t bitshift;
422         bigint_word_t *p;
423         bigint_wordplus_t t=0;
424         if(shift == 0){
425                 return;
426         }
427         byteshift = shift / 8;
428         bitshift = shift & 7;
429         memset(&a->wordv[a->length_W], 0x00, byteshift);
430         if(byteshift){
431                 memmove(((uint8_t*)a->wordv)+byteshift, a->wordv, a->length_W * sizeof(bigint_word_t));
432                 memset(a->wordv, 0, byteshift);
433         }
434         if(bitshift != 0){
435                 p = (bigint_word_t*)((uint8_t*)a->wordv + byteshift / sizeof(bigint_word_t));
436                 words_to_shift = a->length_W + ((byteshift % sizeof(bigint_word_t))?1:0);
437                 /* XXX */
438                 for(i = 0; i < words_to_shift; ++i){
439                         t |= ((bigint_wordplus_t)p[i]) << bitshift;
440                         p[i] = (bigint_word_t)t;
441                         t >>= BIGINT_WORD_SIZE;
442                 }
443                 p[i] = (bigint_word_t)t;
444         }
445         a->length_W += (shift + BIGINT_WORD_SIZE - 1) / BIGINT_WORD_SIZE;
446         bigint_adjust(a);
447 }
448
449 /******************************************************************************/
450
451 void bigint_shiftright(bigint_t* a, uint16_t shift){
452         uint16_t byteshift;
453         uint16_t i;
454         uint8_t bitshift;
455         bigint_wordplus_t t=0;
456         byteshift = shift/8;
457         bitshift = shift&7;
458         if(byteshift >= a->length_W * sizeof(bigint_word_t)){ /* we would shift out more than we have */
459                 bigint_set_zero(a);
460                 return;
461         }
462         if(byteshift == a->length_W * sizeof(bigint_word_t) - 1 && bitshift > GET_FBS(a)){
463                 bigint_set_zero(a);
464                 return;
465         }
466         if(byteshift){
467                 memmove(a->wordv, (uint8_t*)a->wordv + byteshift, a->length_W * sizeof(bigint_word_t) - byteshift);
468                 memset((uint8_t*)a->wordv + a->length_W * sizeof(bigint_word_t) - byteshift, 0,  byteshift);
469         }
470         byteshift /= sizeof(bigint_word_t);
471     a->length_W -= (byteshift  + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t);
472         if(bitshift != 0 && a->length_W){
473          /* shift to the right */
474                 i = a->length_W - 1;
475                 do{
476                         t |= ((bigint_wordplus_t)(a->wordv[i])) << (BIGINT_WORD_SIZE - bitshift);
477                         a->wordv[i] = (bigint_word_t)(t >> BIGINT_WORD_SIZE);
478                         t <<= BIGINT_WORD_SIZE;
479                 }while(i--);
480         }
481         bigint_adjust(a);
482 }
483
484 /******************************************************************************/
485
486 void bigint_xor(bigint_t* dest, const bigint_t* a){
487         uint16_t i;
488         for(i=0; i<a->length_W; ++i){
489                 dest->wordv[i] ^= a->wordv[i];
490         }
491         bigint_adjust(dest);
492 }
493
494 /******************************************************************************/
495
496 void bigint_set_zero(bigint_t* a){
497         a->length_W=0;
498 }
499
500 /******************************************************************************/
501
502 /* using the Karatsuba-Algorithm */
503 /* x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + yh*yl */
504 void bigint_mul_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
505         if(a->length_W==0 || b->length_W==0){
506                 bigint_set_zero(dest);
507                 return;
508         }
509         if(dest==a || dest==b){
510                 bigint_t d;
511                 bigint_word_t d_b[a->length_W + b->length_W];
512                 d.wordv = d_b;
513                 bigint_mul_u(&d, a, b);
514                 bigint_copy(dest, &d);
515                 return;
516         }
517         if(a->length_W==1 || b->length_W==1){
518                 if(a->length_W != 1){
519                         XCHG_PTR(a,b);
520                 }
521                 bigint_wordplus_t t=0;
522                 uint16_t i;
523                 bigint_word_t x = a->wordv[0];
524                 for(i=0; i < b->length_W; ++i){
525                         t += ((bigint_wordplus_t)b->wordv[i]) * ((bigint_wordplus_t)x);
526                         dest->wordv[i] = (bigint_word_t)t;
527                         t>>=BIGINT_WORD_SIZE;
528                 }
529                 dest->wordv[i] = (bigint_word_t)t;
530                 dest->length_W = i+1;
531                 dest->info = 0;
532                 bigint_adjust(dest);
533                 return;
534         }
535         if(a->length_W * sizeof(bigint_word_t) <= 4 && b->length_W * sizeof(bigint_word_t) <= 4){
536                 uint32_t p=0, q=0;
537                 uint64_t r;
538                 memcpy(&p, a->wordv, a->length_W*sizeof(bigint_word_t));
539                 memcpy(&q, b->wordv, b->length_W*sizeof(bigint_word_t));
540                 r = (uint64_t)p * (uint64_t)q;
541                 memcpy(dest->wordv, &r, (dest->length_W = a->length_W + b->length_W)*sizeof(bigint_word_t));
542                 bigint_adjust(dest);
543                 return;
544         }
545 #if BIGINT_WORD_SIZE == 8
546         if(a->length_W <= 4 || b->length_W <= 4){
547                 if(a->length_W > 4){
548                         XCHG_PTR(a,b);
549                 }
550                 uint32_t x = 0, y = 0;
551                 uint16_t j = b->length_W / 4, idx = 0;
552                 uint64_t r = 0;
553                 memcpy(&x, a->wordv, a->length_W);
554                 while(j){
555                         r += (uint64_t)((uint32_t*)b->wordv)[idx] * (uint64_t)x;
556                         ((uint32_t*)dest->wordv)[idx] = (uint32_t)r;
557                         r >>= 32;
558                         ++idx;
559                         --j;
560                 }
561                 idx *= 4;
562                 memcpy(&y, b->wordv + idx, b->length_W - idx);
563                 r += (uint64_t)y * (uint64_t)x;
564                 while(r){
565                         dest->wordv[idx++] = (uint8_t)r;
566                         r >>= 8;
567                 }
568                 dest->length_W = idx;
569                 bigint_adjust(dest);
570                 return;
571         }
572 #endif
573         /* split a in xh & xl; split b in yh & yl */
574         const uint16_t n = (MAX(a->length_W, b->length_W)+1)/2;
575         bigint_t xl, xh, yl, yh;
576         xl.wordv = a->wordv;
577         yl.wordv = b->wordv;
578         if(a->length_W <= n){
579                 bigint_set_zero(&xh);
580                 xl.length_W = a->length_W;
581                 xl.info = a->info;
582         }else{
583                 xl.length_W = n;
584                 xl.info = 0;
585                 bigint_adjust(&xl);
586                 xh.wordv = &(a->wordv[n]);
587                 xh.length_W = a->length_W-n;
588                 xh.info = a->info;
589         }
590         if(b->length_W <= n){
591                 bigint_set_zero(&yh);
592                 yl.length_W = b->length_W;
593                 yl.info = b->info;
594         }else{
595                 yl.length_W = n;
596                 yl.info = 0;
597                 bigint_adjust(&yl);
598                 yh.wordv = &(b->wordv[n]);
599                 yh.length_W = b->length_W-n;
600                 yh.info = b->info;
601         }
602         /* now we have split up a and b */
603         /* remember we want to do:
604          * x*y = (xh * b ** n + xl) * (yh * b ** n + yl)
605          * x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + xl*yl
606          *          5          9     2   4   3    7   5   6   1         8   1
607          */
608         bigint_word_t  tmp1_b[n*2+2];
609         bigint_t tmp1, tmp2;
610         tmp1.wordv = tmp1_b;
611     tmp2.wordv = &tmp1_b[n+1];
612
613         bigint_add_u(&tmp2,  &xh,   &xl);                            /* 2: tmp2 <= xh + xl     */
614         bigint_add_u(&tmp1, &yh,   &yl);                            /* 3: tmp1 <= yh + yl     */
615         bigint_mul_u(dest,  &tmp2, &tmp1);                          /* 4: dest <= tmp2 * tmp1  */
616         bigint_mul_u(&tmp1, &xh,   &yh);                            /* 5: tmp1 <= xh * yh     */
617     bigint_sub_u(dest,  dest,  &tmp1);                          /* 7: dest <= dest - tmp1       */
618     bigint_word_t tmp3_b[2*n];
619     tmp2.wordv = tmp3_b;
620     bigint_mul_u(&tmp2, &xl,   &yl);                            /* 1: tmp3 <= xl * yl     */
621         bigint_sub_u(dest,  dest,  &tmp2);                          /* 6: dest <= dest - tmp3    */
622     bigint_shiftleft(dest, n * sizeof(bigint_word_t) * 8);
623         bigint_add_u(dest, dest, &tmp2);                            /* 8: dest <= tmp3 + dest ** n       */
624         bigint_add_scale_u(dest, &tmp1, 2*n*sizeof(bigint_word_t)); /* 9: dest <= dest + tmp1 ** (2 * n) */
625
626 #if 0
627         bigint_mul_u(dest,  &xl,   &yl);                            /* 1: dest <= xl * yl     */
628         bigint_add_u(&tmp2, &xh,   &xl);                            /* 2: tmp2 <= xh + xl     */
629         bigint_add_u(&tmp1, &yh,   &yl);                            /* 3: tmp1 <= yh + yl     */
630         bigint_mul_u(&tmp3, &tmp2, &tmp1);                          /* 4: tmp3 <= tmp2 * tmp1  */
631         bigint_mul_u(&tmp1, &xh,   &yh);                            /* 5: h    <= xh * yh     */
632         bigint_sub_u(&tmp3, &tmp3, dest);                           /* 6: tmp3 <= tmp3 - dest    */
633     bigint_sub_u(&tmp3, &tmp3, &tmp1);                          /* 7: tmp3 <= tmp3 - h       */
634         bigint_add_scale_u(dest, &tmp3, n*sizeof(bigint_word_t));   /* 8: dest <= dest + tmp3 ** n       */
635         bigint_add_scale_u(dest, &tmp1, 2*n*sizeof(bigint_word_t)); /* 9: dest <= dest + tmp1 ** (2 * n) */
636 #endif
637 }
638
639 /******************************************************************************/
640
641 void bigint_mul_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
642         uint8_t s;
643         s  = GET_SIGN(a)?2:0;
644         s |= GET_SIGN(b)?1:0;
645         switch(s){
646                 case 0: /* both positive */
647                         bigint_mul_u(dest, a,b);
648                         SET_POS(dest);
649                         break;
650                 case 1: /* a positive, b negative */
651                         bigint_mul_u(dest, a,b);
652                         SET_NEG(dest);
653                         break;
654                 case 2: /* a negative, b positive */
655                         bigint_mul_u(dest, a,b);
656                         SET_NEG(dest);
657                         break;
658                 case 3: /* both negative */
659                         bigint_mul_u(dest, a,b);
660                         SET_POS(dest);
661                         break;
662                 default: /* how can this happen?*/
663                         break;
664         }
665 }
666
667 /******************************************************************************/
668
669 /* square */
670 /* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */
671 void bigint_square(bigint_t* dest, const bigint_t* a){
672         if(a->length_W * sizeof(bigint_word_t) <= 4){
673                 uint64_t r = 0;
674                 memcpy(&r, a->wordv, a->length_W * sizeof(bigint_word_t));
675                 r = r * r;
676                 memcpy(dest->wordv, &r, 2 * a->length_W*sizeof(bigint_word_t));
677                 SET_POS(dest);
678                 dest->length_W = 2 * a->length_W;
679                 bigint_adjust(dest);
680                 return;
681         }
682         if(dest==a){
683                 bigint_t d;
684                 bigint_word_t d_b[a->length_W * 2];
685                 d.wordv = d_b;
686                 bigint_square(&d, a);
687                 bigint_copy(dest, &d);
688                 return;
689         }
690         uint16_t n;
691         n = (a->length_W + 1) / 2;
692         bigint_t xh, xl, tmp; /* x-high, x-low, temp */
693         bigint_word_t buffer[2 * n + 1];
694         xl.wordv = a->wordv;
695         xl.length_W = n;
696         xl.info = 0;
697         xh.wordv = &(a->wordv[n]);
698         xh.length_W = a->length_W - n;
699         xh.info = a->info;
700         bigint_adjust(&xl);
701         tmp.wordv = buffer;
702 /* (xh * b**n + xl)**2 = xh**2 * b**2n + 2 * xh * xl * b**n + xl**2 */
703
704 //      cli_putstr("\r\nDBG (a): xl: "); bigint_print_hex(&xl);
705 //      cli_putstr("\r\nDBG (b): xh: "); bigint_print_hex(&xh);
706         bigint_square(dest, &xl);
707 //      cli_putstr("\r\nDBG (1): xl**2: "); bigint_print_hex(dest);
708         bigint_square(&tmp, &xh);
709 //      cli_putstr("\r\nDBG (2): xh**2: "); bigint_print_hex(&tmp);
710         bigint_add_scale_u(dest, &tmp, 2 * n * sizeof(bigint_word_t));
711 //      cli_putstr("\r\nDBG (3): xl**2 + xh**2*n**2: "); bigint_print_hex(dest);
712         bigint_mul_u(&tmp, &xl, &xh);
713 //      cli_putstr("\r\nDBG (4): xl*xh: "); bigint_print_hex(&tmp);
714         bigint_shiftleft(&tmp, 1);
715 //      cli_putstr("\r\nDBG (5): xl*xh*2: "); bigint_print_hex(&tmp);
716         bigint_add_scale_u(dest, &tmp, n*sizeof(bigint_word_t));
717 //      cli_putstr("\r\nDBG (6): x**2: "); bigint_print_hex(dest);
718 //      cli_putstr("\r\n");
719 }
720
721 /******************************************************************************/
722 void bigint_sub_u_bitscale(bigint_t* a, const bigint_t* b, uint16_t bitscale){
723         bigint_t tmp, x;
724         bigint_word_t tmp_b[b->length_W + 1];
725         const uint16_t word_shift = bitscale / BIGINT_WORD_SIZE;
726
727         if(a->length_W < b->length_W + word_shift){
728 #if DEBUG
729                 cli_putstr("\r\nDBG: *bang*\r\n");
730 #endif
731                 bigint_set_zero(a);
732                 return;
733         }
734         tmp.wordv = tmp_b;
735         bigint_copy(&tmp, b);
736         bigint_shiftleft(&tmp, bitscale % BIGINT_WORD_SIZE);
737
738 //      cli_putstr_P(PSTR("\r\nDBG: shifted value: ")); bigint_print_hex(&tmp);
739
740         x.info = a->info;
741         x.wordv = &(a->wordv[word_shift]);
742         x.length_W = a->length_W - word_shift;
743
744         bigint_sub_u(&x, &x, &tmp);
745         a->length_W = x.length_W + word_shift;
746         bigint_adjust(a);
747         return;
748 }
749
750 /******************************************************************************/
751
752 void bigint_reduce(bigint_t* a, const bigint_t* r){
753 //      bigint_adjust((bigint_t*)r);
754         uint8_t rfbs = GET_FBS(r);
755 #if DEBUG
756         cli_putstr("\r\nDBG: (a) = "); bigint_print_hex(a);
757 #endif
758         if(r->length_W==0 || a->length_W==0){
759                 return;
760         }
761         if((r->length_W*sizeof(bigint_word_t)<=4) && (a->length_W*sizeof(bigint_word_t)<=4)){
762                 uint32_t p=0, q=0;
763                 memcpy(&p, a->wordv, a->length_W*sizeof(bigint_word_t));
764                 memcpy(&q, r->wordv, r->length_W*sizeof(bigint_word_t));
765                 p %= q;
766                 memcpy(a->wordv, &p, a->length_W*sizeof(bigint_word_t));
767                 bigint_adjust(a);
768 //              cli_putstr("\r\nDBG: (0) = "); bigint_print_hex(a);
769                 return;
770         }
771         uint16_t shift;
772         while(a->length_W > r->length_W){
773                 shift = (a->length_W - r->length_W) * 8 * sizeof(bigint_word_t) + GET_FBS(a) - rfbs - 1;
774                 /*
775                 if((a->wordv[a->length_W-1] & ((1LL<<GET_FBS(a)) - 1)) > r->wordv[r->length_W-1]){
776                         // cli_putc('~');
777                         cli_putstr("\r\n ~ [a] = ");
778                         cli_hexdump_rev(&a->wordv[a->length_W-1], 4);
779                         cli_putstr("  [r] = ");
780                         cli_hexdump_rev(&r->wordv[r->length_W-1], 4);
781                         shift += 1;
782                 }
783                 */
784 //              cli_putstr("\r\nDBG: (p) shift = "); cli_hexdump_rev(&shift, 2);
785 //              cli_putstr(" a_len = "); cli_hexdump_rev(&a->length_W, 2);
786 //              cli_putstr(" r_len = "); cli_hexdump_rev(&r->length_W, 2);
787 //              uart_flush(0);
788                 bigint_sub_u_bitscale(a, r, shift);
789 //              cli_putstr("\r\nDBG: (1) = "); bigint_print_hex(a);
790         }
791         while((GET_FBS(a) > rfbs) && (a->length_W == r->length_W)){
792                 shift = GET_FBS(a)-rfbs-1;
793 //              cli_putstr("\r\nDBG: (2a) = "); bigint_print_hex(a);
794 //              cli_putstr("\r\nDBG: (q) shift = "); cli_hexdump_rev(&shift, 2);
795                 bigint_sub_u_bitscale(a, r, shift);
796 //              cli_putstr("\r\nDBG: (2b) = "); bigint_print_hex(a);
797         }
798         while(bigint_cmp_u(a,r)>=0){
799                 bigint_sub_u(a,a,r);
800 //              cli_putstr("\r\nDBG: (3) = "); bigint_print_hex(a);
801         }
802 //      cli_putc(' ');
803         bigint_adjust(a);
804 //      cli_putstr("\r\nDBG: (a) = "); bigint_print_hex(a);
805 //      cli_putstr("\r\n");
806 }
807
808 /******************************************************************************/
809
810 /* calculate dest = a**exp % r */
811 /* using square&multiply */
812 void bigint_expmod_u(bigint_t* dest, const bigint_t* a, const bigint_t* exp, const bigint_t* r){
813         if(a->length_W==0 || r->length_W==0){
814                 return;
815         }
816
817         bigint_t res, base;
818         bigint_word_t t, base_b[MAX(a->length_W,r->length_W)], res_b[r->length_W*2];
819         uint16_t i;
820         uint8_t j;
821 //      uint16_t *xaddr = &i;
822 //      cli_putstr("\r\npre-alloc (");
823 //      cli_hexdump_rev(&xaddr, 4);
824 //      cli_putstr(") ...");
825         res.wordv = res_b;
826         base.wordv = base_b;
827         bigint_copy(&base, a);
828 //      cli_putstr("\r\npost-copy");
829         bigint_reduce(&base, r);
830         res.wordv[0]=1;
831         res.length_W=1;
832         res.info = 0;
833         bigint_adjust(&res);
834         if(exp->length_W == 0){
835                 bigint_copy(dest, &res);
836                 return;
837         }
838         uint8_t flag = 0;
839         t=exp->wordv[exp->length_W - 1];
840         for(i=exp->length_W; i > 0; --i){
841                 t = exp->wordv[i - 1];
842                 for(j=BIGINT_WORD_SIZE; j > 0; --j){
843                         if(!flag){
844                                 if(t & (1<<(BIGINT_WORD_SIZE-1))){
845                                         flag = 1;
846                                 }
847                         }
848                         if(flag){
849                                 bigint_square(&res, &res);
850                                 bigint_reduce(&res, r);
851                                 if(t & (1<<(BIGINT_WORD_SIZE-1))){
852                                         bigint_mul_u(&res, &res, &base);
853                                         bigint_reduce(&res, r);
854                                 }
855                         }
856                         t<<=1;
857                 }
858         }
859
860 //      cli_putc('+');
861         SET_POS(&res);
862         bigint_copy(dest, &res);
863 }
864
865 /******************************************************************************/
866
867 #define cli_putstr(a)
868 #define bigint_print_hex(a)
869 #define cli_hexdump_rev(a,b)
870 #define uart_flush(a)
871
872 /* gcd <-- gcd(x,y) a*x+b*y=gcd */
873 void bigint_gcdext(bigint_t* gcd, bigint_t* a, bigint_t* b, const bigint_t* x, const bigint_t* y){
874          bigint_t g, x_, y_, u, v, a_, b_, c_, d_;
875          uint16_t i=0;
876          if(x->length_W==0 || y->length_W==0){
877                  return;
878          }
879          if(x->length_W==1 && x->wordv[0]==1){
880                  gcd->length_W = 1;
881                  gcd->wordv[0] = 1;
882                  if(a){
883                          a->length_W = 1;
884                          a->wordv[0] = 1;
885                          SET_POS(a);
886                          bigint_adjust(a);
887                  }
888                  if(b){
889                          bigint_set_zero(b);
890                  }
891                  return;
892          }
893          if(y->length_W==1 && y->wordv[0]==1){
894                  gcd->length_W = 1;
895                  gcd->wordv[0] = 1;
896                  if(b){
897                          b->length_W = 1;
898                          b->wordv[0] = 1;
899                          SET_POS(b);
900                          bigint_adjust(b);
901                  }
902                  if(a){
903                          bigint_set_zero(a);
904                  }
905                  return;
906          }
907
908          while(x->wordv[i]==0 && y->wordv[i]==0){
909                  ++i;
910          }
911          bigint_word_t g_b[i+2], x_b[x->length_W-i], y_b[y->length_W-i];
912          bigint_word_t u_b[x->length_W-i], v_b[y->length_W-i];
913          bigint_word_t a_b[y->length_W+2], c_b[y->length_W+2];
914          bigint_word_t b_b[x->length_W+2], d_b[x->length_W+2];
915
916          g.wordv = g_b;
917          x_.wordv = x_b;
918          y_.wordv = y_b;
919          memset(g_b, 0, i*sizeof(bigint_word_t));
920          g_b[i]=1;
921          g.length_W = i+1;
922          g.info=0;
923          x_.info = y_.info = 0;
924          x_.length_W = x->length_W-i;
925          y_.length_W = y->length_W-i;
926          memcpy(x_.wordv, x->wordv+i, x_.length_W*sizeof(bigint_word_t));
927          memcpy(y_.wordv, y->wordv+i, y_.length_W*sizeof(bigint_word_t));
928          for(i=0; (x_.wordv[0]&(1<<i))==0 && (y_.wordv[0]&(1<<i))==0; ++i){
929          }
930
931          bigint_adjust(&x_);
932          bigint_adjust(&y_);
933
934          if(i){
935                  bigint_shiftleft(&g, i);
936                  bigint_shiftright(&x_, i);
937                  bigint_shiftright(&y_, i);
938          }
939
940          u.wordv = u_b;
941          v.wordv = v_b;
942          a_.wordv = a_b;
943          b_.wordv = b_b;
944          c_.wordv = c_b;
945          d_.wordv = d_b;
946
947          bigint_copy(&u, &x_);
948          bigint_copy(&v, &y_);
949          a_.wordv[0] = 1;
950          a_.length_W = 1;
951          a_.info = 0;
952          d_.wordv[0] = 1;
953          d_.length_W = 1;
954          d_.info = 0;
955          bigint_set_zero(&b_);
956          bigint_set_zero(&c_);
957          do{
958                  cli_putstr("\r\nDBG (gcdext) 0");
959                  while((u.wordv[0]&1)==0){
960                          cli_putstr("\r\nDBG (gcdext) 0.1");
961                          bigint_shiftright(&u, 1);
962                          if((a_.wordv[0]&1) || (b_.wordv[0]&1)){
963                                  bigint_add_s(&a_, &a_, &y_);
964                                  bigint_sub_s(&b_, &b_, &x_);
965                          }
966                          bigint_shiftright(&a_, 1);
967                          bigint_shiftright(&b_, 1);
968                  }
969                  while((v.wordv[0]&1)==0){
970                          cli_putstr("\r\nDBG (gcdext) 0.2");
971                          bigint_shiftright(&v, 1);
972                          if((c_.wordv[0]&1) || (d_.wordv[0]&1)){
973                                  bigint_add_s(&c_, &c_, &y_);
974                                  bigint_sub_s(&d_, &d_, &x_);
975                          }
976                          bigint_shiftright(&c_, 1);
977                          bigint_shiftright(&d_, 1);
978
979                  }
980                  if(bigint_cmp_u(&u, &v)>=0){
981                         bigint_sub_u(&u, &u, &v);
982                         bigint_sub_s(&a_, &a_, &c_);
983                         bigint_sub_s(&b_, &b_, &d_);
984                  }else{
985                         bigint_sub_u(&v, &v, &u);
986                         bigint_sub_s(&c_, &c_, &a_);
987                         bigint_sub_s(&d_, &d_, &b_);
988                  }
989          }while(u.length_W);
990          if(gcd){
991                  bigint_mul_s(gcd, &v, &g);
992          }
993          if(a){
994                 bigint_copy(a, &c_);
995          }
996          if(b){
997                  bigint_copy(b, &d_);
998          }
999 }
1000
1001 /******************************************************************************/
1002
1003 void bigint_inverse(bigint_t* dest, const bigint_t* a, const bigint_t* m){
1004         bigint_gcdext(NULL, dest, NULL, a, m);
1005         while(dest->info&BIGINT_NEG_MASK){
1006                 bigint_add_s(dest, dest, m);
1007         }
1008 }
1009
1010 /******************************************************************************/
1011
1012 void bigint_changeendianess(bigint_t* a){
1013         uint8_t t, *p, *q;
1014         p = (uint8_t*)(a->wordv);
1015         q = p + a->length_W * sizeof(bigint_word_t) - 1;
1016         while(p<q){
1017                 t = *p;
1018                 *p = *q;
1019                 *q = t;
1020                 ++p; --q;
1021         }
1022 }
1023
1024 /******************************************************************************/
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047