]> git.cryptolib.org Git - avr-crypto-lib.git/blob - bigint/bigint.c
fb2f524cfc2fedc0227a3ee5c2903e1d64b15de9
[avr-crypto-lib.git] / bigint / bigint.c
1 /* bigint.c */
2 /*
3     This file is part of the AVR-Crypto-Lib.
4     Copyright (C) 2008  Daniel Otte (daniel.otte@rub.de)
5
6     This program is free software: you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation, either version 3 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License
17     along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19 /**
20  * \file                bigint.c
21  * \author              Daniel Otte
22  * \date                2010-02-22
23  * 
24  * \license         GPLv3 or later
25  * 
26  */
27  
28
29 #include "bigint.h"
30 #include <string.h>
31
32 #include "bigint_io.h"
33 #include "cli.h"
34
#ifndef MAX
 #define MAX(a,b) (((a)>(b))?(a):(b))
#endif

#ifndef MIN
 #define MIN(a,b) (((a)<(b))?(a):(b))
#endif

/* Accessors for the info byte.
   SET_FBS overwrites the low three bits (first-bit-set index of the top byte),
   SET_NEG / SET_POS set or clear the sign flag. */
#define SET_FBS(a, v) do{(a)->info &=0xF8; (a)->info |= (v);}while(0)
#define GET_FBS(a)   ((a)->info&BIGINT_FBS_MASK)
#define SET_NEG(a)   ((a)->info |= BIGINT_NEG_MASK)
#define SET_POS(a)   ((a)->info &= ~BIGINT_NEG_MASK)

/* Swaps two integer lvalues in place (XOR trick).
   Must not be given the same object twice: that would zero it. */
#define XCHG(a,b)    do{(a)^=(b); (b)^=(a); (a)^=(b);}while(0)

/* Swaps two object-pointer lvalues through a temporary instead of an integer
   round-trip, so it is correct for any pointer width (not only 16-bit AVR). */
#define XCHG_PTR(a,b) do{ void* xchg_ptr_tmp = (void*)(a); \
                          (a) = (b);                       \
                          (b) = xchg_ptr_tmp; }while(0)

#define GET_SIGN(a) ((a)->info&BIGINT_NEG_MASK)
53
54 /******************************************************************************/
55 void bigint_adjust(bigint_t* a){
56         while(a->length_B!=0 && a->wordv[a->length_B-1]==0){
57                 a->length_B--;
58         }
59         if(a->length_B==0){
60                 a->info=0;
61                 return;
62         }
63         uint8_t t;
64         uint8_t i = 0x07;
65         t = a->wordv[a->length_B-1];
66         while((t&0x80)==0 && i){
67                 t<<=1;
68                 i--;
69         }
70         SET_FBS(a, i);
71 }
72
73 /******************************************************************************/
74
75 void bigint_copy(bigint_t* dest, const bigint_t* src){
76         memcpy(dest->wordv, src->wordv, src->length_B);
77         dest->length_B = src->length_B;
78         dest->info = src->info;
79 }
80
81 /******************************************************************************/
82
83 /* this should be implemented in assembly */
84 /*
85 void bigint_add_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
86         uint16_t t=0, i;
87         if(a->length_B < b->length_B){
88                 XCHG_PTR(a,b);
89         }
90         for(i=0; i<b->length_B; ++i){
91                 t = a->wordv[i] + b->wordv[i] + t;
92                 dest->wordv[i] = (uint8_t)t;
93                 t>>=8;
94         }
95         for(; i<a->length_B; ++i){
96                 t = a->wordv[i] + t;
97                 dest->wordv[i] = (uint8_t)t;
98                 t>>=8;
99         }
100         dest->wordv[i++] = t;
101         dest->length_B = i;
102         bigint_adjust(dest);
103 }
104 */
105 /******************************************************************************/
106
107 /* this should be implemented in assembly */
108 void bigint_add_scale_u(bigint_t* dest, const bigint_t* a, uint16_t scale){
109         uint16_t i,j=0;
110         uint16_t t=0;
111         if(scale>dest->length_B)
112                 memset(dest->wordv+dest->length_B, 0, scale-dest->length_B);
113         for(i=scale; i<a->length_B+scale; ++i,++j){
114                 t = a->wordv[j] + t;
115                 if(dest->length_B>i){
116                         t += dest->wordv[i];
117                 }
118                 dest->wordv[i] = (uint8_t)t;
119                 t>>=8;
120         }
121         while(t){
122                 if(dest->length_B>i){
123                         t = dest->wordv[i] + t;
124                 }
125                 dest->wordv[i] = (uint8_t)t;
126                 t>>=8;
127                 ++i;
128         }
129         if(dest->length_B < i){
130                 dest->length_B = i;
131         }
132         bigint_adjust(dest);
133 }
134
135 /******************************************************************************/
136
137 /* this should be implemented in assembly */
138 void bigint_sub_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
139         int8_t borrow=0;
140         int8_t  r;
141         int16_t t;
142         uint16_t i, min, max;
143         min = MIN(a->length_B, b->length_B);
144         max = MAX(a->length_B, b->length_B);
145         r = bigint_cmp_u(a,b);
146         if(r==0){
147                 dest->length_B = 0;
148                 dest->wordv[0] = 0;
149                 bigint_adjust(dest);
150                 return;
151         }
152         if(b->length_B==0){
153                 dest->length_B = a->length_B;
154                 memcpy(dest->wordv, a->wordv, a->length_B);
155                 dest->info = a->info;
156                 SET_POS(dest);
157                 return;
158         }
159         if(a->length_B==0){
160                         dest->length_B = b->length_B;
161                         memcpy(dest->wordv, b->wordv, b->length_B);
162                         dest->info = b->info;
163                         SET_NEG(dest);
164                         return;
165         }
166         if(r<0){
167                 for(i=0; i<min; ++i){
168                         t = a->wordv[i] - b->wordv[i] - borrow;
169                         if(t<1){
170                                 borrow = 0;
171                                 dest->wordv[i]=(uint8_t)(-t);
172                         }else{
173                                 borrow = -1;
174                                 dest->wordv[i]=(uint8_t)(-t);
175                         }
176                 }
177                 for(;i<max; ++i){
178                         t = b->wordv[i] + borrow;
179                         if(t<1){
180                                 borrow = 0;
181                                 dest->wordv[i]=(uint8_t)t;
182                         }else{
183                                 borrow = -1;
184                                 dest->wordv[i]=(uint8_t)t;
185                         }
186                 }
187                 SET_NEG(dest);
188                 dest->length_B = i;
189                 bigint_adjust(dest);
190         }else{
191                 for(i=0; i<min; ++i){
192                         t = a->wordv[i] - b->wordv[i] - borrow;
193                         if(t<0){
194                                 borrow = 1;
195                                 dest->wordv[i]=(uint8_t)t;
196                         }else{
197                                 borrow = 0;
198                                 dest->wordv[i]=(uint8_t)t;
199                         }
200                 }
201                 for(;i<max; ++i){
202                         t = a->wordv[i] - borrow;
203                         if(t<0){
204                                 borrow = 1;
205                                 dest->wordv[i]=(uint8_t)t;
206                         }else{
207                                 borrow = 0;
208                                 dest->wordv[i]=(uint8_t)t;
209                         }
210
211                 }
212                 SET_POS(dest);
213                 dest->length_B = i;
214                 bigint_adjust(dest);
215         }
216 }
217
218 /******************************************************************************/
219
220 int8_t bigint_cmp_u(const bigint_t* a, const bigint_t* b){
221         if(a->length_B > b->length_B){
222                 return 1;
223         }
224         if(a->length_B < b->length_B){
225                 return -1;
226         }
227         if(GET_FBS(a) > GET_FBS(b)){
228                 return 1;
229         }
230         if(GET_FBS(a) < GET_FBS(b)){
231                 return -1;
232         }
233         if(a->length_B==0){
234                 return 0;
235         }
236         uint16_t i;
237         i = a->length_B-1;
238         do{
239                 if(a->wordv[i]!=b->wordv[i]){
240                         if(a->wordv[i]>b->wordv[i]){
241                                 return 1;
242                         }else{
243                                 return -1;
244                         }
245                 }
246         }while(i--);
247         return 0;
248 }
249
250 /******************************************************************************/
251
252 void bigint_add_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
253         uint8_t s;
254         s  = GET_SIGN(a)?2:0;
255         s |= GET_SIGN(b)?1:0;
256         switch(s){
257                 case 0: /* both positive */
258                         bigint_add_u(dest, a,b);
259                         SET_POS(dest);
260                         break;
261                 case 1: /* a positive, b negative */
262                         bigint_sub_u(dest, a, b);
263                         break;
264                 case 2: /* a negative, b positive */
265                         bigint_sub_u(dest, b, a);
266                         break;
267                 case 3: /* both negative */
268                         bigint_add_u(dest, a, b);
269                         SET_NEG(dest);
270                         break;
271                 default: /* how can this happen?*/
272                         break;
273         }
274 }
275
276 /******************************************************************************/
277
278 void bigint_sub_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
279         uint8_t s;
280         s  = GET_SIGN(a)?2:0;
281         s |= GET_SIGN(b)?1:0;
282         switch(s){
283                 case 0: /* both positive */
284                         bigint_sub_u(dest, a,b);
285                         break;
286                 case 1: /* a positive, b negative */
287                         bigint_add_u(dest, a, b);
288                         SET_POS(dest);
289                         break;
290                 case 2: /* a negative, b positive */
291                         bigint_add_u(dest, a, b);
292                         SET_NEG(dest);
293                         break;
294                 case 3: /* both negative */
295                         bigint_sub_u(dest, b, a);
296                         break;
297                 default: /* how can this happen?*/
298                                         break;
299         }
300
301 }
302
303 /******************************************************************************/
304
305 int8_t bigint_cmp_s(const bigint_t* a, const bigint_t* b){
306         uint8_t s;
307         s  = GET_SIGN(a)?2:0;
308         s |= GET_SIGN(b)?1:0;
309         switch(s){
310                 case 0: /* both positive */
311                         return bigint_cmp_u(a, b);
312                         break;
313                 case 1: /* a positive, b negative */
314                         return 1;
315                         break;
316                 case 2: /* a negative, b positive */
317                         return -1;
318                         break;
319                 case 3: /* both negative */
320                         return bigint_cmp_u(b, a);
321                         break;
322                 default: /* how can this happen?*/
323                                         break;
324         }
325         return 0; /* just to satisfy the compiler */
326 }
327
328 /******************************************************************************/
329
330 void bigint_shiftleft(bigint_t* a, uint16_t shift){
331         uint16_t byteshift;
332         uint16_t i;
333         uint8_t bitshift;
334         uint16_t t=0;
335         byteshift = (shift+3)/8;
336         bitshift = shift&7;
337         memmove(a->wordv+byteshift, a->wordv, a->length_B);
338         memset(a->wordv, 0, byteshift);
339         if(bitshift!=0){
340                 if(bitshift<=4){ /* shift to the left */
341                         for(i=byteshift; i<a->length_B+byteshift; ++i){
342                                 t |= (a->wordv[i])<<bitshift;
343                                 a->wordv[i] = (uint8_t)t;
344                                 t >>= 8;
345                         }
346                         a->wordv[i] = (uint8_t)t;
347                 }else{ /* shift to the right */
348                         bitshift = 8 - bitshift;
349                         for(i=a->length_B+byteshift-1; i>byteshift; --i){
350                                 t |= (a->wordv[i])<<(8-bitshift);
351                                 a->wordv[i] = (uint8_t)(t>>8);
352                                 t <<= 8;
353                         }
354                         t |= (a->wordv[i])<<(8-bitshift);
355                         a->wordv[i] = (uint8_t)(t>>8);
356                 }
357
358         }
359         a->length_B += byteshift+1;
360         bigint_adjust(a);
361 }
362
363 /******************************************************************************/
364
365 void bigint_shiftright(bigint_t* a, uint16_t shift){
366         uint16_t byteshift;
367         uint16_t i;
368         uint8_t bitshift;
369         uint16_t t=0;
370         byteshift = shift/8;
371         bitshift = shift&7;
372         if(byteshift >= a->length_B){ /* we would shift out more than we have */
373                 a->length_B=0;
374                 a->wordv[0] = 0;
375                 SET_FBS(a, 0);
376                 return;
377         }
378         if(byteshift == a->length_B-1 && bitshift>GET_FBS(a)){
379                 a->length_B=0;
380                 a->wordv[0] = 0;
381                 SET_FBS(a, 0);
382                 return;
383         }
384         if(byteshift){
385                 memmove(a->wordv, a->wordv+byteshift, a->length_B-byteshift);
386                 memset(a->wordv+a->length_B-byteshift, 0,  byteshift);
387         }
388         if(bitshift!=0){
389          /* shift to the right */
390                 for(i=a->length_B-byteshift-1; i>0; --i){
391                         t |= (a->wordv[i])<<(8-bitshift);
392                         a->wordv[i] = (uint8_t)(t>>8);
393                         t <<= 8;
394                 }
395                 t |= (a->wordv[0])<<(8-bitshift);
396                 a->wordv[0] = (uint8_t)(t>>8);
397         }
398         a->length_B -= byteshift;
399         bigint_adjust(a);
400 }
401
402 /******************************************************************************/
403
404 void bigint_xor(bigint_t* dest, const bigint_t* a){
405         uint16_t i;
406         for(i=0; i<a->length_B; ++i){
407                 dest->wordv[i] ^= a->wordv[i];
408         }
409         bigint_adjust(dest);
410 }
411
412 /******************************************************************************/
413
414 void bigint_set_zero(bigint_t* a){
415         a->length_B=0;
416 }
417
418 /******************************************************************************/
419
420 /* using the Karatsuba-Algorithm */
421 /* x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + yh*yl */
422 void bigint_mul_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
423         if(a->length_B==0 || b->length_B==0){
424                 bigint_set_zero(dest);
425                 depth -= 1;
426                 return;
427         }
428         if(a->length_B==1 || b->length_B==1){
429                 if(a->length_B!=1){
430                         XCHG_PTR(a,b);
431                 }
432                 uint16_t i, t=0;
433                 uint8_t x = a->wordv[0];
434                 for(i=0; i<b->length_B; ++i){
435                         t += b->wordv[i]*x;
436                         dest->wordv[i] = (uint8_t)t;
437                         t>>=8;
438                 }
439                 dest->wordv[i] = (uint8_t)t;
440                 dest->length_B=i+1;
441                 bigint_adjust(dest);
442                 return;
443         }
444         if(a->length_B<=4 && b->length_B<=4){
445                 uint32_t p=0, q=0;
446                 uint64_t r;
447                 memcpy(&p, a->wordv, a->length_B);
448                 memcpy(&q, b->wordv, b->length_B);
449                 r = (uint64_t)p*(uint64_t)q;
450                 memcpy(dest->wordv, &r, a->length_B+b->length_B);
451                 dest->length_B =  a->length_B+b->length_B;
452                 bigint_adjust(dest);
453                 return;
454         }
455         bigint_set_zero(dest);
456         /* split a in xh & xl; split b in yh & yl */
457         uint16_t n;
458         n=(MAX(a->length_B, b->length_B)+1)/2;
459         bigint_t xl, xh, yl, yh;
460         xl.wordv = a->wordv;
461         yl.wordv = b->wordv;
462         if(a->length_B<=n){
463                 xh.info=0;
464                 xh.length_B = 0;
465                 xl.length_B = a->length_B;
466                 xl.info = 0;
467         }else{
468                 xl.length_B=n;
469                 xl.info = 0;
470                 bigint_adjust(&xl);
471                 xh.wordv = a->wordv+n;
472                 xh.length_B = a->length_B-n;
473                 xh.info = 0;
474         }
475         if(b->length_B<=n){
476                 yh.info=0;
477                 yh.length_B = 0;
478                 yl.length_B = b->length_B;
479                 yl.info = b->info;
480         }else{
481                 yl.length_B=n;
482                 yl.info = 0;
483                 bigint_adjust(&yl);
484                 yh.wordv = b->wordv+n;
485                 yh.length_B = b->length_B-n;
486                 yh.info = 0;
487         }
488         /* now we have split up a and b */
489         uint8_t  tmp_b[2*n+2], m_b[2*(n+1)];
490         bigint_t tmp, tmp2, m;
491         /* calculate h=xh*yh; l=xl*yl; sx=xh+xl; sy=yh+yl */
492         tmp.wordv = tmp_b;
493         tmp2.wordv = tmp_b+n+1;
494         m.wordv = m_b;
495
496         bigint_mul_u(dest, &xl, &yl);  /* dest <= xl*yl     */
497         bigint_add_u(&tmp2, &xh, &xl); /* tmp2 <= xh+xl     */
498         bigint_add_u(&tmp, &yh, &yl);  /* tmp  <= yh+yl     */
499         bigint_mul_u(&m, &tmp2, &tmp); /* m    <= tmp*tmp   */
500         bigint_mul_u(&tmp, &xh, &yh);  /* h    <= xh*yh     */
501         bigint_sub_u(&m, &m, dest);    /* m    <= m-dest    */
502     bigint_sub_u(&m, &m, &tmp);    /* m    <= m-h       */
503         bigint_add_scale_u(dest, &m, n);
504         bigint_add_scale_u(dest, &tmp, 2*n);
505 }
506
507 /******************************************************************************/
508
509 void bigint_mul_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
510         uint8_t s;
511         s  = GET_SIGN(a)?2:0;
512         s |= GET_SIGN(b)?1:0;
513         switch(s){
514                 case 0: /* both positive */
515                         bigint_mul_u(dest, a,b);
516                         SET_POS(dest);
517                         break;
518                 case 1: /* a positive, b negative */
519                         bigint_mul_u(dest, a,b);
520                         SET_NEG(dest);
521                         break;
522                 case 2: /* a negative, b positive */
523                         bigint_mul_u(dest, a,b);
524                         SET_NEG(dest);
525                         break;
526                 case 3: /* both negative */
527                         bigint_mul_u(dest, a,b);
528                         SET_POS(dest);
529                         break;
530                 default: /* how can this happen?*/
531                         break;
532         }
533 }
534
535 /******************************************************************************/
536
537 /* square */
538 /* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */
539 void bigint_square(bigint_t* dest, const bigint_t* a){
540         if(a->length_B<=4){
541                 uint64_t r=0;
542                 memcpy(&r, a->wordv, a->length_B);
543                 r = r*r;
544                 memcpy(dest->wordv, &r, 2*a->length_B);
545                 SET_POS(dest);
546                 dest->length_B=2*a->length_B;
547                 bigint_adjust(dest);
548                 return;
549         }
550         uint16_t n;
551         n=(a->length_B+1)/2;
552         bigint_t xh, xl, tmp; /* x-high, x-low, temp */
553         uint8_t buffer[2*n+1];
554         xl.wordv = a->wordv;
555         xl.length_B = n;
556         xh.wordv = a->wordv+n;
557         xh.length_B = a->length_B-n;
558         tmp.wordv = buffer;
559         bigint_square(dest, &xl);
560         bigint_square(&tmp, &xh);
561         bigint_add_scale_u(dest, &tmp, 2*n);
562         bigint_mul_u(&tmp, &xl, &xh);
563         bigint_shiftleft(&tmp, 1);
564         bigint_add_scale_u(dest, &tmp, n);
565
566 }
567
568
569
570
571
572
573