3 This file is part of the AVR-Crypto-Lib.
4 Copyright (C) 2008 Daniel Otte (daniel.otte@rub.de)
6 This program is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with this program. If not, see <http://www.gnu.org/licenses/>.
24 * \license GPLv3 or later
32 #include "bigint_io.h"
/* Larger / smaller of two values.
 * Both arguments may be evaluated twice: never pass expressions with side
 * effects (e.g. MAX(i++, j)). */
#define MAX(a,b) (((a)>(b))?(a):(b))
#define MIN(a,b) (((a)<(b))?(a):(b))

/* Store v in the BIGINT_FBS_MASK field of (a)->info, leaving every other
 * info bit untouched. Clearing with ~BIGINT_FBS_MASK (instead of a literal)
 * keeps SET_FBS and GET_FBS operating on exactly the same field, and masking
 * v prevents an out-of-range value from corrupting the sign bit or other
 * flags. Note: `a` is evaluated more than once. */
#define SET_FBS(a, v) do{ (a)->info = ((a)->info & ~BIGINT_FBS_MASK) \
                                      | ((v) & BIGINT_FBS_MASK); }while(0)
#define GET_FBS(a) ((a)->info&BIGINT_FBS_MASK)

/* Set / clear the sign flag. The whole expansion is parenthesized so the
 * macros are safe inside larger expressions (e.g. the operands of ?:). */
#define SET_NEG(a) ((a)->info |= BIGINT_NEG_MASK)
#define SET_POS(a) ((a)->info &= ~BIGINT_NEG_MASK)

/* In-place XOR swap of two integer lvalues.
 * a and b must designate DISTINCT objects: XCHG(x, x) sets x to 0. */
#define XCHG(a,b) do{(a)^=(b); (b)^=(a); (a)^=(b);}while(0)

/* Swap two pointer variables through a temporary.
 * The previous version XOR-ed the pointers after casting them to uint16_t,
 * which is implementation-defined and truncates any pointer wider than
 * 16 bits; a temporary is portable and just as cheap. Works for pointers to
 * const data as well (void* converts implicitly back to any object pointer). */
#define XCHG_PTR(a,b) do{ void* xchg_ptr_tmp_ = (void*)(a); \
                          (a) = (void*)(b);                 \
                          (b) = xchg_ptr_tmp_; }while(0)

/* Non-zero when the sign flag of (a)->info is set. */
#define GET_SIGN(a) ((a)->info&BIGINT_NEG_MASK)
54 /******************************************************************************/
55 void bigint_adjust(bigint_t* a){
56 while(a->length_B!=0 && a->wordv[a->length_B-1]==0){
65 t = a->wordv[a->length_B-1];
66 while((t&0x80)==0 && i){
73 /******************************************************************************/
75 void bigint_copy(bigint_t* dest, const bigint_t* src){
76 memcpy(dest->wordv, src->wordv, src->length_B);
77 dest->length_B = src->length_B;
78 dest->info = src->info;
81 /******************************************************************************/
83 /* this should be implemented in assembly */
85 void bigint_add_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
87 if(a->length_B < b->length_B){
90 for(i=0; i<b->length_B; ++i){
91 t = a->wordv[i] + b->wordv[i] + t;
92 dest->wordv[i] = (uint8_t)t;
95 for(; i<a->length_B; ++i){
97 dest->wordv[i] = (uint8_t)t;
100 dest->wordv[i++] = t;
105 /******************************************************************************/
107 /* this should be implemented in assembly */
108 void bigint_add_scale_u(bigint_t* dest, const bigint_t* a, uint16_t scale){
111 if(scale>dest->length_B)
112 memset(dest->wordv+dest->length_B, 0, scale-dest->length_B);
113 for(i=scale; i<a->length_B+scale; ++i,++j){
115 if(dest->length_B>i){
118 dest->wordv[i] = (uint8_t)t;
122 if(dest->length_B>i){
123 t = dest->wordv[i] + t;
125 dest->wordv[i] = (uint8_t)t;
129 if(dest->length_B < i){
135 /******************************************************************************/
137 /* this should be implemented in assembly */
138 void bigint_sub_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
142 uint16_t i, min, max;
143 min = MIN(a->length_B, b->length_B);
144 max = MAX(a->length_B, b->length_B);
145 r = bigint_cmp_u(a,b);
153 dest->length_B = a->length_B;
154 memcpy(dest->wordv, a->wordv, a->length_B);
155 dest->info = a->info;
160 dest->length_B = b->length_B;
161 memcpy(dest->wordv, b->wordv, b->length_B);
162 dest->info = b->info;
167 for(i=0; i<min; ++i){
168 t = a->wordv[i] - b->wordv[i] - borrow;
171 dest->wordv[i]=(uint8_t)(-t);
174 dest->wordv[i]=(uint8_t)(-t);
178 t = b->wordv[i] + borrow;
181 dest->wordv[i]=(uint8_t)t;
184 dest->wordv[i]=(uint8_t)t;
191 for(i=0; i<min; ++i){
192 t = a->wordv[i] - b->wordv[i] - borrow;
195 dest->wordv[i]=(uint8_t)t;
198 dest->wordv[i]=(uint8_t)t;
202 t = a->wordv[i] - borrow;
205 dest->wordv[i]=(uint8_t)t;
208 dest->wordv[i]=(uint8_t)t;
218 /******************************************************************************/
220 int8_t bigint_cmp_u(const bigint_t* a, const bigint_t* b){
221 if(a->length_B > b->length_B){
224 if(a->length_B < b->length_B){
227 if(GET_FBS(a) > GET_FBS(b)){
230 if(GET_FBS(a) < GET_FBS(b)){
239 if(a->wordv[i]!=b->wordv[i]){
240 if(a->wordv[i]>b->wordv[i]){
250 /******************************************************************************/
252 void bigint_add_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
255 s |= GET_SIGN(b)?1:0;
257 case 0: /* both positive */
258 bigint_add_u(dest, a,b);
261 case 1: /* a positive, b negative */
262 bigint_sub_u(dest, a, b);
264 case 2: /* a negative, b positive */
265 bigint_sub_u(dest, b, a);
267 case 3: /* both negative */
268 bigint_add_u(dest, a, b);
271 default: /* how can this happen?*/
276 /******************************************************************************/
278 void bigint_sub_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
281 s |= GET_SIGN(b)?1:0;
283 case 0: /* both positive */
284 bigint_sub_u(dest, a,b);
286 case 1: /* a positive, b negative */
287 bigint_add_u(dest, a, b);
290 case 2: /* a negative, b positive */
291 bigint_add_u(dest, a, b);
294 case 3: /* both negative */
295 bigint_sub_u(dest, b, a);
297 default: /* how can this happen?*/
303 /******************************************************************************/
305 int8_t bigint_cmp_s(const bigint_t* a, const bigint_t* b){
308 s |= GET_SIGN(b)?1:0;
310 case 0: /* both positive */
311 return bigint_cmp_u(a, b);
313 case 1: /* a positive, b negative */
316 case 2: /* a negative, b positive */
319 case 3: /* both negative */
320 return bigint_cmp_u(b, a);
322 default: /* how can this happen?*/
325 return 0; /* just to satisfy the compiler */
328 /******************************************************************************/
330 void bigint_shiftleft(bigint_t* a, uint16_t shift){
335 byteshift = (shift+3)/8;
337 memmove(a->wordv+byteshift, a->wordv, a->length_B);
338 memset(a->wordv, 0, byteshift);
340 if(bitshift<=4){ /* shift to the left */
341 for(i=byteshift; i<a->length_B+byteshift; ++i){
342 t |= (a->wordv[i])<<bitshift;
343 a->wordv[i] = (uint8_t)t;
346 a->wordv[i] = (uint8_t)t;
348 }else{ /* shift to the right */
349 for(i=a->length_B+byteshift-1; i>byteshift-1; --i){
350 t |= (a->wordv[i])<<(bitshift);
351 a->wordv[i] = (uint8_t)(t>>8);
354 t |= (a->wordv[i])<<(bitshift);
355 a->wordv[i] = (uint8_t)(t>>8);
358 a->length_B += byteshift;
362 /******************************************************************************/
364 void bigint_shiftright(bigint_t* a, uint16_t shift){
371 if(byteshift >= a->length_B){ /* we would shift out more than we have */
377 if(byteshift == a->length_B-1 && bitshift>GET_FBS(a)){
384 memmove(a->wordv, a->wordv+byteshift, a->length_B-byteshift);
385 memset(a->wordv+a->length_B-byteshift, 0, byteshift);
388 /* shift to the right */
389 for(i=a->length_B-byteshift-1; i>0; --i){
390 t |= (a->wordv[i])<<(8-bitshift);
391 a->wordv[i] = (uint8_t)(t>>8);
394 t |= (a->wordv[0])<<(8-bitshift);
395 a->wordv[0] = (uint8_t)(t>>8);
397 a->length_B -= byteshift;
401 /******************************************************************************/
403 void bigint_xor(bigint_t* dest, const bigint_t* a){
405 for(i=0; i<a->length_B; ++i){
406 dest->wordv[i] ^= a->wordv[i];
411 /******************************************************************************/
413 void bigint_set_zero(bigint_t* a){
417 /******************************************************************************/
/* using the Karatsuba algorithm; operands are split at n bytes, B = 256 */
/* x*y = (xh*yh)*B^(2n) + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*B^n + xl*yl */
421 void bigint_mul_u(bigint_t* dest, const bigint_t* a, const bigint_t* b){
422 if(a->length_B==0 || b->length_B==0){
423 bigint_set_zero(dest);
426 if(a->length_B==1 || b->length_B==1){
431 uint8_t x = a->wordv[0];
432 for(i=0; i<b->length_B; ++i){
434 dest->wordv[i] = (uint8_t)t;
437 dest->wordv[i] = (uint8_t)t;
442 if(a->length_B<=4 && b->length_B<=4){
445 memcpy(&p, a->wordv, a->length_B);
446 memcpy(&q, b->wordv, b->length_B);
447 r = (uint64_t)p*(uint64_t)q;
448 memcpy(dest->wordv, &r, a->length_B+b->length_B);
449 dest->length_B = a->length_B+b->length_B;
453 bigint_set_zero(dest);
454 /* split a in xh & xl; split b in yh & yl */
456 n=(MAX(a->length_B, b->length_B)+1)/2;
457 bigint_t xl, xh, yl, yh;
463 xl.length_B = a->length_B;
469 xh.wordv = a->wordv+n;
470 xh.length_B = a->length_B-n;
476 yl.length_B = b->length_B;
482 yh.wordv = b->wordv+n;
483 yh.length_B = b->length_B-n;
486 /* now we have split up a and b */
487 uint8_t tmp_b[2*n+2], m_b[2*(n+1)];
488 bigint_t tmp, tmp2, m;
490 tmp2.wordv = tmp_b+n+1;
493 bigint_mul_u(dest, &xl, &yl); /* dest <= xl*yl */
494 bigint_add_u(&tmp2, &xh, &xl); /* tmp2 <= xh+xl */
495 bigint_add_u(&tmp, &yh, &yl); /* tmp <= yh+yl */
496 bigint_mul_u(&m, &tmp2, &tmp); /* m <= tmp*tmp */
497 bigint_mul_u(&tmp, &xh, &yh); /* h <= xh*yh */
498 bigint_sub_u(&m, &m, dest); /* m <= m-dest */
499 bigint_sub_u(&m, &m, &tmp); /* m <= m-h */
500 bigint_add_scale_u(dest, &m, n);
501 bigint_add_scale_u(dest, &tmp, 2*n);
504 /******************************************************************************/
506 void bigint_mul_s(bigint_t* dest, const bigint_t* a, const bigint_t* b){
509 s |= GET_SIGN(b)?1:0;
511 case 0: /* both positive */
512 bigint_mul_u(dest, a,b);
515 case 1: /* a positive, b negative */
516 bigint_mul_u(dest, a,b);
519 case 2: /* a negative, b positive */
520 bigint_mul_u(dest, a,b);
523 case 3: /* both negative */
524 bigint_mul_u(dest, a,b);
527 default: /* how can this happen?*/
532 /******************************************************************************/
535 /* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */
536 void bigint_square(bigint_t* dest, const bigint_t* a){
539 memcpy(&r, a->wordv, a->length_B);
541 memcpy(dest->wordv, &r, 2*a->length_B);
543 dest->length_B=2*a->length_B;
549 bigint_t xh, xl, tmp; /* x-high, x-low, temp */
550 uint8_t buffer[2*n+1];
553 xh.wordv = a->wordv+n;
554 xh.length_B = a->length_B-n;
556 bigint_square(dest, &xl);
557 bigint_square(&tmp, &xh);
558 bigint_add_scale_u(dest, &tmp, 2*n);
559 bigint_mul_u(&tmp, &xl, &xh);
560 bigint_shiftleft(&tmp, 1);
561 bigint_add_scale_u(dest, &tmp, n);
564 /******************************************************************************/
566 void bigint_sub_u_bitscale(bigint_t* a, const bigint_t* b, uint16_t bitscale){
568 uint8_t tmp_b[b->length_B+1];
569 uint16_t i,j,byteshift=bitscale/8;
573 if(a->length_B < b->length_B+byteshift){
574 cli_putstr_P(PSTR("\r\nERROR: bigint_sub_u_bitscale result negative"));
580 bigint_copy(&tmp, b);
581 bigint_shiftleft(&tmp, bitscale&7);
583 for(j=0,i=byteshift; i<tmp.length_B+byteshift; ++i, ++j){
584 t = a->wordv[i] - tmp.wordv[j] - borrow;
585 a->wordv[i] = (uint8_t)t;
593 if(i+1 > a->length_B){
594 cli_putstr_P(PSTR("\r\nERROR: bigint_sub_u_bitscale result negative (2) shift="));
595 cli_hexdump_rev(&bitscale, 2);
599 a->wordv[i] -= borrow;
600 if(a->wordv[i]!=0xff){
608 /******************************************************************************/
610 void bigint_reduce(bigint_t* a, const bigint_t* r){
611 uint8_t rfbs = GET_FBS(r);
615 while(a->length_B > r->length_B){
616 bigint_sub_u_bitscale(a, r, (a->length_B-r->length_B)*8+GET_FBS(a)-rfbs-1);
618 while(GET_FBS(a) > rfbs+1){
619 bigint_sub_u_bitscale(a, r, GET_FBS(a)-rfbs-1);
621 while(bigint_cmp_u(a,r)>=0){