diff options
| -rw-r--r-- | pyecsca/codegen/templates/mult_bnaf.c | 53 | ||||
| -rw-r--r-- | pyecsca/codegen/templates/mult_wnaf.c | 2 | ||||
| -rw-r--r-- | test/conftest.py | 7 | ||||
| -rw-r--r-- | test/test_equivalence.py | 43 |
4 files changed, 64 insertions, 41 deletions
diff --git a/pyecsca/codegen/templates/mult_bnaf.c b/pyecsca/codegen/templates/mult_bnaf.c index d0cafdf..090807c 100644 --- a/pyecsca/codegen/templates/mult_bnaf.c +++ b/pyecsca/codegen/templates/mult_bnaf.c @@ -1,15 +1,30 @@ #include "mult.h" #include "point.h" -point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) { +point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) { + point_t *q; + long i; + {% if scalarmult.complete %} + bn_naf_pad_left(naf, 0, (bits + 1) - naf->length); + q = point_copy(curve->neutral); + i = 0; + {% else %} + bn_naf_strip_left(naf, 0); + int8_t val = naf->data[0]; + if (val == 1) { + q = point_copy(point); + } else if (val == -1) { + q = point_copy(neg); + } + i = 1; + {% endif %} + {% if scalarmult.always %} point_t *q_copy = point_new(); - point_t *dummy = point_new(); {% endif %} - - point_t *q = point_copy(curve->neutral); - for (long i = naf->length - 1; i >= 0; i--) { + for (; i < naf->length; i++) { point_dbl(q, curve, q); + {% if scalarmult.always %} point_set(q, q_copy); {% endif %} @@ -17,28 +32,32 @@ point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n if (naf->data[i] == 1) { point_accumulate(q, point, curve, q); {% if scalarmult.always %} - point_accumulate(q_copy, neg, curve, dummy); + point_accumulate(q_copy, neg, curve, q_copy); {% endif %} } else if (naf->data[i] == -1) { point_accumulate(q, neg, curve, q); {% if scalarmult.always %} - point_accumulate(q_copy, point, curve, dummy); + point_accumulate(q_copy, point, curve, q_copy); {% endif %} } } {% if scalarmult.always %} point_free(q_copy); - point_free(dummy); {% endif %} return q; } -point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) { +point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) { {% if scalarmult.always %} point_t *r_copy = point_new(); - point_t *dummy = point_new(); {% endif %} + {% if scalarmult.complete %} + bn_naf_pad_left(naf, 0, (bits + 1) - naf->length); + {% endif %} + + bn_naf_reverse(naf); + point_t *q = point_copy(point); point_t *r = point_copy(curve->neutral); point_t *q_neg = point_new(); @@ -46,17 +65,18 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n {% if scalarmult.always %} point_set(r, r_copy); {% endif %} + if (naf->data[i] == 1) { point_accumulate(r, q, curve, r); {% if scalarmult.always %} point_neg(q, curve, q_neg); - point_accumulate(r_copy, q_neg, curve, dummy); + point_accumulate(r_copy, q_neg, curve, r_copy); {% endif %} } else if (naf->data[i] == -1) { point_neg(q, curve, q_neg); point_accumulate(r, q_neg, curve, r); {% if scalarmult.always %} - point_accumulate(r_copy, q, curve, dummy); + point_accumulate(r_copy, q, curve, r_copy); {% endif %} } point_dbl(q, curve, q); @@ -66,7 +86,6 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n {% if scalarmult.always %} point_free(r_copy); - point_free(dummy); {% endif %} return r; } @@ -74,14 +93,14 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) { point_t *neg = point_new(); point_neg(point, curve, neg); - wnaf_t *naf = bn_bnaf(scalar); - {# TODO: Handle the ".complete" option #} + wnaf_t *naf = bn_bnaf(scalar); + size_t bits = bn_bit_length(&curve->n); {% if scalarmult.direction == ProcessingDirection.LTR %} - point_t *q = scalar_mult_ltr(point, neg, curve, naf); + point_t *q = scalar_mult_ltr(point, neg, curve, naf, bits); {% elif scalarmult.direction == ProcessingDirection.RTL %} - point_t *q = scalar_mult_rtl(point, neg, curve, naf); + point_t *q = scalar_mult_rtl(point, neg, curve, naf, bits); {% endif %} free(naf->data); diff --git a/pyecsca/codegen/templates/mult_wnaf.c b/pyecsca/codegen/templates/mult_wnaf.c index 3c5f2b2..c9228fb 100644 --- a/pyecsca/codegen/templates/mult_wnaf.c +++ b/pyecsca/codegen/templates/mult_wnaf.c @@ -26,7 +26,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin wnaf_t *naf = bn_wnaf(scalar, {{ scalarmult.width }}); - for (long i = naf->length - 1; i >= 0; i--) { + for (long i = 0; i < naf->length; i++) { point_dbl(q, curve, q); int8_t val = naf->data[i]; if (val > 0) { diff --git a/test/conftest.py b/test/conftest.py index ca32485..80faf71 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -38,7 +38,6 @@ window_mults = [ (WindowBoothMultiplier, dict(width=5)), (WindowBoothMultiplier, dict(width=6)) ] - naf_mults = [ (WindowNAFMultiplier, dict(width=2)), (WindowNAFMultiplier, dict(width=3)), @@ -48,7 +47,11 @@ naf_mults = [ (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.LTR)), (BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.RTL)), (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.LTR)), - (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)) + (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.RTL)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.LTR)), + (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.RTL)) ] comb_mults = [ (CombMultiplier, dict(width=2, always=True)), diff --git a/test/test_equivalence.py b/test/test_equivalence.py index e3bf710..f808726 100644 --- a/test/test_equivalence.py +++ b/test/test_equivalence.py @@ -136,27 +136,28 @@ def test_equivalence(target, secp128r1, capfd): mult = target.mult target.connect() target.set_params(secp128r1) - priv, pub = target.generate() - with local(DefaultContext()) as ctx: - mult.init(secp128r1, secp128r1.generator) - expected = mult.multiply(priv).to_affine() - captured = capfd.readouterr() - with capfd.disabled(): - assert secp128r1.curve.is_on_curve(pub) - #assert pub == expected - from_codegen = parse_trace(captured.err) - from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1]) - codegen_set = set(make_hashable(from_codegen)) - sim_set = set(make_hashable(from_sim)) - if codegen_set != sim_set: - print(len(from_codegen), len(from_sim)) - print("In codegen but not in sim:") - for entry in codegen_set - sim_set: - print(entry) - print("In sim but not in codegen:") - for entry in sim_set - codegen_set: - print(entry) - assert from_codegen == from_sim + for _ in range(3): + priv, pub = target.generate() + with local(DefaultContext()) as ctx: + mult.init(secp128r1, secp128r1.generator) + expected = mult.multiply(priv).to_affine() + captured = capfd.readouterr() + with capfd.disabled(): + assert secp128r1.curve.is_on_curve(pub) + assert pub == expected + from_codegen = parse_trace(captured.err) + from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1]) + codegen_set = set(make_hashable(from_codegen)) + sim_set = set(make_hashable(from_sim)) + if codegen_set != sim_set: + print(len(from_codegen), len(from_sim)) + print("In codegen but not in sim:") + for entry in codegen_set - sim_set: + print(entry) + print("In sim but not in codegen:") + for entry in sim_set - codegen_set: + print(entry) + assert from_codegen == from_sim target.quit() target.disconnect() |
