diff options
30 files changed, 1206 insertions, 342 deletions
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2e480d..f9325ca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,11 +4,11 @@ on: [push, pull_request] env: LLVM_CONFIG: /usr/bin/llvm-config-14 - OTHER_PACKAGES: swig libpcsclite-dev llvm-14 libllvm14 llvm-14-dev valgrind gcc gcc-arm-none-eabi binutils-arm-none-eabi libnewlib-arm-none-eabi + OTHER_PACKAGES: swig libpcsclite-dev llvm-14 libllvm14 llvm-14-dev valgrind gdb gcc gcc-arm-none-eabi binutils-arm-none-eabi libnewlib-arm-none-eabi jobs: test: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] @@ -50,7 +50,7 @@ jobs: make test - name: Test (C) run: | - cd test && make test_bn && ./test_bn && cd .. + cd test && make test && cd .. - name: Code coverage uses: codecov/codecov-action@v3 if: ${{ matrix.python-version == 3.9 }} @@ -4,7 +4,6 @@ htmlcov/ /build/ __pycache__ -/test/test_bn pyecsca-codegen-*.elf pyecsca-codegen-*.hex
\ No newline at end of file @@ -1,5 +1,3 @@ -TESTS = test_builder test_client test_render test_impl test_simulator - test: pytest -m "not slow" --cov=pyecsca.codegen diff --git a/ext/libtommath b/ext/libtommath -Subproject e9b2847b8ddf800e44f486966cfa71b82339248 +Subproject c71669aa4e8d9a9ab3076a8398484e3d3d0c04f diff --git a/pyecsca/codegen/bn/bn.c b/pyecsca/codegen/bn/bn.c index 148403c..d7a44b6 100644 --- a/pyecsca/codegen/bn/bn.c +++ b/pyecsca/codegen/bn/bn.c @@ -27,6 +27,8 @@ void math_init(void) { #endif //TODO: COMBA } +const int bn_digit_bits __attribute__((used)) = MP_DIGIT_BIT; + bn_err bn_init(bn_t *bn) { return mp_init(bn); } @@ -47,6 +49,10 @@ bn_err bn_from_hex(const char *data, bn_t *out) { return mp_read_radix(out, data, 16); } +bn_err bn_from_dec(const char *data, bn_t *out) { + return mp_read_radix(out, data, 10); +} + bn_err bn_from_int(unsigned int value, bn_t *out) { if (sizeof(unsigned int) == 8) { mp_set_u64(out, value); @@ -394,6 +400,9 @@ wnaf_t *bn_wnaf(const bn_t *bn, int w) { } wnaf_t *result = NULL; + size_t bits = bn_bit_length(bn) + 1; + int8_t arr[bits]; + bn_t half_width; if (mp_init(&half_width) != BN_OKAY) { return NULL; @@ -418,38 +427,38 @@ wnaf_t *bn_wnaf(const bn_t *bn, int w) { goto exit_val_mod; } - result = malloc(sizeof(wnaf_t)); - result->w = w; - result->length = bn_bit_length(bn) + 1; - result->data = calloc(result->length, sizeof(int8_t)); - size_t i = 0; - while (!bn_is_0(&k) && !(bn_get_sign(&k) == BN_NEG)) { + while (mp_cmp_d(&k, 0) == MP_GT) { if (bn_get_bit(&k, 0) == 1) { bn_mod(&k, &full_width, &val_mod); if (mp_cmp(&val_mod, &half_width) == MP_GT) { if (mp_sub(&val_mod, &full_width, &val_mod) != BN_OKAY) { - free(result->data); - free(result); - result = NULL; - break; + goto exit_result; } } int8_t val = (int8_t) mp_get_i32(&val_mod); - result->data[i++] = val; + arr[i++] = val; if (mp_sub(&k, &val_mod, &k) != BN_OKAY) { - free(result->data); - free(result); - result = NULL; - break; + goto exit_result; } } else { - result->data[i++] = 0; + arr[i++] = 0; } bn_rsh(&k, 1, &k); } - bn_clear(&val_mod); + result = malloc(sizeof(wnaf_t)); + result->w = w; + result->length = i; + result->data = calloc(result->length, sizeof(int8_t)); + + // Revert + for (size_t j = 0; j < i; j++) { + result->data[j] = arr[i - j - 1]; + } + +exit_result: + bn_clear(&val_mod); exit_val_mod: bn_clear(&k); exit_k: @@ -463,6 +472,90 @@ wnaf_t *bn_bnaf(const bn_t *bn) { return bn_wnaf(bn, 2); } +void bn_naf_pad_left(wnaf_t *naf, int8_t value, size_t amount) { + if (amount == 0) { + return; + } + int8_t *new_data = calloc(naf->length + amount, sizeof(int8_t)); + for (size_t i = 0; i < naf->length; i++) { + new_data[i + amount] = naf->data[i]; + } + for (size_t i = 0; i < amount; i++) { + new_data[i] = value; + } + free(naf->data); + naf->data = new_data; + naf->length += amount; +} + +void bn_naf_pad_right(wnaf_t *naf, int8_t value, size_t amount) { + if (amount == 0) { + return; + } + naf->data = realloc(naf->data, (naf->length + amount) * sizeof(int8_t)); + for (size_t i = naf->length; i < naf->length + amount; i++) { + naf->data[i] = value; + } + naf->length += amount; +} + +void bn_naf_strip_left(wnaf_t *naf, int8_t value) { + size_t i = 0; + while (i < naf->length && naf->data[i] == value) { + i++; + } + if (i == 0) { + return; + } + if (i == naf->length) { + free(naf->data); + naf->data = NULL; + naf->length = 0; + return; + } + int8_t *new_data = calloc(naf->length - i, sizeof(int8_t)); + for (size_t j = 0; j < naf->length - i; j++) { + new_data[j] = naf->data[j + i]; + } + free(naf->data); + naf->data = new_data; + naf->length -= i; +} + +void bn_naf_strip_right(wnaf_t *naf, int8_t value) { + size_t i = naf->length; + while (i > 0 && naf->data[i - 1] == value) { + i--; + } + if (i == naf->length) { + return; + } + if (i == 0) { + free(naf->data); + naf->data = NULL; + naf->length = 0; + return; + } + naf->data = realloc(naf->data, i * sizeof(int8_t)); + naf->length = i; +} + +void bn_naf_reverse(wnaf_t *naf) { + for (size_t i = 0; i < naf->length / 2; i++) { + int8_t temp = naf->data[i]; + naf->data[i] = naf->data[naf->length - i - 1]; + naf->data[naf->length - i - 1] = temp; + } +} + +void bn_naf_clear(wnaf_t *naf) { + if (naf == NULL) { + return; + } + free(naf->data); + free(naf); +} + wsliding_t *bn_wsliding_ltr(const bn_t *bn, int w) { if (w > 8 || w < 2) { return NULL; @@ -540,8 +633,8 @@ wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w) { wsliding_t *result = NULL; int blen = bn_bit_length(bn); - uint8_t arr[blen + 2]; - memset(arr, 0, (blen + 2) * sizeof(uint8_t)); + uint8_t arr[blen + w]; + memset(arr, 0, (blen + w) * sizeof(uint8_t)); bn_t k; if (mp_init(&k) != BN_OKAY) { @@ -555,7 +648,7 @@ wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w) { } int i = 0; - while (!bn_is_0(&k) && !(bn_get_sign(&k) == BN_NEG)) { + while (mp_cmp_d(&k, 0) == MP_GT) { if (!bn_get_bit(&k, 0)) { arr[i++] = 0; bn_rsh(&k, 1, &k); @@ -594,6 +687,14 @@ exit_k: return result; } +void bn_wsliding_clear(wsliding_t *wsliding) { + if (wsliding == NULL) { + return; + } + free(wsliding->data); + free(wsliding); +} + small_base_t *bn_convert_base_small(const bn_t *bn, int m) { small_base_t *result = NULL; @@ -604,7 +705,9 @@ small_base_t *bn_convert_base_small(const bn_t *bn, int m) { bn_copy(bn, &k); int len = 0; - if (mp_log_n(&k, m, &len) != BN_OKAY) { + if (mp_cmp_d(&k, 0) == MP_EQ) { + len = 0; + } else if (mp_log_n(&k, m, &len) != BN_OKAY) { goto exit_len; } @@ -630,6 +733,14 @@ exit_k: return result; } +void bn_small_base_clear(small_base_t *sb) { + if (sb == NULL) { + return; + } + free(sb->data); + free(sb); +} + large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m) { large_base_t *result = NULL; @@ -640,7 +751,9 @@ large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m) { bn_copy(bn, &k); int len = 0; - if (mp_log(&k, m, &len) != BN_OKAY) { + if (mp_cmp_d(&k, 0) == MP_EQ) { + len = 0; + } else if (mp_log(&k, m, &len) != BN_OKAY) { goto exit_len; } @@ -666,4 +779,72 @@ exit_len: bn_clear(&k); exit_k: return result; +} + +void bn_large_base_clear(large_base_t *lb) { + if (lb == NULL) { + return; + } + for (int i = 0; i < lb->length; i++) { + bn_clear(&lb->data[i]); + } + free(lb->data); + bn_clear(&lb->m); + free(lb); +} + +int32_t bn_booth_word(int32_t digit, int32_t w) { + int32_t s = ~((digit >> w) - 1); //s = ~((digit >> w) - 1) + int32_t d = (1 << (w + 1)) - digit - 1; //d = (1 << (w + 1)) - digit - 1 + d = (d & s) | (digit & ~s); // d = (d & s) | (digit & ~s) + d = (d >> 1) + (d & 1); //d = (d >> 1) + (d & 1) + + if (s) { //return -d if s else d + return -d; + } else { + return d; + } +} + +booth_t *bn_booth(const bn_t *bn, int32_t w, size_t bits) { + if (w >= 30) { + return NULL; + } + int32_t mask = (1 << (w + 1)) - 1; + bn_t d, m; + bn_init(&d); + bn_init(&m); + bn_from_int(mask, &m); + + size_t len = (bits / w) + 1; + booth_t *result = malloc(sizeof(booth_t)); + result->length = len; + result->w = w; + result->data = calloc(len, sizeof(int32_t)); + + long l = 0; + for (long i = bits + (w - (bits % w) - 1); i > 0; i -= w) { + int32_t digit; + bn_copy(bn, &d); + if (i >= w) { + bn_rsh(&d, i - w, &d); + } else { + bn_lsh(&d, w - i, &d); + } + bn_and(&d, &m, &d); + digit = bn_to_int(&d); + int32_t val = bn_booth_word(digit, w); + result->data[l++] = val; + } + bn_clear(&d); + bn_clear(&m); + return result; +} + +void bn_booth_clear(booth_t *booth) { + if (booth == NULL) { + return; + } + free(booth->data); + free(booth); }
\ No newline at end of file diff --git a/pyecsca/codegen/bn/bn.h b/pyecsca/codegen/bn/bn.h index 7c25c22..03526b0 100644 --- a/pyecsca/codegen/bn/bn.h +++ b/pyecsca/codegen/bn/bn.h @@ -76,8 +76,16 @@ typedef struct { bn_t m; } large_base_t; +typedef struct { + int32_t *data; + size_t length; + int w; +} booth_t; + void math_init(void); +extern const int bn_digit_bits; + bn_err bn_init(bn_t *bn); #define bn_init_multi mp_init_multi bn_err bn_copy(const bn_t *from, bn_t *to); @@ -86,6 +94,7 @@ void bn_clear(bn_t *bn); bn_err bn_from_bin(const uint8_t *data, size_t size, bn_t *out); bn_err bn_from_hex(const char *data, bn_t *out); +bn_err bn_from_dec(const char *data, bn_t *out); bn_err bn_from_int(unsigned int value, bn_t *out); bn_err bn_to_binpad(const bn_t *one, uint8_t *data, size_t size); @@ -135,11 +144,24 @@ int bn_bit_length(const bn_t *bn); wnaf_t *bn_wnaf(const bn_t *bn, int w); wnaf_t *bn_bnaf(const bn_t *bn); +void bn_naf_pad_left(wnaf_t *naf, int8_t value, size_t amount); +void bn_naf_pad_right(wnaf_t *naf, int8_t value, size_t amount); +void bn_naf_strip_left(wnaf_t *naf, int8_t value); +void bn_naf_strip_right(wnaf_t *naf, int8_t value); +void bn_naf_reverse(wnaf_t *naf); +void bn_naf_clear(wnaf_t *naf); wsliding_t *bn_wsliding_ltr(const bn_t *bn, int w); wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w); +void bn_wsliding_clear(wsliding_t *wsliding); small_base_t *bn_convert_base_small(const bn_t *bn, int m); +void bn_small_base_clear(small_base_t *sb); large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m); +void bn_large_base_clear(large_base_t *lb); + +int32_t bn_booth_word(int32_t digit, int32_t w); +booth_t *bn_booth(const bn_t *bn, int32_t w, size_t bits); +void bn_booth_clear(booth_t *booth); #endif //BN_H_
\ No newline at end of file diff --git a/pyecsca/codegen/builder.py b/pyecsca/codegen/builder.py index 63e7801..b7b5e22 100644 --- a/pyecsca/codegen/builder.py +++ b/pyecsca/codegen/builder.py @@ -81,7 +81,7 @@ def get_multiplier(ctx: click.Context, param, value: Optional[str]) -> Optional[ if value is None: return None res = re.match( - "(?P<name>[a-zA-Z\-]+)\((?P<args>([a-zA-Z_]+ *= *[a-zA-Z0-9.]+, ?)*?([a-zA-Z_]+ *= *[a-zA-Z0-9.]+)*)\)", + r"(?P<name>[a-zA-Z\-]+)\((?P<args>([a-zA-Z_]+ *= *[a-zA-Z0-9.]+, ?)*?([a-zA-Z_]+ *= *[a-zA-Z0-9.]+)*)\)", value) if not res: raise click.BadParameter("Couldn't parse multiplier spec: {}.".format(value)) diff --git a/pyecsca/codegen/render.py b/pyecsca/codegen/render.py index 692deab..b1c1477 100644 --- a/pyecsca/codegen/render.py +++ b/pyecsca/codegen/render.py @@ -30,7 +30,7 @@ from pyecsca.ec.mult import ( BGMWMultiplier, CombMultiplier, AccumulationOrder, - ProcessingDirection + ProcessingDirection, WindowBoothMultiplier ) from pyecsca.ec.op import OpType, CodeOp @@ -227,6 +227,7 @@ def render_scalarmult_impl(scalarmult: ScalarMultiplier) -> str: DifferentialLadderMultiplier=DifferentialLadderMultiplier, BinaryNAFMultiplier=BinaryNAFMultiplier, WindowNAFMultiplier=WindowNAFMultiplier, + WindowBoothMultiplier=WindowBoothMultiplier, SlidingWindowMultiplier=SlidingWindowMultiplier, FixedWindowLTRMultiplier=FixedWindowLTRMultiplier, FullPrecompMultiplier=FullPrecompMultiplier, diff --git a/pyecsca/codegen/templates/formula_add.c b/pyecsca/codegen/templates/formula_add.c index 6026601..48bab07 100644 --- a/pyecsca/codegen/templates/formula_add.c +++ b/pyecsca/codegen/templates/formula_add.c @@ -16,16 +16,17 @@ __attribute__((noinline)) void point_add(const point_t *one, const point_t *othe {%- if short_circuit %} if (point_equals(one, curve->neutral)) { point_set(other, out_one); - return; + goto end; } if (point_equals(other, curve->neutral)) { point_set(one, out_one); - return; + goto end; } {%- endif %} {{ ops.render_initializations(initializations) }} {{ ops.render_ops(operations) }} {{ ops.render_returns(returns) }} //NOP_128(); +end: {{ end_action("add") }} }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/formula_dbl.c b/pyecsca/codegen/templates/formula_dbl.c index 451b0ee..e1cfa15 100644 --- a/pyecsca/codegen/templates/formula_dbl.c +++ b/pyecsca/codegen/templates/formula_dbl.c @@ -16,12 +16,13 @@ __attribute__((noinline)) void point_dbl(const point_t *one, const curve_t *curv {%- if short_circuit %} if (point_equals(one, curve->neutral)) { point_set(one, out_one); - return; + goto end; } {%- endif %} {{ ops.render_initializations(initializations) }} {{ ops.render_ops(operations) }} {{ ops.render_returns(returns) }} //NOP_128(); +end: {{ end_action("dbl") }} }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/formula_neg.c b/pyecsca/codegen/templates/formula_neg.c index 93fbe20..fa96c63 100644 --- a/pyecsca/codegen/templates/formula_neg.c +++ b/pyecsca/codegen/templates/formula_neg.c @@ -16,12 +16,13 @@ __attribute__((noinline)) void point_neg(const point_t *one, const curve_t *curv {%- if short_circuit %} if (point_equals(one, curve->neutral)) { point_set(one, out_one); - return; + goto end; } {%- endif %} {{ ops.render_initializations(initializations) }} {{ ops.render_ops(operations) }} {{ ops.render_returns(returns) }} //NOP_128(); +end: {{ end_action("neg") }} }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/formula_scl.c b/pyecsca/codegen/templates/formula_scl.c index 48ac52e..f1471a2 100644 --- a/pyecsca/codegen/templates/formula_scl.c +++ b/pyecsca/codegen/templates/formula_scl.c @@ -16,12 +16,13 @@ __attribute__((noinline)) void point_scl(const point_t *one, const curve_t *curv {%- if short_circuit %} if (point_equals(one, curve->neutral)) { point_set(one, out_one); - return; + goto end; } {%- endif %} {{ ops.render_initializations(initializations) }} {{ ops.render_ops(operations) }} {{ ops.render_returns(returns) }} //NOP_128(); +end: {{ end_action("scl") }} }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/formula_tpl.c b/pyecsca/codegen/templates/formula_tpl.c index d280bad..0b4cd64 100644 --- a/pyecsca/codegen/templates/formula_tpl.c +++ b/pyecsca/codegen/templates/formula_tpl.c @@ -16,12 +16,13 @@ __attribute__((noinline)) void point_tpl(const point_t *one, const curve_t *curv {%- if short_circuit %} if (point_equals(one, curve->neutral)) { point_set(one, out_one); - return; + goto end; } {%- endif %} {{ ops.render_initializations(initializations) }} {{ ops.render_ops(operations) }} {{ ops.render_returns(returns) }} //NOP_128(); +end: {{ end_action("tpl") }} }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/mult.c b/pyecsca/codegen/templates/mult.c index 0144e36..4070952 100644 --- a/pyecsca/codegen/templates/mult.c +++ b/pyecsca/codegen/templates/mult.c @@ -31,6 +31,10 @@ {% include "mult_wnaf.c" %} +{%- elif isinstance(scalarmult, WindowBoothMultiplier) -%} + + {% include "mult_booth.c" %} + {%- elif isinstance(scalarmult, SlidingWindowMultiplier) -%} {% include "mult_sliding_w.c" %} diff --git a/pyecsca/codegen/templates/mult_bgmw.c b/pyecsca/codegen/templates/mult_bgmw.c index 5298fb1..e2e8c72 100644 --- a/pyecsca/codegen/templates/mult_bgmw.c +++ b/pyecsca/codegen/templates/mult_bgmw.c @@ -48,8 +48,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin {%- endif %} point_accumulate(a, b, curve, a); } - free(bs->data); - free(bs); + bn_small_base_clear(bs); {%- if "scl" in scalarmult.formulas %} point_scl(a, curve, a); diff --git a/pyecsca/codegen/templates/mult_bnaf.c b/pyecsca/codegen/templates/mult_bnaf.c index 68d1569..9c760af 100644 --- a/pyecsca/codegen/templates/mult_bnaf.c +++ b/pyecsca/codegen/templates/mult_bnaf.c @@ -1,51 +1,109 @@ #include "mult.h" #include "point.h" -point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) { - point_t *q = point_copy(curve->neutral); - for (long i = naf->length - 1; i >= 0; i--) { +point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) { + point_t *q; + long i; + {% if scalarmult.complete %} + bn_naf_pad_left(naf, 0, (bits + 1) - naf->length); + q = point_copy(curve->neutral); + i = 0; + {% else %} + bn_naf_strip_left(naf, 0); + int8_t val = naf->data[0]; + if (val == 1) { + q = point_copy(point); + } else if (val == -1) { + q = point_copy(neg); + } + i = 1; + {% endif %} + + {% if scalarmult.always %} + point_t *q_copy = point_new(); + {% endif %} + for (; i < naf->length; i++) { point_dbl(q, curve, q); + + {% if scalarmult.always %} + point_set(q, q_copy); + {% endif %} + if (naf->data[i] == 1) { point_accumulate(q, point, curve, q); + {% if scalarmult.always %} + point_accumulate(q_copy, neg, curve, q_copy); + {% endif %} } else if (naf->data[i] == -1) { point_accumulate(q, neg, curve, q); + {% if scalarmult.always %} + point_accumulate(q_copy, point, curve, q_copy); + {% endif %} } } + {% if scalarmult.always %} + point_free(q_copy); + {% endif %} return q; } -point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) { - point_t *r = point_copy(point); - point_t *q = point_copy(curve->neutral); - point_t *r_neg = point_new(); +point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) { + {% if scalarmult.always %} + point_t *r_copy = point_new(); + {% endif %} + + {% if scalarmult.complete %} + bn_naf_pad_left(naf, 0, (bits + 1) - naf->length); + {% endif %} + + bn_naf_reverse(naf); + + point_t *q = point_copy(point); + point_t *r = point_copy(curve->neutral); + point_t *q_neg = point_new(); for (long i = 0; i < naf->length; i++) { + {% if scalarmult.always %} + point_set(r, r_copy); + {% endif %} + if (naf->data[i] == 1) { - point_accumulate(q, r, curve, q); + point_accumulate(r, q, curve, r); + {% if scalarmult.always %} + point_neg(q, curve, q_neg); + point_accumulate(r_copy, q_neg, curve, r_copy); + {% endif %} } else if (naf->data[i] == -1) { - point_neg(r, curve, r_neg); - point_accumulate(q, r_neg, curve, q); + point_neg(q, curve, q_neg); + point_accumulate(r, q_neg, curve, r); + {% if scalarmult.always %} + point_accumulate(r_copy, q, curve, r_copy); + {% endif %} } - point_dbl(r, curve, r); + point_dbl(q, curve, q); } - point_free(r_neg); - point_free(r); + point_free(q_neg); + point_free(q); - return q; + {% if scalarmult.always %} + point_free(r_copy); + {% endif %} + return r; } static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) { point_t *neg = point_new(); point_neg(point, curve, neg); + wnaf_t *naf = bn_bnaf(scalar); + size_t bits = bn_bit_length(&curve->n); {% if scalarmult.direction == ProcessingDirection.LTR %} - point_t *q = scalar_mult_ltr(point, neg, curve, naf); + point_t *q = scalar_mult_ltr(point, neg, curve, naf, bits); {% elif scalarmult.direction == ProcessingDirection.RTL %} - point_t *q = scalar_mult_rtl(point, neg, curve, naf); + point_t *q = scalar_mult_rtl(point, neg, curve, naf, bits); {% endif %} - free(naf->data); - free(naf); + bn_naf_clear(naf); {%- if "scl" in scalarmult.formulas %} point_scl(q, curve, q); diff --git a/pyecsca/codegen/templates/mult_booth.c b/pyecsca/codegen/templates/mult_booth.c new file mode 100644 index 0000000..4c1ba40 --- /dev/null +++ b/pyecsca/codegen/templates/mult_booth.c @@ -0,0 +1,78 @@ +#include "mult.h" +#include "point.h" + + + +static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) { + point_t *points[{{ 2 ** (scalarmult.width - 1) }}]; + {% if scalarmult.precompute_negation %} + point_t *points_neg[{{ 2 ** (scalarmult.width - 1) }}]; + {% endif %} + + point_t *current = point_copy(point); + point_t *dbl = point_new(); + point_dbl(current, curve, dbl); + points[0] = point_copy(current); + {% if scalarmult.precompute_negation %} + points_neg[0] = point_new(); + point_neg(points[0], curve, points_neg[0]); + {% endif %} + {% if scalarmult.width > 1 %} + points[1] = point_copy(dbl); + {% if scalarmult.precompute_negation %} + points_neg[1] = point_new(); + point_neg(points[1], curve, points_neg[1]); + {% endif %} + {% endif %} + + point_set(dbl, current); + {% if scalarmult.width > 2 %} + for (long i = 2; i < {{ 2 ** (scalarmult.width - 1) }}; i++) { + point_add(current, point, curve, current); + points[i] = point_copy(current); + {% if scalarmult.precompute_negation %} + points_neg[i] = point_new(); + point_neg(points[i], curve, points_neg[i]); + {% endif %} + } + {% endif %} + point_free(current); + point_free(dbl); + + size_t bits = bn_bit_length(&curve->n); + + booth_t *bs = bn_booth(scalar, {{ scalarmult.width }}, bits); + + point_t *q = point_copy(curve->neutral); + point_t *neg = point_new(); + for (long i = 0; i < bs->length; i++) { + for (long j = 0; j < {{ scalarmult.width }}; j++) { + point_dbl(q, curve, q); + } + int32_t val = bs->data[i]; + if (val > 0) { + point_accumulate(q, points[val - 1], curve, q); + } else if (val < 0) { + {% if scalarmult.precompute_negation %} + point_accumulate(q, points_neg[-val - 1], curve, q); + {% else %} + point_neg(points[-val - 1], curve, neg); + point_accumulate(q, neg, curve, q); + {% endif %} + } + } + bn_booth_clear(bs); + point_free(neg); + + {%- if "scl" in scalarmult.formulas %} + point_scl(q, curve, q); + {%- endif %} + point_set(q, out); + for (long i = 0; i < {{ 2 ** (scalarmult.width - 1) }}; i++) { + point_free(points[i]); + {% if scalarmult.precompute_negation %} + point_free(points_neg[i]); + {% endif %} + } + point_free(q); +}
\ No newline at end of file diff --git a/pyecsca/codegen/templates/mult_comb.c b/pyecsca/codegen/templates/mult_comb.c index 9df9796..1fbb5a3 100644 --- a/pyecsca/codegen/templates/mult_comb.c +++ b/pyecsca/codegen/templates/mult_comb.c @@ -39,6 +39,10 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin bn_from_int(1, &base); bn_lsh(&base, d, &base); + {% if scalarmult.always %} + point_t *dummy = point_new(); + {% endif %} + large_base_t *bs = bn_convert_base_large(scalar, &base); for (int i = d - 1; i >= 0; i--) { point_dbl(q, curve, q); @@ -50,14 +54,18 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin } if (word) { point_accumulate(q, points[word], curve, q); + } else { + {% if scalarmult.always %} + int j = i % {{ 2**scalarmult.width }}; + if (j == 0) { + point_accumulate(q, point, curve, dummy); + } else { + point_accumulate(q, points[j], curve, dummy); + } + {% endif %} } } - for (int i = 0; i < bs->length; i++) { - bn_clear(&bs->data[i]); - } - free(bs->data); - bn_clear(&bs->m); - free(bs); + bn_large_base_clear(bs); bn_clear(&base); diff --git a/pyecsca/codegen/templates/mult_fixed_w.c b/pyecsca/codegen/templates/mult_fixed_w.c index b0a4bb0..6a079b3 100644 --- a/pyecsca/codegen/templates/mult_fixed_w.c +++ b/pyecsca/codegen/templates/mult_fixed_w.c @@ -20,18 +20,22 @@ void scalar_mult_by_m_base(point_t *point, curve_t *curve) { static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) { point_t *q = point_copy(curve->neutral); - point_t *points[{{ scalarmult.m }}]; + point_t *points[{{ scalarmult.m - 1 }}]; point_t *current = point_copy(point); point_t *dbl = point_new(); point_dbl(current, curve, dbl); points[0] = point_copy(current); - points[1] = point_copy(dbl); + {% if scalarmult.m > 2 %} + points[1] = point_copy(dbl); + {% endif %} point_set(dbl, current); - for (long i = 2; i < {{ scalarmult.m }}; i++) { - point_add(current, point, curve, current); - points[i] = point_copy(current); - } + {% if scalarmult.m > 3 %} + for (long i = 2; i < {{ scalarmult.m - 1 }}; i++) { + point_add(current, point, curve, current); + points[i] = point_copy(current); + } + {% endif %} point_free(current); point_free(dbl); @@ -49,14 +53,13 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin point_accumulate(q, points[val-1], curve, q); } } - free(bs->data); - free(bs); + bn_small_base_clear(bs); {%- if "scl" in scalarmult.formulas %} point_scl(q, curve, q); {%- endif %} point_set(q, out); - for (long i = 0; i < {{ scalarmult.m }}; i++) { + for (long i = 0; i < {{ scalarmult.m - 1 }}; i++) { point_free(points[i]); } point_free(q); diff --git a/pyecsca/codegen/templates/mult_rtl.c b/pyecsca/codegen/templates/mult_rtl.c index 71949b4..119ee7e 100644 --- a/pyecsca/codegen/templates/mult_rtl.c +++ b/pyecsca/codegen/templates/mult_rtl.c @@ -5,6 +5,12 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou point_t *q = point_copy(point); point_t *r = point_copy(curve->neutral); + {% if scalarmult.complete %} + size_t bits = bn_bit_length(&curve->n); + {% else %} + size_t bits = bn_bit_length(scalar); + {% endif %} + {%- if scalarmult.always %} point_t *dummy = point_new(); {%- endif %} @@ -12,7 +18,7 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou bn_init(©); bn_copy(scalar, ©); - while (!bn_is_0(©)) { + for (int i = 0; i < bits; i++) { if (bn_get_bit(©, 0) == 1) { point_accumulate(r, q, curve, r); } else { diff --git a/pyecsca/codegen/templates/mult_simple_ldr.c b/pyecsca/codegen/templates/mult_simple_ldr.c index ceb257a..33bfcd9 100644 --- a/pyecsca/codegen/templates/mult_simple_ldr.c +++ b/pyecsca/codegen/templates/mult_simple_ldr.c @@ -11,7 +11,7 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou {%- endif %} for (int i = nbits; i >= 0; i--) { - if (bn_get_bit(scalar, i) == 1) { + if (bn_get_bit(scalar, i) == 0) { point_add(p0, p1, curve, p1); point_dbl(p0, curve, p0); } else { diff --git a/pyecsca/codegen/templates/mult_sliding_w.c b/pyecsca/codegen/templates/mult_sliding_w.c index 1e80a84..347c313 100644 --- a/pyecsca/codegen/templates/mult_sliding_w.c +++ b/pyecsca/codegen/templates/mult_sliding_w.c @@ -34,8 +34,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin {%- endif %} point_set(q, out); - free(ws->data); - free(ws); + bn_wsliding_clear(ws); for (long i = 0; i < {{ 2 ** (scalarmult.width - 1) }}; i++) { point_free(points[i]); } diff --git a/pyecsca/codegen/templates/mult_wnaf.c b/pyecsca/codegen/templates/mult_wnaf.c index 3c5f2b2..569e78b 100644 --- a/pyecsca/codegen/templates/mult_wnaf.c +++ b/pyecsca/codegen/templates/mult_wnaf.c @@ -26,7 +26,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin wnaf_t *naf = bn_wnaf(scalar, {{ scalarmult.width }}); - for (long i = naf->length - 1; i >= 0; i--) { + for (long i = 0; i < naf->length; i++) { point_dbl(q, curve, q); int8_t val = naf->data[i]; if (val > 0) { @@ -40,8 +40,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin {%- endif %} } } - free(naf->data); - free(naf); + bn_naf_clear(naf); {%- if "scl" in scalarmult.formulas %} point_scl(q, curve, q); diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 0000000..05cabee --- /dev/null +++ b/test/.gitignore @@ -0,0 +1,2 @@ +test_bn_val +test_bn_san diff --git a/test/Makefile b/test/Makefile index d13b487..d0068a2 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,3 +1,13 @@ -test_bn: test_bn.c ../pyecsca/codegen/bn/bn.c - gcc -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a
\ No newline at end of file +test_bn_san: test_bn.c ../pyecsca/codegen/bn/bn.c + gcc -g -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a + +test_bn_val: test_bn.c ../pyecsca/codegen/bn/bn.c + gcc -g -o $@ $^ -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a + + +test: test_bn_san test_bn_val + ./test_bn_san + valgrind -q ./test_bn_val + +.PHONY: test
\ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 1c1449a..80faf71 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,6 @@ import pytest +from pyecsca.ec.mult import * from pyecsca.ec.params import get_params, DomainParameters @@ -11,3 +12,97 @@ def secp128r1() -> DomainParameters: @pytest.fixture(scope="session") def curve25519() -> DomainParameters: return get_params("other", "Curve25519", "xz") + + +# fmt: off +window_mults = [ + (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.RTL)), + (FixedWindowLTRMultiplier, dict(m=2**1)), + (FixedWindowLTRMultiplier, dict(m=2**2)), + (FixedWindowLTRMultiplier, dict(m=2**3)), + (FixedWindowLTRMultiplier, dict(m=2**4)), + (FixedWindowLTRMultiplier, dict(m=2**5)), + (FixedWindowLTRMultiplier, dict(m=2**6)), + (WindowBoothMultiplier, dict(width=2)), + (WindowBoothMultiplier, dict(width=3)), + (WindowBoothMultiplier, dict(width=4)), + (WindowBoothMultiplier, dict(width=5)), + (WindowBoothMultiplier, dict(width=6)) +] +naf_mults = [ + (WindowNAFMultiplier, dict(width=2)), + (WindowNAFMultiplier, dict(width=3)), + (WindowNAFMultiplier, dict(width=4)), + (WindowNAFMultiplier, dict(width=5)), + (WindowNAFMultiplier, dict(width=6)), + (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.RTL)) +] +comb_mults = [ + (CombMultiplier, dict(width=2, always=True)), + (CombMultiplier, dict(width=3, always=True)), + (CombMultiplier, dict(width=4, always=True)), + (CombMultiplier, dict(width=5, always=True)), + (CombMultiplier, dict(width=6, always=True)), + (CombMultiplier, dict(width=2, always=False)), + (CombMultiplier, dict(width=3, always=False)), + (CombMultiplier, dict(width=4, always=False)), + (CombMultiplier, dict(width=5, always=False)), + (CombMultiplier, dict(width=6, always=False)), + (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.RTL)) +] +binary_mults = [ + (LTRMultiplier, dict(always=False, complete=True)), + (LTRMultiplier, dict(always=True, complete=True)), + (LTRMultiplier, dict(always=False, complete=False)), + (LTRMultiplier, dict(always=True, complete=False)), + (RTLMultiplier, dict(always=False, complete=True)), + (RTLMultiplier, dict(always=True, complete=True)), + (RTLMultiplier, dict(always=False, complete=False)), + (RTLMultiplier, dict(always=True, complete=False)), + (CoronMultiplier, dict()) +] +other_mults = [ + (FullPrecompMultiplier, dict(always=False, complete=True)), + (FullPrecompMultiplier, dict(always=True, complete=True)), + (FullPrecompMultiplier, dict(always=False, complete=False)), + (FullPrecompMultiplier, dict(always=True, complete=False)), + (SimpleLadderMultiplier, dict(complete=True)), + (SimpleLadderMultiplier, dict(complete=False)) +] +# fmt: on + + +@pytest.fixture( + scope="session", + params=window_mults + naf_mults + comb_mults + binary_mults + other_mults, + ids=lambda p: "{}-{}".format( + p[0].__name__, ":".join(f"{k}={v}" for k, v in p[1].items()) + ), +) +def simple_multiplier(request): + return request.param diff --git a/test/gdb_script.py b/test/gdb_script.py new file mode 100644 index 0000000..93d4472 --- /dev/null +++ b/test/gdb_script.py @@ -0,0 +1,89 @@ +import os +import json + +import gdb + +trace_file = open(os.environ["TRACE_FILE"], "w") + + +def extract_bn(bn): + data_ptr = bn["dp"] + used = int(bn["used"]) + bs = int(gdb.lookup_global_symbol("bn_digit_bits").value()) + result = 0 + for i in range(used): + limb = int((data_ptr + i).dereference()) + result += limb << (i * bs) + return result + + +def extract_point(point): + result = {} + for field in point.type.fields(): + field_name = field.name + if len(field_name) != 1: + continue + field_value = point[field_name] + result[field_name] = extract_bn(field_value) + return result + + +class TraceFunction(gdb.Breakpoint): + def stop(self): + try: + set_bp.enabled = True + frame = gdb.newest_frame() + block = frame.block() + print(frame.name(), flush=True, file=trace_file) + out = [] + for sym in block: + if sym.is_argument: + name = sym.name + try: + value = frame.read_var(name) + except Exception as e: + value = f"<unavailable: {e}>" + deref = value.dereference() + if deref.type.name == "point_t": + if "out" in name: + out.append(deref) + else: + pt = extract_point(deref) + print(f"{name}: {json.dumps(pt)}", flush=True, file=trace_file) + bp = TraceExit(frame) + bp.silent = True + bp.target = out + except RuntimeError: + pass + return False # Continue execution + + +class TraceExit(gdb.FinishBreakpoint): + def stop(self): + set_bp.enabled = False + for i, point in enumerate(self.target): + print(f"out_{i}: {json.dumps(extract_point(point))}", flush=True, file=trace_file) + return False # Continue execution + + +def register_bp(name): + if gdb.lookup_global_symbol(name) is not None: + bp = TraceFunction(name) + bp.silent = True + return bp + return None + + +register_bp("point_add") +register_bp("point_dadd") +register_bp("point_dadd") +register_bp("point_ladd") +register_bp("point_dbl") +register_bp("point_neg") +register_bp("point_scl") +register_bp("point_tpl") +set_bp = register_bp("point_set") +set_bp.enabled = False + +gdb.execute("run") +trace_file.close() diff --git a/test/test_bn.c b/test/test_bn.c index 66de22f..b04c936 100644 --- a/test/test_bn.c +++ b/test/test_bn.c @@ -4,88 +4,403 @@ int test_wsliding_ltr() { printf("test_wsliding_ltr: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // sliding_window_ltr begin + {"181", 3, 6, {5, 0, 0, 5, 0, 1}}, + {"1", 3, 1, {1}}, + {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}}, + {"170", 4, 6, {5, 0, 0, 0, 5, 0}}, + {"554", 5, 6, {17, 0, 0, 0, 5, 0}}, + {"123456789123456789123456789", 5, 83, {25, 1, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 23, 0, 0, 0, 0, 25, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 17, 0, 0, 0, 0, 19, 0, 0, 0, 0, 29, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}} + // sliding_window_ltr end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wsliding_t *ws = bn_wsliding_ltr(&bn, cases[t].w); + if (ws == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (ws->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (ws->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(ws->data); + free(ws); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_wsliding_rtl() { + printf("test_wsliding_rtl: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // sliding_window_rtl begin + {"181", 3, 8, {1, 0, 0, 3, 0, 0, 0, 5}}, + {"1", 3, 1, {1}}, + {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}}, + {"170", 4, 6, {5, 0, 0, 0, 5, 0}}, + {"554", 5, 10, {1, 0, 0, 0, 0, 0, 0, 0, 21, 0}}, + {"123456789123456789123456789",5, 87, {1, 0, 0, 0, 0, 19, 0, 0, 0, 0, 1, 0, 0, 0, 0, 29, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 17, 0, 0, 0, 0, 27, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}} + // sliding_window_rtl end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wsliding_t *ws = bn_wsliding_rtl(&bn, cases[t].w); + if (ws == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (ws->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (ws->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(ws->data); + free(ws); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_convert_base_small() { + printf("test_convert_base_small: "); + int failed = 0; + struct { + const char *value; + int base; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // convert_base_small begin + {"11", 2, 4, {1, 1, 0, 1}}, + {"255", 2, 8, {1, 1, 1, 1, 1, 1, 1, 1}}, + {"1234", 10, 4, {4, 3, 2, 1}}, + {"0", 2, 1, {0}}, + {"1", 2, 1, {1}}, + {"123456789123456789123456789", 16, 22, {5, 1, 15, 5, 4, 0, 12, 7, 15, 9, 1, 11, 3, 14, 2, 15, 13, 15, 14, 1, 6, 6}} + // convert_base_small end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + small_base_t *bs = bn_convert_base_small(&bn, cases[t].base); + if (bs == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (bs->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (bs->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, bs->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(bs->data); + free(bs); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_convert_base_large() { + printf("test_convert_base_large: "); + int failed = 0; + struct { + const char *value; + const char *base; + int expected_len; + const char *expected[100]; // max expected length + } cases[] = { + // convert_base_large begin + {"123456789123456", "2", 47, {"0", "0", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "1", "0", "0", "1", "1", "1", "1", "1", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "0", "1", "0", "0", "0", "1", "0", "0", "1", "0", "0", "0", "0", "0", "1", "1", "1"}}, + {"123456789123456789123456789", "123456", 6, {"104661", "75537", "83120", "74172", "37630", "4"}}, + {"352099265818416392997042486274568094251", "18446744073709551616", 3, {"12367597952119210539", "640595372834356666", "1"}} + // convert_base_large end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn, base; + bn_init(&bn); + bn_init(&base); + bn_from_dec(cases[t].value, &bn); + bn_from_dec(cases[t].base, &base); + large_base_t *bs = bn_convert_base_large(&bn, &base); + if (bs == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + bn_clear(&base); + continue; + } + if (bs->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + bn_t exp; + bn_init(&exp); + bn_from_dec(cases[t].expected[i], &exp); + if (!bn_eq(&bs->data[i], &exp)) { + printf("Case %d: Bad data at %d\n", t, i); + failed++; + bn_clear(&exp); + break; + } + bn_clear(&exp); + } + for (int i = 0; i < bs->length; i++) { + bn_clear(&bs->data[i]); + } + bn_clear(&bs->m); + bn_clear(&bn); + bn_clear(&base); + free(bs->data); + free(bs); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_bn_wnaf() { + printf("test_bn_wnaf: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + int8_t expected[100]; // max expected length + } cases[] = { + // wnaf begin + {"19", 2, 5, {1, 0, 1, 0, -1}}, + {"45", 3, 5, {3, 0, 0, 0, -3}}, + {"0", 3, 0, {}}, + {"1", 2, 1, {1}}, + {"21", 4, 5, {1, 0, 0, 0, 5}}, + {"123456789", 3, 28, {1, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 3, 0, 0, -3}}, + {"123456789123456789123456789", 5, 84, {13, 0, 0, 0, 0, 0, -15, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -13, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, -11}} + // wnaf end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wnaf_t *naf = bn_wnaf(&bn, cases[t].w); + if (naf == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (naf->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, naf->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (naf->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, naf->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(naf->data); + free(naf); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_bn_wnaf_manipulation() { + printf("test_bn_wnaf_manipulation: "); bn_t bn; bn_init(&bn); - bn_from_int(181, &bn); - wsliding_t *ws = bn_wsliding_ltr(&bn, 3); - if (ws == NULL) { - printf("NULL\n"); + bn_from_dec("123456789", &bn); + wnaf_t *naf = bn_wnaf(&bn, 3); + bn_clear(&bn); + if (naf->length != 28) { + printf("FAILED (bad length %li instead of 28)\n", naf->length); return 1; } - if (ws->length != 6) { - printf("Bad length (%li instead of 6)\n", ws->length); + bn_naf_pad_left(naf, 0, 5); + if (naf->length != 33) { + printf("FAILED (bad length after pad left %li instead of 33)\n", naf->length); return 1; } - uint8_t expected[6] = {5, 0, 0, 5, 0, 1}; - for (int i = 0; i < 6; i++) { - if (ws->data[i] != expected[i]) { - printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]); + for (int i = 0; i < 5; i++) { + if (naf->data[i] != 0) { + printf("FAILED (bad data after pad left at %d (%i instead of 0))\n", i, naf->data[i]); return 1; } } - printf("OK\n"); - bn_clear(&bn); - free(ws->data); - free(ws); - return 0; -} - -int test_wsliding_rtl() { - printf("test_wsliding_rtl: "); - bn_t bn; - bn_init(&bn); - bn_from_int(181, &bn); - wsliding_t *ws = bn_wsliding_rtl(&bn, 3); - if (ws == NULL) { - printf("NULL\n"); + bn_naf_strip_left(naf, 0); + if (naf->length != 28) { + printf("FAILED (bad length after strip left %li instead of 28)\n", naf->length); return 1; } - if (ws->length != 8) { - printf("Bad length (%li instead of 8)\n", ws->length); + bn_naf_pad_right(naf, 0, 3); + if (naf->length != 31) { + printf("FAILED (bad length after pad right %li instead of 31)\n", naf->length); return 1; } - uint8_t expected[8] = {1, 0, 0, 3, 0, 0, 0, 5}; - for (int i = 0; i < 8; i++) { - if (ws->data[i] != expected[i]) { - printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]); + for (int i = 28; i < 31; i++) { + if (naf->data[i] != 0) { + printf("FAILED (bad data after pad right at %d (%i instead of 0))\n", i, naf->data[i]); return 1; } } + bn_naf_strip_right(naf, 0); + if (naf->length != 28) { + printf("FAILED (bad length after strip right %li instead of 28)\n", naf->length); + return 1; + } + int8_t rev[28] = {-3, 0, 0, 3, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 1}; + bn_naf_reverse(naf); + for (int i = 0; i < 28; i++) { + if (naf->data[i] != rev[i]) { + printf("FAILED (bad data after reverse at %d (%i instead of %i))\n", i, naf->data[i], rev[i]); + return 1; + } + } + + free(naf->data); + free(naf); printf("OK\n"); - bn_clear(&bn); - free(ws->data); - free(ws); return 0; } -int test_convert_base() { - printf("test_convert_base: "); - bn_t bn; - bn_init(&bn); - bn_from_int(11, &bn); - small_base_t *bs = bn_convert_base_small(&bn, 2); - if (bs == NULL) { - printf("NULL\n"); - return 1; - } - if (bs->length != 4) { - printf("Bad length (%li instead of 4)\n", bs->length); - return 1; +int test_booth() { + printf("test_booth: "); + for (int i = 0; i < (1 << 6); i++) { + int32_t bw = bn_booth_word(i, 5); + if (i <= 31) { + if (bw != (i + 1) / 2) { + printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, (i + 1) / 2); + return 1; + } + } else { + if (bw != -((64 - i) / 2)) { + printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, -((64 - i) / 2)); + return 1; + } + } } - uint8_t expected[4] = {1, 1, 0, 1}; - for (int i = 0; i < 4; i++) { - if (bs->data[i] != expected[i]) { - printf("Bad data (%i insead of %i)\n", bs->data[i], expected[i]); - return 1; + int failed = 0; + struct { + const char *value; + int w; + size_t bits; + int expected_len; + int32_t expected[256]; // max expected length + } cases[] = { + // booth begin + {"12345678123456781234567812345678123456781234567812345678", 1, 224, 225, {0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0}}, + {"12345678123456781234567812345678123456781234567812345678", 2, 224, 113, {0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0}}, + {"12345678123456781234567812345678123456781234567812345678", 5, 224, 45, {1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8, 2, 9, -6, 5, 13, -2, 1, -14, 7, -15, 11, 8, -16, 5, -14, -12, 11, -6, -4, 1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8}}, + {"12345678123456781234567812345678123456781234567812345678", 16, 224, 15, {0, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136}}, + {"12345678123456781234567812345678123456781234567812345678", 24, 224, 10, {18, 3430008, 1193046, 7868980, 5666834, 3430008, 1193046, 7868980, 5666834, 3430008}}, + {"12345678123456781234567812345678123456781234567812345678", 30, 224, 0, {}} + // booth end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_hex(cases[t].value, &bn); + booth_t *booth = bn_booth(&bn, cases[t].w, cases[t].bits); + if (booth == NULL && cases[t].expected_len != 0) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; } + if (cases[t].expected_len != 0) { + if (booth->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, booth->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (booth->data[i] != cases[t].expected[i]) { + printf("FAILED (bad booth data at %d: %d instead of %d)\n", i, booth->data[i], cases[t].expected[i]); + failed++; + break; + } + } + } + bn_clear(&bn); + bn_booth_clear(booth); } printf("OK\n"); - bn_clear(&bn); - free(bs->data); - free(bs); return 0; } int main(void) { - return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base(); + return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base_small() + test_convert_base_large() + test_bn_wnaf() + test_bn_wnaf_manipulation() + test_booth(); } diff --git a/test/test_equivalence.py b/test/test_equivalence.py new file mode 100644 index 0000000..6973ceb --- /dev/null +++ b/test/test_equivalence.py @@ -0,0 +1,185 @@ +import subprocess +import json +from tempfile import NamedTemporaryFile +from typing import Generator, Any +from click.testing import CliRunner + +import pytest +from importlib import resources +from os.path import join + +from pyecsca.codegen.builder import build_impl +from pyecsca.ec.formula import FormulaAction, NegationFormula +from pyecsca.ec.model import CurveModel +from pyecsca.ec.coordinates import CoordinateModel +from pyecsca.sca.target.binary import BinaryTarget +from pyecsca.codegen.client import ImplTarget +from pyecsca.ec.context import DefaultContext, local, Node + +from pyecsca.ec.mult import WindowBoothMultiplier + + +class GDBTarget(ImplTarget, BinaryTarget): + def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs): + super().__init__(model, coords, **kwargs) + self.trace_file = None + + def connect(self): + self.trace_file = NamedTemporaryFile("r+") + with resources.path("test", "gdb_script.py") as gdb_script: + self.process = subprocess.Popen( + [ + "gdb", + "-q", + "-batch-silent", + "-x", + gdb_script, + "--args", + *self.binary, + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env={ + "TRACE_FILE": self.trace_file.name, + }, + text=True, + bufsize=1, + ) + + def disconnect(self): + super().disconnect() + if self.trace_file is not None: + self.trace_file.close() + self.trace_file = None + + +@pytest.fixture(scope="module") +def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]: + mult_class, mult_kwargs = simple_multiplier + mult_name = mult_class.__name__ + formulas = ["add-1998-cmo", "dbl-1998-cmo"] + if NegationFormula in mult_class.requires: + formulas.append("neg") + runner = CliRunner() + with runner.isolated_filesystem() as tmpdir: + res = runner.invoke( + build_impl, + [ + "--platform", + "HOST", + "--ecdsa", + "--ecdh", + secp128r1.curve.model.shortname, + secp128r1.curve.coordinate_model.name, + *formulas, + f"{mult_name}({','.join(f'{key}={value}' for key, value in mult_kwargs.items())})", + ".", + ], + env={"DEBUG": "1", "CFLAGS": "-g -O0"}, + ) + assert res.exit_code == 0 + target = GDBTarget( + secp128r1.curve.model, + secp128r1.curve.coordinate_model, + binary=join(tmpdir, "pyecsca-codegen-HOST.elf"), + ) + formula_instances = [ + secp128r1.curve.coordinate_model.formulas[formula] for formula in formulas + ] + mult = mult_class(*formula_instances, **mult_kwargs) + target.mult = mult + yield target + + +def parse_trace(captured: str): + current_function = None + args = [] + rets = [] + result = [] + for line in captured.split("\n"): + if ":" not in line: + func = line.strip() + if func.startswith("point_"): + func = func[len("point_") :] + if func == "set": + # The sets that happen inside another formula (like add) are a sign of short-circuiting. + # The Python simulation does not record the short-circuits, so we ignore them here. + current_function = None + else: + if current_function is not None: + result.append((current_function, args, rets)) + current_function = func + args = [] + rets = [] + else: + name, data = line.split(":", 1) + name = name.strip() + value = json.loads(data) + if "out" in name: + rets.append(value) + else: + args.append(value) + return result + + +def parse_ctx(scalarmult: Node): + result = [] + for node in scalarmult.children: + action: FormulaAction = node.action + formula = action.formula + name = formula.shortname + args = [] + for point in action.input_points: + point_value = {k: int(v) for k, v in point.coords.items()} + args.append(point_value) + rets = [] + for point in action.output_points: + point_value = {k: int(v) for k, v in point.coords.items()} + rets.append(point_value) + result.append((name, args, rets)) + return result + + +def make_hashable(trace): + result = [] + for entry in trace: + name, args, rets = entry + args_t = tuple(tuple(arg.items()) for arg in args) + rets_t = tuple(tuple(ret.items()) for ret in rets) + result.append((name, args_t, rets_t)) + return tuple(result) + + +def test_equivalence(target, secp128r1): + mult = target.mult + for i in range(10): + target.connect() + target.init_prng(i.to_bytes(4, "big")) + target.set_params(secp128r1) + + priv, pub = target.generate() + + with local(DefaultContext()) as ctx: + mult.init(secp128r1, secp128r1.generator) + expected = mult.multiply(priv).to_affine() + + assert secp128r1.curve.is_on_curve(pub) + assert pub == expected + err = target.trace_file.read() + from_codegen = parse_trace(err) + from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1]) + codegen_set = set(make_hashable(from_codegen)) + sim_set = set(make_hashable(from_sim)) + if codegen_set != sim_set: + print(len(from_codegen), len(from_sim)) + print("In codegen but not in sim:") + for entry in codegen_set - sim_set: + print(entry) + print("In sim but not in codegen:") + for entry in sim_set - codegen_set: + print(entry) + assert from_codegen == from_sim + + target.quit() + target.disconnect() diff --git a/test/test_impl.py b/test/test_impl.py index 34f427b..a61aabd 100644 --- a/test/test_impl.py +++ b/test/test_impl.py @@ -1,29 +1,18 @@ from copy import copy from os.path import join +from typing import Any, Generator + import pytest from click.testing import CliRunner +from pyecsca.ec.formula import NegationFormula from pyecsca.ec.key_agreement import ECDH_SHA1 from pyecsca.ec.mod import mod -from pyecsca.ec.mult import ( - LTRMultiplier, - RTLMultiplier, - CoronMultiplier, - BinaryNAFMultiplier, - WindowNAFMultiplier, - SlidingWindowMultiplier, - AccumulationOrder, - ProcessingDirection, - ScalarMultiplier, - FixedWindowLTRMultiplier, - FullPrecompMultiplier, - BGMWMultiplier, - CombMultiplier, -) +from pyecsca.ec.mult import ScalarMultiplier, WindowBoothMultiplier from pyecsca.ec.signature import ECDSA_SHA1, SignatureResult from pyecsca.codegen.builder import build_impl -from pyecsca.codegen.client import HostTarget, ImplTarget +from pyecsca.codegen.client import HostTarget @pytest.fixture( @@ -39,195 +28,15 @@ def additional(request): return request.param -@pytest.fixture( - scope="module", - params=[ - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False}, - ), - id="LTR1", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": True}, - ), - id="LTR2", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False, "always": True}, - ), - id="LTR3", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": True, "always": True}, - ), - id="LTR4", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False, "accumulation_order": AccumulationOrder.PeqRP}, - ), - id="LTR5", - ), - pytest.param( - (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": False}), - id="RTL1", - ), - pytest.param( - (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": True}), - id="RTL2", - ), - pytest.param( - (CoronMultiplier, "coron", ["add-1998-cmo", "dbl-1998-cmo"], {}), id="Coron" - ), - pytest.param( - ( - BinaryNAFMultiplier, - "bnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"direction": ProcessingDirection.LTR}, - ), - id="BNAF1", - ), - pytest.param( - ( - BinaryNAFMultiplier, - "bnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"direction": ProcessingDirection.RTL}, - ), - id="BNAF2", - ), - pytest.param( - ( - WindowNAFMultiplier, - "wnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"width": 3}, - ), - id="WNAF1", - ), - pytest.param( - ( - WindowNAFMultiplier, - "wnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"width": 3, "precompute_negation": True}, - ), - id="WNAF2", - ), - pytest.param( - ( - SlidingWindowMultiplier, - "sliding", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3}, - ), - id="SLI1", - ), - pytest.param( - ( - SlidingWindowMultiplier, - "sliding", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3, "recoding_direction": ProcessingDirection.RTL}, - ), - id="SLI2", - ), - pytest.param( - ( - FixedWindowLTRMultiplier, - "fixed", - ["add-1998-cmo", "dbl-1998-cmo"], - {"m": 4}, - ), - id="FIX1", - ), - pytest.param( - ( - FixedWindowLTRMultiplier, - "fixed", - ["add-1998-cmo", "dbl-1998-cmo"], - {"m": 5}, - ), - id="FIX2", - ), - pytest.param( - ( - FullPrecompMultiplier, - "precomp", - ["add-1998-cmo", "dbl-1998-cmo"], - {"direction": ProcessingDirection.LTR}, - ), - id="PRE1", - ), - pytest.param( - ( - FullPrecompMultiplier, - "precomp", - ["add-1998-cmo", "dbl-1998-cmo"], - {"direction": ProcessingDirection.RTL}, - ), - id="PRE2", - ), - pytest.param( - ( - BGMWMultiplier, - "bgmw", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3, "direction": ProcessingDirection.LTR}, - ), - id="BGMW1", - ), - pytest.param( - ( - BGMWMultiplier, - "bgmw", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 5, "direction": ProcessingDirection.RTL}, - ), - id="BGMW2", - ), - pytest.param( - ( - CombMultiplier, - "comb", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3}, - ), - id="Comb1", - ), - pytest.param( - ( - CombMultiplier, - "comb", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 5}, - ), - id="Comb2", - ), - ], -) -def target(request, additional, secp128r1) -> ImplTarget: - mult_class, mult_name, formulas, mult_kwargs = request.param +@pytest.fixture(scope="module") +def target( + simple_multiplier, additional, secp128r1 +) -> Generator[HostTarget, Any, None]: + mult_class, mult_kwargs = simple_multiplier + mult_name = mult_class.__name__ + formulas = ["add-1998-cmo", "dbl-1998-cmo"] + if NegationFormula in mult_class.requires: + formulas.append("neg") runner = CliRunner() with runner.isolated_filesystem() as tmpdir: res = runner.invoke( @@ -235,7 +44,6 @@ def target(request, additional, secp128r1) -> ImplTarget: [ "--platform", "HOST", - *additional, "--ecdsa", "--ecdh", secp128r1.curve.model.shortname, |
