aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--test/gdb_script.py87
-rw-r--r--test/test_equivalence.py189
2 files changed, 276 insertions, 0 deletions
diff --git a/test/gdb_script.py b/test/gdb_script.py
new file mode 100644
index 0000000..c231a3c
--- /dev/null
+++ b/test/gdb_script.py
@@ -0,0 +1,87 @@
+import sys
+import json
+
+import gdb
+
+
+def extract_bn(bn):
+ data_ptr = bn["dp"]
+ used = int(bn["used"])
+ bs = int(gdb.lookup_global_symbol("bn_digit_bits").value())
+ result = 0
+ for i in range(used):
+ limb = int((data_ptr + i).dereference())
+ result += limb << (i * bs)
+ return result
+
+
+def extract_point(point):
+ result = {}
+ for field in point.type.fields():
+ field_name = field.name
+ if len(field_name) != 1:
+ continue
+ field_value = point[field_name]
+ result[field_name] = extract_bn(field_value)
+ return result
+
+
+class TraceFunction(gdb.Breakpoint):
+ def stop(self):
+ try:
+ set_bp.enabled = True
+ frame = gdb.newest_frame()
+ block = frame.block()
+ print(frame.name(), file=sys.stderr)
+ out = []
+ for sym in block:
+ if sym.is_argument:
+ name = sym.name
+ try:
+ value = frame.read_var(name)
+ except Exception as e:
+ value = f"<unavailable: {e}>"
+ deref = value.dereference()
+ if deref.type.name == "point_t":
+ if "out" in name:
+ out.append(deref)
+ else:
+ pt = extract_point(deref)
+ print(f"{name}: {json.dumps(pt)}", file=sys.stderr)
+ bp = TraceExit(frame)
+ bp.silent = True
+ bp.target = out
+ except RuntimeError as e:
+ pass
+ return False # Continue execution
+
+
+class TraceExit(gdb.FinishBreakpoint):
+ def stop(self):
+ set_bp.enabled = False
+ for i, point in enumerate(self.target):
+ print(f"out_{i}: {json.dumps(extract_point(point))}", file=sys.stderr)
+ return False # Continue execution
+
+
+def register_bp(name):
+ if gdb.lookup_global_symbol(name) is not None:
+ bp = TraceFunction(name)
+ bp.silent = True
+ return bp
+ return None
+
+
+register_bp("point_add")
+register_bp("point_dadd")
+register_bp("point_dadd")
+register_bp("point_ladd")
+register_bp("point_dbl")
+register_bp("point_neg")
+register_bp("point_scl")
+register_bp("point_tpl")
+set_bp = register_bp("point_set")
+set_bp.enabled = False
+
+gdb.execute("run")
+# print("\x04", file=sys.stderr)
diff --git a/test/test_equivalence.py b/test/test_equivalence.py
new file mode 100644
index 0000000..dceae99
--- /dev/null
+++ b/test/test_equivalence.py
@@ -0,0 +1,189 @@
+import subprocess
+import json
+from typing import Generator, Any
+from click.testing import CliRunner
+
+import pytest
+from importlib import resources
+from os.path import join
+
+from pyecsca.codegen.builder import build_impl
+from pyecsca.ec.formula import FormulaAction, NegationFormula
+from pyecsca.ec.model import CurveModel
+from pyecsca.ec.coordinates import CoordinateModel
+from pyecsca.sca.target.binary import BinaryTarget
+from pyecsca.codegen.client import ImplTarget
+from pyecsca.ec.context import DefaultContext, local, Node
+
+from pyecsca.ec.mult import (
+ LTRMultiplier,
+ RTLMultiplier,
+ CoronMultiplier,
+ BinaryNAFMultiplier,
+ WindowNAFMultiplier,
+ SlidingWindowMultiplier,
+ AccumulationOrder,
+ ProcessingDirection,
+ ScalarMultiplier,
+ FixedWindowLTRMultiplier,
+ FullPrecompMultiplier,
+ BGMWMultiplier,
+ CombMultiplier,
+ ScalarMultiplicationAction,
+)
+
+
+class GDBTarget(ImplTarget, BinaryTarget):
+ def __init__(self, model: CurveModel, coords: CoordinateModel, **kwargs):
+ super().__init__(model, coords, **kwargs)
+
+ def connect(self):
+ with resources.path("test", "gdb_script.py") as gdb_script:
+ self.process = subprocess.Popen(
+ ["gdb", "-batch", "-x", gdb_script, "--args", *self.binary],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ # stderr=subprocess.PIPE,
+ text=True,
+ bufsize=1,
+ )
+
+ # def disconnect(self):
+ # if self.process is None:
+ # return
+ # if self.process.stdin is not None:
+ # self.process.stdin.close()
+ # if self.process.stdout is not None:
+ # self.process.stdout.close()
+ # if self.process.stderr is not None:
+ # self.process.stderr.close()
+ # self.process.terminate()
+ # self.process.wait()
+
+
+@pytest.fixture(scope="module")
+def target(simple_multiplier, secp128r1) -> Generator[GDBTarget, Any, None]:
+ mult_class, mult_kwargs = simple_multiplier
+ mult_name = mult_class.__name__
+ formulas = ["add-1998-cmo", "dbl-1998-cmo"]
+ if NegationFormula in mult_class.requires:
+ formulas.append("neg")
+ runner = CliRunner()
+ with runner.isolated_filesystem() as tmpdir:
+ res = runner.invoke(
+ build_impl,
+ [
+ "--platform",
+ "HOST",
+ "--ecdsa",
+ "--ecdh",
+ secp128r1.curve.model.shortname,
+ secp128r1.curve.coordinate_model.name,
+ *formulas,
+ f"{mult_name}({','.join(f'{key}={value}' for key, value in mult_kwargs.items())})",
+ ".",
+ ],
+ env={"DEBUG": "1", "CFLAGS": "-g -O0"},
+ )
+ assert res.exit_code == 0
+ target = GDBTarget(
+ secp128r1.curve.model,
+ secp128r1.curve.coordinate_model,
+ binary=join(tmpdir, "pyecsca-codegen-HOST.elf"),
+ )
+ formula_instances = [
+ secp128r1.curve.coordinate_model.formulas[formula] for formula in formulas
+ ]
+ mult = mult_class(*formula_instances, **mult_kwargs)
+ target.mult = mult
+ yield target
+
+
+def parse_trace(captured: str):
+ current_function = None
+ args = []
+ rets = []
+ result = []
+ for line in captured.split("\n"):
+ if ":" not in line:
+ func = line.strip()
+ if func.startswith("point_"):
+ func = func[len("point_") :]
+ if func == "set":
+ # The sets that happen inside another formula (like add) are a sign of short-circuiting.
+ # The Python simulation does not record the short-circuits, so we ignore them here.
+ current_function = None
+ else:
+ if current_function is not None:
+ result.append((current_function, args, rets))
+ current_function = func
+ args = []
+ rets = []
+ else:
+ name, data = line.split(":", 1)
+ name = name.strip()
+ value = json.loads(data)
+ if "out" in name:
+ rets.append(value)
+ else:
+ args.append(value)
+ return result
+
+
+def parse_ctx(scalarmult: Node):
+ result = []
+ for node in scalarmult.children:
+ action: FormulaAction = node.action
+ formula = action.formula
+ name = formula.shortname
+ args = []
+ for point in action.input_points:
+ point_value = {k: int(v) for k, v in point.coords.items()}
+ args.append(point_value)
+ rets = []
+ for point in action.output_points:
+ point_value = {k: int(v) for k, v in point.coords.items()}
+ rets.append(point_value)
+ result.append((name, args, rets))
+ return result
+
+
+def make_hashable(trace):
+ result = []
+ for entry in trace:
+ name, args, rets = entry
+ args_t = tuple(tuple(arg.items()) for arg in args)
+ rets_t = tuple(tuple(ret.items()) for ret in rets)
+ result.append((name, args_t, rets_t))
+ return tuple(result)
+
+
+def test_equivalence(target, secp128r1, capfd):
+ mult = target.mult
+ target.connect()
+ target.set_params(secp128r1)
+ for _ in range(1):
+ priv, pub = target.generate()
+ assert secp128r1.curve.is_on_curve(pub)
+ with local(DefaultContext()) as ctx:
+ mult.init(secp128r1, secp128r1.generator)
+ expected = mult.multiply(priv).to_affine()
+ captured = capfd.readouterr()
+ with capfd.disabled():
+ assert pub == expected
+ from_codegen = parse_trace(captured.err)
+ from_sim = parse_ctx(ctx.actions[0]) + parse_ctx(ctx.actions[1])
+ codegen_set = set(make_hashable(from_codegen))
+ sim_set = set(make_hashable(from_sim))
+ if codegen_set != sim_set:
+ print(len(from_codegen), len(from_sim))
+ print("In codegen but not in sim:")
+ for entry in codegen_set - sim_set:
+ print(entry)
+ print("In sim but not in codegen:")
+ for entry in sim_set - codegen_set:
+ print(entry)
+ assert from_codegen == from_sim
+
+ target.quit()
+ target.disconnect()