]> git.cryptolib.org Git - avr-crypto-lib.git/blob - bigint/bigint.c
adding modulo-reduction to bigint
[avr-crypto-lib.git] / bigint / bigint.c
1 /* bigint.c */
2 /*
3     This file is part of the AVR-Crypto-Lib.
4     Copyright (C) 2008  Daniel Otte (daniel.otte@rub.de)
5
6     This program is free software: you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation, either version 3 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License
17     along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19 /**
20  * \file                bigint.c
21  * \author              Daniel Otte
22  * \date                2010-02-22
23  * 
24  * \license         GPLv3 or later
25  * 
26  */
27  
28
29 #include "bigint.h"
30 #include <string.h>
31
32 #include "bigint_io.h"
33 #include "cli.h"
34
#ifndef MAX
 #define MAX(a,b) (((a)>(b))?(a):(b))
#endif

#ifndef MIN
 #define MIN(a,b) (((a)<(b))?(a):(b))
#endif

/*
 * Accessors for bigint_t.info (the masks come from bigint.h):
 *  - the BIGINT_FBS_MASK bits hold the index (0..7) of the highest set bit of
 *    the most significant byte ("first bit set"), as computed by bigint_adjust;
 *  - the BIGINT_NEG_MASK bit is the sign (set = negative).
 * SET_FBS evaluates its first argument twice; pass side-effect-free expressions.
 */
#define SET_FBS(a, v) do{ (a)->info = (uint8_t)(((a)->info & ~BIGINT_FBS_MASK) | ((v) & BIGINT_FBS_MASK)); }while(0)
#define GET_FBS(a)   ((a)->info&BIGINT_FBS_MASK)
#define SET_NEG(a)   ((a)->info |= BIGINT_NEG_MASK)
#define SET_POS(a)   ((a)->info &= ~BIGINT_NEG_MASK)
/* XOR swap of two integer lvalues; never pass the same object twice (it would be zeroed). */
#define XCHG(a,b)    do{(a)^=(b); (b)^=(a); (a)^=(b);}while(0)
/* Swap two object pointers through a temporary. Pointers are deliberately not
 * round-tripped through an integer type: a 16-bit integer would truncate them
 * on any target whose pointers are wider than 16 bits. */
#define XCHG_PTR(a,b) do{ void* xchg_ptr_tmp = (void*)(a); (a) = (b); (b) = xchg_ptr_tmp; }while(0)

#define GET_SIGN(a) ((a)->info&BIGINT_NEG_MASK)
53
54 /******************************************************************************/
55 void bigint_adjust(bigint_t* a){
56         while(a->length_B!=0 && a->wordv[a->length_B-1]==0){
57                 a->length_B--;
58         }
59         if(a->length_B==0){
60                 a->info=0;
61                 return;
62         }
63         uint8_t t;
64         uint8_t i = 0x07;
65         t = a->wordv[a->length_B-1];
66         while((t&0x80)==0 && i){
67                 t<<=1;
68                 i--;
69         }
70         SET_FBS(a, i);
71 }
72
73 /******************************************************************************/
74
75 void bigint_copy(bigint_t* dest, const bigint_t* src){
76         memcpy(dest->wordv, src->wordv, src->length_B);
77         dest->length_B = src->length_B;
78         dest->info = src->info;
79 }
80
81 /******************************************************************************/
82
83 /* this should be implemented in assembly */
84 /*
85 void bigint_add_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
86         uint16_t t=0, i;
87         if(a->length_B < b->length_B){
88                 XCHG_PTR(a,b);
89         }
90         for(i=0; i<b->length_B; ++i){
91                 t = a->wordv[i] + b->wordv[i] + t;
92                 dest->wordv[i] = (uint8_t)t;
93                 t>>=8;
94         }
95         for(; i<a->length_B; ++i){
96                 t = a->wordv[i] + t;
97                 dest->wordv[i] = (uint8_t)t;
98                 t>>=8;
99         }
100         dest->wordv[i++] = t;
101         dest->length_B = i;
102         bigint_adjust(dest);
103 }
104 */
105 /******************************************************************************/
106
107 /* this should be implemented in assembly */
108 void bigint_add_scale_u(bigint_t* dest, const bigint_t* a, uint16_t scale){
109         uint16_t i,j=0;
110         uint16_t t=0;
111         if(scale>dest->length_B)
112                 memset(dest->wordv+dest->length_B, 0, scale-dest->length_B);
113         for(i=scale; i<a->length_B+scale; ++i,++j){
114                 t = a->wordv[j] + t;
115                 if(dest->length_B>i){
116                         t += dest->wordv[i];
117                 }
118                 dest->wordv[i] = (uint8_t)t;
119                 t>>=8;
120         }
121         while(t){
122                 if(dest->length_B>i){
123                         t = dest->wordv[i] + t;
124                 }
125                 dest->wordv[i] = (uint8_t)t;
126                 t>>=8;
127                 ++i;
128         }
129         if(dest->length_B < i){
130                 dest->length_B = i;
131         }
132         bigint_adjust(dest);
133 }
134
135 /******************************************************************************/
136
137 /* this should be implemented in assembly */
138 void bigint_sub_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
139         int8_t borrow=0;
140         int8_t  r;
141         int16_t t;
142         uint16_t i, min, max;
143         min = MIN(a->length_B, b->length_B);
144         max = MAX(a->length_B, b->length_B);
145         r = bigint_cmp_u(a,b);
146         if(r==0){
147                 dest->length_B = 0;
148                 dest->wordv[0] = 0;
149                 bigint_adjust(dest);
150                 return;
151         }
152         if(b->length_B==0){
153                 dest->length_B = a->length_B;
154                 memcpy(dest->wordv, a->wordv, a->length_B);
155                 dest->info = a->info;
156                 SET_POS(dest);
157                 return;
158         }
159         if(a->length_B==0){
160                         dest->length_B = b->length_B;
161                         memcpy(dest->wordv, b->wordv, b->length_B);
162                         dest->info = b->info;
163                         SET_NEG(dest);
164                         return;
165         }
166         if(r<0){
167                 for(i=0; i<min; ++i){
168                         t = a->wordv[i] - b->wordv[i] - borrow;
169                         if(t<1){
170                                 borrow = 0;
171                                 dest->wordv[i]=(uint8_t)(-t);
172                         }else{
173                                 borrow = -1;
174                                 dest->wordv[i]=(uint8_t)(-t);
175                         }
176                 }
177                 for(;i<max; ++i){
178                         t = b->wordv[i] + borrow;
179                         if(t<1){
180                                 borrow = 0;
181                                 dest->wordv[i]=(uint8_t)t;
182                         }else{
183                                 borrow = -1;
184                                 dest->wordv[i]=(uint8_t)t;
185                         }
186                 }
187                 SET_NEG(dest);
188                 dest->length_B = i;
189                 bigint_adjust(dest);
190         }else{
191                 for(i=0; i<min; ++i){
192                         t = a->wordv[i] - b->wordv[i] - borrow;
193                         if(t<0){
194                                 borrow = 1;
195                                 dest->wordv[i]=(uint8_t)t;
196                         }else{
197                                 borrow = 0;
198                                 dest->wordv[i]=(uint8_t)t;
199                         }
200                 }
201                 for(;i<max; ++i){
202                         t = a->wordv[i] - borrow;
203                         if(t<0){
204                                 borrow = 1;
205                                 dest->wordv[i]=(uint8_t)t;
206                         }else{
207                                 borrow = 0;
208                                 dest->wordv[i]=(uint8_t)t;
209                         }
210
211                 }
212                 SET_POS(dest);
213                 dest->length_B = i;
214                 bigint_adjust(dest);
215         }
216 }
217
218 /******************************************************************************/
219
220 int8_t bigint_cmp_u(const bigint_t* a, const bigint_t* b){
221         if(a->length_B > b->length_B){
222                 return 1;
223         }
224         if(a->length_B < b->length_B){
225                 return -1;
226         }
227         if(GET_FBS(a) > GET_FBS(b)){
228                 return 1;
229         }
230         if(GET_FBS(a) < GET_FBS(b)){
231                 return -1;
232         }
233         if(a->length_B==0){
234                 return 0;
235         }
236         uint16_t i;
237         i = a->length_B-1;
238         do{
239                 if(a->wordv[i]!=b->wordv[i]){
240                         if(a->wordv[i]>b->wordv[i]){
241                                 return 1;
242                         }else{
243                                 return -1;
244                         }
245                 }
246         }while(i--);
247         return 0;
248 }
249
250 /******************************************************************************/
251
252 void bigint_add_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
253         uint8_t s;
254         s  = GET_SIGN(a)?2:0;
255         s |= GET_SIGN(b)?1:0;
256         switch(s){
257                 case 0: /* both positive */
258                         bigint_add_u(dest, a,b);
259                         SET_POS(dest);
260                         break;
261                 case 1: /* a positive, b negative */
262                         bigint_sub_u(dest, a, b);
263                         break;
264                 case 2: /* a negative, b positive */
265                         bigint_sub_u(dest, b, a);
266                         break;
267                 case 3: /* both negative */
268                         bigint_add_u(dest, a, b);
269                         SET_NEG(dest);
270                         break;
271                 default: /* how can this happen?*/
272                         break;
273         }
274 }
275
276 /******************************************************************************/
277
278 void bigint_sub_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
279         uint8_t s;
280         s  = GET_SIGN(a)?2:0;
281         s |= GET_SIGN(b)?1:0;
282         switch(s){
283                 case 0: /* both positive */
284                         bigint_sub_u(dest, a,b);
285                         break;
286                 case 1: /* a positive, b negative */
287                         bigint_add_u(dest, a, b);
288                         SET_POS(dest);
289                         break;
290                 case 2: /* a negative, b positive */
291                         bigint_add_u(dest, a, b);
292                         SET_NEG(dest);
293                         break;
294                 case 3: /* both negative */
295                         bigint_sub_u(dest, b, a);
296                         break;
297                 default: /* how can this happen?*/
298                                         break;
299         }
300
301 }
302
303 /******************************************************************************/
304
305 int8_t bigint_cmp_s(const bigint_t* a, const bigint_t* b){
306         uint8_t s;
307         s  = GET_SIGN(a)?2:0;
308         s |= GET_SIGN(b)?1:0;
309         switch(s){
310                 case 0: /* both positive */
311                         return bigint_cmp_u(a, b);
312                         break;
313                 case 1: /* a positive, b negative */
314                         return 1;
315                         break;
316                 case 2: /* a negative, b positive */
317                         return -1;
318                         break;
319                 case 3: /* both negative */
320                         return bigint_cmp_u(b, a);
321                         break;
322                 default: /* how can this happen?*/
323                                         break;
324         }
325         return 0; /* just to satisfy the compiler */
326 }
327
328 /******************************************************************************/
329
330 void bigint_shiftleft(bigint_t* a, uint16_t shift){
331         uint16_t byteshift;
332         uint16_t i;
333         uint8_t bitshift;
334         uint16_t t=0;
335         byteshift = (shift+3)/8;
336         bitshift = shift&7;
337         memmove(a->wordv+byteshift, a->wordv, a->length_B);
338         memset(a->wordv, 0, byteshift);
339         if(bitshift!=0){
340                 if(bitshift<=4){ /* shift to the left */
341                         for(i=byteshift; i<a->length_B+byteshift; ++i){
342                                 t |= (a->wordv[i])<<bitshift;
343                                 a->wordv[i] = (uint8_t)t;
344                                 t >>= 8;
345                         }
346                         a->wordv[i] = (uint8_t)t;
347                         byteshift++;
348                 }else{ /* shift to the right */
349                         for(i=a->length_B+byteshift-1; i>byteshift-1; --i){
350                                 t |= (a->wordv[i])<<(bitshift);
351                                 a->wordv[i] = (uint8_t)(t>>8);
352                                 t <<= 8;
353                         }
354                         t |= (a->wordv[i])<<(bitshift);
355                         a->wordv[i] = (uint8_t)(t>>8);
356                 }
357         }
358         a->length_B += byteshift;
359         bigint_adjust(a);
360 }
361
362 /******************************************************************************/
363
364 void bigint_shiftright(bigint_t* a, uint16_t shift){
365         uint16_t byteshift;
366         uint16_t i;
367         uint8_t bitshift;
368         uint16_t t=0;
369         byteshift = shift/8;
370         bitshift = shift&7;
371         if(byteshift >= a->length_B){ /* we would shift out more than we have */
372                 a->length_B=0;
373                 a->wordv[0] = 0;
374                 SET_FBS(a, 0);
375                 return;
376         }
377         if(byteshift == a->length_B-1 && bitshift>GET_FBS(a)){
378                 a->length_B=0;
379                 a->wordv[0] = 0;
380                 SET_FBS(a, 0);
381                 return;
382         }
383         if(byteshift){
384                 memmove(a->wordv, a->wordv+byteshift, a->length_B-byteshift);
385                 memset(a->wordv+a->length_B-byteshift, 0,  byteshift);
386         }
387         if(bitshift!=0){
388          /* shift to the right */
389                 for(i=a->length_B-byteshift-1; i>0; --i){
390                         t |= (a->wordv[i])<<(8-bitshift);
391                         a->wordv[i] = (uint8_t)(t>>8);
392                         t <<= 8;
393                 }
394                 t |= (a->wordv[0])<<(8-bitshift);
395                 a->wordv[0] = (uint8_t)(t>>8);
396         }
397         a->length_B -= byteshift;
398         bigint_adjust(a);
399 }
400
401 /******************************************************************************/
402
403 void bigint_xor(bigint_t* dest, const bigint_t* a){
404         uint16_t i;
405         for(i=0; i<a->length_B; ++i){
406                 dest->wordv[i] ^= a->wordv[i];
407         }
408         bigint_adjust(dest);
409 }
410
411 /******************************************************************************/
412
413 void bigint_set_zero(bigint_t* a){
414         a->length_B=0;
415 }
416
417 /******************************************************************************/
418
419 /* using the Karatsuba-Algorithm */
420 /* x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + yh*yl */
421 void bigint_mul_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
422         if(a->length_B==0 || b->length_B==0){
423                 bigint_set_zero(dest);
424                 return;
425         }
426         if(a->length_B==1 || b->length_B==1){
427                 if(a->length_B!=1){
428                         XCHG_PTR(a,b);
429                 }
430                 uint16_t i, t=0;
431                 uint8_t x = a->wordv[0];
432                 for(i=0; i<b->length_B; ++i){
433                         t += b->wordv[i]*x;
434                         dest->wordv[i] = (uint8_t)t;
435                         t>>=8;
436                 }
437                 dest->wordv[i] = (uint8_t)t;
438                 dest->length_B=i+1;
439                 bigint_adjust(dest);
440                 return;
441         }
442         if(a->length_B<=4 && b->length_B<=4){
443                 uint32_t p=0, q=0;
444                 uint64_t r;
445                 memcpy(&p, a->wordv, a->length_B);
446                 memcpy(&q, b->wordv, b->length_B);
447                 r = (uint64_t)p*(uint64_t)q;
448                 memcpy(dest->wordv, &r, a->length_B+b->length_B);
449                 dest->length_B =  a->length_B+b->length_B;
450                 bigint_adjust(dest);
451                 return;
452         }
453         bigint_set_zero(dest);
454         /* split a in xh & xl; split b in yh & yl */
455         uint16_t n;
456         n=(MAX(a->length_B, b->length_B)+1)/2;
457         bigint_t xl, xh, yl, yh;
458         xl.wordv = a->wordv;
459         yl.wordv = b->wordv;
460         if(a->length_B<=n){
461                 xh.info=0;
462                 xh.length_B = 0;
463                 xl.length_B = a->length_B;
464                 xl.info = 0;
465         }else{
466                 xl.length_B=n;
467                 xl.info = 0;
468                 bigint_adjust(&xl);
469                 xh.wordv = a->wordv+n;
470                 xh.length_B = a->length_B-n;
471                 xh.info = 0;
472         }
473         if(b->length_B<=n){
474                 yh.info=0;
475                 yh.length_B = 0;
476                 yl.length_B = b->length_B;
477                 yl.info = b->info;
478         }else{
479                 yl.length_B=n;
480                 yl.info = 0;
481                 bigint_adjust(&yl);
482                 yh.wordv = b->wordv+n;
483                 yh.length_B = b->length_B-n;
484                 yh.info = 0;
485         }
486         /* now we have split up a and b */
487         uint8_t  tmp_b[2*n+2], m_b[2*(n+1)];
488         bigint_t tmp, tmp2, m;
489         tmp.wordv = tmp_b;
490         tmp2.wordv = tmp_b+n+1;
491         m.wordv = m_b;
492
493         bigint_mul_u(dest, &xl, &yl);  /* dest <= xl*yl     */
494         bigint_add_u(&tmp2, &xh, &xl); /* tmp2 <= xh+xl     */
495         bigint_add_u(&tmp, &yh, &yl);  /* tmp  <= yh+yl     */
496         bigint_mul_u(&m, &tmp2, &tmp); /* m    <= tmp*tmp   */
497         bigint_mul_u(&tmp, &xh, &yh);  /* h    <= xh*yh     */
498         bigint_sub_u(&m, &m, dest);    /* m    <= m-dest    */
499     bigint_sub_u(&m, &m, &tmp);    /* m    <= m-h       */
500         bigint_add_scale_u(dest, &m, n);
501         bigint_add_scale_u(dest, &tmp, 2*n);
502 }
503
504 /******************************************************************************/
505
506 void bigint_mul_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
507         uint8_t s;
508         s  = GET_SIGN(a)?2:0;
509         s |= GET_SIGN(b)?1:0;
510         switch(s){
511                 case 0: /* both positive */
512                         bigint_mul_u(dest, a,b);
513                         SET_POS(dest);
514                         break;
515                 case 1: /* a positive, b negative */
516                         bigint_mul_u(dest, a,b);
517                         SET_NEG(dest);
518                         break;
519                 case 2: /* a negative, b positive */
520                         bigint_mul_u(dest, a,b);
521                         SET_NEG(dest);
522                         break;
523                 case 3: /* both negative */
524                         bigint_mul_u(dest, a,b);
525                         SET_POS(dest);
526                         break;
527                 default: /* how can this happen?*/
528                         break;
529         }
530 }
531
532 /******************************************************************************/
533
534 /* square */
535 /* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */
536 void bigint_square(bigint_t* dest, const bigint_t* a){
537         if(a->length_B<=4){
538                 uint64_t r=0;
539                 memcpy(&r, a->wordv, a->length_B);
540                 r = r*r;
541                 memcpy(dest->wordv, &r, 2*a->length_B);
542                 SET_POS(dest);
543                 dest->length_B=2*a->length_B;
544                 bigint_adjust(dest);
545                 return;
546         }
547         uint16_t n;
548         n=(a->length_B+1)/2;
549         bigint_t xh, xl, tmp; /* x-high, x-low, temp */
550         uint8_t buffer[2*n+1];
551         xl.wordv = a->wordv;
552         xl.length_B = n;
553         xh.wordv = a->wordv+n;
554         xh.length_B = a->length_B-n;
555         tmp.wordv = buffer;
556         bigint_square(dest, &xl);
557         bigint_square(&tmp, &xh);
558         bigint_add_scale_u(dest, &tmp, 2*n);
559         bigint_mul_u(&tmp, &xl, &xh);
560         bigint_shiftleft(&tmp, 1);
561         bigint_add_scale_u(dest, &tmp, n);
562 }
563
564 /******************************************************************************/
565
566 void bigint_sub_u_bitscale(bigint_t* a, const bigint_t* b, uint16_t bitscale){
567         bigint_t tmp;
568         uint8_t tmp_b[b->length_B+1];
569         uint16_t i,j,byteshift=bitscale/8;
570         uint8_t borrow=0;
571         int16_t t;
572
573         if(a->length_B < b->length_B+byteshift){
574                 cli_putstr_P(PSTR("\r\nERROR: bigint_sub_u_bitscale result negative"));
575                 bigint_set_zero(a);
576                 return;
577         }
578
579         tmp.wordv = tmp_b;
580         bigint_copy(&tmp, b);
581         bigint_shiftleft(&tmp, bitscale&7);
582
583         for(j=0,i=byteshift; i<tmp.length_B+byteshift; ++i, ++j){
584                 t = a->wordv[i] - tmp.wordv[j] - borrow;
585                 a->wordv[i] = (uint8_t)t;
586                 if(t<0){
587                         borrow = 1;
588                 }else{
589                         borrow = 0;
590                 }
591         }
592         while(borrow){
593                 if(i+1 > a->length_B){
594                         cli_putstr_P(PSTR("\r\nERROR: bigint_sub_u_bitscale result negative (2) shift="));
595                         cli_hexdump_rev(&bitscale, 2);
596                         bigint_set_zero(a);
597                         return;
598                 }
599                 a->wordv[i] -= borrow;
600                 if(a->wordv[i]!=0xff){
601                         borrow=0;
602                 }
603                 ++i;
604         }
605         bigint_adjust(a);
606 }
607
608 /******************************************************************************/
609
610 void bigint_reduce(bigint_t* a, const bigint_t* r){
611         uint8_t rfbs = GET_FBS(r);
612         if(r->length_B==0){
613                 return;
614         }
615         while(a->length_B > r->length_B){
616                 bigint_sub_u_bitscale(a, r, (a->length_B-r->length_B)*8+GET_FBS(a)-rfbs-1);
617         }
618         while(GET_FBS(a) > rfbs+1){
619                 bigint_sub_u_bitscale(a, r, GET_FBS(a)-rfbs-1);
620         }
621         while(bigint_cmp_u(a,r)>=0){
622                 bigint_sub_u(a,a,r);
623         }
624         bigint_adjust(a);
625 }
626
627
628