aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pyecsca/codegen/templates/mult_bnaf.c53
-rw-r--r--pyecsca/codegen/templates/mult_wnaf.c2
-rw-r--r--test/conftest.py7
-rw-r--r--test/test_equivalence.py43
4 files changed, 64 insertions, 41 deletions
diff --git a/pyecsca/codegen/templates/mult_bnaf.c b/pyecsca/codegen/templates/mult_bnaf.c
index d0cafdf..090807c 100644
--- a/pyecsca/codegen/templates/mult_bnaf.c
+++ b/pyecsca/codegen/templates/mult_bnaf.c
@@ -1,15 +1,30 @@
#include "mult.h"
#include "point.h"
-point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) {
+point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) {
+ point_t *q;
+ long i;
+ {% if scalarmult.complete %}
+ bn_naf_pad_left(naf, 0, (bits + 1) - naf->length);
+ q = point_copy(curve->neutral);
+ i = 0;
+ {% else %}
+ bn_naf_strip_left(naf, 0);
+ int8_t val = naf->data[0];
+ if (val == 1) {
+ q = point_copy(point);
+ } else if (val == -1) {
+ q = point_copy(neg);
+ }
+ i = 1;
+ {% endif %}
+
{% if scalarmult.always %}
point_t *q_copy = point_new();
- point_t *dummy = point_new();
{% endif %}
-
- point_t *q = point_copy(curve->neutral);
- for (long i = naf->length - 1; i >= 0; i--) {
+ for (; i < naf->length; i++) {
point_dbl(q, curve, q);
+
{% if scalarmult.always %}
point_set(q, q_copy);
{% endif %}
@@ -17,28 +32,32 @@ point_t *scalar_mult_ltr(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n
if (naf->data[i] == 1) {
point_accumulate(q, point, curve, q);
{% if scalarmult.always %}
- point_accumulate(q_copy, neg, curve, dummy);
+ point_accumulate(q_copy, neg, curve, q_copy);
{% endif %}
} else if (naf->data[i] == -1) {
point_accumulate(q, neg, curve, q);
{% if scalarmult.always %}
- point_accumulate(q_copy, point, curve, dummy);
+ point_accumulate(q_copy, point, curve, q_copy);
{% endif %}
}
}
{% if scalarmult.always %}
point_free(q_copy);
- point_free(dummy);
{% endif %}
return q;
}
-point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf) {
+point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *naf, size_t bits) {
{% if scalarmult.always %}
point_t *r_copy = point_new();
- point_t *dummy = point_new();
{% endif %}
+ {% if scalarmult.complete %}
+ bn_naf_pad_left(naf, 0, (bits + 1) - naf->length);
+ {% endif %}
+
+ bn_naf_reverse(naf);
+
point_t *q = point_copy(point);
point_t *r = point_copy(curve->neutral);
point_t *q_neg = point_new();
@@ -46,17 +65,18 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n
{% if scalarmult.always %}
point_set(r, r_copy);
{% endif %}
+
if (naf->data[i] == 1) {
point_accumulate(r, q, curve, r);
{% if scalarmult.always %}
point_neg(q, curve, q_neg);
- point_accumulate(r_copy, q_neg, curve, dummy);
+ point_accumulate(r_copy, q_neg, curve, r_copy);
{% endif %}
} else if (naf->data[i] == -1) {
point_neg(q, curve, q_neg);
point_accumulate(r, q_neg, curve, r);
{% if scalarmult.always %}
- point_accumulate(r_copy, q, curve, dummy);
+ point_accumulate(r_copy, q, curve, r_copy);
{% endif %}
}
point_dbl(q, curve, q);
@@ -66,7 +86,6 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n
{% if scalarmult.always %}
point_free(r_copy);
- point_free(dummy);
{% endif %}
return r;
}
@@ -74,14 +93,14 @@ point_t* scalar_mult_rtl(point_t *point, point_t *neg, curve_t *curve, wnaf_t *n
static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) {
point_t *neg = point_new();
point_neg(point, curve, neg);
- wnaf_t *naf = bn_bnaf(scalar);
- {# TODO: Handle the ".complete" option #}
+ wnaf_t *naf = bn_bnaf(scalar);
+ size_t bits = bn_bit_length(&curve->n);
{% if scalarmult.direction == ProcessingDirection.LTR %}
- point_t *q = scalar_mult_ltr(point, neg, curve, naf);
+ point_t *q = scalar_mult_ltr(point, neg, curve, naf, bits);
{% elif scalarmult.direction == ProcessingDirection.RTL %}
- point_t *q = scalar_mult_rtl(point, neg, curve, naf);
+ point_t *q = scalar_mult_rtl(point, neg, curve, naf, bits);
{% endif %}
free(naf->data);
diff --git a/pyecsca/codegen/templates/mult_wnaf.c b/pyecsca/codegen/templates/mult_wnaf.c
index 3c5f2b2..c9228fb 100644
--- a/pyecsca/codegen/templates/mult_wnaf.c
+++ b/pyecsca/codegen/templates/mult_wnaf.c
@@ -26,7 +26,7 @@ static void scalar_mult_inner(bn_t *scalar, point_t *point, curve_t *curve, poin
wnaf_t *naf = bn_wnaf(scalar, {{ scalarmult.width }});
- for (long i = naf->length - 1; i >= 0; i--) {
+ for (long i = 0; i < naf->length; i++) {
point_dbl(q, curve, q);
int8_t val = naf->data[i];
if (val > 0) {
diff --git a/test/conftest.py b/test/conftest.py
index ca32485..80faf71 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -38,7 +38,6 @@ window_mults = [
(WindowBoothMultiplier, dict(width=5)),
(WindowBoothMultiplier, dict(width=6))
]
-
naf_mults = [
(WindowNAFMultiplier, dict(width=2)),
(WindowNAFMultiplier, dict(width=3)),
@@ -48,7 +47,11 @@ naf_mults = [
(BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.LTR)),
(BinaryNAFMultiplier, dict(always=False, direction=ProcessingDirection.RTL)),
(BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.LTR)),
- (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL))
+ (BinaryNAFMultiplier, dict(always=True, direction=ProcessingDirection.RTL)),
+ (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(complete=False, always=False, direction=ProcessingDirection.RTL)),
+ (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.LTR)),
+ (BinaryNAFMultiplier, dict(complete=False, always=True, direction=ProcessingDirection.RTL))
]
comb_mults = [
(CombMultiplier, dict(width=2, always=True)),
diff --git a/test/test_equivalence.py b/test/test_equivalence.py
index e3bf710..f808726 100644
--- a/test/test_equivalence.py
+++ b/test/test_equivalence.py
@@ -136,27 +136,28 @@ def test_equivalence(target, secp128r1, capfd):
mult = target.mult
target.connect()
target.set_params(secp128r1)
- priv, pub = target.generate()
- with local(DefaultContext()) as ctx:
- mult.init(secp128r1, secp128r1.generator)
- expected = mult.multiply(priv).to_affine()
- captured = capfd.readouterr()
- with capfd.disabled():
- assert secp128r1.curve.is_on_curve(pub)
- #assert pub == expected
- from_codegen = parse_trace(captured.err)
- from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
- codegen_set = set(make_hashable(from_codegen))
- sim_set = set(make_hashable(from_sim))
- if codegen_set != sim_set:
- print(len(from_codegen), len(from_sim))
- print("In codegen but not in sim:")
- for entry in codegen_set - sim_set:
- print(entry)
- print("In sim but not in codegen:")
- for entry in sim_set - codegen_set:
- print(entry)
- assert from_codegen == from_sim
+ for _ in range(3):
+ priv, pub = target.generate()
+ with local(DefaultContext()) as ctx:
+ mult.init(secp128r1, secp128r1.generator)
+ expected = mult.multiply(priv).to_affine()
+ captured = capfd.readouterr()
+ with capfd.disabled():
+ assert secp128r1.curve.is_on_curve(pub)
+ assert pub == expected
+ from_codegen = parse_trace(captured.err)
+ from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
+ codegen_set = set(make_hashable(from_codegen))
+ sim_set = set(make_hashable(from_sim))
+ if codegen_set != sim_set:
+ print(len(from_codegen), len(from_sim))
+ print("In codegen but not in sim:")
+ for entry in codegen_set - sim_set:
+ print(entry)
+ print("In sim but not in codegen:")
+ for entry in sim_set - codegen_set:
+ print(entry)
+ assert from_codegen == from_sim
target.quit()
target.disconnect()