diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/.gitignore | 2 | ||||
| -rw-r--r-- | test/Makefile | 14 | ||||
| -rw-r--r-- | test/conftest.py | 95 | ||||
| -rw-r--r-- | test/gdb_script.py | 89 | ||||
| -rw-r--r-- | test/test_bn.c | 427 | ||||
| -rw-r--r-- | test/test_equivalence.py | 185 | ||||
| -rw-r--r-- | test/test_impl.py | 220 |
7 files changed, 768 insertions, 264 deletions
diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 0000000..05cabee --- /dev/null +++ b/test/.gitignore @@ -0,0 +1,2 @@ +test_bn_val +test_bn_san diff --git a/test/Makefile b/test/Makefile index d13b487..d0068a2 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,3 +1,13 @@ -test_bn: test_bn.c ../pyecsca/codegen/bn/bn.c - gcc -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a
\ No newline at end of file +test_bn_san: test_bn.c ../pyecsca/codegen/bn/bn.c + gcc -g -o $@ $^ -fsanitize=address -fsanitize=undefined -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a + +test_bn_val: test_bn.c ../pyecsca/codegen/bn/bn.c + gcc -g -o $@ $^ -I ../pyecsca/codegen/ -I ../pyecsca/codegen/tommath/ -L ../pyecsca/codegen/tommath/ -l:libtommath-HOST.a + + +test: test_bn_san test_bn_val + ./test_bn_san + valgrind -q ./test_bn_val + +.PHONY: test
\ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index 1c1449a..80faf71 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,6 @@ import pytest +from pyecsca.ec.mult import * from pyecsca.ec.params import get_params, DomainParameters @@ -11,3 +12,97 @@ def secp128r1() -> DomainParameters: @pytest.fixture(scope="session") def curve25519() -> DomainParameters: return get_params("other", "Curve25519", "xz") + + +# fmt: off +window_mults = [ + (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.LTR)), + (SlidingWindowMultiplier, dict(width=2, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=3, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=4, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=5, recoding_direction=ProcessingDirection.RTL)), + (SlidingWindowMultiplier, dict(width=6, recoding_direction=ProcessingDirection.RTL)), + (FixedWindowLTRMultiplier, dict(m=2**1)), + (FixedWindowLTRMultiplier, dict(m=2**2)), + (FixedWindowLTRMultiplier, dict(m=2**3)), + (FixedWindowLTRMultiplier, dict(m=2**4)), + (FixedWindowLTRMultiplier, dict(m=2**5)), + (FixedWindowLTRMultiplier, dict(m=2**6)), + (WindowBoothMultiplier, dict(width=2)), + (WindowBoothMultiplier, dict(width=3)), + (WindowBoothMultiplier, dict(width=4)), + (WindowBoothMultiplier, dict(width=5)), + (WindowBoothMultiplier, dict(width=6)) +] +naf_mults = [ + (WindowNAFMultiplier, dict(width=2)), + (WindowNAFMultiplier, dict(width=3)), + (WindowNAFMultiplier, dict(width=4)), + (WindowNAFMultiplier, dict(width=5)), + (WindowNAFMultiplier, dict(width=6)), + (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.RTL)) +] +comb_mults = [ + (CombMultiplier, dict(width=2, always=True)), + (CombMultiplier, dict(width=3, always=True)), + (CombMultiplier, dict(width=4, always=True)), + (CombMultiplier, dict(width=5, always=True)), + (CombMultiplier, dict(width=6, always=True)), + (CombMultiplier, dict(width=2, always=False)), + (CombMultiplier, dict(width=3, always=False)), + (CombMultiplier, dict(width=4, always=False)), + (CombMultiplier, dict(width=5, always=False)), + (CombMultiplier, dict(width=6, always=False)), + (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.LTR)), + (BGMWMultiplier, dict(width=2, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=3, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=4, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=5, direction=ProcessingDirection.RTL)), + (BGMWMultiplier, dict(width=6, direction=ProcessingDirection.RTL)) +] +binary_mults = [ + (LTRMultiplier, dict(always=False, complete=True)), + (LTRMultiplier, dict(always=True, complete=True)), + (LTRMultiplier, dict(always=False, complete=False)), + (LTRMultiplier, dict(always=True, complete=False)), + (RTLMultiplier, dict(always=False, complete=True)), + (RTLMultiplier, dict(always=True, complete=True)), + (RTLMultiplier, dict(always=False, complete=False)), + (RTLMultiplier, dict(always=True, complete=False)), + (CoronMultiplier, dict()) +] +other_mults = [ + (FullPrecompMultiplier, dict(always=False, complete=True)), + (FullPrecompMultiplier, dict(always=True, complete=True)), + (FullPrecompMultiplier, dict(always=False, complete=False)), + (FullPrecompMultiplier, dict(always=True, complete=False)), + (SimpleLadderMultiplier, dict(complete=True)), + (SimpleLadderMultiplier, dict(complete=False)) +] +# fmt: on + + +@pytest.fixture( + scope="session", + params=window_mults + naf_mults + comb_mults + binary_mults + other_mults, + ids=lambda p: "{}-{}".format( + p[0].__name__, ":".join(f"{k}={v}" for k, v in p[1].items()) + ), +) +def simple_multiplier(request): + return request.param diff --git a/test/gdb_script.py b/test/gdb_script.py new file mode 100644 index 0000000..93d4472 --- /dev/null +++ b/test/gdb_script.py @@ -0,0 +1,89 @@ +import os +import json + +import gdb + +trace_file = open(os.environ["TRACE_FILE"], "w") + + +def extract_bn(bn): + data_ptr = bn["dp"] + used = int(bn["used"]) + bs = int(gdb.lookup_global_symbol("bn_digit_bits").value()) + result = 0 + for i in range(used): + limb = int((data_ptr + i).dereference()) + result += limb << (i * bs) + return result + + +def extract_point(point): + result = {} + for field in point.type.fields(): + field_name = field.name + if len(field_name) != 1: + continue + field_value = point[field_name] + result[field_name] = extract_bn(field_value) + return result + + +class TraceFunction(gdb.Breakpoint): + def stop(self): + try: + set_bp.enabled = True + frame = gdb.newest_frame() + block = frame.block() + print(frame.name(), flush=True, file=trace_file) + out = [] + for sym in block: + if sym.is_argument: + name = sym.name + try: + value = frame.read_var(name) + except Exception as e: + value = f"<unavailable: {e}>" + deref = value.dereference() + if deref.type.name == "point_t": + if "out" in name: + out.append(deref) + else: + pt = extract_point(deref) + print(f"{name}: {json.dumps(pt)}", flush=True, file=trace_file) + bp = TraceExit(frame) + bp.silent = True + bp.target = out + except RuntimeError: + pass + return False # Continue execution + + +class TraceExit(gdb.FinishBreakpoint): + def stop(self): + set_bp.enabled = False + for i, point in enumerate(self.target): + print(f"out_{i}: {json.dumps(extract_point(point))}", flush=True, file=trace_file) + return False # Continue execution + + +def register_bp(name): + if gdb.lookup_global_symbol(name) is not None: + bp = TraceFunction(name) + bp.silent = True + return bp + return None + + +register_bp("point_add") +register_bp("point_dadd") +register_bp("point_dadd") +register_bp("point_ladd") +register_bp("point_dbl") +register_bp("point_neg") +register_bp("point_scl") +register_bp("point_tpl") +set_bp = register_bp("point_set") +set_bp.enabled = False + +gdb.execute("run") +trace_file.close() diff --git a/test/test_bn.c b/test/test_bn.c index 66de22f..b04c936 100644 --- a/test/test_bn.c +++ b/test/test_bn.c @@ -4,88 +4,403 @@ int test_wsliding_ltr() { printf("test_wsliding_ltr: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // sliding_window_ltr begin + {"181", 3, 6, {5, 0, 0, 5, 0, 1}}, + {"1", 3, 1, {1}}, + {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}}, + {"170", 4, 6, {5, 0, 0, 0, 5, 0}}, + {"554", 5, 6, {17, 0, 0, 0, 5, 0}}, + {"123456789123456789123456789", 5, 83, {25, 1, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 23, 0, 0, 0, 0, 25, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 17, 0, 0, 0, 0, 19, 0, 0, 0, 0, 29, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}} + // sliding_window_ltr end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wsliding_t *ws = bn_wsliding_ltr(&bn, cases[t].w); + if (ws == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (ws->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (ws->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(ws->data); + free(ws); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_wsliding_rtl() { + printf("test_wsliding_rtl: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // sliding_window_rtl begin + {"181", 3, 8, {1, 0, 0, 3, 0, 0, 0, 5}}, + {"1", 3, 1, {1}}, + {"1234", 2, 11, {1, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0}}, + {"170", 4, 6, {5, 0, 0, 0, 5, 0}}, + {"554", 5, 10, {1, 0, 0, 0, 0, 0, 0, 0, 21, 0}}, + {"123456789123456789123456789",5, 87, {1, 0, 0, 0, 0, 19, 0, 0, 0, 0, 1, 0, 0, 0, 0, 29, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 17, 0, 0, 0, 0, 27, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 21}} + // sliding_window_rtl end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wsliding_t *ws = bn_wsliding_rtl(&bn, cases[t].w); + if (ws == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (ws->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, ws->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (ws->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, ws->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(ws->data); + free(ws); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_convert_base_small() { + printf("test_convert_base_small: "); + int failed = 0; + struct { + const char *value; + int base; + int expected_len; + uint8_t expected[100]; // max expected length + } cases[] = { + // convert_base_small begin + {"11", 2, 4, {1, 1, 0, 1}}, + {"255", 2, 8, {1, 1, 1, 1, 1, 1, 1, 1}}, + {"1234", 10, 4, {4, 3, 2, 1}}, + {"0", 2, 1, {0}}, + {"1", 2, 1, {1}}, + {"123456789123456789123456789", 16, 22, {5, 1, 15, 5, 4, 0, 12, 7, 15, 9, 1, 11, 3, 14, 2, 15, 13, 15, 14, 1, 6, 6}} + // convert_base_small end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + small_base_t *bs = bn_convert_base_small(&bn, cases[t].base); + if (bs == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (bs->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (bs->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, bs->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(bs->data); + free(bs); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_convert_base_large() { + printf("test_convert_base_large: "); + int failed = 0; + struct { + const char *value; + const char *base; + int expected_len; + const char *expected[100]; // max expected length + } cases[] = { + // convert_base_large begin + {"123456789123456", "2", 47, {"0", "0", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "1", "0", "0", "1", "1", "1", "1", "1", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", "0", "1", "0", "0", "0", "1", "0", "0", "1", "0", "0", "0", "0", "0", "1", "1", "1"}}, + {"123456789123456789123456789", "123456", 6, {"104661", "75537", "83120", "74172", "37630", "4"}}, + {"352099265818416392997042486274568094251", "18446744073709551616", 3, {"12367597952119210539", "640595372834356666", "1"}} + // convert_base_large end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn, base; + bn_init(&bn); + bn_init(&base); + bn_from_dec(cases[t].value, &bn); + bn_from_dec(cases[t].base, &base); + large_base_t *bs = bn_convert_base_large(&bn, &base); + if (bs == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + bn_clear(&base); + continue; + } + if (bs->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, bs->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + bn_t exp; + bn_init(&exp); + bn_from_dec(cases[t].expected[i], &exp); + if (!bn_eq(&bs->data[i], &exp)) { + printf("Case %d: Bad data at %d\n", t, i); + failed++; + bn_clear(&exp); + break; + } + bn_clear(&exp); + } + for (int i = 0; i < bs->length; i++) { + bn_clear(&bs->data[i]); + } + bn_clear(&bs->m); + bn_clear(&bn); + bn_clear(&base); + free(bs->data); + free(bs); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_bn_wnaf() { + printf("test_bn_wnaf: "); + int failed = 0; + struct { + const char *value; + int w; + int expected_len; + int8_t expected[100]; // max expected length + } cases[] = { + // wnaf begin + {"19", 2, 5, {1, 0, 1, 0, -1}}, + {"45", 3, 5, {3, 0, 0, 0, -3}}, + {"0", 3, 0, {}}, + {"1", 2, 1, {1}}, + {"21", 4, 5, {1, 0, 0, 0, 5}}, + {"123456789", 3, 28, {1, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 3, 0, 0, -3}}, + {"123456789123456789123456789", 5, 84, {13, 0, 0, 0, 0, 0, -15, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -13, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, -7, 0, 0, 0, 0, -11}} + // wnaf end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_dec(cases[t].value, &bn); + wnaf_t *naf = bn_wnaf(&bn, cases[t].w); + if (naf == NULL) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; + } + if (naf->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, naf->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (naf->data[i] != cases[t].expected[i]) { + printf("Case %d: Bad data at %d (%i instead of %i)\n", t, i, naf->data[i], cases[t].expected[i]); + failed++; + break; + } + } + bn_clear(&bn); + free(naf->data); + free(naf); + } + if (failed == 0) { + printf("OK\n"); + } else { + printf("FAILED (%d cases)\n", failed); + } + return failed; +} + +int test_bn_wnaf_manipulation() { + printf("test_bn_wnaf_manipulation: "); bn_t bn; bn_init(&bn); - bn_from_int(181, &bn); - wsliding_t *ws = bn_wsliding_ltr(&bn, 3); - if (ws == NULL) { - printf("NULL\n"); + bn_from_dec("123456789", &bn); + wnaf_t *naf = bn_wnaf(&bn, 3); + bn_clear(&bn); + if (naf->length != 28) { + printf("FAILED (bad length %li instead of 28)\n", naf->length); return 1; } - if (ws->length != 6) { - printf("Bad length (%li instead of 6)\n", ws->length); + bn_naf_pad_left(naf, 0, 5); + if (naf->length != 33) { + printf("FAILED (bad length after pad left %li instead of 33)\n", naf->length); return 1; } - uint8_t expected[6] = {5, 0, 0, 5, 0, 1}; - for (int i = 0; i < 6; i++) { - if (ws->data[i] != expected[i]) { - printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]); + for (int i = 0; i < 5; i++) { + if (naf->data[i] != 0) { + printf("FAILED (bad data after pad left at %d (%i instead of 0))\n", i, naf->data[i]); return 1; } } - printf("OK\n"); - bn_clear(&bn); - free(ws->data); - free(ws); - return 0; -} - -int test_wsliding_rtl() { - printf("test_wsliding_rtl: "); - bn_t bn; - bn_init(&bn); - bn_from_int(181, &bn); - wsliding_t *ws = bn_wsliding_rtl(&bn, 3); - if (ws == NULL) { - printf("NULL\n"); + bn_naf_strip_left(naf, 0); + if (naf->length != 28) { + printf("FAILED (bad length after strip left %li instead of 28)\n", naf->length); return 1; } - if (ws->length != 8) { - printf("Bad length (%li instead of 8)\n", ws->length); + bn_naf_pad_right(naf, 0, 3); + if (naf->length != 31) { + printf("FAILED (bad length after pad right %li instead of 31)\n", naf->length); return 1; } - uint8_t expected[8] = {1, 0, 0, 3, 0, 0, 0, 5}; - for (int i = 0; i < 8; i++) { - if (ws->data[i] != expected[i]) { - printf("Bad data (%i instead of %i)\n", ws->data[i], expected[i]); + for (int i = 28; i < 31; i++) { + if (naf->data[i] != 0) { + printf("FAILED (bad data after pad right at %d (%i instead of 0))\n", i, naf->data[i]); return 1; } } + bn_naf_strip_right(naf, 0); + if (naf->length != 28) { + printf("FAILED (bad length after strip right %li instead of 28)\n", naf->length); + return 1; + } + int8_t rev[28] = {-3, 0, 0, 3, 0, 0, 0, 0, -3, 0, 0, 0, -3, 0, 0, 0, 0, 0, -1, 0, 0, 3, 0, 0, -1, 0, 0, 1}; + bn_naf_reverse(naf); + for (int i = 0; i < 28; i++) { + if (naf->data[i] != rev[i]) { + printf("FAILED (bad data after reverse at %d (%i instead of %i))\n", i, naf->data[i], rev[i]); + return 1; + } + } + + free(naf->data); + free(naf); printf("OK\n"); - bn_clear(&bn); - free(ws->data); - free(ws); return 0; } -int test_convert_base() { - printf("test_convert_base: "); - bn_t bn; - bn_init(&bn); - bn_from_int(11, &bn); - small_base_t *bs = bn_convert_base_small(&bn, 2); - if (bs == NULL) { - printf("NULL\n"); - return 1; - } - if (bs->length != 4) { - printf("Bad length (%li instead of 4)\n", bs->length); - return 1; +int test_booth() { + printf("test_booth: "); + for (int i = 0; i < (1 << 6); i++) { + int32_t bw = bn_booth_word(i, 5); + if (i <= 31) { + if (bw != (i + 1) / 2) { + printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, (i + 1) / 2); + return 1; + } + } else { + if (bw != -((64 - i) / 2)) { + printf("FAILED (bad booth for %d: %d instead of %d)\n", i, bw, -((64 - i) / 2)); + return 1; + } + } } - uint8_t expected[4] = {1, 1, 0, 1}; - for (int i = 0; i < 4; i++) { - if (bs->data[i] != expected[i]) { - printf("Bad data (%i insead of %i)\n", bs->data[i], expected[i]); - return 1; + int failed = 0; + struct { + const char *value; + int w; + size_t bits; + int expected_len; + int32_t expected[256]; // max expected length + } cases[] = { + // booth begin + {"12345678123456781234567812345678123456781234567812345678", 1, 224, 225, {0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 1, -1, 0, 0, 1, 0, -1, 1, -1, 0, 0, 1, -1, 1, -1, 1, 0, -1, 0, 1, 0, 0, 0, -1, 0, 0, 0}}, + {"12345678123456781234567812345678123456781234567812345678", 2, 224, 113, {0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0, 0, 1, 1, -2, 1, -1, 1, 0, 1, 1, 2, -2, 2, 0, -2, 0}}, + {"12345678123456781234567812345678123456781234567812345678", 5, 224, 45, {1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8, 2, 9, -6, 5, 13, -2, 1, -14, 7, -15, 11, 8, -16, 5, -14, -12, 11, -6, -4, 1, 4, 13, 3, -10, 15, 0, 9, 3, 9, -10, -12, -8}}, + {"12345678123456781234567812345678123456781234567812345678", 16, 224, 15, {0, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136, 4660, 22136}}, + {"12345678123456781234567812345678123456781234567812345678", 24, 224, 10, {18, 3430008, 1193046, 7868980, 5666834, 3430008, 1193046, 7868980, 5666834, 3430008}}, + {"12345678123456781234567812345678123456781234567812345678", 30, 224, 0, {}} + // booth end + }; + int num_cases = sizeof(cases) / sizeof(cases[0]); + for (int t = 0; t < num_cases; t++) { + bn_t bn; + bn_init(&bn); + bn_from_hex(cases[t].value, &bn); + booth_t *booth = bn_booth(&bn, cases[t].w, cases[t].bits); + if (booth == NULL && cases[t].expected_len != 0) { + printf("Case %d: NULL\n", t); + failed++; + bn_clear(&bn); + continue; } + if (cases[t].expected_len != 0) { + if (booth->length != cases[t].expected_len) { + printf("Case %d: Bad length (%li instead of %i)\n", t, booth->length, cases[t].expected_len); + failed++; + } + for (int i = 0; i < cases[t].expected_len; i++) { + if (booth->data[i] != cases[t].expected[i]) { + printf("FAILED (bad booth data at %d: %d instead of %d)\n", i, booth->data[i], cases[t].expected[i]); + failed++; + break; + } + } + } + bn_clear(&bn); + bn_booth_clear(booth); } printf("OK\n"); - bn_clear(&bn); - free(bs->data); - free(bs); return 0; } int main(void) { - return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base(); + return test_wsliding_ltr() + test_wsliding_rtl() + test_convert_base_small() + test_convert_base_large() + test_bn_wnaf() + test_bn_wnaf_manipulation() + test_booth(); } diff --git a/test/test_equivalence.py b/test/test_equivalence.py new file mode 100644 index 0000000..6973ceb --- /dev/null +++ b/test/test_equivalence.py @@ -0,0 +1,185 @@ +import subprocess +import json +from tempfile import NamedTemporaryFile +from typing import Generator, Any +from click.testing import CliRunner + +import pytest +from importlib import resources +from os.path import join + +from pyecsca.codegen.builder import build_impl +from pyecsca.ec.formula import FormulaAction, NegationFormula +from pyecsca.ec.model import CurveModel +from pyecsca.ec.coordinates import CoordinateModel +from pyecsca.sca.target.binary import BinaryTarget +from pyecsca.codegen.client import ImplTarget +from pyecsca.ec.context import DefaultContext, local, Node + +from pyecsca.ec.mult import WindowBoothMultiplier + + +class GDBTarget(ImplTarget, BinaryTarget): + def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs): + super().__init__(model, coords, **kwargs) + self.trace_file = None + + def connect(self): + self.trace_file = NamedTemporaryFile("r+") + with resources.path("test", "gdb_script.py") as gdb_script: + self.process = subprocess.Popen( + [ + "gdb", + "-q", + "-batch-silent", + "-x", + gdb_script, + "--args", + *self.binary, + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + env={ + "TRACE_FILE": self.trace_file.name, + }, + text=True, + bufsize=1, + ) + + def disconnect(self): + super().disconnect() + if self.trace_file is not None: + self.trace_file.close() + self.trace_file = None + + +@pytest.fixture(scope="module") +def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]: + mult_class, mult_kwargs = simple_multiplier + mult_name = mult_class.__name__ + formulas = ["add-1998-cmo", "dbl-1998-cmo"] + if NegationFormula in mult_class.requires: + formulas.append("neg") + runner = CliRunner() + with runner.isolated_filesystem() as tmpdir: + res = runner.invoke( + build_impl, + [ + "--platform", + "HOST", + "--ecdsa", + "--ecdh", + secp128r1.curve.model.shortname, + secp128r1.curve.coordinate_model.name, + *formulas, + f"{mult_name}({','.join(f'{key}={value}' for key, value in mult_kwargs.items())})", + ".", + ], + env={"DEBUG": "1", "CFLAGS": "-g -O0"}, + ) + assert res.exit_code == 0 + target = GDBTarget( + secp128r1.curve.model, + secp128r1.curve.coordinate_model, + binary=join(tmpdir, "pyecsca-codegen-HOST.elf"), + ) + formula_instances = [ + secp128r1.curve.coordinate_model.formulas[formula] for formula in formulas + ] + mult = mult_class(*formula_instances, **mult_kwargs) + target.mult = mult + yield target + + +def parse_trace(captured: str): + current_function = None + args = [] + rets = [] + result = [] + for line in captured.split("\n"): + if ":" not in line: + func = line.strip() + if func.startswith("point_"): + func = func[len("point_") :] + if func == "set": + # The sets that happen inside another formula (like add) are a sign of short-circuiting. + # The Python simulation does not record the short-circuits, so we ignore them here. + current_function = None + else: + if current_function is not None: + result.append((current_function, args, rets)) + current_function = func + args = [] + rets = [] + else: + name, data = line.split(":", 1) + name = name.strip() + value = json.loads(data) + if "out" in name: + rets.append(value) + else: + args.append(value) + return result + + +def parse_ctx(scalarmult: Node): + result = [] + for node in scalarmult.children: + action: FormulaAction = node.action + formula = action.formula + name = formula.shortname + args = [] + for point in action.input_points: + point_value = {k: int(v) for k, v in point.coords.items()} + args.append(point_value) + rets = [] + for point in action.output_points: + point_value = {k: int(v) for k, v in point.coords.items()} + rets.append(point_value) + result.append((name, args, rets)) + return result + + +def make_hashable(trace): + result = [] + for entry in trace: + name, args, rets = entry + args_t = tuple(tuple(arg.items()) for arg in args) + rets_t = tuple(tuple(ret.items()) for ret in rets) + result.append((name, args_t, rets_t)) + return tuple(result) + + +def test_equivalence(target, secp128r1): + mult = target.mult + for i in range(10): + target.connect() + target.init_prng(i.to_bytes(4, "big")) + target.set_params(secp128r1) + + priv, pub = target.generate() + + with local(DefaultContext()) as ctx: + mult.init(secp128r1, secp128r1.generator) + expected = mult.multiply(priv).to_affine() + + assert secp128r1.curve.is_on_curve(pub) + assert pub == expected + err = target.trace_file.read() + from_codegen = parse_trace(err) + from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1]) + codegen_set = set(make_hashable(from_codegen)) + sim_set = set(make_hashable(from_sim)) + if codegen_set != sim_set: + print(len(from_codegen), len(from_sim)) + print("In codegen but not in sim:") + for entry in codegen_set - sim_set: + print(entry) + print("In sim but not in codegen:") + for entry in sim_set - codegen_set: + print(entry) + assert from_codegen == from_sim + + target.quit() + target.disconnect() diff --git a/test/test_impl.py b/test/test_impl.py index 34f427b..a61aabd 100644 --- a/test/test_impl.py +++ b/test/test_impl.py @@ -1,29 +1,18 @@ from copy import copy from os.path import join +from typing import Any, Generator + import pytest from click.testing import CliRunner +from pyecsca.ec.formula import NegationFormula from pyecsca.ec.key_agreement import ECDH_SHA1 from pyecsca.ec.mod import mod -from pyecsca.ec.mult import ( - LTRMultiplier, - RTLMultiplier, - CoronMultiplier, - BinaryNAFMultiplier, - WindowNAFMultiplier, - SlidingWindowMultiplier, - AccumulationOrder, - ProcessingDirection, - ScalarMultiplier, - FixedWindowLTRMultiplier, - FullPrecompMultiplier, - BGMWMultiplier, - CombMultiplier, -) +from pyecsca.ec.mult import ScalarMultiplier, WindowBoothMultiplier from pyecsca.ec.signature import ECDSA_SHA1, SignatureResult from pyecsca.codegen.builder import build_impl -from pyecsca.codegen.client import HostTarget, ImplTarget +from pyecsca.codegen.client import HostTarget @pytest.fixture( @@ -39,195 +28,15 @@ def additional(request): return request.param -@pytest.fixture( - scope="module", - params=[ - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False}, - ), - id="LTR1", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": True}, - ), - id="LTR2", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False, "always": True}, - ), - id="LTR3", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": True, "always": True}, - ), - id="LTR4", - ), - pytest.param( - ( - LTRMultiplier, - "ltr", - ["add-1998-cmo", "dbl-1998-cmo"], - {"complete": False, "accumulation_order": AccumulationOrder.PeqRP}, - ), - id="LTR5", - ), - pytest.param( - (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": False}), - id="RTL1", - ), - pytest.param( - (RTLMultiplier, "rtl", ["add-1998-cmo", "dbl-1998-cmo"], {"always": True}), - id="RTL2", - ), - pytest.param( - (CoronMultiplier, "coron", ["add-1998-cmo", "dbl-1998-cmo"], {}), id="Coron" - ), - pytest.param( - ( - BinaryNAFMultiplier, - "bnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"direction": ProcessingDirection.LTR}, - ), - id="BNAF1", - ), - pytest.param( - ( - BinaryNAFMultiplier, - "bnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"direction": ProcessingDirection.RTL}, - ), - id="BNAF2", - ), - pytest.param( - ( - WindowNAFMultiplier, - "wnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"width": 3}, - ), - id="WNAF1", - ), - pytest.param( - ( - WindowNAFMultiplier, - "wnaf", - ["add-1998-cmo", "dbl-1998-cmo", "neg"], - {"width": 3, "precompute_negation": True}, - ), - id="WNAF2", - ), - pytest.param( - ( - SlidingWindowMultiplier, - "sliding", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3}, - ), - id="SLI1", - ), - pytest.param( - ( - SlidingWindowMultiplier, - "sliding", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3, "recoding_direction": ProcessingDirection.RTL}, - ), - id="SLI2", - ), - pytest.param( - ( - FixedWindowLTRMultiplier, - "fixed", - ["add-1998-cmo", "dbl-1998-cmo"], - {"m": 4}, - ), - id="FIX1", - ), - pytest.param( - ( - FixedWindowLTRMultiplier, - "fixed", - ["add-1998-cmo", "dbl-1998-cmo"], - {"m": 5}, - ), - id="FIX2", - ), - pytest.param( - ( - FullPrecompMultiplier, - "precomp", - ["add-1998-cmo", "dbl-1998-cmo"], - {"direction": ProcessingDirection.LTR}, - ), - id="PRE1", - ), - pytest.param( - ( - FullPrecompMultiplier, - "precomp", - ["add-1998-cmo", "dbl-1998-cmo"], - {"direction": ProcessingDirection.RTL}, - ), - id="PRE2", - ), - pytest.param( - ( - BGMWMultiplier, - "bgmw", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3, "direction": ProcessingDirection.LTR}, - ), - id="BGMW1", - ), - pytest.param( - ( - BGMWMultiplier, - "bgmw", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 5, "direction": ProcessingDirection.RTL}, - ), - id="BGMW2", - ), - pytest.param( - ( - CombMultiplier, - "comb", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 3}, - ), - id="Comb1", - ), - pytest.param( - ( - CombMultiplier, - "comb", - ["add-1998-cmo", "dbl-1998-cmo"], - {"width": 5}, - ), - id="Comb2", - ), - ], -) -def target(request, additional, secp128r1) -> ImplTarget: - mult_class, mult_name, formulas, mult_kwargs = request.param +@pytest.fixture(scope="module") +def target( + simple_multiplier, additional, secp128r1 +) -> Generator[HostTarget, Any, None]: + mult_class, mult_kwargs = simple_multiplier + mult_name = mult_class.__name__ + formulas = ["add-1998-cmo", "dbl-1998-cmo"] + if NegationFormula in mult_class.requires: + formulas.append("neg") runner = CliRunner() with runner.isolated_filesystem() as tmpdir: res = runner.invoke( @@ -235,7 +44,6 @@ def target(request, additional, secp128r1) -> ImplTarget: [ "--platform", "HOST", - *additional, "--ecdsa", "--ecdh", secp128r1.curve.model.shortname, |
