]> git.cryptolib.org Git - avr-crypto-lib.git/blob - bigint/bigint.c
c5f799e52936b46f6d2d87db52b48df4edbea2d7
[avr-crypto-lib.git] / bigint / bigint.c
1 /* bigint.c */
2 /*
3     This file is part of the ARM-Crypto-Lib.
4     Copyright (C) 2008  Daniel Otte (daniel.otte@rub.de)
5
6     This program is free software: you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation, either version 3 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License
17     along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19 /**
20  * \file                bigint.c
21  * \author              Daniel Otte
22  * \date                2010-02-22
23  * 
24  * \license         GPLv3 or later
25  * 
26  */
27  
28
29 #define STRING2(x) #x
30 #define STRING(x) STRING2(x)
31 #define STR_LINE STRING(__LINE__)
32
33 #include "bigint.h"
34 #include <string.h>
35
36 #define DEBUG 1
37
38 #if DEBUG || 1
39 #include "cli.h"
40 #include "uart_i.h"
41 #include "bigint_io.h"
42 #include <stdio.h>
43 #endif
44
45 #ifndef MAX
46  #define MAX(a,b) (((a)>(b))?(a):(b))
47 #endif
48
49 #ifndef MIN
50  #define MIN(a,b) (((a) < (b)) ? (a) : (b))
51 #endif
52
53 #define SET_FBS(a, v) do {(a)->info &= ~BIGINT_FBS_MASK; (a)->info |= (v);} while(0)
54 #define GET_FBS(a)   ((a)->info & BIGINT_FBS_MASK)
55 #define SET_NEG(a)   (a)->info |= BIGINT_NEG_MASK
56 #define SET_POS(a)   (a)->info &= ~BIGINT_NEG_MASK
57 #define XCHG(a,b)    do{(a) ^= (b); (b) ^= (a); (a) ^= (b);} while(0)
58 #define XCHG_PTR(a,b)    do{ a = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b))); \
59                                  b = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b))); \
60                                  a = (void*)(((intptr_t)(a)) ^ ((intptr_t)(b)));} while(0)
61
62 #define GET_SIGN(a) ((a)->info & BIGINT_NEG_MASK)
63
64 /******************************************************************************/
65 void bigint_adjust(bigint_t *a){
66         while (a->length_W != 0 && a->wordv[a->length_W - 1] == 0) {
67                 a->length_W--;
68         }
69         if (a->length_W == 0) {
70                 a->info=0;
71                 return;
72         }
73         bigint_word_t t;
74         uint8_t i = BIGINT_WORD_SIZE - 1;
75         t = a->wordv[a->length_W - 1];
76         while ((t & (((bigint_word_t)1) << (BIGINT_WORD_SIZE - 1))) == 0 && i) {
77                 t <<= 1;
78                 i--;
79         }
80         SET_FBS(a, i);
81 }
82
83 /******************************************************************************/
84
85 bigint_length_t bigint_length_b(const bigint_t *a){
86         if(!a->length_W || a->length_W==0){
87                 return 0;
88         }
89         return (a->length_W-1) * BIGINT_WORD_SIZE + GET_FBS(a);
90 }
91
92 /******************************************************************************/
93
94 bigint_length_t bigint_length_B(const bigint_t *a){
95         return a->length_W * sizeof(bigint_word_t);
96 }
97
98 /******************************************************************************/
99
100 uint32_t bigint_get_first_set_bit(const bigint_t *a){
101         if(a->length_W == 0) {
102                 return (uint32_t)(-1);
103         }
104         return (a->length_W-1) * sizeof(bigint_word_t) * CHAR_BIT + GET_FBS(a);
105 }
106
107
108 /******************************************************************************/
109
110 uint32_t bigint_get_last_set_bit(const bigint_t *a){
111         uint32_t r=0;
112         uint8_t b=0;
113         bigint_word_t x=1;
114         if(a->length_W==0){
115                 return (uint32_t)(-1);
116         }
117         while(a->wordv[r]==0 && r<a->length_W){
118                 ++r;
119         }
120         if(a->wordv[r] == 0){
121                 return (uint32_t)(-1);
122         }
123         while((x&a->wordv[r])==0){
124                 ++b;
125                 x <<= 1;
126         }
127         return r*BIGINT_WORD_SIZE+b;
128 }
129
130 /******************************************************************************/
131
132 void bigint_copy(bigint_t *dest, const bigint_t *src){
133     if(dest->wordv != src->wordv){
134             memcpy(dest->wordv, src->wordv, src->length_W * sizeof(bigint_word_t));
135     }
136     dest->length_W = src->length_W;
137         dest->info = src->info;
138 }
139
140 /******************************************************************************/
141
142 /* this should be implemented in assembly */
143 void bigint_add_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){
144         bigint_length_t i;
145         bigint_wordplus_t t = 0LL;
146         if (a->length_W < b->length_W) {
147                 XCHG_PTR(a, b);
148         }
149         for(i = 0; i < b->length_W; ++i) {
150                 t += a->wordv[i];
151                 t += b->wordv[i];
152                 dest->wordv[i] = (bigint_word_t)t;
153                 t >>= BIGINT_WORD_SIZE;
154         }
155         for(; i < a->length_W; ++i){
156                 t += a->wordv[i];
157                 dest->wordv[i] = (bigint_word_t)t;
158                 t >>= BIGINT_WORD_SIZE;
159         }
160         if(t){
161                 dest->wordv[i++] = (bigint_word_t)t;
162         }
163         dest->length_W = i;
164         SET_POS(dest);
165         bigint_adjust(dest);
166 }
167
168 /******************************************************************************/
169
170 /* this should be implemented in assembly */
171 void bigint_add_scale_u(bigint_t *dest, const bigint_t *a, bigint_length_t scale){
172         if(a->length_W == 0){
173                 return;
174         }
175         if(scale == 0){
176                 bigint_add_u(dest, dest, a);
177                 return;
178         }
179         bigint_t x;
180 #if BIGINT_WORD_SIZE == 8
181         memset(dest->wordv + dest->length_W, 0, MAX(dest->length_W, a->length_W + scale) - dest->length_W);
182         x.wordv = dest->wordv + scale;
183         x.length_W = dest->length_W - scale;
184         if((int16_t)x.length_W < 0){
185                 x.length_W = 0;
186                 x.info = 0;
187         } else {
188                 x.info = dest->info;
189         }
190         bigint_add_u(&x, &x, a);
191         dest->length_W = x.length_W + scale;
192         dest->info = 0;
193         bigint_adjust(dest);
194 #else
195         bigint_t s;
196         bigint_length_t word_shift = scale / sizeof(bigint_word_t), byte_shift = scale % sizeof(bigint_word_t);
197         bigint_word_t bv[a->length_W + 1];
198         s.wordv = bv;
199         bv[0] = bv[a->length_W] = 0;
200         memcpy((uint8_t*)bv + byte_shift, a->wordv, a->length_W * sizeof(bigint_word_t));
201         s.length_W = a->length_W + 1;
202         bigint_adjust(&s);
203         memset(dest->wordv + dest->length_W, 0, (MAX(dest->length_W, s.length_W + word_shift) - dest->length_W) * sizeof(bigint_word_t));
204         x.wordv = dest->wordv + word_shift;
205         x.length_W = dest->length_W - word_shift;
206         if((int16_t)x.length_W < 0){
207                 x.length_W = 0;
208                 x.info = 0;
209         }else{
210                 x.info = dest->info;
211         }
212         bigint_add_u(&x, &x, &s);
213         dest->length_W = x.length_W + word_shift;
214         dest->info = 0;
215         bigint_adjust(dest);
216 #endif
217 }
218
219 /******************************************************************************/
220
221 /* this should be implemented in assembly */
222 void bigint_sub_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){
223         int8_t borrow = 0;
224         int8_t  r;
225         bigint_wordplus_signed_t t = 0;
226         bigint_length_t i;
227         if(b->length_W == 0){
228                 bigint_copy(dest, a);
229                 SET_POS(dest);
230                 return;
231         }
232         if(a->length_W == 0){
233                 bigint_copy(dest, b);
234                 SET_NEG(dest);
235                 return;
236         }
237     r = bigint_cmp_u(a,b);
238     if(r == 0){
239         bigint_set_zero(dest);
240         return;
241     }
242         if(r < 0){
243                 bigint_sub_u(dest, b, a);
244                 SET_NEG(dest);
245                 return;
246         }
247         for(i = 0; i < a->length_W; ++i){
248                 t = a->wordv[i];
249                 if(i < b->length_W){
250                         t -= b->wordv[i];
251                 }
252                 t -= borrow;
253                 dest->wordv[i] = (bigint_word_t)t;
254                 borrow = t < 0 ? 1 : 0;
255         }
256         SET_POS(dest);
257         dest->length_W = i;
258         bigint_adjust(dest);
259 }
260
261 /******************************************************************************/
262
263 int8_t bigint_cmp_u(const bigint_t *a, const bigint_t *b){
264         if(a->length_W > b->length_W){
265                 return 1;
266         }
267         if(a->length_W < b->length_W){
268                 return -1;
269         }
270         if(a->length_W == 0){
271                 return 0;
272         }
273         bigint_length_t i;
274         i = a->length_W - 1;
275         do{
276                 if(a->wordv[i] != b->wordv[i]){
277                         if(a->wordv[i] > b->wordv[i]){
278                                 return 1;
279                         }else{
280                                 return -1;
281                         }
282                 }
283         }while(i--);
284         return 0;
285 }
286
287 /******************************************************************************/
288
289 void bigint_add_s(bigint_t *dest, const bigint_t *a, const bigint_t *b){
290         uint8_t s;
291         s  = GET_SIGN(a)?2:0;
292         s |= GET_SIGN(b)?1:0;
293         switch(s){
294                 case 0: /* both positive */
295                         bigint_add_u(dest, a,b);
296                         SET_POS(dest);
297                         break;
298                 case 1: /* a positive, b negative */
299                         bigint_sub_u(dest, a, b);
300                         break;
301                 case 2: /* a negative, b positive */
302                         bigint_sub_u(dest, b, a);
303                         break;
304                 case 3: /* both negative */
305                         bigint_add_u(dest, a, b);
306                         SET_NEG(dest);
307                         break;
308                 default: /* how can this happen?*/
309                         break;
310         }
311 }
312
313 /******************************************************************************/
314
315 void bigint_sub_s(bigint_t *dest, const bigint_t *a, const bigint_t *b){
316         uint8_t s;
317         s  = GET_SIGN(a)?2:0;
318         s |= GET_SIGN(b)?1:0;
319         switch(s){
320                 case 0: /* both positive */
321                         bigint_sub_u(dest, a,b);
322                         break;
323                 case 1: /* a positive, b negative */
324                         bigint_add_u(dest, a, b);
325                         SET_POS(dest);
326                         break;
327                 case 2: /* a negative, b positive */
328                         bigint_add_u(dest, a, b);
329                         SET_NEG(dest);
330                         break;
331                 case 3: /* both negative */
332                         bigint_sub_u(dest, b, a);
333                         break;
334                 default: /* how can this happen?*/
335                                         break;
336         }
337
338 }
339
340 /******************************************************************************/
341
342 int8_t bigint_cmp_s(const bigint_t *a, const bigint_t *b){
343         uint8_t s;
344         if(a->length_W==0 && b->length_W==0){
345                 return 0;
346         }
347         s  = GET_SIGN(a)?2:0;
348         s |= GET_SIGN(b)?1:0;
349         switch(s){
350                 case 0: /* both positive */
351                         return bigint_cmp_u(a, b);
352                         break;
353                 case 1: /* a positive, b negative */
354                         return 1;
355                         break;
356                 case 2: /* a negative, b positive */
357                         return -1;
358                         break;
359                 case 3: /* both negative */
360                         return bigint_cmp_u(b, a);
361                         break;
362                 default: /* how can this happen?*/
363                                         break;
364         }
365         return 0; /* just to satisfy the compiler */
366 }
367
368 /******************************************************************************/
369
370 void bigint_shiftleft(bigint_t *a, bigint_length_t shift){
371         bigint_length_t byteshift, words_to_shift;
372         int16_t i;
373         uint8_t bitshift;
374         bigint_word_t *p;
375         bigint_wordplus_t t = 0;
376
377         if (shift == 0) {
378                 return;
379         }
380         byteshift = shift / 8;
381         bitshift = shift & 7;
382     if (byteshift % sizeof(bigint_word_t)) {
383         a->wordv[a->length_W + byteshift / sizeof(bigint_t)] = 0;
384     }
385         if (byteshift) {
386                 memmove(((uint8_t*)a->wordv) + byteshift, a->wordv, a->length_W * sizeof(bigint_word_t));
387                 memset(a->wordv, 0, byteshift);
388         }
389         if (bitshift == 0) {
390             a->length_W += (byteshift + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t);
391             bigint_adjust(a);
392             return;
393         }
394         p = &a->wordv[byteshift / sizeof(bigint_word_t)];
395         words_to_shift = a->length_W + (byteshift % sizeof(bigint_word_t) ? 1 : 0);
396     for (i = 0; i < words_to_shift; ++i) {
397         t |= ((bigint_wordplus_t)p[i]) << bitshift;
398         p[i] = (bigint_word_t)t;
399         t >>= BIGINT_WORD_SIZE;
400     }
401     if (t) {
402         p[i] = (bigint_word_t)t;
403         a->length_W += 1;
404     }
405     a->length_W += (byteshift + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t);
406     bigint_adjust(a);
407 }
408
409 /******************************************************************************/
410
411 void bigint_shiftright(bigint_t *a, bigint_length_t shift){
412         bigint_length_t byteshift;
413         bigint_length_t i;
414         uint8_t bitshift;
415         bigint_wordplus_t t = 0;
416         byteshift = shift / 8;
417         bitshift = shift & 7;
418
419         if(bigint_get_first_set_bit(a) < shift){ /* we would shift out more than we have */
420                 bigint_set_zero(a);
421                 return;
422         }
423
424         if(byteshift){
425                 memmove(a->wordv, (uint8_t*)a->wordv + byteshift, a->length_W * sizeof(bigint_word_t) - byteshift);
426                 memset((uint8_t*)&a->wordv[a->length_W] - byteshift, 0, byteshift);
427         }
428
429     a->length_W -= byteshift / sizeof(bigint_word_t);
430
431     if(bitshift != 0 && a->length_W){
432          /* shift to the right */
433                 i = a->length_W - 1;
434                 do{
435                         t |= ((bigint_wordplus_t)(a->wordv[i])) << (BIGINT_WORD_SIZE - bitshift);
436                         a->wordv[i] = (bigint_word_t)(t >> BIGINT_WORD_SIZE);
437                         t <<= BIGINT_WORD_SIZE;
438                 }while(i--);
439         }
440         bigint_adjust(a);
441 }
442
443 /******************************************************************************/
444
445 void bigint_xor(bigint_t *dest, const bigint_t *a){
446         bigint_length_t i;
447         for(i=0; i<a->length_W; ++i){
448                 dest->wordv[i] ^= a->wordv[i];
449         }
450         bigint_adjust(dest);
451 }
452
453 /******************************************************************************/
454
455 void bigint_set_zero(bigint_t *a){
456         a->length_W=0;
457 }
458
459 /******************************************************************************/
460
461 /* using the Karatsuba-Algorithm */
462 /* x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + yh*yl */
463 void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){
464         if(a->length_W == 0 || b->length_W == 0){
465                 bigint_set_zero(dest);
466                 return;
467         }
468         if(dest == a || dest == b){
469                 bigint_t d;
470                 bigint_word_t d_b[a->length_W + b->length_W];
471                 d.wordv = d_b;
472                 bigint_mul_u(&d, a, b);
473                 bigint_copy(dest, &d);
474                 return;
475         }
476         if(a->length_W == 1 || b->length_W == 1){
477                 if(a->length_W != 1){
478                         XCHG_PTR(a,b);
479                 }
480                 bigint_wordplus_t t = 0;
481                 bigint_length_t i;
482                 bigint_word_t x = a->wordv[0];
483                 for(i=0; i < b->length_W; ++i){
484                         t += ((bigint_wordplus_t)b->wordv[i]) * ((bigint_wordplus_t)x);
485                         dest->wordv[i] = (bigint_word_t)t;
486                         t >>= BIGINT_WORD_SIZE;
487                 }
488                 dest->length_W = i;
489                 if(t){
490                     dest->wordv[i] = (bigint_word_t)t;
491                     dest->length_W += 1;
492                 }
493                 dest->info = 0;
494                 bigint_adjust(dest);
495                 return;
496         }
497         if(a->length_W * sizeof(bigint_word_t) <= 4 && b->length_W * sizeof(bigint_word_t) <= 4){
498                 uint32_t p=0, q=0;
499                 uint64_t r;
500                 memcpy(&p, a->wordv, a->length_W*sizeof(bigint_word_t));
501                 memcpy(&q, b->wordv, b->length_W*sizeof(bigint_word_t));
502                 r = (uint64_t)p * (uint64_t)q;
503                 memcpy(dest->wordv, &r, (dest->length_W = a->length_W + b->length_W)*sizeof(bigint_word_t));
504                 bigint_adjust(dest);
505                 return;
506         }
507         /* split a in xh & xl; split b in yh & yl */
508         const bigint_length_t n = (MAX(a->length_W, b->length_W)+1)/2;
509         bigint_t xl, xh, yl, yh;
510         xl.wordv = a->wordv;
511         yl.wordv = b->wordv;
512         if(a->length_W<=n){
513                 bigint_set_zero(&xh);
514                 xl.length_W = a->length_W;
515                 xl.info = a->info;
516         }else{
517                 xl.length_W=n;
518                 xl.info = 0;
519                 bigint_adjust(&xl);
520                 xh.wordv = &(a->wordv[n]);
521                 xh.length_W = a->length_W-n;
522                 xh.info = a->info;
523         }
524         if(b->length_W<=n){
525                 bigint_set_zero(&yh);
526                 yl.length_W = b->length_W;
527                 yl.info = b->info;
528         }else{
529                 yl.length_W=n;
530                 yl.info = 0;
531                 bigint_adjust(&yl);
532                 yh.wordv = &(b->wordv[n]);
533                 yh.length_W = b->length_W-n;
534                 yh.info = b->info;
535         }
536         /* now we have split up a and b */
537         /* remember we want to do:
538          * x*y = (xh * b**n + xl) * (yh * b**n + yl)
539          *     = (xh * yh) * b**2n + xh * b**n * yl + yh * b**n * xl + xl * yl
540          *     = (xh * yh) * b**2n + (xh * yl + yh * xl) * b**n + xl *yl
541          *     // xh * yl + yh * xl = (xh + yh) * (xl + yl) - xh * yh - xl * yl
542          * x*y = (xh * yh) * b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + xl*yl
543          *          5              9     2   4   3    7   5   6   1         8   1
544          */
545         bigint_word_t  tmp_b[2 * n + 2], m_b[2 * (n + 1)];
546         bigint_t tmp, tmp2, m;
547         tmp.wordv = tmp_b;
548         tmp2.wordv = &(tmp_b[n + 1]);
549         m.wordv = m_b;
550
551         bigint_mul_u(dest, &xl, &yl);  /* 1: dest <= xl*yl     */
552         bigint_add_u(&tmp2, &xh, &xl); /* 2: tmp2 <= xh+xl     */
553         bigint_add_u(&tmp, &yh, &yl);  /* 3: tmp  <= yh+yl     */
554         bigint_mul_u(&m, &tmp2, &tmp); /* 4: m    <= tmp2*tmp  */
555         bigint_mul_u(&tmp, &xh, &yh);  /* 5: tmp  <= xh*yh     */
556         bigint_sub_u(&m, &m, dest);    /* 6: m    <= m-dest    */
557     bigint_sub_u(&m, &m, &tmp);    /* 7: m    <= m-tmp     */
558         bigint_add_scale_u(dest, &m, n * sizeof(bigint_word_t));       /* 8: dest <= dest+m**n*/
559         bigint_add_scale_u(dest, &tmp, 2 * n * sizeof(bigint_word_t)); /* 9: dest <= dest+tmp**(2*n) */
560 }
561
562 /******************************************************************************/
563
564 void bigint_mul_s(bigint_t *dest, const bigint_t *a, const bigint_t *b){
565         uint8_t s;
566         s  = GET_SIGN(a)?2:0;
567         s |= GET_SIGN(b)?1:0;
568         switch(s){
569                 case 0: /* both positive */
570                         bigint_mul_u(dest, a,b);
571                         SET_POS(dest);
572                         break;
573                 case 1: /* a positive, b negative */
574                         bigint_mul_u(dest, a,b);
575                         SET_NEG(dest);
576                         break;
577                 case 2: /* a negative, b positive */
578                         bigint_mul_u(dest, a,b);
579                         SET_NEG(dest);
580                         break;
581                 case 3: /* both negative */
582                         bigint_mul_u(dest, a,b);
583                         SET_POS(dest);
584                         break;
585                 default: /* how can this happen?*/
586                         break;
587         }
588 }
589
590 /******************************************************************************/
591 #if OLD_SQUARE
592
593 #if DEBUG_SQUARE
594 unsigned square_depth = 0;
595 #endif
596
597 /* square */
598 /* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */
599 void bigint_square(bigint_t *dest, const bigint_t *a){
600     if(a->length_W * sizeof(bigint_word_t) <= 4){
601                 uint64_t r = 0;
602                 memcpy(&r, a->wordv, a->length_W * sizeof(bigint_word_t));
603                 r = r * r;
604                 memcpy(dest->wordv, &r, 2 * a->length_W * sizeof(bigint_word_t));
605                 SET_POS(dest);
606                 dest->length_W = 2 * a->length_W;
607                 bigint_adjust(dest);
608                 return;
609         }
610
611         if(dest->wordv == a->wordv){
612                 bigint_t d;
613                 bigint_word_t d_b[a->length_W*2];
614                 d.wordv = d_b;
615                 bigint_square(&d, a);
616                 bigint_copy(dest, &d);
617                 return;
618         }
619         bigint_fast_square(dest, a);
620         return;
621 #if DEBUG_SQUARE
622         square_depth += 1;
623 #endif
624
625         bigint_length_t n;
626         n=(a->length_W+1)/2;
627         bigint_t xh, xl, tmp; /* x-high, x-low, temp */
628         bigint_word_t buffer[2*n+1];
629         xl.wordv = a->wordv;
630         xl.length_W = n;
631         xl.info = 0;
632         xh.wordv = &(a->wordv[n]);
633         xh.length_W = a->length_W-n;
634         xh.info = a->info;
635         bigint_adjust(&xl);
636         tmp.wordv = buffer;
637 /* (xh * b**n + xl)**2 = xh**2 * b**2n + 2 * xh * xl * b**n + xl**2 */
638 #if DEBUG_SQUARE
639         if(square_depth == 1){
640         cli_putstr("\r\nDBG (a): xl: "); bigint_print_hex(&xl);
641         cli_putstr("\r\nDBG (b): xh: "); bigint_print_hex(&xh);
642         }
643 #endif
644         bigint_square(dest, &xl);
645 #if DEBUG_SQUARE
646         if(square_depth == 1){
647             cli_putstr("\r\nDBG (1): xl**2: "); bigint_print_hex(dest);
648         }
649 #endif
650     bigint_square(&tmp, &xh);
651 #if DEBUG_SQUARE
652     if(square_depth == 1){
653         cli_putstr("\r\nDBG (2): xh**2: "); bigint_print_hex(&tmp);
654     }
655 #endif
656         bigint_add_scale_u(dest, &tmp, 2 * n * sizeof(bigint_word_t));
657 #if DEBUG_SQUARE
658         if(square_depth == 1){
659             cli_putstr("\r\nDBG (3): xl**2 + xh**2*n**2: "); bigint_print_hex(dest);
660         }
661 #endif
662         bigint_mul_u(&tmp, &xl, &xh);
663 #if DEBUG_SQUARE
664         if(square_depth == 1){
665             cli_putstr("\r\nDBG (4): xl*xh: "); bigint_print_hex(&tmp);
666         }
667 #endif
668         bigint_shiftleft(&tmp, 1);
669 #if DEBUG_SQUARE
670         if(square_depth == 1){
671             cli_putstr("\r\nDBG (5): xl*xh*2: "); bigint_print_hex(&tmp);
672         }
673 #endif
674         bigint_add_scale_u(dest, &tmp, n * sizeof(bigint_word_t));
675 #if DEBUG_SQUARE
676         if(square_depth == 1){
677             cli_putstr("\r\nDBG (6): x**2: "); bigint_print_hex(dest);
678             cli_putstr("\r\n");
679         }
680         square_depth -= 1;
681 #endif
682 }
683
684 #endif
685
686 /******************************************************************************/
687
688 void bigint_square(bigint_t *dest, const bigint_t *a) {
689     union __attribute__((packed)) {
690         bigint_word_t u[2];
691         bigint_wordplus_t uv;
692     } acc;
693     bigint_word_t q, c1, c2;
694     bigint_length_t i, j;
695
696     if(a->length_W * sizeof(bigint_word_t) <= 4){
697         uint64_t r = 0;
698         memcpy(&r, a->wordv, a->length_W * sizeof(bigint_word_t));
699         r = r * r;
700         memcpy(dest->wordv, &r, 2 * a->length_W * sizeof(bigint_word_t));
701         SET_POS(dest);
702         dest->length_W = 2 * a->length_W;
703         bigint_adjust(dest);
704         return;
705     }
706
707     if(dest->wordv == a->wordv){
708         bigint_t d;
709         bigint_word_t d_b[a->length_W * 2];
710         d.wordv = d_b;
711         bigint_square(&d, a);
712         bigint_copy(dest, &d);
713         return;
714     }
715
716     memset(dest->wordv, 0, a->length_W * 2 * sizeof(bigint_word_t));
717
718     for (i = 0; i < a->length_W; ++i) {
719         acc.uv = (bigint_wordplus_t)a->wordv[i] * (bigint_wordplus_t)a->wordv[i] + (bigint_wordplus_t)dest->wordv[2 * i];
720         dest->wordv[2 * i] = acc.u[0];
721         c1 = acc.u[1];
722         c2 = 0;
723         for (j = i + 1; j < a->length_W; ++j) {
724             acc.uv = (bigint_wordplus_t)a->wordv[i] * (bigint_wordplus_t)a->wordv[j];
725             q = acc.u[1] >> (BIGINT_WORD_SIZE - 1);
726             acc.uv <<= 1;
727             acc.uv += dest->wordv[i + j];
728             q += (acc.uv < dest->wordv[i + j]);
729             acc.uv += c1;
730             q += (acc.uv < c1);
731             dest->wordv[i + j] = acc.u[0];
732             c1 = (bigint_wordplus_t)acc.u[1] + c2;
733             c2 = q + (c1 < c2);
734         }
735         dest->wordv[i + a->length_W] += c1;
736         if (i < a->length_W - 1) {
737             dest->wordv[i + a->length_W + 1] = c2 + (dest->wordv[i + a->length_W] < c1);
738         }
739     }
740     dest->info = 0;
741     dest->length_W = 2 * a->length_W;
742     bigint_adjust(dest);
743 }
744
745 /******************************************************************************/
746
747 void bigint_sub_u_bitscale(bigint_t *a, const bigint_t *b, bigint_length_t bitscale){
748         bigint_t tmp, x;
749         bigint_word_t tmp_b[b->length_W + 1];
750         const bigint_length_t word_shift = bitscale / BIGINT_WORD_SIZE;
751
752         if (a->length_W < b->length_W + word_shift) {
753 #if DEBUG
754                 cli_putstr("\r\nDBG: *bang*\r\n");
755 #endif
756                 bigint_set_zero(a);
757                 return;
758         }
759         tmp.wordv = tmp_b;
760         bigint_copy(&tmp, b);
761         bigint_shiftleft(&tmp, bitscale % BIGINT_WORD_SIZE);
762
763         x.info = a->info;
764         x.wordv = &(a->wordv[word_shift]);
765         x.length_W = a->length_W - word_shift;
766
767         bigint_sub_u(&x, &x, &tmp);
768         bigint_adjust(a);
769         return;
770 }
771
772 /******************************************************************************/
773
774 void bigint_reduce(bigint_t *a, const bigint_t *r){
775         uint8_t rfbs = GET_FBS(r);
776         if (r->length_W == 0 || a->length_W == 0) {
777                 return;
778         }
779
780         if (bigint_length_b(a) + 3 > bigint_length_b(r)) {
781         if ((r->length_W * sizeof(bigint_word_t) <= 4) && (a->length_W * sizeof(bigint_word_t) <= 4)) {
782             uint32_t p = 0, q = 0;
783             memcpy(&p, a->wordv, a->length_W * sizeof(bigint_word_t));
784             memcpy(&q, r->wordv, r->length_W * sizeof(bigint_word_t));
785             p %= q;
786             memcpy(a->wordv, &p, a->length_W * sizeof(bigint_word_t));
787             a->length_W = r->length_W;
788             bigint_adjust(a);
789             return;
790         }
791         unsigned shift;
792         while (a->length_W > r->length_W) {
793             shift = (a->length_W - r->length_W) * CHAR_BIT * sizeof(bigint_word_t) + GET_FBS(a) - rfbs - 1;
794             bigint_sub_u_bitscale(a, r, shift);
795         }
796         while ((GET_FBS(a) > rfbs) && (a->length_W == r->length_W)) {
797             shift = GET_FBS(a) - rfbs - 1;
798             bigint_sub_u_bitscale(a, r, shift);
799         }
800         }
801         while (bigint_cmp_u(a, r) >= 0) {
802                 bigint_sub_u(a, a, r);
803         }
804         bigint_adjust(a);
805 }
806
807 /******************************************************************************/
808
809 /* calculate dest = a**exp % r */
810 /* using square&multiply */
811 void bigint_expmod_u_sam(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r){
812         if(a->length_W == 0){
813             bigint_set_zero(dest);
814                 return;
815         }
816
817         bigint_t res, base;
818         bigint_word_t t, base_w[MAX(a->length_W, r->length_W)], res_w[r->length_W * 2];
819         bigint_length_t i;
820         uint8_t j;
821         res.wordv = res_w;
822         base.wordv = base_w;
823         bigint_copy(&base, a);
824         bigint_reduce(&base, r);
825         res.wordv[0] = 1;
826         res.length_W = 1;
827         res.info = 0;
828         bigint_adjust(&res);
829         if(exp->length_W == 0){
830                 bigint_copy(dest, &res);
831                 return;
832         }
833         bigint_copy(&res, &base);
834         uint8_t flag = 0;
835         t = exp->wordv[exp->length_W - 1];
836         for (i = exp->length_W; i > 0; --i) {
837                 t = exp->wordv[i - 1];
838                 for (j = BIGINT_WORD_SIZE; j > 0; --j) {
839                         if (!flag) {
840                                 if (t & (((bigint_word_t)1) << (BIGINT_WORD_SIZE - 1))) {
841                                         flag = 1;
842                                 }
843                         } else {
844                                 bigint_square(&res, &res);
845                                 bigint_reduce(&res, r);
846                                 if(t & (((bigint_word_t)1) << (BIGINT_WORD_SIZE - 1))){
847                                         bigint_mul_u(&res, &res, &base);
848                                         bigint_reduce(&res, r);
849                                 }
850                         }
851                         t <<= 1;
852                 }
853         }
854
855         SET_POS(&res);
856         bigint_copy(dest, &res);
857 }
858
859 /******************************************************************************/
860 /* gcd <-- gcd(x,y) a*x+b*y=gcd */
861 void bigint_gcdext(bigint_t *gcd, bigint_t *a, bigint_t *b, const bigint_t *x, const bigint_t *y){
862          bigint_length_t i = 0;
863          if(x->length_W == 0 || y->length_W == 0){
864              if(gcd){
865                  bigint_set_zero(gcd);
866              }
867              if(a){
868                  bigint_set_zero(a);
869              }
870          if(b){
871              bigint_set_zero(b);
872          }
873                  return;
874          }
875          if(x->length_W == 1 && x->wordv[0] == 1){
876              if(gcd){
877              gcd->length_W = 1;
878              gcd->wordv[0] = 1;
879              gcd->info = 0;
880              }
881                  if(a){
882                          a->length_W = 1;
883                          a->wordv[0] = 1;
884                          SET_POS(a);
885                          bigint_adjust(a);
886                  }
887                  if(b){
888                          bigint_set_zero(b);
889                  }
890                  return;
891          }
892          if(y->length_W == 1 && y->wordv[0] == 1){
893                  if(gcd){
894              gcd->length_W = 1;
895              gcd->wordv[0] = 1;
896              gcd->info = 0;
897                  }
898                  if(b){
899                          b->length_W = 1;
900                          b->wordv[0] = 1;
901                          SET_POS(b);
902                          bigint_adjust(b);
903                  }
904                  if(a){
905                          bigint_set_zero(a);
906                  }
907                  return;
908          }
909
910          while(x->wordv[i] == 0 && y->wordv[i] == 0){
911                  ++i;
912          }
913          bigint_word_t g_b[i + 2], x_b[x->length_W - i], y_b[y->length_W - i];
914          bigint_word_t u_b[x->length_W - i], v_b[y->length_W - i];
915          bigint_word_t a_b[y->length_W + 2], c_b[y->length_W + 2];
916          bigint_word_t b_b[x->length_W + 2], d_b[x->length_W + 2];
917      bigint_t g, x_, y_, u, v, a_, b_, c_, d_;
918
919          g.wordv = g_b;
920          x_.wordv = x_b;
921          y_.wordv = y_b;
922          memset(g_b, 0, i * sizeof(bigint_word_t));
923          g_b[i] = 1;
924          g.length_W = i + 1;
925          g.info = 0;
926          x_.info = y_.info = 0;
927          x_.length_W = x->length_W - i;
928          y_.length_W = y->length_W - i;
929          memcpy(x_.wordv, x->wordv + i, x_.length_W * sizeof(bigint_word_t));
930          memcpy(y_.wordv, y->wordv + i, y_.length_W * sizeof(bigint_word_t));
931          for(i = 0; (x_.wordv[0] & (1 << i)) == 0 && (y_.wordv[0] & (1 << i)) == 0; ++i){
932          }
933
934          bigint_adjust(&x_);
935          bigint_adjust(&y_);
936
937          if(i){
938                  bigint_shiftleft(&g, i);
939                  bigint_shiftright(&x_, i);
940                  bigint_shiftright(&y_, i);
941          }
942
943          u.wordv = u_b;
944          v.wordv = v_b;
945          a_.wordv = a_b;
946          b_.wordv = b_b;
947          c_.wordv = c_b;
948          d_.wordv = d_b;
949
950          bigint_copy(&u, &x_);
951          bigint_copy(&v, &y_);
952          a_.wordv[0] = 1;
953          a_.length_W = 1;
954          a_.info = 0;
955          d_.wordv[0] = 1;
956          d_.length_W = 1;
957          d_.info = 0;
958          bigint_set_zero(&b_);
959          bigint_set_zero(&c_);
960          do {
961                  while ((u.wordv[0] & 1) == 0) {
962                          bigint_shiftright(&u, 1);
963                          if((a_.wordv[0] & 1) || (b_.wordv[0] & 1)){
964                                  bigint_add_s(&a_, &a_, &y_);
965                                  bigint_sub_s(&b_, &b_, &x_);
966                          }
967                          bigint_shiftright(&a_, 1);
968                          bigint_shiftright(&b_, 1);
969                  }
970                  while ((v.wordv[0] & 1) == 0) {
971                      bigint_shiftright(&v, 1);
972                          if((c_.wordv[0] & 1) || (d_.wordv[0] & 1)){
973                                  bigint_add_s(&c_, &c_, &y_);
974                                  bigint_sub_s(&d_, &d_, &x_);
975                          }
976                          bigint_shiftright(&c_, 1);
977                          bigint_shiftright(&d_, 1);
978                  }
979                  if(bigint_cmp_u(&u, &v) >= 0){
980                         bigint_sub_u(&u, &u, &v);
981                         bigint_sub_s(&a_, &a_, &c_);
982                         bigint_sub_s(&b_, &b_, &d_);
983                  }else{
984                         bigint_sub_u(&v, &v, &u);
985                         bigint_sub_s(&c_, &c_, &a_);
986                         bigint_sub_s(&d_, &d_, &b_);
987                  }
988          } while(u.length_W);
989          if(gcd){
990                  bigint_mul_s(gcd, &v, &g);
991          }
992          if(a){
993                 bigint_copy(a, &c_);
994          }
995          if(b){
996                  bigint_copy(b, &d_);
997          }
998 }
999
1000 /******************************************************************************/
1001
1002 void bigint_inverse(bigint_t *dest, const bigint_t *a, const bigint_t *m){
1003         bigint_gcdext(NULL, dest, NULL, a, m);
1004         while(dest->info&BIGINT_NEG_MASK){
1005                 bigint_add_s(dest, dest, m);
1006         }
1007 }
1008
1009 /******************************************************************************/
1010
1011 void bigint_changeendianess(bigint_t *a){
1012         uint8_t t, *p, *q;
1013         p = (uint8_t*)(a->wordv);
1014         q = p + a->length_W * sizeof(bigint_word_t) - 1;
1015         while(p<q){
1016                 t = *p;
1017                 *p = *q;
1018                 *q = t;
1019                 ++p; --q;
1020         }
1021 }
1022
1023 /******************************************************************************/
1024
1025 void bigint_mul_word_u(bigint_t *a, bigint_word_t b){
1026     bigint_wordplus_t c0 = 0, c1 = 0;
1027     bigint_length_t i;
1028
1029     if(b == 0){
1030         bigint_set_zero(a);
1031         return;
1032     }
1033
1034     for(i = 0; i < a->length_W; ++i){
1035         c1 = ((bigint_wordplus_t)(a->wordv[i])) * ((bigint_wordplus_t)b);
1036         c1 += c0;
1037         a->wordv[i] = (bigint_word_t)c1;
1038         c0 = c1 >> BIGINT_WORD_SIZE;
1039     }
1040     if(c0){
1041         a->wordv[a->length_W] = (bigint_word_t)c0;
1042         a->length_W += 1;
1043     }
1044     bigint_adjust(a);
1045 }
1046
1047 /******************************************************************************/
1048 #if 1
1049
1050 void bigint_clip(bigint_t *dest, bigint_length_t length_W){
1051     if(dest->length_W > length_W){
1052         dest->length_W = length_W;
1053     }
1054     bigint_adjust(dest);
1055 }
1056 /******************************************************************************/
1057
1058 /*
1059  * m_ = m * m'[0]
1060  * dest = (a * b) % m (?)
1061  */
1062
1063 void bigint_mont_mul(bigint_t *dest, const bigint_t *a, const bigint_t *b, const bigint_t *m, const bigint_t *m_){
1064     const bigint_length_t s = MAX(MAX(a->length_W, b->length_W), m->length_W);
1065     bigint_t u, t;
1066     bigint_word_t u_w[s + 2], t_w[s + 2];
1067     bigint_length_t i;
1068
1069     if (a->length_W == 0 || b->length_W == 0) {
1070         bigint_set_zero(dest);
1071         return;
1072     }
1073     u.wordv = u_w;
1074     u.info = 0;
1075     u.length_W = 0;
1076     t.wordv = t_w;
1077     for (i = 0; i < a->length_W; ++i) {
1078         bigint_copy(&t, b);
1079         bigint_mul_word_u(&t, a->wordv[i]);
1080         bigint_add_u(&u, &u, &t);
1081         bigint_copy(&t, m_);
1082         if (u.length_W != 0) {
1083             bigint_mul_word_u(&t, u.wordv[0]);
1084             bigint_add_u(&u, &u, &t);
1085         }
1086         bigint_shiftright(&u, BIGINT_WORD_SIZE);
1087     }
1088     for (; i < s; ++i) {
1089         bigint_copy(&t, m_);
1090         if (u.length_W != 0) {
1091             bigint_mul_word_u(&t, u.wordv[0]);
1092             bigint_add_u(&u, &u, &t);
1093         }
1094         bigint_shiftright(&u, BIGINT_WORD_SIZE);
1095     }
1096     bigint_reduce(&u, m);
1097     bigint_copy(dest, &u);
1098 }
1099
1100 /******************************************************************************/
1101
1102 void bigint_mont_red(bigint_t *dest, const bigint_t *a, const bigint_t *m, const bigint_t *m_){
1103     bigint_t u, t;
1104     bigint_length_t i, s = MAX(a->length_W, m->length_W);
1105     bigint_word_t u_w[s + 2], t_w[s + 2];
1106
1107     t.wordv = t_w;
1108     u.wordv = u_w;
1109     if (a->length_W == 0) {
1110         bigint_set_zero(dest);
1111         return;
1112     }
1113     bigint_copy(&u, a);
1114     for (i = 0; i < m->length_W; ++i) {
1115         bigint_copy(&t, m_);
1116         if (u.length_W != 0) {
1117             bigint_mul_word_u(&t, u.wordv[0]);
1118             bigint_add_u(&u, &u, &t);
1119         }
1120         bigint_shiftright(&u, BIGINT_WORD_SIZE);
1121     }
1122     bigint_reduce(&u, m);
1123     bigint_copy(dest, &u);
1124 }
1125
1126 /******************************************************************************/
1127 /*
1128  * m_ = m * (- m0^-1 (mod 2^W))
1129  */
1130 void bigint_mont_gen_m_(bigint_t* dest, const bigint_t* m){
1131     bigint_word_t x_w[2], m_w_0[1];
1132     bigint_t x, m_0;
1133 #if 0
1134     printf("\nm = ");
1135     bigint_print_hex(m);
1136     uart0_flush();
1137 #endif
1138     if (m->length_W == 0) {
1139         bigint_set_zero(dest);
1140         return;
1141     }
1142     if ((m->wordv[0] & 1) == 0) {
1143         printf_P(PSTR("ERROR: m must not be even, m = "));
1144         bigint_print_hex(m);
1145         putchar('\n');
1146         uart0_flush();
1147         return;
1148     }
1149     x.wordv = x_w;
1150     x.info = 0;
1151     x.length_W = 2;
1152     x_w[0] = 0;
1153     x_w[1] = 1;
1154     m_0.wordv = m_w_0;
1155     m_0.info = 0;
1156     m_0.length_W = 1;
1157     m_0.wordv[0] = m->wordv[0];
1158     bigint_adjust(&x);
1159     bigint_adjust(&m_0);
1160 #if 0
1161     printf("\nm0 = ");
1162     bigint_print_hex(&m_0);
1163     printf("\nx = ");
1164     bigint_print_hex(&x);
1165     uart0_flush();
1166 #endif
1167     bigint_inverse(dest, &m_0, &x);
1168 #if 0
1169     printf("\nm0^-1 = ");
1170     bigint_print_hex(dest);
1171     uart0_flush();
1172 #endif
1173     bigint_sub_s(&x, &x, dest);
1174 #if 0
1175     printf("\n-m0^-1 = ");
1176     bigint_print_hex(&x);
1177     uart0_flush();
1178 #endif
1179     bigint_copy(dest, m);
1180     bigint_mul_word_u(dest, x.wordv[0]);
1181 }
1182
1183 /******************************************************************************/
1184
1185 /*
1186  * dest = a * R mod m
1187  */
1188 void bigint_mont_trans(bigint_t *dest, const bigint_t *a, const bigint_t *m){
1189     bigint_t t;
1190     bigint_word_t t_w[a->length_W + m->length_W];
1191
1192     t.wordv = t_w;
1193     memset(t_w, 0, m->length_W * sizeof(bigint_word_t));
1194     memcpy(&t_w[m->length_W], a->wordv, a->length_W * sizeof(bigint_word_t));
1195     t.info = a->info;
1196     t.length_W = a->length_W + m->length_W;
1197     bigint_reduce(&t, m);
1198     bigint_copy(dest, &t);
1199 }
1200
1201 /******************************************************************************/
1202
1203 /* calculate dest = a**exp % r */
1204 /* using square&multiply */
1205 void bigint_expmod_u_mont_accel(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r, const bigint_t *m_){
1206     if(r->length_W == 0) {
1207         return;
1208     }
1209
1210     bigint_length_t s = r->length_W;
1211     bigint_t res, ax;
1212     bigint_word_t t, res_w[r->length_W * 2], ax_w[MAX(s, a->length_W)];
1213     bigint_length_t i;
1214     uint8_t j;
1215
1216     res.wordv = res_w;
1217     ax.wordv = ax_w;
1218
1219     res.wordv[0] = 1;
1220     res.length_W = 1;
1221     res.info = 0;
1222     bigint_adjust(&res);
1223     if (exp->length_W == 0) {
1224         bigint_copy(dest, &res);
1225         return;
1226     }
1227     bigint_copy(&ax, a);
1228     bigint_reduce(&ax, r);
1229     bigint_mont_trans(&ax, &ax, r);
1230     bigint_mont_trans(&res, &res, r);
1231     uint8_t flag = 0;
1232     t = exp->wordv[exp->length_W - 1];
1233     for (i = exp->length_W; i > 0; --i) {
1234         t = exp->wordv[i - 1];
1235         for(j = BIGINT_WORD_SIZE; j > 0; --j){
1236             if (!flag) {
1237                 if(t & (((bigint_wordplus_t)1) << (BIGINT_WORD_SIZE - 1))){
1238                     flag = 1;
1239                 }
1240             }
1241             if (flag) {
1242                 bigint_square(&res, &res);
1243                 bigint_mont_red(&res, &res, r, m_);
1244                 if (t & (((bigint_wordplus_t)1) << (BIGINT_WORD_SIZE - 1))) {
1245                     bigint_mont_mul(&res, &res, &ax, r, m_);
1246                 }
1247             }
1248             t <<= 1;
1249         }
1250     }
1251     SET_POS(&res);
1252     bigint_mont_red(dest, &res, r, m_);
1253 }
1254
1255 /******************************************************************************/
1256
1257 void bigint_expmod_u_mont_sam(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r){
1258     if(r->length_W == 0) {
1259         return;
1260     }
1261     if(a->length_W == 0) {
1262         bigint_set_zero(dest);
1263         return;
1264     }
1265     bigint_t m_;
1266     bigint_word_t m_w_[r->length_W + 1];
1267     m_.wordv = m_w_;
1268     bigint_mont_gen_m_(&m_, r);
1269     bigint_expmod_u_mont_accel(dest, a, exp, r,&m_);
1270 }
1271
1272 /******************************************************************************/
1273
1274 #endif
1275
1276 void bigint_expmod_u(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r){
1277     if (r->wordv[0] & 1) {
1278         bigint_expmod_u_mont_sam(dest, a, exp, r);
1279     } else {
1280         bigint_expmod_u_sam(dest, a, exp, r);
1281     }
1282 }
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298