aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/test.yml6
-rw-r--r--.gitignore1
-rw-r--r--Makefile2
m---------ext/libtommath0
-rw-r--r--pyecsca/codegen/bn/bn.c225
-rw-r--r--pyecsca/codegen/bn/bn.h22
-rw-r--r--pyecsca/codegen/builder.py2
-rw-r--r--pyecsca/codegen/render.py3
-rw-r--r--pyecsca/codegen/templates/formula_add.c5
-rw-r--r--pyecsca/codegen/templates/formula_dbl.c3
-rw-r--r--pyecsca/codegen/templates/formula_neg.c3
-rw-r--r--pyecsca/codegen/templates/formula_scl.c3
-rw-r--r--pyecsca/codegen/templates/formula_tpl.c3
-rw-r--r--pyecsca/codegen/templates/mult.c4
-rw-r--r--pyecsca/codegen/templates/mult_bgmw.c3
-rw-r--r--pyecsca/codegen/templates/mult_bnaf.c94
-rw-r--r--pyecsca/codegen/templates/mult_booth.c78
-rw-r--r--pyecsca/codegen/templates/mult_comb.c20
-rw-r--r--pyecsca/codegen/templates/mult_fixed_w.c21
-rw-r--r--pyecsca/codegen/templates/mult_rtl.c8
-rw-r--r--pyecsca/codegen/templates/mult_simple_ldr.c2
-rw-r--r--pyecsca/codegen/templates/mult_sliding_w.c3
-rw-r--r--pyecsca/codegen/templates/mult_wnaf.c5
-rw-r--r--test/.gitignore2
-rw-r--r--test/Makefile14
-rw-r--r--test/conftest.py95
-rw-r--r--test/gdb_script.py89
-rw-r--r--test/test_bn.c427
-rw-r--r--test/test_equivalence.py185
-rw-r--r--test/test_impl.py220
30 files changed, 1206 insertions, 342 deletions
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2e480d..f9325ca 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -4,11 +4,11 @@ on: [push, pull_request]
env:
LLVM_CONFIG: /usr/bin/llvm-config-14
- OTHER_PACKAGES: swig libpcsclite-dev llvm-14 libllvm14 llvm-14-dev valgrind gcc gcc-arm-none-eabi binutils-arm-none-eabi libnewlib-arm-none-eabi
+ OTHER_PACKAGES: swig libpcsclite-dev llvm-14 libllvm14 llvm-14-dev valgrind gdb gcc gcc-arm-none-eabi binutils-arm-none-eabi libnewlib-arm-none-eabi
jobs:
test:
- runs-on: ubuntu-22.04
+ runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
@@ -50,7 +50,7 @@ jobs:
make test
- name: Test (C)
run: |
- cd test && make test_bn && ./test_bn && cd ..
+ cd test && make test && cd ..
- name: Code coverage
uses: codecov/codecov-action@v3
if: ${{ matrix.python-version == 3.9 }}
diff --git a/.gitignore b/.gitignore
index 9e55abb..58caf65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,6 @@
htmlcov/
/build/
__pycache__
-/test/test_bn
pyecsca-codegen-*.elf
pyecsca-codegen-*.hex \ No newline at end of file
diff --git a/Makefile b/Makefile
index faeaacb..3d19252 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,3 @@
-TESTS = test_builder test_client test_render test_impl test_simulator
-
test:
pytest -m "not slow" --cov=pyecsca.codegen
diff --git a/ext/libtommath b/ext/libtommath
-Subproject e9b2847b8ddf800e44f486966cfa71b82339248
+Subproject c71669aa4e8d9a9ab3076a8398484e3d3d0c04f
diff --git a/pyecsca/codegen/bn/bn.c b/pyecsca/codegen/bn/bn.c
index 148403c..d7a44b6 100644
--- a/pyecsca/codegen/bn/bn.c
+++ b/pyecsca/codegen/bn/bn.c
@@ -27,6 +27,8 @@ void math_init(void) {
#endif //TODO: COMBA
}
+const int bn_digit_bits __attribute__((used)) = MP_DIGIT_BIT;
+
bn_err bn_init(bn_t *bn) {
return mp_init(bn);
}
@@ -47,6 +49,10 @@ bn_err bn_from_hex(const char *data, bn_t *out) {
return mp_read_radix(out, data, 16);
}
+bn_err bn_from_dec(const char *data, bn_t *out) {
+ return mp_read_radix(out, data, 10);
+}
+
bn_err bn_from_int(unsigned int value, bn_t *out) {
if (sizeof(unsigned int) == 8) {
mp_set_u64(out, value);
@@ -394,6 +400,9 @@ wnaf_t *bn_wnaf(const bn_t *bn, int w) {
}
wnaf_t *result = NULL;
+ size_t bits = bn_bit_length(bn) + 1;
+ int8_t arr[bits];
+
bn_t half_width;
if (mp_init(&half_width) != BN_OKAY) {
return NULL;
@@ -418,38 +427,38 @@ wnaf_t *bn_wnaf(const bn_t *bn, int w) {
goto exit_val_mod;
}
- result = malloc(sizeof(wnaf_t));
- result->w = w;
- result->length = bn_bit_length(bn) + 1;
- result->data = calloc(result->length, sizeof(int8_t));
-
size_t i = 0;
- while (!bn_is_0(&k) && !(bn_get_sign(&k) == BN_NEG)) {
+ while (mp_cmp_d(&k, 0) == MP_GT) {
if (bn_get_bit(&k, 0) == 1) {
bn_mod(&k, &full_width, &val_mod);
if (mp_cmp(&val_mod, &half_width) == MP_GT) {
if (mp_sub(&val_mod, &full_width, &val_mod) != BN_OKAY) {
- free(result->data);
- free(result);
- result = NULL;
- break;
+ goto exit_result;
}
}
int8_t val = (int8_t) mp_get_i32(&val_mod);
- result->data[i++] = val;
+ arr[i++] = val;
if (mp_sub(&k, &val_mod, &k) != BN_OKAY) {
- free(result->data);
- free(result);
- result = NULL;
- break;
+ goto exit_result;
}
} else {
- result->data[i++] = 0;
+ arr[i++] = 0;
}
bn_rsh(&k, 1, &k);
}
- bn_clear(&val_mod);
+ result = malloc(sizeof(wnaf_t));
+ result->w = w;
+ result->length = i;
+ result->data = calloc(result->length, sizeof(int8_t));
+
+ // Revert
+ for (size_t j = 0; j < i; j++) {
+ result->data[j] = arr[i - j - 1];
+ }
+
+exit_result:
+ bn_clear(&val_mod);
exit_val_mod:
bn_clear(&k);
exit_k:
@@ -463,6 +472,90 @@ wnaf_t *bn_bnaf(const bn_t *bn) {
return bn_wnaf(bn, 2);
}
+void bn_naf_pad_left(wnaf_t *naf, int8_t value, size_t amount) {
+ if (amount == 0) {
+ return;
+ }
+ int8_t *new_data = calloc(naf->length + amount, sizeof(int8_t));
+ for (size_t i = 0; i < naf->length; i++) {
+ new_data[i + amount] = naf->data[i];
+ }
+ for (size_t i = 0; i < amount; i++) {
+ new_data[i] = value;
+ }
+ free(naf->data);
+ naf->data = new_data;
+ naf->length += amount;
+}
+
+void bn_naf_pad_right(wnaf_t *naf, int8_t value, size_t amount) {
+ if (amount == 0) {
+ return;
+ }
+ naf->data = realloc(naf->data, (naf->length + amount) * sizeof(int8_t));
+ for (size_t i = naf->length; i < naf->length + amount; i++) {
+ naf->data[i] = value;
+ }
+ naf->length += amount;
+}
+
+void bn_naf_strip_left(wnaf_t *naf, int8_t value) {
+ size_t i = 0;
+ while (i < naf->length && naf->data[i] == value) {
+ i++;
+ }
+ if (i == 0) {
+ return;
+ }
+ if (i == naf->length) {
+ free(naf->data);
+ naf->data = NULL;
+ naf->length = 0;
+ return;
+ }
+ int8_t *new_data = calloc(naf->length - i, sizeof(int8_t));
+ for (size_t j = 0; j < naf->length - i; j++) {
+ new_data[j] = naf->data[j + i];
+ }
+ free(naf->data);
+ naf->data = new_data;
+ naf->length -= i;
+}
+
+void bn_naf_strip_right(wnaf_t *naf, int8_t value) {
+ size_t i = naf->length;
+ while (i > 0 && naf->data[i - 1] == value) {
+ i--;
+ }
+ if (i == naf->length) {
+ return;
+ }
+ if (i == 0) {
+ free(naf->data);
+ naf->data = NULL;
+ naf->length = 0;
+ return;
+ }
+ naf->data = realloc(naf->data, i * sizeof(int8_t));
+ naf->length = i;
+}
+
+void bn_naf_reverse(wnaf_t *naf) {
+ for (size_t i = 0; i < naf->length / 2; i++) {
+ int8_t temp = naf->data[i];
+ naf->data[i] = naf->data[naf->length - i - 1];
+ naf->data[naf->length - i - 1] = temp;
+ }
+}
+
+void bn_naf_clear(wnaf_t *naf) {
+ if (naf == NULL) {
+ return;
+ }
+ free(naf->data);
+ free(naf);
+}
+
wsliding_t *bn_wsliding_ltr(const bn_t *bn, int w) {
if (w > 8 || w < 2) {
return NULL;
@@ -540,8 +633,8 @@ wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w) {
wsliding_t *result = NULL;
int blen = bn_bit_length(bn);
- uint8_t arr[blen + 2];
- memset(arr, 0, (blen + 2) * sizeof(uint8_t));
+ uint8_t arr[blen + w];
+ memset(arr, 0, (blen + w) * sizeof(uint8_t));
bn_t k;
if (mp_init(&k) != BN_OKAY) {
@@ -555,7 +648,7 @@ wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w) {
}
int i = 0;
- while (!bn_is_0(&k) && !(bn_get_sign(&k) == BN_NEG)) {
+ while (mp_cmp_d(&k, 0) == MP_GT) {
if (!bn_get_bit(&k, 0)) {
arr[i++] = 0;
bn_rsh(&k, 1, &k);
@@ -594,6 +687,14 @@ exit_k:
return result;
}
+void bn_wsliding_clear(wsliding_t *wsliding) {
+ if (wsliding == NULL) {
+ return;
+ }
+ free(wsliding->data);
+ free(wsliding);
+}
+
small_base_t *bn_convert_base_small(const bn_t *bn, int m) {
small_base_t *result = NULL;
@@ -604,7 +705,9 @@ small_base_t *bn_convert_base_small(const bn_t *bn, int m) {
bn_copy(bn, &k);
int len = 0;
- if (mp_log_n(&k, m, &len) != BN_OKAY) {
+ if (mp_cmp_d(&k, 0) == MP_EQ) {
+ len = 0;
+ } else if (mp_log_n(&k, m, &len) != BN_OKAY) {
goto exit_len;
}
@@ -630,6 +733,14 @@ exit_k:
return result;
}
+void bn_small_base_clear(small_base_t *sb) {
+ if (sb == NULL) {
+ return;
+ }
+ free(sb->data);
+ free(sb);
+}
+
large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m) {
large_base_t *result = NULL;
@@ -640,7 +751,9 @@ large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m) {
bn_copy(bn, &k);
int len = 0;
- if (mp_log(&k, m, &len) != BN_OKAY) {
+ if (mp_cmp_d(&k, 0) == MP_EQ) {
+ len = 0;
+ } else if (mp_log(&k, m, &len) != BN_OKAY) {
goto exit_len;
}
@@ -666,4 +779,72 @@ exit_len:
bn_clear(&k);
exit_k:
return result;
+}
+
+void bn_large_base_clear(large_base_t *lb) {
+ if (lb == NULL) {
+ return;
+ }
+ for (int i = 0; i < lb->length; i++) {
+ bn_clear(&lb->data[i]);
+ }
+ free(lb->data);
+ bn_clear(&lb->m);
+ free(lb);
+}
+
+int32_t bn_booth_word(int32_t digit, int32_t w) {
+ int32_t s = ~((digit >> w) - 1); //s = ~((digit >> w) - 1)
+ int32_t d = (1 << (w + 1)) - digit - 1; //d = (1 << (w + 1)) - digit - 1
+ d = (d & s) | (digit & ~s); // d = (d & s) | (digit & ~s)
+ d = (d >> 1) + (d & 1); //d = (d >> 1) + (d & 1)
+
+ if (s) { //return -d if s else d
+ return -d;
+ } else {
+ return d;
+ }
+}
+
+booth_t *bn_booth(const bn_t *bn, int32_t w, size_t bits) {
+ if (w >= 30) {
+ return NULL;
+ }
+ int32_t mask = (1 << (w + 1)) - 1;
+ bn_t d, m;
+ bn_init(&d);
+ bn_init(&m);
+ bn_from_int(mask, &m);
+
+ size_t len = (bits / w) + 1;
+ booth_t *result = malloc(sizeof(booth_t));
+ result->length = len;
+ result->w = w;
+ result->data = calloc(len, sizeof(int32_t));
+
+ long l = 0;
+ for (long i = bits + (w - (bits % w) - 1); i > 0; i -= w) {
+ int32_t digit;
+ bn_copy(bn, &d);
+ if (i >= w) {
+ bn_rsh(&d, i - w, &d);
+ } else {
+ bn_lsh(&d, w - i, &d);
+ }
+ bn_and(&d, &m, &d);
+ digit = bn_to_int(&d);
+ int32_t val = bn_booth_word(digit, w);
+ result->data[l++] = val;
+ }
+ bn_clear(&d);
+ bn_clear(&m);
+ return result;
+}
+
+void bn_booth_clear(booth_t *booth) {
+ if (booth == NULL) {
+ return;
+ }
+ free(booth->data);
+ free(booth);
} \ No newline at end of file
diff --git a/pyecsca/codegen/bn/bn.h b/pyecsca/codegen/bn/bn.h
index 7c25c22..03526b0 100644
--- a/pyecsca/codegen/bn/bn.h
+++ b/pyecsca/codegen/bn/bn.h
@@ -76,8 +76,16 @@ typedef struct {
bn_t m;
} large_base_t;
+typedef struct {
+ int32_t *data;
+ size_t length;
+ int w;
+} booth_t;
+
void math_init(void);
+extern const int bn_digit_bits;
+
bn_err bn_init(bn_t *bn);
#define bn_init_multi mp_init_multi
bn_err bn_copy(const bn_t *from, bn_t *to);
@@ -86,6 +94,7 @@ void bn_clear(bn_t *bn);
bn_err bn_from_bin(const uint8_t *data, size_t size, bn_t *out);
bn_err bn_from_hex(const char *data, bn_t *out);
+bn_err bn_from_dec(const char *data, bn_t *out);
bn_err bn_from_int(unsigned int value, bn_t *out);
bn_err bn_to_binpad(const bn_t *one, uint8_t *data, size_t size);
@@ -135,11 +144,24 @@ int bn_bit_length(const bn_t *bn);
wnaf_t *bn_wnaf(const bn_t *bn, int w);
wnaf_t *bn_bnaf(const bn_t *bn);
+void bn_naf_pad_left(wnaf_t *naf, int8_t value, size_t amount);
+void bn_naf_pad_right(wnaf_t *naf, int8_t value, size_t amount);
+void bn_naf_strip_left(wnaf_t *naf, int8_t value);
+void bn_naf_strip_right(wnaf_t *naf, int8_t value);
+void bn_naf_reverse(wnaf_t *naf);
+void bn_naf_clear(wnaf_t *naf);
wsliding_t *bn_wsliding_ltr(const bn_t *bn, int w);
wsliding_t *bn_wsliding_rtl(const bn_t *bn, int w);
+void bn_wsliding_clear(wsliding_t *wsliding);
small_base_t *bn_convert_base_small(const bn_t *bn, int m);
+void bn_small_base_clear(small_base_t *sb);
large_base_t *bn_convert_base_large(const bn_t *bn, const bn_t *m);
+void bn_large_base_clear(large_base_t *lb);
+
+int32_t bn_booth_word(int32_t digit, int32_t w);
+booth_t *bn_booth(const bn_t *bn, int32_t w, size_t bits);
+void bn_booth_clear(booth_t *booth);
#endif //BN_H_ \ No newline at end of file
diff --git a/pyecsca/codegen/builder.py b/pyecsca/codegen/builder.py
index 63e7801..b7b5e22 100644
--- a/pyecsca/codegen/builder.py
+++ b/pyecsca/codegen/builder.py
@@ -81,7 +81,7 @@ def get_multiplier(ctx: click.Context, param, value: Optional[str]) -> Optional[
if value is None:
return None
res = re.match(
- "(?P<name>[a-zA-Z\-]+)\((?P<args>([a-zA-Z_]+ *= *[a-zA-Z0-9.]+, ?)*?([a-zA-Z_]+ *= *[a-zA-Z0-9.]+)*)\)",
+ r"(?P<name>[a-zA-Z\-]+)\((?P<args>([a-zA-Z_]+ *= *[a-zA-Z0-9.]+, ?)*?([a-zA-Z_]+ *= *[a-zA-Z0-9.]+)*)\)",
value)
if not res:
raise click.BadParameter("Couldn't parse multiplier spec: {}.".format(value))
diff --git a/pyecsca/codegen/render.py b/pyecsca/codegen/render.py
index 692deab..b1c1477 100644
--- a/pyecsca/codegen/render.py
+++ b/pyecsca/codegen/render.py
@@ -30,7 +30,7 @@ from pyecsca.ec.mult import (
BGMWMultiplier,
CombMultiplier,
AccumulationOrder,
- ProcessingDirection
+ ProcessingDirection, WindowBoothMultiplier
)
from pyecsca.ec.op import OpType, CodeOp
@@ -227,6 +227,7 @@ def render_scalarmult_impl(scalarmult: ScalarMultiplier) -> str:
DifferentialLadderMultiplier=DifferentialLadderMultiplier,
BinaryNAFMultiplier=BinaryNAFMultiplier,
WindowNAFMultiplier=WindowNAFMultiplier,
+ WindowBoothMultiplier=WindowBoothMultiplier,
SlidingWindowMultiplier=SlidingWindowMultiplier,
FixedWindowLTRMultiplier=FixedWindowLTRMultiplier,
FullPrecompMultiplier=FullPrecompMultiplier,
diff --git a/pyecsca/codegen/templates/formula_add.c b/pyecsca/codegen/templates/formula_add.c
index 6026601..48bab07 100644
--- a/pyecsca/codegen/templates/formula_add.c
+++ b/pyecsca/codegen/templates/formula_add.c
@@ -16,16 +16,17 @@ __attribute__((noinline)) void point_add(const point_t *one, const point_t *othe
{%- if short_circuit %}
if (point_equals(one, curve->neutral)) {
point_set(other, out_one);
- return;
+ goto end;
}
if (point_equals(other, curve->neutral)) {
point_set(one, out_one);
- return;
+ goto end;
}
{%- endif %}
{{ ops.render_initializations(initializations) }}
{{ ops.render_ops(operations) }}
{{ ops.render_returns(returns) }}
//NOP_128();
+end:
{{ end_action("add") }}
} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/formula_dbl.c b/pyecsca/codegen/templates/formula_dbl.c
index 451b0ee..e1cfa15 100644
--- a/pyecsca/codegen/templates/formula_dbl.c
+++ b/pyecsca/codegen/templates/formula_dbl.c
@@ -16,12 +16,13 @@ __attribute__((noinline)) void point_dbl(const point_t *one, const curve_t *curv
{%- if short_circuit %}
if (point_equals(one, curve->neutral)) {
point_set(one, out_one);
- return;
+ goto end;
}
{%- endif %}
{{ ops.render_initializations(initializations) }}
{{ ops.render_ops(operations) }}
{{ ops.render_returns(returns) }}
//NOP_128();
+end:
{{ end_action("dbl") }}
} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/formula_neg.c b/pyecsca/codegen/templates/formula_neg.c
index 93fbe20..fa96c63 100644
--- a/pyecsca/codegen/templates/formula_neg.c
+++ b/pyecsca/codegen/templates/formula_neg.c
@@ -16,12 +16,13 @@ __attribute__((noinline)) void point_neg(const point_t *one, const curve_t *curv
{%- if short_circuit %}
if (point_equals(one, curve->neutral)) {
point_set(one, out_one);
- return;
+ goto end;
}
{%- endif %}
{{ ops.render_initializations(initializations) }}
{{ ops.render_ops(operations) }}
{{ ops.render_returns(returns) }}
//NOP_128();
+end:
{{ end_action("neg") }}
} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/formula_scl.c b/pyecsca/codegen/templates/formula_scl.c
index 48ac52e..f1471a2 100644
--- a/pyecsca/codegen/templates/formula_scl.c
+++ b/pyecsca/codegen/templates/formula_scl.c
@@ -16,12 +16,13 @@ __attribute__((noinline)) void point_scl(const point_t *one, const curve_t *curv
{%- if short_circuit %}
if (point_equals(one, curve->neutral)) {
point_set(one, out_one);
- return;
+ goto end;
}
{%- endif %}
{{ ops.render_initializations(initializations) }}
{{ ops.render_ops(operations) }}
{{ ops.render_returns(returns) }}
//NOP_128();
+end:
{{ end_action("scl") }}
} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/formula_tpl.c b/pyecsca/codegen/templates/formula_tpl.c
index d280bad..0b4cd64 100644
--- a/pyecsca/codegen/templates/formula_tpl.c
+++ b/pyecsca/codegen/templates/formula_tpl.c
@@ -16,12 +16,13 @@ __attribute__((noinline)) void point_tpl(const point_t *one, const curve_t *curv
{%- if short_circuit %}
if (point_equals(one, curve->neutral)) {
point_set(one, out_one);
- return;
+ goto end;
}
{%- endif %}
{{ ops.render_initializations(initializations) }}
{{ ops.render_ops(operations) }}
{{ ops.render_returns(returns) }}
//NOP_128();
+end:
{{ end_action("tpl") }}
} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/mult.c b/pyecsca/codegen/templates/mult.c
index 0144e36..4070952 100644
--- a/pyecsca/codegen/templates/mult.c
+++ b/pyecsca/codegen/templates/mult.c
@@ -31,6 +31,10 @@
{% include "mult_wnaf.c" %}
+{%- elif isinstance(scalarmult, WindowBoothMultiplier) -%}
+
+ {% include "mult_booth.c" %}
+
{%- elif isinstance(scalarmult, SlidingWindowMultiplier) -%}
{% include "mult_sliding_w.c" %}
diff --git a/pyecsca/codegen/templates/mult_bgmw.c b/pyecsca/codegen/templates/mult_bgmw.c
index 5298fb1..e2e8c72 100644
--- a/pyecsca/codegen/templates/mult_bgmw.c
+++ b/pyecsca/codegen/templates/mult_bgmw.c
@@ -48,8 +48,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
{%- endif %}
point_accumulate(a, b, curve, a);
}
- free(bs->data);
- free(bs);
+ bn_small_base_clear(bs);
{%- if "scl" in scalarmult.formulas %}
point_scl(a, curve, a);
diff --git a/pyecsca/codegen/templates/mult_bnaf.c b/pyecsca/codegen/templates/mult_bnaf.c
index 68d1569..9c760af 100644
--- a/pyecsca/codegen/templates/mult_bnaf.c
+++ b/pyecsca/codegen/templates/mult_bnaf.c
@@ -1,51 +1,109 @@
#include "mult.h"
#include "point.h"
-point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) {
- point_t *q = point_copy(curve->neutral);
- for (long i = naf->length - 1; i >= 0; i--) {
+point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) {
+ point_t *q;
+ long i;
+ {% if scalarmult.complete %}
+ bn_naf_pad_left(naf, 0, (bits + 1) - naf->length);
+ q = point_copy(curve->neutral);
+ i = 0;
+ {% else %}
+ bn_naf_strip_left(naf, 0);
+ int8_t val = naf->data[0];
+ if (val == 1) {
+ q = point_copy(point);
+ } else if (val == -1) {
+ q = point_copy(neg);
+ }
+ i = 1;
+ {% endif %}
+
+ {% if scalarmult.always %}
+ point_t *q_copy = point_new();
+ {% endif %}
+ for (; i < naf->length; i++) {
point_dbl(q, curve, q);
+
+ {% if scalarmult.always %}
+ point_set(q, q_copy);
+ {% endif %}
+
if (naf->data[i] == 1) {
point_accumulate(q, point, curve, q);
+ {% if scalarmult.always %}
+ point_accumulate(q_copy, neg, curve, q_copy);
+ {% endif %}
} else if (naf->data[i] == -1) {
point_accumulate(q, neg, curve, q);
+ {% if scalarmult.always %}
+ point_accumulate(q_copy, point, curve, q_copy);
+ {% endif %}
}
}
+ {% if scalarmult.always %}
+ point_free(q_copy);
+ {% endif %}
return q;
}
-point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) {
- point_t *r = point_copy(point);
- point_t *q = point_copy(curve->neutral);
- point_t *r_neg = point_new();
+point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) {
+ {% if scalarmult.always %}
+ point_t *r_copy = point_new();
+ {% endif %}
+
+ {% if scalarmult.complete %}
+ bn_naf_pad_left(naf, 0, (bits + 1) - naf->length);
+ {% endif %}
+
+ bn_naf_reverse(naf);
+
+ point_t *q = point_copy(point);
+ point_t *r = point_copy(curve->neutral);
+ point_t *q_neg = point_new();
for (long i = 0; i < naf->length; i++) {
+ {% if scalarmult.always %}
+ point_set(r, r_copy);
+ {% endif %}
+
if (naf->data[i] == 1) {
- point_accumulate(q, r, curve, q);
+ point_accumulate(r, q, curve, r);
+ {% if scalarmult.always %}
+ point_neg(q, curve, q_neg);
+ point_accumulate(r_copy, q_neg, curve, r_copy);
+ {% endif %}
} else if (naf->data[i] == -1) {
- point_neg(r, curve, r_neg);
- point_accumulate(q, r_neg, curve, q);
+ point_neg(q, curve, q_neg);
+ point_accumulate(r, q_neg, curve, r);
+ {% if scalarmult.always %}
+ point_accumulate(r_copy, q, curve, r_copy);
+ {% endif %}
}
- point_dbl(r, curve, r);
+ point_dbl(q, curve, q);
}
- point_free(r_neg);
- point_free(r);
+ point_free(q_neg);
+ point_free(q);
- return q;
+ {% if scalarmult.always %}
+ point_free(r_copy);
+ {% endif %}
+ return r;
}
static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) {
point_t *neg = point_new();
point_neg(point, curve, neg);
+
wnaf_t *naf = bn_bnaf(scalar);
+ size_t bits = bn_bit_length(&curve->n);
{% if scalarmult.direction == ProcessingDirection.LTR %}
- point_t *q = scalar_mult_ltr(point, neg, curve, naf);
+ point_t *q = scalar_mult_ltr(point, neg, curve, naf, bits);
{% elif scalarmult.direction == ProcessingDirection.RTL %}
- point_t *q = scalar_mult_rtl(point, neg, curve, naf);
+ point_t *q = scalar_mult_rtl(point, neg, curve, naf, bits);
{% endif %}
- free(naf->data);
- free(naf);
+ bn_naf_clear(naf);
{%- if "scl" in scalarmult.formulas %}
point_scl(q, curve, q);
diff --git a/pyecsca/codegen/templates/mult_booth.c b/pyecsca/codegen/templates/mult_booth.c
new file mode 100644
index 0000000..4c1ba40
--- /dev/null
+++ b/pyecsca/codegen/templates/mult_booth.c
@@ -0,0 +1,78 @@
+#include "mult.h"
+#include "point.h"
+
+
+
+static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) {
+ point_t *points[{{ 2 ** (scalarmult.width - 1) }}];
+ {% if scalarmult.precompute_negation %}
+ point_t *points_neg[{{ 2 ** (scalarmult.width - 1) }}];
+ {% endif %}
+
+ point_t *current = point_copy(point);
+ point_t *dbl = point_new();
+ point_dbl(current, curve, dbl);
+ points[0] = point_copy(current);
+ {% if scalarmult.precompute_negation %}
+ points_neg[0] = point_new();
+ point_neg(points[0], curve, points_neg[0]);
+ {% endif %}
+ {% if scalarmult.width > 1 %}
+ points[1] = point_copy(dbl);
+ {% if scalarmult.precompute_negation %}
+ points_neg[1] = point_new();
+ point_neg(points[1], curve, points_neg[1]);
+ {% endif %}
+ {% endif %}
+
+ point_set(dbl, current);
+ {% if scalarmult.width > 2 %}
+ for (long i = 2; i < {{ 2 ** (scalarmult.width - 1) }}; i++) {
+ point_add(current, point, curve, current);
+ points[i] = point_copy(current);
+ {% if scalarmult.precompute_negation %}
+ points_neg[i] = point_new();
+ point_neg(points[i], curve, points_neg[i]);
+ {% endif %}
+ }
+ {% endif %}
+ point_free(current);
+ point_free(dbl);
+
+ size_t bits = bn_bit_length(&curve->n);
+
+ booth_t *bs = bn_booth(scalar, {{ scalarmult.width }}, bits);
+
+ point_t *q = point_copy(curve->neutral);
+ point_t *neg = point_new();
+ for (long i = 0; i < bs->length; i++) {
+ for (long j = 0; j < {{ scalarmult.width }}; j++) {
+ point_dbl(q, curve, q);
+ }
+ int32_t val = bs->data[i];
+ if (val > 0) {
+ point_accumulate(q, points[val - 1], curve, q);
+ } else if (val < 0) {
+ {% if scalarmult.precompute_negation %}
+ point_accumulate(q, points_neg[-val - 1], curve, q);
+ {% else %}
+ point_neg(points[-val - 1], curve, neg);
+ point_accumulate(q, neg, curve, q);
+ {% endif %}
+ }
+ }
+ bn_booth_clear(bs);
+ point_free(neg);
+
+ {%- if "scl" in scalarmult.formulas %}
+ point_scl(q, curve, q);
+ {%- endif %}
+ point_set(q, out);
+ for (long i = 0; i < {{ 2 ** (scalarmult.width - 1) }}; i++) {
+ point_free(points[i]);
+ {% if scalarmult.precompute_negation %}
+ point_free(points_neg[i]);
+ {% endif %}
+ }
+ point_free(q);
+} \ No newline at end of file
diff --git a/pyecsca/codegen/templates/mult_comb.c b/pyecsca/codegen/templates/mult_comb.c
index 9df9796..1fbb5a3 100644
--- a/pyecsca/codegen/templates/mult_comb.c
+++ b/pyecsca/codegen/templates/mult_comb.c
@@ -39,6 +39,10 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
bn_from_int(1, &base);
bn_lsh(&base, d, &base);
+ {% if scalarmult.always %}
+ point_t *dummy = point_new();
+ {% endif %}
+
large_base_t *bs = bn_convert_base_large(scalar, &base);
for (int i = d - 1; i >= 0; i--) {
point_dbl(q, curve, q);
@@ -50,14 +54,18 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
}
if (word) {
point_accumulate(q, points[word], curve, q);
+ } else {
+ {% if scalarmult.always %}
+ int j = i % {{ 2**scalarmult.width }};
+ if (j == 0) {
+ point_accumulate(q, point, curve, dummy);
+ } else {
+ point_accumulate(q, points[j], curve, dummy);
+ }
+ {% endif %}
}
}
- for (int i = 0; i < bs->length; i++) {
- bn_clear(&bs->data[i]);
- }
- free(bs->data);
- bn_clear(&bs->m);
- free(bs);
+ bn_large_base_clear(bs);
bn_clear(&base);
diff --git a/pyecsca/codegen/templates/mult_fixed_w.c b/pyecsca/codegen/templates/mult_fixed_w.c
index b0a4bb0..6a079b3 100644
--- a/pyecsca/codegen/templates/mult_fixed_w.c
+++ b/pyecsca/codegen/templates/mult_fixed_w.c
@@ -20,18 +20,22 @@ void scalar_mult_by_m_base(point_t *point, curve_t *curve) {
static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) {
point_t *q = point_copy(curve->neutral);
- point_t *points[{{ scalarmult.m }}];
+ point_t *points[{{ scalarmult.m - 1 }}];
point_t *current = point_copy(point);
point_t *dbl = point_new();
point_dbl(current, curve, dbl);
points[0] = point_copy(current);
- points[1] = point_copy(dbl);
+ {% if scalarmult.m > 2 %}
+ points[1] = point_copy(dbl);
+ {% endif %}
point_set(dbl, current);
- for (long i = 2; i < {{ scalarmult.m }}; i++) {
- point_add(current, point, curve, current);
- points[i] = point_copy(current);
- }
+ {% if scalarmult.m > 3 %}
+ for (long i = 2; i < {{ scalarmult.m - 1 }}; i++) {
+ point_add(current, point, curve, current);
+ points[i] = point_copy(current);
+ }
+ {% endif %}
point_free(current);
point_free(dbl);
@@ -49,14 +53,13 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
point_accumulate(q, points[val-1], curve, q);
}
}
- free(bs->data);
- free(bs);
+ bn_small_base_clear(bs);
{%- if "scl" in scalarmult.formulas %}
point_scl(q, curve, q);
{%- endif %}
point_set(q, out);
- for (long i = 0; i < {{ scalarmult.m }}; i++) {
+ for (long i = 0; i < {{ scalarmult.m - 1 }}; i++) {
point_free(points[i]);
}
point_free(q);
diff --git a/pyecsca/codegen/templates/mult_rtl.c b/pyecsca/codegen/templates/mult_rtl.c
index 71949b4..119ee7e 100644
--- a/pyecsca/codegen/templates/mult_rtl.c
+++ b/pyecsca/codegen/templates/mult_rtl.c
@@ -5,6 +5,12 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou
point_t *q = point_copy(point);
point_t *r = point_copy(curve->neutral);
+ {% if scalarmult.complete %}
+ size_t bits = bn_bit_length(&curve->n);
+ {% else %}
+ size_t bits = bn_bit_length(scalar);
+ {% endif %}
+
{%- if scalarmult.always %}
point_t *dummy = point_new();
{%- endif %}
@@ -12,7 +18,7 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou
bn_init(&copy);
bn_copy(scalar, &copy);
- while (!bn_is_0(&copy)) {
+ for (int i = 0; i < bits; i++) {
if (bn_get_bit(&copy, 0) == 1) {
point_accumulate(r, q, curve, r);
} else {
diff --git a/pyecsca/codegen/templates/mult_simple_ldr.c b/pyecsca/codegen/templates/mult_simple_ldr.c
index ceb257a..33bfcd9 100644
--- a/pyecsca/codegen/templates/mult_simple_ldr.c
+++ b/pyecsca/codegen/templates/mult_simple_ldr.c
@@ -11,7 +11,7 @@ void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *ou
{%- endif %}
for (int i = nbits; i >= 0; i--) {
- if (bn_get_bit(scalar, i) == 1) {
+ if (bn_get_bit(scalar, i) == 0) {
point_add(p0, p1, curve, p1);
point_dbl(p0, curve, p0);
} else {
diff --git a/pyecsca/codegen/templates/mult_sliding_w.c b/pyecsca/codegen/templates/mult_sliding_w.c
index 1e80a84..347c313 100644
--- a/pyecsca/codegen/templates/mult_sliding_w.c
+++ b/pyecsca/codegen/templates/mult_sliding_w.c
@@ -34,8 +34,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
{%- endif %}
point_set(q, out);
- free(ws->data);
- free(ws);
+ bn_wsliding_clear(ws);
for (long i = 0; i < {{ 2 ** (scalarmult.width - 1) }}; i++) {
point_free(points[i]);
}
diff --git a/pyecsca/codegen/templates/mult_wnaf.c b/pyecsca/codegen/templates/mult_wnaf.c
index 3c5f2b2..569e78b 100644
--- a/pyecsca/codegen/templates/mult_wnaf.c
+++ b/pyecsca/codegen/templates/mult_wnaf.c
@@ -26,7 +26,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
wnaf_t *naf = bn_wnaf(scalar, {{ scalarmult.width }});
- for (long i = naf->length - 1; i >= 0; i--) {
+ for (long i = 0; i < naf->length; i++) {
point_dbl(q, curve, q);
int8_t val = naf->data[i];
if (val > 0) {
@@ -40,8 +40,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
{%- endif %}
}
}
- free(naf->data);
- free(naf);
+ bn_naf_clear(naf);
{%- if "scl" in scalarmult.formulas %}
point_scl(q, curve, q);
diff --git a/test/.gitignore b/test/.gitignore
new file mode 100644
index 0000000..05cabee
--- /dev/null
+++ b/test/.gitignore
@@ -0,0 +1,2 @@
+test_bn_val
+test_bn_san
diff --git a/test/Makefile b/test/Makefile
index d13b487..d0068a2 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -1,3 +1,13 @@
-test_bn: test_bn.c ../pyecsca/codegen/bn/bn.c
- gcc -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a \ No newline at end of file
+test_bn_san: test_bn.c ../pyecsca/codegen/bn/bn.c
+ gcc -g -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a
+
+test_bn_val: test_bn.c ../pyecsca/codegen/bn/bn.c
+ gcc -g -o $@ $^ -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a
+
+
+test: test_bn_san test_bn_val
+ ./test_bn_san
+ valgrind -q ./test_bn_val
+
+.PHONY: test \ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index 1c1449a..80faf71 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,5 +1,6 @@
import pytest
+from pyecsca.ec.mult import *
from pyecsca.ec.params import get_params, DomainParameters
@@ -11,3 +12,97 @@ def secp128r1() -> DomainParameters:
@pytest.fixture(scope="session")
def curve25519() -> DomainParameters:
return get_params("other", "Curve25519", "xz")
+
+
+# fmt: off
+window_mults = [
+ (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.LTR)),
+ (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.LTR)),
+ (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.LTR)),
+ (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.LTR)),
+ (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.LTR)),
+ (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.RTL)),
+ (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.RTL)),
+ (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.RTL)),
+ (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.RTL)),
+ (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.RTL)),
+ (FixedWindowLTRMultiplier, dict(m=2**1)),
+ (FixedWindowLTRMultiplier, dict(m=2**2)),
+ (FixedWindowLTRMultiplier, dict(m=2**3)),
+ (FixedWindowLTRMultiplier, dict(m=2**4)),
+ (FixedWindowLTRMultiplier, dict(m=2**5)),
+ (FixedWindowLTRMultiplier, dict(m=2**6)),
+ (WindowBoothMultiplier, dict(width=2)),
+ (WindowBoothMultiplier, dict(width=3)),
+ (WindowBoothMultiplier, dict(width=4)),
+ (WindowBoothMultiplier, dict(width=5)),
+ (WindowBoothMultiplier, dict(width=6))
+]
+naf_mults = [
+ (WindowNAFMultiplier, dict(width=2)),
+ (WindowNAFMultiplier, dict(width=3)),
+ (WindowNAFMultiplier, dict(width=4)),
+ (WindowNAFMultiplier, dict(width=5)),
+ (WindowNAFMultiplier, dict(width=6)),
+ (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.RTL)),
+ (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)),
+ (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.RTL)),
+ (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.RTL))
+]
+comb_mults = [
+ (CombMultiplier, dict(width=2, always=True)),
+ (CombMultiplier, dict(width=3, always=True)),
+ (CombMultiplier, dict(width=4, always=True)),
+ (CombMultiplier, dict(width=5, always=True)),
+ (CombMultiplier, dict(width=6, always=True)),
+ (CombMultiplier, dict(width=2, always=False)),
+ (CombMultiplier, dict(width=3, always=False)),
+ (CombMultiplier, dict(width=4, always=False)),
+ (CombMultiplier, dict(width=5, always=False)),
+ (CombMultiplier, dict(width=6, always=False)),
+ (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.LTR)),
+ (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.LTR)),
+ (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.LTR)),
+ (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.LTR)),
+ (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.LTR)),
+ (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.RTL)),
+ (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.RTL)),
+ (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.RTL)),
+ (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.RTL)),
+ (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.RTL))
+]
+binary_mults = [
+ (LTRMultiplier, dict(always=False, complete=True)),
+ (LTRMultiplier, dict(always=True, complete=True)),
+ (LTRMultiplier, dict(always=False, complete=False)),
+ (LTRMultiplier, dict(always=True, complete=False)),
+ (RTLMultiplier, dict(always=False, complete=True)),
+ (RTLMultiplier, dict(always=True, complete=True)),
+ (RTLMultiplier, dict(always=False, complete=False)),
+ (RTLMultiplier, dict(always=True, complete=False)),
+ (CoronMultiplier, dict())
+]
+other_mults = [
+ (FullPrecompMultiplier, dict(always=False, complete=True)),
+ (FullPrecompMultiplier, dict(always=True, complete=True)),
+ (FullPrecompMultiplier, dict(always=False, complete=False)),
+ (FullPrecompMultiplier, dict(always=True, complete=False)),
+ (SimpleLadderMultiplier, dict(complete=True)),
+ (SimpleLadderMultiplier, dict(complete=False))
+]
+# fmt: on
+
+
+@pytest.fixture(
+ scope="session",
+ params=window_mults + naf_mults + comb_mults + binary_mults + other_mults,
+ ids=lambda p: "{}-{}".format(
+ p[0].__name__, ":".join(f"{k}={v}" for k, v in p[1].items())
+ ),
+)
+def simple_multiplier(request):
+ return request.param
diff --git a/test/gdb_script.py b/test/gdb_script.py
new file mode 100644
index 0000000..93d4472
--- /dev/null
+++ b/test/gdb_script.py
@@ -0,0 +1,89 @@
+import os
+import json
+
+import gdb
+
+trace_file = open(os.environ["TRACE_FILE"], "w")
+
+
+def extract_bn(bn):
+ data_ptr = bn["dp"]
+ used = int(bn["used"])
+ bs = int(gdb.lookup_global_symbol("bn_digit_bits").value())
+ result = 0
+ for i in range(used):
+ limb = int((data_ptr + i).dereference())
+ result += limb << (i * bs)
+ return result
+
+
+def extract_point(point):
+ result = {}
+ for field in point.type.fields():
+ field_name = field.name
+ if len(field_name) != 1:
+ continue
+ field_value = point[field_name]
+ result[field_name] = extract_bn(field_value)
+ return result
+
+
+class TraceFunction(gdb.Breakpoint):
+ def stop(self):
+ try:
+ set_bp.enabled = True
+ frame = gdb.newest_frame()
+ block = frame.block()
+ print(frame.name(), flush=True, file=trace_file)
+ out = []
+ for sym in block:
+ if sym.is_argument:
+ name = sym.name
+ try:
+ value = frame.read_var(name)
+ except Exception as e:
+ value = f"<unavailable: {e}>"
+ deref = value.dereference()
+ if deref.type.name == "point_t":
+ if "out" in name:
+ out.append(deref)
+ else:
+ pt = extract_point(deref)
+ print(f"{name}: {json.dumps(pt)}", flush=True, file=trace_file)
+ bp = TraceExit(frame)
+ bp.silent = True
+ bp.target = out
+ except RuntimeError:
+ pass
+ return False # Continue execution
+
+
+class TraceExit(gdb.FinishBreakpoint):
+ def stop(self):
+ set_bp.enabled = False
+ for i, point in enumerate(self.target):
+ print(f"out_{i}: {json.dumps(extract_point(point))}", flush=True, file=trace_file)
+ return False # Continue execution
+
+
+def register_bp(name):
+ if gdb.lookup_global_symbol(name) is not None:
+ bp = TraceFunction(name)
+ bp.silent = True
+ return bp
+ return None
+
+
+register_bp("point_add")
+register_bp("point_dadd")
+register_bp("point_dadd")
+register_bp("point_ladd")
+register_bp("point_dbl")
+register_bp("point_neg")
+register_bp("point_scl")
+register_bp("point_tpl")
+set_bp = register_bp("point_set")
+set_bp.enabled = False
+
+gdb.execute("run")
+trace_file.close()
diff --git a/test/test_bn.c b/test/test_bn.c
index 66de22f..b04c936 100644
--- a/test/test_bn.c
+++ b/test/test_bn.c
@@ -4,88 +4,403 @@
int test_wsliding_ltr() {
printf("test_wsliding_ltr: ");
+ int failed = 0;
+ struct {
+ const char *value;
+ int w;
+ int expected_len;
+ uint8_t expected[100]; // max expected length
+ } cases[] = {
+ // sliding_window_ltr begin
+ {"181", 3, 6, {5, 0, 0, 5, 0, 1}},
+ {"1", 3, 1, {1}},
+ {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}},
+ {"170", 4, 6, {5, 0, 0, 0, 5, 0}},
+ {"554", 5, 6, {17, 0, 0, 0, 5, 0}},
+ {"123456789123456789123456789", 5, 83, {25, 1, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 23, 0, 0, 0, 0, 25, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 17, 0, 0, 0, 0, 19, 0, 0, 0, 0, 29, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}}
+ // sliding_window_ltr end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn;
+ bn_init(&bn);
+ bn_from_dec(cases[t].value, &bn);
+ wsliding_t *ws = bn_wsliding_ltr(&bn, cases[t].w);
+ if (ws == NULL) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ continue;
+ }
+ if (ws->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ if (ws->data[i] != cases[t].expected[i]) {
+ printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]);
+ failed++;
+ break;
+ }
+ }
+ bn_clear(&bn);
+ free(ws->data);
+ free(ws);
+ }
+ if (failed == 0) {
+ printf("OK\n");
+ } else {
+ printf("FAILED (%d cases)\n", failed);
+ }
+ return failed;
+}
+
+int test_wsliding_rtl() {
+ printf("test_wsliding_rtl: ");
+ int failed = 0;
+ struct {
+ const char *value;
+ int w;
+ int expected_len;
+ uint8_t expected[100]; // max expected length
+ } cases[] = {
+ // sliding_window_rtl begin
+ {"181", 3, 8, {1, 0, 0, 3, 0, 0, 0, 5}},
+ {"1", 3, 1, {1}},
+ {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}},
+ {"170", 4, 6, {5, 0, 0, 0, 5, 0}},
+ {"554", 5, 10, {1, 0, 0, 0, 0, 0, 0, 0, 21, 0}},
+ {"123456789123456789123456789",5, 87, {1, 0, 0, 0, 0, 19, 0, 0, 0, 0, 1, 0, 0, 0, 0, 29, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 17, 0, 0, 0, 0, 27, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}}
+ // sliding_window_rtl end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn;
+ bn_init(&bn);
+ bn_from_dec(cases[t].value, &bn);
+ wsliding_t *ws = bn_wsliding_rtl(&bn, cases[t].w);
+ if (ws == NULL) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ continue;
+ }
+ if (ws->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ if (ws->data[i] != cases[t].expected[i]) {
+ printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]);
+ failed++;
+ break;
+ }
+ }
+ bn_clear(&bn);
+ free(ws->data);
+ free(ws);
+ }
+ if (failed == 0) {
+ printf("OK\n");
+ } else {
+ printf("FAILED (%d cases)\n", failed);
+ }
+ return failed;
+}
+
+int test_convert_base_small() {
+ printf("test_convert_base_small: ");
+ int failed = 0;
+ struct {
+ const char *value;
+ int base;
+ int expected_len;
+ uint8_t expected[100]; // max expected length
+ } cases[] = {
+ // convert_base_small begin
+ {"11", 2, 4, {1, 1, 0, 1}},
+ {"255", 2, 8, {1, 1, 1, 1, 1, 1, 1, 1}},
+ {"1234", 10, 4, {4, 3, 2, 1}},
+ {"0", 2, 1, {0}},
+ {"1", 2, 1, {1}},
+ {"123456789123456789123456789", 16, 22, {5, 1, 15, 5, 4, 0, 12, 7, 15, 9, 1, 11, 3, 14, 2, 15, 13, 15, 14, 1, 6, 6}}
+ // convert_base_small end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn;
+ bn_init(&bn);
+ bn_from_dec(cases[t].value, &bn);
+ small_base_t *bs = bn_convert_base_small(&bn, cases[t].base);
+ if (bs == NULL) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ continue;
+ }
+ if (bs->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ if (bs->data[i] != cases[t].expected[i]) {
+ printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, bs->data[i], cases[t].expected[i]);
+ failed++;
+ break;
+ }
+ }
+ bn_clear(&bn);
+ free(bs->data);
+ free(bs);
+ }
+ if (failed == 0) {
+ printf("OK\n");
+ } else {
+ printf("FAILED (%d cases)\n", failed);
+ }
+ return failed;
+}
+
+int test_convert_base_large() {
+ printf("test_convert_base_large: ");
+ int failed = 0;
+ struct {
+ const char *value;
+ const char *base;
+ int expected_len;
+ const char *expected[100]; // max expected length
+ } cases[] = {
+ // convert_base_large begin
+ {"123456789123456", "2", 47, {"0", "0", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "1", "0", "0", "1", "1", "1", "1", "1", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "0", "1", "0", "0", "0", "1", "0", "0", "1", "0", "0", "0", "0", "0", "1", "1", "1"}},
+ {"123456789123456789123456789", "123456", 6, {"104661", "75537", "83120", "74172", "37630", "4"}},
+ {"352099265818416392997042486274568094251", "18446744073709551616", 3, {"12367597952119210539", "640595372834356666", "1"}}
+ // convert_base_large end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn, base;
+ bn_init(&bn);
+ bn_init(&base);
+ bn_from_dec(cases[t].value, &bn);
+ bn_from_dec(cases[t].base, &base);
+ large_base_t *bs = bn_convert_base_large(&bn, &base);
+ if (bs == NULL) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ bn_clear(&base);
+ continue;
+ }
+ if (bs->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ bn_t exp;
+ bn_init(&exp);
+ bn_from_dec(cases[t].expected[i], &exp);
+ if (!bn_eq(&bs->data[i], &exp)) {
+ printf("Case %d: Bad data at %d\n", t, i);
+ failed++;
+ bn_clear(&exp);
+ break;
+ }
+ bn_clear(&exp);
+ }
+ for (int i = 0; i < bs->length; i++) {
+ bn_clear(&bs->data[i]);
+ }
+ bn_clear(&bs->m);
+ bn_clear(&bn);
+ bn_clear(&base);
+ free(bs->data);
+ free(bs);
+ }
+ if (failed == 0) {
+ printf("OK\n");
+ } else {
+ printf("FAILED (%d cases)\n", failed);
+ }
+ return failed;
+}
+
+int test_bn_wnaf() {
+ printf("test_bn_wnaf: ");
+ int failed = 0;
+ struct {
+ const char *value;
+ int w;
+ int expected_len;
+ int8_t expected[100]; // max expected length
+ } cases[] = {
+ // wnaf begin
+ {"19", 2, 5, {1, 0, 1, 0, -1}},
+ {"45", 3, 5, {3, 0, 0, 0, -3}},
+ {"0", 3, 0, {}},
+ {"1", 2, 1, {1}},
+ {"21", 4, 5, {1, 0, 0, 0, 5}},
+ {"123456789", 3, 28, {1, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 3, 0, 0, -3}},
+ {"123456789123456789123456789", 5, 84, {13, 0, 0, 0, 0, 0, -15, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -13, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, -11}}
+ // wnaf end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn;
+ bn_init(&bn);
+ bn_from_dec(cases[t].value, &bn);
+ wnaf_t *naf = bn_wnaf(&bn, cases[t].w);
+ if (naf == NULL) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ continue;
+ }
+ if (naf->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, naf->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ if (naf->data[i] != cases[t].expected[i]) {
+ printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, naf->data[i], cases[t].expected[i]);
+ failed++;
+ break;
+ }
+ }
+ bn_clear(&bn);
+ free(naf->data);
+ free(naf);
+ }
+ if (failed == 0) {
+ printf("OK\n");
+ } else {
+ printf("FAILED (%d cases)\n", failed);
+ }
+ return failed;
+}
+
+int test_bn_wnaf_manipulation() {
+ printf("test_bn_wnaf_manipulation: ");
bn_t bn;
bn_init(&bn);
- bn_from_int(181, &bn);
- wsliding_t *ws = bn_wsliding_ltr(&bn, 3);
- if (ws == NULL) {
- printf("NULL\n");
+ bn_from_dec("123456789", &bn);
+ wnaf_t *naf = bn_wnaf(&bn, 3);
+ bn_clear(&bn);
+ if (naf->length != 28) {
+ printf("FAILED (bad length %li instead of 28)\n", naf->length);
return 1;
}
- if (ws->length != 6) {
- printf("Bad length (%li instead of 6)\n", ws->length);
+ bn_naf_pad_left(naf, 0, 5);
+ if (naf->length != 33) {
+ printf("FAILED (bad length after pad left %li instead of 33)\n", naf->length);
return 1;
}
- uint8_t expected[6] = {5, 0, 0, 5, 0, 1};
- for (int i = 0; i < 6; i++) {
- if (ws->data[i] != expected[i]) {
- printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]);
+ for (int i = 0; i < 5; i++) {
+ if (naf->data[i] != 0) {
+ printf("FAILED (bad data after pad left at %d (%i instead of 0))\n", i, naf->data[i]);
return 1;
}
}
- printf("OK\n");
- bn_clear(&bn);
- free(ws->data);
- free(ws);
- return 0;
-}
-
-int test_wsliding_rtl() {
- printf("test_wsliding_rtl: ");
- bn_t bn;
- bn_init(&bn);
- bn_from_int(181, &bn);
- wsliding_t *ws = bn_wsliding_rtl(&bn, 3);
- if (ws == NULL) {
- printf("NULL\n");
+ bn_naf_strip_left(naf, 0);
+ if (naf->length != 28) {
+ printf("FAILED (bad length after strip left %li instead of 28)\n", naf->length);
return 1;
}
- if (ws->length != 8) {
- printf("Bad length (%li instead of 8)\n", ws->length);
+ bn_naf_pad_right(naf, 0, 3);
+ if (naf->length != 31) {
+ printf("FAILED (bad length after pad right %li instead of 31)\n", naf->length);
return 1;
}
- uint8_t expected[8] = {1, 0, 0, 3, 0, 0, 0, 5};
- for (int i = 0; i < 8; i++) {
- if (ws->data[i] != expected[i]) {
- printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]);
+ for (int i = 28; i < 31; i++) {
+ if (naf->data[i] != 0) {
+ printf("FAILED (bad data after pad right at %d (%i instead of 0))\n", i, naf->data[i]);
return 1;
}
}
+ bn_naf_strip_right(naf, 0);
+ if (naf->length != 28) {
+ printf("FAILED (bad length after strip right %li instead of 28)\n", naf->length);
+ return 1;
+ }
+ int8_t rev[28] = {-3, 0, 0, 3, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 1};
+ bn_naf_reverse(naf);
+ for (int i = 0; i < 28; i++) {
+ if (naf->data[i] != rev[i]) {
+ printf("FAILED (bad data after reverse at %d (%i instead of %i))\n", i, naf->data[i], rev[i]);
+ return 1;
+ }
+ }
+
+ free(naf->data);
+ free(naf);
printf("OK\n");
- bn_clear(&bn);
- free(ws->data);
- free(ws);
return 0;
}
-int test_convert_base() {
- printf("test_convert_base: ");
- bn_t bn;
- bn_init(&bn);
- bn_from_int(11, &bn);
- small_base_t *bs = bn_convert_base_small(&bn, 2);
- if (bs == NULL) {
- printf("NULL\n");
- return 1;
- }
- if (bs->length != 4) {
- printf("Bad length (%li instead of 4)\n", bs->length);
- return 1;
+int test_booth() {
+ printf("test_booth: ");
+ for (int i = 0; i < (1 << 6); i++) {
+ int32_t bw = bn_booth_word(i, 5);
+ if (i <= 31) {
+ if (bw != (i + 1) / 2) {
+ printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, (i + 1) / 2);
+ return 1;
+ }
+ } else {
+ if (bw != -((64 - i) / 2)) {
+ printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, -((64 - i) / 2));
+ return 1;
+ }
+ }
}
- uint8_t expected[4] = {1, 1, 0, 1};
- for (int i = 0; i < 4; i++) {
- if (bs->data[i] != expected[i]) {
- printf("Bad data (%i insead of %i)\n", bs->data[i], expected[i]);
- return 1;
+ int failed = 0;
+ struct {
+ const char *value;
+ int w;
+ size_t bits;
+ int expected_len;
+ int32_t expected[256]; // max expected length
+ } cases[] = {
+ // booth begin
+ {"12345678123456781234567812345678123456781234567812345678", 1, 224, 225, {0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0}},
+ {"12345678123456781234567812345678123456781234567812345678", 2, 224, 113, {0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0}},
+ {"12345678123456781234567812345678123456781234567812345678", 5, 224, 45, {1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8, 2, 9, -6, 5, 13, -2, 1, -14, 7, -15, 11, 8, -16, 5, -14, -12, 11, -6, -4, 1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8}},
+ {"12345678123456781234567812345678123456781234567812345678", 16, 224, 15, {0, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136}},
+ {"12345678123456781234567812345678123456781234567812345678", 24, 224, 10, {18, 3430008, 1193046, 7868980, 5666834, 3430008, 1193046, 7868980, 5666834, 3430008}},
+ {"12345678123456781234567812345678123456781234567812345678", 30, 224, 0, {}}
+ // booth end
+ };
+ int num_cases = sizeof(cases) / sizeof(cases[0]);
+ for (int t = 0; t < num_cases; t++) {
+ bn_t bn;
+ bn_init(&bn);
+ bn_from_hex(cases[t].value, &bn);
+ booth_t *booth = bn_booth(&bn, cases[t].w, cases[t].bits);
+ if (booth == NULL && cases[t].expected_len != 0) {
+ printf("Case %d: NULL\n", t);
+ failed++;
+ bn_clear(&bn);
+ continue;
}
+ if (cases[t].expected_len != 0) {
+ if (booth->length != cases[t].expected_len) {
+ printf("Case %d: Bad length (%li instead of %i)\n", t, booth->length, cases[t].expected_len);
+ failed++;
+ }
+ for (int i = 0; i < cases[t].expected_len; i++) {
+ if (booth->data[i] != cases[t].expected[i]) {
+ printf("FAILED (bad booth data at %d: %d instead of %d)\n", i, booth->data[i], cases[t].expected[i]);
+ failed++;
+ break;
+ }
+ }
+ }
+ bn_clear(&bn);
+ bn_booth_clear(booth);
}
printf("OK\n");
- bn_clear(&bn);
- free(bs->data);
- free(bs);
return 0;
}
int main(void) {
- return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base();
+ return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base_small() + test_convert_base_large() + test_bn_wnaf() + test_bn_wnaf_manipulation() + test_booth();
}
diff --git a/test/test_equivalence.py b/test/test_equivalence.py
new file mode 100644
index 0000000..6973ceb
--- /dev/null
+++ b/test/test_equivalence.py
@@ -0,0 +1,185 @@
+import subprocess
+import json
+from tempfile import NamedTemporaryFile
+from typing import Generator, Any
+from click.testing import CliRunner
+
+import pytest
+from importlib import resources
+from os.path import join
+
+from pyecsca.codegen.builder import build_impl
+from pyecsca.ec.formula import FormulaAction, NegationFormula
+from pyecsca.ec.model import CurveModel
+from pyecsca.ec.coordinates import CoordinateModel
+from pyecsca.sca.target.binary import BinaryTarget
+from pyecsca.codegen.client import ImplTarget
+from pyecsca.ec.context import DefaultContext, local, Node
+
+from pyecsca.ec.mult import WindowBoothMultiplier
+
+
+class GDBTarget(ImplTarget, BinaryTarget):
+ def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs):
+ super().__init__(model, coords, **kwargs)
+ self.trace_file = None
+
+ def connect(self):
+ self.trace_file = NamedTemporaryFile("r+")
+ with resources.path("test", "gdb_script.py") as gdb_script:
+ self.process = subprocess.Popen(
+ [
+ "gdb",
+ "-q",
+ "-batch-silent",
+ "-x",
+ gdb_script,
+ "--args",
+ *self.binary,
+ ],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ env={
+ "TRACE_FILE": self.trace_file.name,
+ },
+ text=True,
+ bufsize=1,
+ )
+
+ def disconnect(self):
+ super().disconnect()
+ if self.trace_file is not None:
+ self.trace_file.close()
+ self.trace_file = None
+
+
+@pytest.fixture(scope="module")
+def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]:
+ mult_class, mult_kwargs = simple_multiplier
+ mult_name = mult_class.__name__
+ formulas = ["add-1998-cmo", "dbl-1998-cmo"]
+ if NegationFormula in mult_class.requires:
+ formulas.append("neg")
+ runner = CliRunner()
+ with runner.isolated_filesystem() as tmpdir:
+ res = runner.invoke(
+ build_impl,
+ [
+ "--platform",
+ "HOST",
+ "--ecdsa",
+ "--ecdh",
+ secp128r1.curve.model.shortname,
+ secp128r1.curve.coordinate_model.name,
+ *formulas,
+ f"{mult_name}({','.join(f'{key}={value}' for key, value in mult_kwargs.items())})",
+ ".",
+ ],
+ env={"DEBUG": "1", "CFLAGS": "-g -O0"},
+ )
+ assert res.exit_code == 0
+ target = GDBTarget(
+ secp128r1.curve.model,
+ secp128r1.curve.coordinate_model,
+ binary=join(tmpdir, "pyecsca-codegen-HOST.elf"),
+ )
+ formula_instances = [
+ secp128r1.curve.coordinate_model.formulas[formula] for formula in formulas
+ ]
+ mult = mult_class(*formula_instances, **mult_kwargs)
+ target.mult = mult
+ yield target
+
+
+def parse_trace(captured: str):
+ current_function = None
+ args = []
+ rets = []
+ result = []
+ for line in captured.split("\n"):
+ if ":" not in line:
+ func = line.strip()
+ if func.startswith("point_"):
+ func = func[len("point_") :]
+ if func == "set":
+ # The sets that happen inside another formula (like add) are a sign of short-circuiting.
+ # The Python simulation does not record the short-circuits, so we ignore them here.
+ current_function = None
+ else:
+ if current_function is not None:
+ result.append((current_function, args, rets))
+ current_function = func
+ args = []
+ rets = []
+ else:
+ name, data = line.split(":", 1)
+ name = name.strip()
+ value = json.loads(data)
+ if "out" in name:
+ rets.append(value)
+ else:
+ args.append(value)
+ return result
+
+
+def parse_ctx(scalarmult: Node):
+ result = []
+ for node in scalarmult.children:
+ action: FormulaAction = node.action
+ formula = action.formula
+ name = formula.shortname
+ args = []
+ for point in action.input_points:
+ point_value = {k: int(v) for k, v in point.coords.items()}
+ args.append(point_value)
+ rets = []
+ for point in action.output_points:
+ point_value = {k: int(v) for k, v in point.coords.items()}
+ rets.append(point_value)
+ result.append((name, args, rets))
+ return result
+
+
+def make_hashable(trace):
+ result = []
+ for entry in trace:
+ name, args, rets = entry
+ args_t = tuple(tuple(arg.items()) for arg in args)
+ rets_t = tuple(tuple(ret.items()) for ret in rets)
+ result.append((name, args_t, rets_t))
+ return tuple(result)
+
+
+def test_equivalence(target, secp128r1):
+ mult = target.mult
+ for i in range(10):
+ target.connect()
+ target.init_prng(i.to_bytes(4, "big"))
+ target.set_params(secp128r1)
+
+ priv, pub = target.generate()
+
+ with local(DefaultContext()) as ctx:
+ mult.init(secp128r1, secp128r1.generator)
+ expected = mult.multiply(priv).to_affine()
+
+ assert secp128r1.curve.is_on_curve(pub)
+ assert pub == expected
+ err = target.trace_file.read()
+ from_codegen = parse_trace(err)
+ from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
+ codegen_set = set(make_hashable(from_codegen))
+ sim_set = set(make_hashable(from_sim))
+ if codegen_set != sim_set:
+ print(len(from_codegen), len(from_sim))
+ print("In codegen but not in sim:")
+ for entry in codegen_set - sim_set:
+ print(entry)
+ print("In sim but not in codegen:")
+ for entry in sim_set - codegen_set:
+ print(entry)
+ assert from_codegen == from_sim
+
+ target.quit()
+ target.disconnect()
diff --git a/test/test_impl.py b/test/test_impl.py
index 34f427b..a61aabd 100644
--- a/test/test_impl.py
+++ b/test/test_impl.py
@@ -1,29 +1,18 @@
from copy import copy
from os.path import join
+from typing import Any, Generator
+
import pytest
from click.testing import CliRunner
+from pyecsca.ec.formula import NegationFormula
from pyecsca.ec.key_agreement import ECDH_SHA1
from pyecsca.ec.mod import mod
-from pyecsca.ec.mult import (
- LTRMultiplier,
- RTLMultiplier,
- CoronMultiplier,
- BinaryNAFMultiplier,
- WindowNAFMultiplier,
- SlidingWindowMultiplier,
- AccumulationOrder,
- ProcessingDirection,
- ScalarMultiplier,
- FixedWindowLTRMultiplier,
- FullPrecompMultiplier,
- BGMWMultiplier,
- CombMultiplier,
-)
+from pyecsca.ec.mult import ScalarMultiplier, WindowBoothMultiplier
from pyecsca.ec.signature import ECDSA_SHA1, SignatureResult
from pyecsca.codegen.builder import build_impl
-from pyecsca.codegen.client import HostTarget, ImplTarget
+from pyecsca.codegen.client import HostTarget
@pytest.fixture(
@@ -39,195 +28,15 @@ def additional(request):
return request.param
-@pytest.fixture(
- scope="module",
- params=[
- pytest.param(
- (
- LTRMultiplier,
- "ltr",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"complete": False},
- ),
- id="LTR1",
- ),
- pytest.param(
- (
- LTRMultiplier,
- "ltr",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"complete": True},
- ),
- id="LTR2",
- ),
- pytest.param(
- (
- LTRMultiplier,
- "ltr",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"complete": False, "always": True},
- ),
- id="LTR3",
- ),
- pytest.param(
- (
- LTRMultiplier,
- "ltr",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"complete": True, "always": True},
- ),
- id="LTR4",
- ),
- pytest.param(
- (
- LTRMultiplier,
- "ltr",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"complete": False, "accumulation_order": AccumulationOrder.PeqRP},
- ),
- id="LTR5",
- ),
- pytest.param(
- (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": False}),
- id="RTL1",
- ),
- pytest.param(
- (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": True}),
- id="RTL2",
- ),
- pytest.param(
- (CoronMultiplier, "coron", ["add-1998-cmo", "dbl-1998-cmo"], {}), id="Coron"
- ),
- pytest.param(
- (
- BinaryNAFMultiplier,
- "bnaf",
- ["add-1998-cmo", "dbl-1998-cmo", "neg"],
- {"direction": ProcessingDirection.LTR},
- ),
- id="BNAF1",
- ),
- pytest.param(
- (
- BinaryNAFMultiplier,
- "bnaf",
- ["add-1998-cmo", "dbl-1998-cmo", "neg"],
- {"direction": ProcessingDirection.RTL},
- ),
- id="BNAF2",
- ),
- pytest.param(
- (
- WindowNAFMultiplier,
- "wnaf",
- ["add-1998-cmo", "dbl-1998-cmo", "neg"],
- {"width": 3},
- ),
- id="WNAF1",
- ),
- pytest.param(
- (
- WindowNAFMultiplier,
- "wnaf",
- ["add-1998-cmo", "dbl-1998-cmo", "neg"],
- {"width": 3, "precompute_negation": True},
- ),
- id="WNAF2",
- ),
- pytest.param(
- (
- SlidingWindowMultiplier,
- "sliding",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 3},
- ),
- id="SLI1",
- ),
- pytest.param(
- (
- SlidingWindowMultiplier,
- "sliding",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 3, "recoding_direction": ProcessingDirection.RTL},
- ),
- id="SLI2",
- ),
- pytest.param(
- (
- FixedWindowLTRMultiplier,
- "fixed",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"m": 4},
- ),
- id="FIX1",
- ),
- pytest.param(
- (
- FixedWindowLTRMultiplier,
- "fixed",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"m": 5},
- ),
- id="FIX2",
- ),
- pytest.param(
- (
- FullPrecompMultiplier,
- "precomp",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"direction": ProcessingDirection.LTR},
- ),
- id="PRE1",
- ),
- pytest.param(
- (
- FullPrecompMultiplier,
- "precomp",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"direction": ProcessingDirection.RTL},
- ),
- id="PRE2",
- ),
- pytest.param(
- (
- BGMWMultiplier,
- "bgmw",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 3, "direction": ProcessingDirection.LTR},
- ),
- id="BGMW1",
- ),
- pytest.param(
- (
- BGMWMultiplier,
- "bgmw",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 5, "direction": ProcessingDirection.RTL},
- ),
- id="BGMW2",
- ),
- pytest.param(
- (
- CombMultiplier,
- "comb",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 3},
- ),
- id="Comb1",
- ),
- pytest.param(
- (
- CombMultiplier,
- "comb",
- ["add-1998-cmo", "dbl-1998-cmo"],
- {"width": 5},
- ),
- id="Comb2",
- ),
- ],
-)
-def target(request, additional, secp128r1) -> ImplTarget:
- mult_class, mult_name, formulas, mult_kwargs = request.param
+@pytest.fixture(scope="module")
+def target(
+ simple_multiplier, additional, secp128r1
+) -> Generator[HostTarget, Any, None]:
+ mult_class, mult_kwargs = simple_multiplier
+ mult_name = mult_class.__name__
+ formulas = ["add-1998-cmo", "dbl-1998-cmo"]
+ if NegationFormula in mult_class.requires:
+ formulas.append("neg")
runner = CliRunner()
with runner.isolated_filesystem() as tmpdir:
res = runner.invoke(
@@ -235,7 +44,6 @@ def target(request, additional, secp128r1) -> ImplTarget:
[
"--platform",
"HOST",
- *additional,
"--ecdsa",
"--ecdh",
secp128r1.curve.model.shortname,