From 10951fe0bb7dae5229dff9408d8157490005590c Mon Sep 17 00:00:00 2001 From: bg Date: Mon, 17 Feb 2014 00:27:02 +0100 Subject: [PATCH] bug fixing and support for malloc instead of stack memory (some functions) --- bigint/bigint.c | 413 +++++++++++++--------------- bigint/bigint.h | 5 +- rsa/rsa_basic.c | 12 +- rsa/rsaes_pkcs1v15.c | 2 +- test_src/main-rsaes_pkcs1v15-test.c | 4 +- 5 files changed, 207 insertions(+), 229 deletions(-) diff --git a/bigint/bigint.c b/bigint/bigint.c index 8524daa..d868a49 100644 --- a/bigint/bigint.c +++ b/bigint/bigint.c @@ -33,6 +33,21 @@ #include "bigint.h" #include +#define PREFERE_HEAP_SPACE 1 + +#if PREFERE_HEAP_SPACE +#include + +#define ALLOC_BIGINT_WORDS(var,words) bigint_word_t *(var) = malloc((words) * sizeof(bigint_word_t)) +#define FREE(x) free(x) + +#else + +#define ALLOC_BIGINT_WORDS(var,words) bigint_word_t var[words] +#define FREE(x) + +#endif + #define DEBUG 1 #if DEBUG || 1 @@ -367,12 +382,26 @@ int8_t bigint_cmp_s(const bigint_t *a, const bigint_t *b){ /******************************************************************************/ +void bigint_shiftleft_bits(bigint_t *a, uint8_t shift) { + bigint_length_t i; + bigint_wordplus_t t = 0; + + for (i = 0; i < a->length_W; ++i) { + t |= ((bigint_wordplus_t)a->wordv[i]) << shift; + a->wordv[i] = (bigint_word_t)t; + t >>= BIGINT_WORD_SIZE; + } + if (t) { + a->wordv[a->length_W++] = (bigint_word_t)t; + } + bigint_adjust(a); +} + +/******************************************************************************/ + void bigint_shiftleft(bigint_t *a, bigint_length_t shift){ - bigint_length_t byteshift, words_to_shift; - int16_t i; + bigint_length_t byteshift; uint8_t bitshift; - bigint_word_t *p; - bigint_wordplus_t t = 0; if (a->length_W == 0 || shift == 0) { return; @@ -385,36 +414,63 @@ void bigint_shiftleft(bigint_t *a, bigint_length_t shift){ if (byteshift) { memmove(((uint8_t*)a->wordv) + byteshift, a->wordv, a->length_W * sizeof(bigint_word_t)); memset(a->wordv, 0, byteshift); + a->length_W += (byteshift + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t); } + if (bitshift == 0) { - a->length_W += (byteshift + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t); bigint_adjust(a); - return; + } else { + bigint_shiftleft_bits(a, bitshift); } - p = &a->wordv[byteshift / sizeof(bigint_word_t)]; - words_to_shift = a->length_W + (byteshift % sizeof(bigint_word_t) ? 1 : 0); - for (i = 0; i < words_to_shift; ++i) { - t |= ((bigint_wordplus_t)p[i]) << bitshift; - p[i] = (bigint_word_t)t; - t >>= BIGINT_WORD_SIZE; +} + +/******************************************************************************/ + +void bigint_shiftright_bits(bigint_t *a, uint8_t shift){ + bigint_length_t i; + bigint_wordplus_t t = 0; + + i = a->length_W; + while (i--) { + t |= ((bigint_wordplus_t)(a->wordv[i])) << (BIGINT_WORD_SIZE - shift); + a->wordv[i] = (bigint_word_t)(t >> BIGINT_WORD_SIZE); + t <<= BIGINT_WORD_SIZE; } - if (t) { - p[i] = (bigint_word_t)t; - a->length_W += 1; + + bigint_adjust(a); +} + +/******************************************************************************/ + +void bigint_shiftright_1bit(bigint_t *a){ + bigint_length_t i; + bigint_word_t t1 = 0, t2; + + i = a->length_W; + while (i--) { + t2 = a->wordv[i] & 1; + a->wordv[i] >>= 1; + a->wordv[i] |= t1; + t1 = t2 << (BIGINT_WORD_SIZE - 1); } - a->length_W += (byteshift + sizeof(bigint_word_t) - 1) / sizeof(bigint_word_t); bigint_adjust(a); } /******************************************************************************/ +void bigint_shiftright_1word(bigint_t *a){ + if (a->length_W == 0) { + return; + } + memmove(a->wordv, &a->wordv[1], (a->length_W - 1) * sizeof(bigint_word_t)); + a->length_W--; +} + +/******************************************************************************/ + void bigint_shiftright(bigint_t *a, bigint_length_t shift){ - bigint_length_t byteshift; - bigint_length_t i; - uint8_t bitshift; - bigint_wordplus_t t = 0; - byteshift = shift / 8; - bitshift = shift & 7; + bigint_length_t byteshift = shift / 8; + uint8_t bitshift = shift & 7; if (a->length_W == 0) { return; @@ -429,19 +485,12 @@ void bigint_shiftright(bigint_t *a, bigint_length_t shift){ memmove(a->wordv, (uint8_t*)a->wordv + byteshift, a->length_W * sizeof(bigint_word_t) - byteshift); memset((uint8_t*)&a->wordv[a->length_W] - byteshift, 0, byteshift); a->length_W -= byteshift / sizeof(bigint_word_t); + bigint_adjust(a); } - - if(bitshift != 0 && a->length_W){ - /* shift to the right */ - i = a->length_W - 1; - do{ - t |= ((bigint_wordplus_t)(a->wordv[i])) << (BIGINT_WORD_SIZE - bitshift); - a->wordv[i] = (bigint_word_t)(t >> BIGINT_WORD_SIZE); - t <<= BIGINT_WORD_SIZE; - }while(i--); + if(bitshift != 0){ + bigint_shiftright_bits(a, bitshift); } - bigint_adjust(a); } /******************************************************************************/ @@ -457,7 +506,7 @@ void bigint_xor(bigint_t *dest, const bigint_t *a){ /******************************************************************************/ void bigint_set_zero(bigint_t *a){ - a->length_W=0; + a->length_W = 0; } /******************************************************************************/ @@ -465,11 +514,11 @@ void bigint_set_zero(bigint_t *a){ /* using the Karatsuba-Algorithm */ /* x*y = (xh*yh)*b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + yh*yl */ void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){ - if(a->length_W == 0 || b->length_W == 0){ + if (a->length_W == 0 || b->length_W == 0) { bigint_set_zero(dest); return; } - if(dest == a || dest == b){ + if (dest == a || dest == b) { bigint_t d; bigint_word_t d_b[a->length_W + b->length_W]; d.wordv = d_b; @@ -477,20 +526,20 @@ void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){ bigint_copy(dest, &d); return; } - if(a->length_W == 1 || b->length_W == 1){ - if(a->length_W != 1){ + if (a->length_W == 1 || b->length_W == 1) { + if (a->length_W != 1) { XCHG_PTR(a,b); } bigint_wordplus_t t = 0; bigint_length_t i; bigint_word_t x = a->wordv[0]; - for(i=0; i < b->length_W; ++i){ + for (i = 0; i < b->length_W; ++i) { t += ((bigint_wordplus_t)b->wordv[i]) * ((bigint_wordplus_t)x); dest->wordv[i] = (bigint_word_t)t; t >>= BIGINT_WORD_SIZE; } dest->length_W = i; - if(t){ + if (t) { dest->wordv[i] = (bigint_word_t)t; dest->length_W += 1; } @@ -498,39 +547,39 @@ void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){ bigint_adjust(dest); return; } - if(a->length_W * sizeof(bigint_word_t) <= 4 && b->length_W * sizeof(bigint_word_t) <= 4){ - uint32_t p=0, q=0; + if (a->length_W * sizeof(bigint_word_t) <= 4 && b->length_W * sizeof(bigint_word_t) <= 4) { + uint32_t p = 0, q = 0; uint64_t r; - memcpy(&p, a->wordv, a->length_W*sizeof(bigint_word_t)); - memcpy(&q, b->wordv, b->length_W*sizeof(bigint_word_t)); + memcpy(&p, a->wordv, a->length_W * sizeof(bigint_word_t)); + memcpy(&q, b->wordv, b->length_W * sizeof(bigint_word_t)); r = (uint64_t)p * (uint64_t)q; - memcpy(dest->wordv, &r, (dest->length_W = a->length_W + b->length_W)*sizeof(bigint_word_t)); + memcpy(dest->wordv, &r, (dest->length_W = a->length_W + b->length_W) * sizeof(bigint_word_t)); bigint_adjust(dest); return; } /* split a in xh & xl; split b in yh & yl */ - const bigint_length_t n = (MAX(a->length_W, b->length_W)+1)/2; + const bigint_length_t n = (MAX(a->length_W, b->length_W) + 1) / 2; bigint_t xl, xh, yl, yh; xl.wordv = a->wordv; yl.wordv = b->wordv; - if(a->length_W<=n){ + if (a->length_W <= n) { bigint_set_zero(&xh); xl.length_W = a->length_W; xl.info = a->info; - }else{ - xl.length_W=n; + } else { + xl.length_W = n; xl.info = 0; bigint_adjust(&xl); xh.wordv = &(a->wordv[n]); xh.length_W = a->length_W-n; xh.info = a->info; } - if(b->length_W<=n){ + if (b->length_W <= n) { bigint_set_zero(&yh); yl.length_W = b->length_W; yl.info = b->info; - }else{ - yl.length_W=n; + } else { + yl.length_W = n; yl.info = 0; bigint_adjust(&yl); yh.wordv = &(b->wordv[n]); @@ -546,11 +595,12 @@ void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){ * x*y = (xh * yh) * b**2n + ((xh+xl)*(yh+yl) - xh*yh - xl*yl)*b**n + xl*yl * 5 9 2 4 3 7 5 6 1 8 1 */ - bigint_word_t tmp_b[2 * n + 2], m_b[2 * (n + 1)]; + ALLOC_BIGINT_WORDS(tmp_w, 2 * n + 2); + ALLOC_BIGINT_WORDS(m_w, 2 * n + 2); bigint_t tmp, tmp2, m; - tmp.wordv = tmp_b; - tmp2.wordv = &(tmp_b[n + 1]); - m.wordv = m_b; + tmp.wordv = tmp_w; + tmp2.wordv = &(tmp_w[n + 1]); + m.wordv = m_w; bigint_mul_u(dest, &xl, &yl); /* 1: dest <= xl*yl */ bigint_add_u(&tmp2, &xh, &xl); /* 2: tmp2 <= xh+xl */ @@ -561,6 +611,8 @@ void bigint_mul_u(bigint_t *dest, const bigint_t *a, const bigint_t *b){ bigint_sub_u(&m, &m, &tmp); /* 7: m <= m-tmp */ bigint_add_scale_u(dest, &m, n * sizeof(bigint_word_t)); /* 8: dest <= dest+m**n*/ bigint_add_scale_u(dest, &tmp, 2 * n * sizeof(bigint_word_t)); /* 9: dest <= dest+tmp**(2*n) */ + FREE(m_w); + FREE(tmp_w); } /******************************************************************************/ @@ -591,102 +643,6 @@ void bigint_mul_s(bigint_t *dest, const bigint_t *a, const bigint_t *b){ } } -/******************************************************************************/ -#if OLD_SQUARE - -#if DEBUG_SQUARE -unsigned square_depth = 0; -#endif - -/* square */ -/* (xh*b^n+xl)^2 = xh^2*b^2n + 2*xh*xl*b^n + xl^2 */ -void bigint_square(bigint_t *dest, const bigint_t *a){ - if(a->length_W * sizeof(bigint_word_t) <= 4){ - uint64_t r = 0; - memcpy(&r, a->wordv, a->length_W * sizeof(bigint_word_t)); - r = r * r; - memcpy(dest->wordv, &r, 2 * a->length_W * sizeof(bigint_word_t)); - SET_POS(dest); - dest->length_W = 2 * a->length_W; - bigint_adjust(dest); - return; - } - - if(dest->wordv == a->wordv){ - bigint_t d; - bigint_word_t d_b[a->length_W*2]; - d.wordv = d_b; - bigint_square(&d, a); - bigint_copy(dest, &d); - return; - } - bigint_fast_square(dest, a); - return; -#if DEBUG_SQUARE - square_depth += 1; -#endif - - bigint_length_t n; - n=(a->length_W+1)/2; - bigint_t xh, xl, tmp; /* x-high, x-low, temp */ - bigint_word_t buffer[2*n+1]; - xl.wordv = a->wordv; - xl.length_W = n; - xl.info = 0; - xh.wordv = &(a->wordv[n]); - xh.length_W = a->length_W-n; - xh.info = a->info; - bigint_adjust(&xl); - tmp.wordv = buffer; -/* (xh * b**n + xl)**2 = xh**2 * b**2n + 2 * xh * xl * b**n + xl**2 */ -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (a): xl: "); bigint_print_hex(&xl); - cli_putstr("\r\nDBG (b): xh: "); bigint_print_hex(&xh); - } -#endif - bigint_square(dest, &xl); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (1): xl**2: "); bigint_print_hex(dest); - } -#endif - bigint_square(&tmp, &xh); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (2): xh**2: "); bigint_print_hex(&tmp); - } -#endif - bigint_add_scale_u(dest, &tmp, 2 * n * sizeof(bigint_word_t)); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (3): xl**2 + xh**2*n**2: "); bigint_print_hex(dest); - } -#endif - bigint_mul_u(&tmp, &xl, &xh); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (4): xl*xh: "); bigint_print_hex(&tmp); - } -#endif - bigint_shiftleft(&tmp, 1); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (5): xl*xh*2: "); bigint_print_hex(&tmp); - } -#endif - bigint_add_scale_u(dest, &tmp, n * sizeof(bigint_word_t)); -#if DEBUG_SQUARE - if(square_depth == 1){ - cli_putstr("\r\nDBG (6): x**2: "); bigint_print_hex(dest); - cli_putstr("\r\n"); - } - square_depth -= 1; -#endif -} - -#endif - /******************************************************************************/ void bigint_square(bigint_t *dest, const bigint_t *a) { @@ -697,7 +653,7 @@ void bigint_square(bigint_t *dest, const bigint_t *a) { bigint_word_t q, c1, c2; bigint_length_t i, j; - if(a->length_W * sizeof(bigint_word_t) <= 4){ + if (a->length_W * sizeof(bigint_word_t) <= 4) { uint64_t r = 0; memcpy(&r, a->wordv, a->length_W * sizeof(bigint_word_t)); r = r * r; @@ -708,12 +664,13 @@ void bigint_square(bigint_t *dest, const bigint_t *a) { return; } - if(dest->wordv == a->wordv){ + if (dest->wordv == a->wordv) { bigint_t d; - bigint_word_t d_b[a->length_W * 2]; - d.wordv = d_b; + ALLOC_BIGINT_WORDS(d_w, a->length_W * 2); + d.wordv = d_w; bigint_square(&d, a); bigint_copy(dest, &d); + FREE(d_w); return; } @@ -750,7 +707,6 @@ void bigint_square(bigint_t *dest, const bigint_t *a) { void bigint_sub_u_bitscale(bigint_t *a, const bigint_t *b, bigint_length_t bitscale){ bigint_t tmp, x; - bigint_word_t tmp_b[b->length_W + 1]; const bigint_length_t word_shift = bitscale / BIGINT_WORD_SIZE; if (a->length_W < b->length_W + word_shift) { @@ -760,15 +716,17 @@ void bigint_sub_u_bitscale(bigint_t *a, const bigint_t *b, bigint_length_t bitsc bigint_set_zero(a); return; } - tmp.wordv = tmp_b; + ALLOC_BIGINT_WORDS(tmp_w, b->length_W + 1); + tmp.wordv = tmp_w; bigint_copy(&tmp, b); - bigint_shiftleft(&tmp, bitscale % BIGINT_WORD_SIZE); + bigint_shiftleft_bits(&tmp, bitscale % BIGINT_WORD_SIZE); x.info = a->info; x.wordv = &(a->wordv[word_shift]); x.length_W = a->length_W - word_shift; bigint_sub_u(&x, &x, &tmp); + FREE(tmp_w); bigint_adjust(a); return; } @@ -813,15 +771,26 @@ void bigint_reduce(bigint_t *a, const bigint_t *r){ /* calculate dest = a**exp % r */ /* using square&multiply */ void bigint_expmod_u_sam(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r){ - if(a->length_W == 0){ + if (a->length_W == 0) { bigint_set_zero(dest); return; } - bigint_t res, base; - bigint_word_t t, base_w[MAX(a->length_W, r->length_W)], res_w[r->length_W * 2]; + if(exp->length_W == 0){ + dest->info = 0; + dest->length_W = 1; + dest->wordv[0] = 1; + return; + } + + bigint_t res, base; + bigint_word_t t; bigint_length_t i; uint8_t j; + + ALLOC_BIGINT_WORDS(base_w, MAX(a->length_W, r->length_W)); + ALLOC_BIGINT_WORDS(res_w, r->length_W * 2); + res.wordv = res_w; base.wordv = base_w; bigint_copy(&base, a); @@ -830,10 +799,6 @@ void bigint_expmod_u_sam(bigint_t *dest, const bigint_t *a, const bigint_t *exp, res.length_W = 1; res.info = 0; bigint_adjust(&res); - if(exp->length_W == 0){ - bigint_copy(dest, &res); - return; - } bigint_copy(&res, &base); uint8_t flag = 0; t = exp->wordv[exp->length_W - 1]; @@ -858,6 +823,8 @@ void bigint_expmod_u_sam(bigint_t *dest, const bigint_t *a, const bigint_t *exp, SET_POS(&res); bigint_copy(dest, &res); + FREE(res_w); + FREE(base_w); } /******************************************************************************/ @@ -932,16 +899,17 @@ void bigint_gcdext(bigint_t *gcd, bigint_t *a, bigint_t *b, const bigint_t *x, c y_.length_W = y->length_W - i; memcpy(x_.wordv, x->wordv + i, x_.length_W * sizeof(bigint_word_t)); memcpy(y_.wordv, y->wordv + i, y_.length_W * sizeof(bigint_word_t)); - for(i = 0; (x_.wordv[0] & (1 << i)) == 0 && (y_.wordv[0] & (1 << i)) == 0; ++i){ - } + + for(i = 0; (x_.wordv[0] & ((bigint_word_t)1 << i)) == 0 && (y_.wordv[0] & ((bigint_word_t)1 << i)) == 0; ++i) + ; bigint_adjust(&x_); bigint_adjust(&y_); if(i){ - bigint_shiftleft(&g, i); - bigint_shiftright(&x_, i); - bigint_shiftright(&y_, i); + bigint_shiftleft_bits(&g, i); + bigint_shiftright_bits(&x_, i); + bigint_shiftright_bits(&y_, i); } u.wordv = u_b; @@ -963,22 +931,22 @@ void bigint_gcdext(bigint_t *gcd, bigint_t *a, bigint_t *b, const bigint_t *x, c bigint_set_zero(&c_); do { while ((u.wordv[0] & 1) == 0) { - bigint_shiftright(&u, 1); + bigint_shiftright_1bit(&u); if((a_.wordv[0] & 1) || (b_.wordv[0] & 1)){ bigint_add_s(&a_, &a_, &y_); bigint_sub_s(&b_, &b_, &x_); } - bigint_shiftright(&a_, 1); - bigint_shiftright(&b_, 1); + bigint_shiftright_1bit(&a_); + bigint_shiftright_1bit(&b_); } while ((v.wordv[0] & 1) == 0) { - bigint_shiftright(&v, 1); + bigint_shiftright_1bit(&v); if((c_.wordv[0] & 1) || (d_.wordv[0] & 1)){ bigint_add_s(&c_, &c_, &y_); bigint_sub_s(&d_, &d_, &x_); } - bigint_shiftright(&c_, 1); - bigint_shiftright(&d_, 1); + bigint_shiftright_1bit(&c_); + bigint_shiftright_1bit(&d_); } if(bigint_cmp_u(&u, &v) >= 0){ bigint_sub_u(&u, &u, &v); @@ -1014,9 +982,9 @@ void bigint_inverse(bigint_t *dest, const bigint_t *a, const bigint_t *m){ void bigint_changeendianess(bigint_t *a){ uint8_t t, *p, *q; - p = (uint8_t*)(a->wordv); + p = (uint8_t*)a->wordv; q = p + a->length_W * sizeof(bigint_word_t) - 1; - while(plength_W, b->length_W), m->length_W); bigint_t u, t; - bigint_word_t u_w[s + 2], t_w[s + 2]; + bigint_length_t i; if (a->length_W == 0 || b->length_W == 0) { bigint_set_zero(dest); return; } + ALLOC_BIGINT_WORDS(u_w, s + 2); + ALLOC_BIGINT_WORDS(t_w, s + 2); u.wordv = u_w; u.info = 0; u.length_W = 0; @@ -1087,7 +1057,7 @@ void bigint_mont_mul(bigint_t *dest, const bigint_t *a, const bigint_t *b, const bigint_mul_word_u(&t, u.wordv[0]); bigint_add_u(&u, &u, &t); } - bigint_shiftright(&u, BIGINT_WORD_SIZE); + bigint_shiftright_1word(&u); } for (; i < s; ++i) { bigint_copy(&t, m_); @@ -1095,25 +1065,29 @@ void bigint_mont_mul(bigint_t *dest, const bigint_t *a, const bigint_t *b, const bigint_mul_word_u(&t, u.wordv[0]); bigint_add_u(&u, &u, &t); } - bigint_shiftright(&u, BIGINT_WORD_SIZE); + bigint_shiftright_1word(&u); } bigint_reduce(&u, m); bigint_copy(dest, &u); + FREE(t_w); + FREE(u_w); } /******************************************************************************/ void bigint_mont_red(bigint_t *dest, const bigint_t *a, const bigint_t *m, const bigint_t *m_){ bigint_t u, t; - bigint_length_t i, s = MAX(a->length_W, m->length_W); - bigint_word_t u_w[s + 2], t_w[s + 2]; + bigint_length_t i, s = MAX(a->length_W, MAX(m->length_W, m_->length_W)); - t.wordv = t_w; - u.wordv = u_w; if (a->length_W == 0) { bigint_set_zero(dest); return; } + + ALLOC_BIGINT_WORDS(u_w, s + 2); + ALLOC_BIGINT_WORDS(t_w, s + 2); + t.wordv = t_w; + u.wordv = u_w; bigint_copy(&u, a); for (i = 0; i < m->length_W; ++i) { bigint_copy(&t, m_); @@ -1121,10 +1095,12 @@ void bigint_mont_red(bigint_t *dest, const bigint_t *a, const bigint_t *m, const bigint_mul_word_u(&t, u.wordv[0]); bigint_add_u(&u, &u, &t); } - bigint_shiftright(&u, BIGINT_WORD_SIZE); + bigint_shiftright_1word(&u); } bigint_reduce(&u, m); bigint_copy(dest, &u); + FREE(t_w); + FREE(u_w); } /******************************************************************************/ @@ -1134,20 +1110,17 @@ void bigint_mont_red(bigint_t *dest, const bigint_t *a, const bigint_t *m, const void bigint_mont_gen_m_(bigint_t* dest, const bigint_t* m){ bigint_word_t x_w[2], m_w_0[1]; bigint_t x, m_0; -#if 0 - printf("\nm = "); - bigint_print_hex(m); - uart0_flush(); -#endif if (m->length_W == 0) { bigint_set_zero(dest); return; } if ((m->wordv[0] & 1) == 0) { +#if DEBUG printf_P(PSTR("ERROR: m must not be even, m = ")); bigint_print_hex(m); putchar('\n'); uart0_flush(); +#endif return; } x.wordv = x_w; @@ -1161,27 +1134,11 @@ void bigint_mont_gen_m_(bigint_t* dest, const bigint_t* m){ m_0.wordv[0] = m->wordv[0]; bigint_adjust(&x); bigint_adjust(&m_0); -#if 0 - printf("\nm0 = "); - bigint_print_hex(&m_0); - printf("\nx = "); - bigint_print_hex(&x); - uart0_flush(); -#endif bigint_inverse(dest, &m_0, &x); -#if 0 - printf("\nm0^-1 = "); - bigint_print_hex(dest); - uart0_flush(); -#endif bigint_sub_s(&x, &x, dest); -#if 0 - printf("\n-m0^-1 = "); - bigint_print_hex(&x); - uart0_flush(); -#endif bigint_copy(dest, m); bigint_mul_word_u(dest, x.wordv[0]); + } /******************************************************************************/ @@ -1191,8 +1148,8 @@ void bigint_mont_gen_m_(bigint_t* dest, const bigint_t* m){ */ void bigint_mont_trans(bigint_t *dest, const bigint_t *a, const bigint_t *m){ bigint_t t; - bigint_word_t t_w[a->length_W + m->length_W]; + ALLOC_BIGINT_WORDS(t_w, a->length_W + m->length_W); t.wordv = t_w; memset(t_w, 0, m->length_W * sizeof(bigint_word_t)); memcpy(&t_w[m->length_W], a->wordv, a->length_W * sizeof(bigint_word_t)); @@ -1200,6 +1157,7 @@ void bigint_mont_trans(bigint_t *dest, const bigint_t *a, const bigint_t *m){ t.length_W = a->length_W + m->length_W; bigint_reduce(&t, m); bigint_copy(dest, &t); + FREE(t_w); } /******************************************************************************/ @@ -1211,41 +1169,48 @@ void bigint_expmod_u_mont_accel(bigint_t *dest, const bigint_t *a, const bigint_ return; } - bigint_length_t s = r->length_W; bigint_t res, ax; - bigint_word_t t, res_w[r->length_W * 2], ax_w[MAX(s, a->length_W)]; + bigint_word_t t; bigint_length_t i; uint8_t j; + if (exp->length_W == 0) { + dest->length_W = 1; + dest->info = 0; + dest->wordv[0] = 1; + return; + } + + ALLOC_BIGINT_WORDS(res_w, r->length_W * 2); + ALLOC_BIGINT_WORDS(ax_w, MAX(r->length_W, a->length_W)); + res.wordv = res_w; ax.wordv = ax_w; res.wordv[0] = 1; res.length_W = 1; res.info = 0; - bigint_adjust(&res); - if (exp->length_W == 0) { - bigint_copy(dest, &res); - return; - } + bigint_copy(&ax, a); bigint_reduce(&ax, r); + bigint_mont_trans(&ax, &ax, r); bigint_mont_trans(&res, &res, r); + uint8_t flag = 0; t = exp->wordv[exp->length_W - 1]; for (i = exp->length_W; i > 0; --i) { t = exp->wordv[i - 1]; for(j = BIGINT_WORD_SIZE; j > 0; --j){ if (!flag) { - if(t & (((bigint_wordplus_t)1) << (BIGINT_WORD_SIZE - 1))){ + if(t & (((bigint_word_t)1) << (BIGINT_WORD_SIZE - 1))){ flag = 1; } } if (flag) { bigint_square(&res, &res); bigint_mont_red(&res, &res, r, m_); - if (t & (((bigint_wordplus_t)1) << (BIGINT_WORD_SIZE - 1))) { + if (t & (((bigint_word_t)1) << (BIGINT_WORD_SIZE - 1))) { bigint_mont_mul(&res, &res, &ax, r, m_); } } @@ -1254,6 +1219,8 @@ void bigint_expmod_u_mont_accel(bigint_t *dest, const bigint_t *a, const bigint_ } SET_POS(&res); bigint_mont_red(dest, &res, r, m_); + FREE(ax_w); + FREE(res_w); } /******************************************************************************/ @@ -1278,7 +1245,15 @@ void bigint_expmod_u_mont_sam(bigint_t *dest, const bigint_t *a, const bigint_t #endif void bigint_expmod_u(bigint_t *dest, const bigint_t *a, const bigint_t *exp, const bigint_t *r){ - if (r->wordv[0] & 1) { +#if 0 + printf("\nDBG: expmod_u (a ** e %% m) <%s %s %d>\n\ta: ", __FILE__, __func__, __LINE__); + bigint_print_hex(a); + printf("\n\te: "); + bigint_print_hex(exp); + printf("\n\tm: "); + bigint_print_hex(r); +#endif + if (0 && r->wordv[0] & 1) { bigint_expmod_u_mont_sam(dest, a, exp, r); } else { bigint_expmod_u_sam(dest, a, exp, r); diff --git a/bigint/bigint.h b/bigint/bigint.h index 82e10f8..d07713d 100644 --- a/bigint/bigint.h +++ b/bigint/bigint.h @@ -32,7 +32,7 @@ #include #include -#define BIGINT_WORD_SIZE 8 +#define BIGINT_WORD_SIZE 32 #if BIGINT_WORD_SIZE == 8 typedef uint8_t bigint_word_t; @@ -80,7 +80,10 @@ int8_t bigint_cmp_u(const bigint_t * a, const bigint_t * b); void bigint_add_s(bigint_t *dest, const bigint_t *a, const bigint_t *b); void bigint_sub_s(bigint_t *dest, const bigint_t *a, const bigint_t *b); int8_t bigint_cmp_s(const bigint_t *a, const bigint_t *b); +void bigint_shiftleft_bits(bigint_t *a, uint8_t shift); void bigint_shiftleft(bigint_t *a, bigint_length_t shift); +void bigint_shiftright_1bit(bigint_t *a); +void bigint_shiftright_1word(bigint_t *a); void bigint_shiftright(bigint_t *a, bigint_length_t shift); void bigint_xor(bigint_t *dest, const bigint_t *a); void bigint_set_zero(bigint_t *a); diff --git a/rsa/rsa_basic.c b/rsa/rsa_basic.c index a7c3219..345668e 100644 --- a/rsa/rsa_basic.c +++ b/rsa/rsa_basic.c @@ -189,13 +189,12 @@ void rsa_os2ip(bigint_t *dest, const void *data, uint32_t length_B){ uint8_t off; off = (sizeof(bigint_word_t) - length_B % sizeof(bigint_word_t)) % sizeof(bigint_word_t); #if DEBUG - cli_putstr_P(PSTR("\r\nDBG: off = 0x")); - cli_hexdump_byte(off); + printf("\r\nDBG: off = 0x%02x", off); #endif - if(!data){ - if(off){ + if (!data) { + if (off) { dest->wordv = realloc(dest->wordv, length_B + sizeof(bigint_word_t) - off); - memmove((uint8_t*)dest->wordv+off, dest->wordv, length_B); + memmove((uint8_t*)dest->wordv + off, dest->wordv, length_B); memset(dest->wordv, 0, off); } }else{ @@ -206,8 +205,7 @@ void rsa_os2ip(bigint_t *dest, const void *data, uint32_t length_B){ } dest->length_W = (length_B + off) / sizeof(bigint_word_t); #if DEBUG - cli_putstr_P(PSTR("\r\nDBG: dest->length_W = 0x")); - cli_hexdump_rev(&(dest->length_W), 2); + printf("\r\nDBG: dest->length_W = %u", dest->length_W); #endif #endif dest->info = 0; diff --git a/rsa/rsaes_pkcs1v15.c b/rsa/rsaes_pkcs1v15.c index 4f41546..2825b81 100644 --- a/rsa/rsaes_pkcs1v15.c +++ b/rsa/rsaes_pkcs1v15.c @@ -79,7 +79,7 @@ uint8_t rsa_encrypt_pkcs1v15(void *dest, uint16_t *out_length, const void *src, cli_hexdump_block(x.wordv, x.length_W * sizeof(bigint_word_t), 4, 16); #endif bigint_adjust(&x); - rsa_os2ip(&x, NULL, length_B+pad_length+3); + rsa_os2ip(&x, NULL, length_B + pad_length + 3); rsa_enc(&x, key); rsa_i2osp(NULL, &x, out_length); return 0; diff --git a/test_src/main-rsaes_pkcs1v15-test.c b/test_src/main-rsaes_pkcs1v15-test.c index 8933b02..1411b93 100644 --- a/test_src/main-rsaes_pkcs1v15-test.c +++ b/test_src/main-rsaes_pkcs1v15-test.c @@ -588,11 +588,13 @@ void quick_test(void){ cli_putstr_P(PSTR("\r\n\r\nciphertext:")); cli_hexdump_block(ciphertext, clen, 4, 16); - if(clen!=sizeof(ENCRYPTED)){ + if(clen != sizeof(ENCRYPTED)){ cli_putstr_P(PSTR("\r\n>>FAIL (no size match)<<")); + return; }else{ if(memcmp_P(ciphertext, ENCRYPTED, clen)){ cli_putstr_P(PSTR("\r\n>>FAIL (no content match)<<")); + return; }else{ cli_putstr_P(PSTR("\r\n>>OK<<")); } -- 2.39.2