diff options
Diffstat (limited to 'pyecsca/codegen')
| -rw-r--r-- | pyecsca/codegen/builder.py | 288 | ||||
| -rw-r--r-- | pyecsca/codegen/client.py | 158 | ||||
| -rw-r--r-- | pyecsca/codegen/common.py | 111 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/host/host_hal.h | 1 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/host/uart.c | 4 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/host/uart.h | 2 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/stm32f0/stm32f0_hal.h | 1 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/stm32f3/stm32f3_hal.h | 1 | ||||
| -rw-r--r-- | pyecsca/codegen/hal/xmega/xmega_hal.h | 1 | ||||
| -rw-r--r-- | pyecsca/codegen/render.py | 274 | ||||
| -rw-r--r-- | pyecsca/codegen/simpleserial/simpleserial.c | 1 | ||||
| -rw-r--r-- | pyecsca/codegen/templates/defs.h | 2 | ||||
| -rw-r--r-- | pyecsca/codegen/templates/main.c | 135 | ||||
| -rw-r--r-- | pyecsca/codegen/templates/mult_rtl.c | 2 | ||||
| -rw-r--r-- | pyecsca/codegen/templates/point.c | 7 |
15 files changed, 557 insertions, 431 deletions
diff --git a/pyecsca/codegen/builder.py b/pyecsca/codegen/builder.py index d65eb2b..b1a8816 100644 --- a/pyecsca/codegen/builder.py +++ b/pyecsca/codegen/builder.py @@ -1,284 +1,28 @@ #!/usr/bin/env python3 -import os import re import shutil import subprocess -import tempfile -from ast import Pow from copy import copy from os import path -from typing import List, Set, Mapping, Any, Optional, Tuple, MutableMapping +from typing import List, Optional, Tuple import click -from jinja2 import Environment, PackageLoader -from pkg_resources import resource_filename from public import public +from pyecsca.ec.configuration import Multiplication, Squaring, Reduction, HashType, RandomMod from pyecsca.ec.coordinates import CoordinateModel -from pyecsca.ec.formula import (Formula, AdditionFormula, DoublingFormula, TriplingFormula, - NegationFormula, ScalingFormula, DifferentialAdditionFormula, - LadderFormula) -from pyecsca.ec.model import (CurveModel, ShortWeierstrassModel, MontgomeryModel, EdwardsModel, - TwistedEdwardsModel) -from pyecsca.ec.mult import (ScalarMultiplier, LTRMultiplier, RTLMultiplier, CoronMultiplier, - LadderMultiplier, SimpleLadderMultiplier, DifferentialLadderMultiplier, - BinaryNAFMultiplier) -from pyecsca.ec.op import CodeOp, OpType +from pyecsca.ec.formula import (Formula, AdditionFormula) +from pyecsca.ec.model import (CurveModel) +from pyecsca.ec.mult import (ScalarMultiplier) -from .common import (Platform, Multiplication, Squaring, Reduction, HashType, RandomMod, - Configuration, MULTIPLIERS, wrap_enum) - -env = Environment( - loader=PackageLoader("pyecsca.codegen") -) - - -def render_op(op: OpType, result: str, left: str, right: str, mod: str) -> Optional[str]: - if op == OpType.Add: - return "bn_mod_add(&{}, &{}, &{}, &{});".format(left, right, mod, result) - elif op == OpType.Sub: - return "bn_mod_sub(&{}, &{}, &{}, &{});".format(left, right, mod, result) - elif op == OpType.Mult: - return "bn_mod_mul(&{}, &{}, &{}, &{});".format(left, right, mod, result) - elif op == OpType.Div or op == OpType.Inv: - return "bn_mod_div(&{}, &{}, &{}, &{});".format(left, right, mod, result) - elif op == OpType.Sqr: - return "bn_mod_sqr(&{}, &{}, &{});".format(left, mod, result) - elif op == OpType.Pow: - return "bn_mod_pow(&{}, &{}, &{}, &{});".format(left, right, mod, result) - elif op == OpType.Id: - return "bn_copy(&{}, &{});".format(left, result) - else: - print(op, result, left, right, mod) - return None - - -env.globals["render_op"] = render_op -env.globals["isinstance"] = isinstance - - -def render_defs(model: CurveModel, coords: CoordinateModel) -> str: - return env.get_template("defs.h").render(params=model.parameter_names, - variables=coords.variables) - - -def render_curve_impl(model: CurveModel) -> str: - return env.get_template("curve.c").render(params=model.parameter_names) - - -def transform_ops(ops: List[CodeOp], parameters: List[str], outputs: Set[str], - renames: Mapping[str, str] = None) -> MutableMapping[Any, Any]: - def rename(name: str): - if renames is not None and name not in outputs: - return renames.get(name, name) - return name - - allocations: List[str] = [] - initializations = {} - const_mapping = {} - operations = [] - frees = [] - for op in ops: - if op.result not in allocations: - allocations.append(op.result) - frees.append(op.result) - for param in op.parameters: - if param not in allocations and param not in parameters: - raise ValueError("Should be allocated or parameter: {}".format(param)) - for const in op.constants: - name = "c" + str(const) - if name not in allocations: - allocations.append(name) - initializations[name] = const - const_mapping[const] = name - frees.append(name) - operations.append((op.operator, op.result, rename(op.left), rename(op.right))) - mapped = [] - for op in operations: - o2 = op[2] - if o2 in const_mapping: - o2 = const_mapping[o2] - o3 = op[3] - if o3 in const_mapping and not (isinstance(op[0], Pow) and o3 == 2): - o3 = const_mapping[o3] - mapped.append((op[0], op[1], o2, o3)) - returns = {} - if renames: - for r_from, r_to in renames.items(): - if r_from in outputs: - returns[r_from] = r_to - - return dict(allocations=allocations, - initializations=initializations, - const_mapping=const_mapping, operations=mapped, - frees=frees, returns=returns) - - -def render_ops(ops: List[CodeOp], parameters: List[str], outputs: Set[str], - renames: Mapping[str, str] = None) -> str: - namespace = transform_ops(ops, parameters, outputs, renames) - return env.get_template("ops.c").render(namespace) - - -def render_coords_impl(coords: CoordinateModel) -> str: - ops = [] - for s in coords.satisfying: - try: - ops.append(CodeOp(s)) - except Exception: - pass - renames = {"x": "out_x", "y": "out_y"} - for variable in coords.variables: - renames[variable] = "point->{}".format(variable) - for param in coords.curve_model.parameter_names: - renames[param] = "curve->{}".format(param) - namespace = transform_ops(ops, coords.curve_model.parameter_names, - coords.curve_model.coordinate_names, renames) - returns = namespace["returns"] - namespace["returns"] = {} - frees = namespace["frees"] - namespace["frees"] = {} - - return env.get_template("point.c").render(variables=coords.variables, **namespace, - to_affine_rets=returns, to_affine_frees=frees) - - -def render_formula_impl(formula: Formula, short_circuit: bool = False) -> str: - if isinstance(formula, AdditionFormula): - tname = "formula_add.c" - elif isinstance(formula, DoublingFormula): - tname = "formula_dbl.c" - elif isinstance(formula, TriplingFormula): - tname = "formula_tpl.c" - elif isinstance(formula, NegationFormula): - tname = "formula_neg.c" - elif isinstance(formula, ScalingFormula): - tname = "formula_scl.c" - elif isinstance(formula, DifferentialAdditionFormula): - tname = "formula_dadd.c" - elif isinstance(formula, LadderFormula): - tname = "formula_ladd.c" - else: - raise ValueError - template = env.get_template(tname) - inputs = ["one", "other", "diff"] - outputs = ["out_one", "out_other"] - renames = {} - for input in formula.inputs: - var = input[0] - num = int(input[1:]) - formula.input_index - renames[input] = "{}->{}".format(inputs[num], var) - for param in formula.coordinate_model.curve_model.parameter_names: - renames[param] = "curve->{}".format(param) - for output in formula.outputs: - var = output[0] - num = int(output[1:]) - formula.output_index - renames[output] = "{}->{}".format(outputs[num], var) - namespace = transform_ops(formula.code, formula.coordinate_model.curve_model.parameter_names, - formula.outputs, renames) - namespace["short_circuit"] = short_circuit - return template.render(namespace) - - -def render_scalarmult_impl(scalarmult: ScalarMultiplier) -> str: - return env.get_template("mult.c").render(scalarmult=scalarmult, LTRMultiplier=LTRMultiplier, - RTLMultiplier=RTLMultiplier, - CoronMultiplier=CoronMultiplier, - LadderMultiplier=LadderMultiplier, - SimpleLadderMultiplier=SimpleLadderMultiplier, - DifferentialLadderMultiplier=DifferentialLadderMultiplier, - BinaryNAFMultiplier=BinaryNAFMultiplier) - - -def render_main(model: CurveModel, coords: CoordinateModel, keygen: bool, ecdh: bool, ecdsa: bool) -> str: - return env.get_template("main.c").render(curve_variables=coords.variables, - curve_parameters=model.parameter_names, - keygen=keygen, ecdh=ecdh, ecdsa=ecdsa) - - -def render_makefile(platform: Platform, hash_type: HashType, mod_rand: RandomMod) -> str: - return env.get_template("Makefile").render(platform=str(platform), hash_type=str(hash_type), - mod_rand=str(mod_rand)) - - -def save_render(dir: str, fname: str, rendered: str): - with open(path.join(dir, fname), "w") as f: - f.write(rendered) - - -@public -def render(config: Configuration) -> Tuple[str, str, str]: - temp = tempfile.mkdtemp() - symlinks = ["asn1", "bn", "hal", "hash", "mult", "prng", "simpleserial", "tommath", "fat.h", - "point.h", "curve.h", "mult.h", "Makefile.inc"] - for sym in symlinks: - os.symlink(resource_filename("pyecsca.codegen", sym), path.join(temp, sym)) - gen_dir = path.join(temp, "gen") - os.mkdir(gen_dir) - save_render(temp, "Makefile", - render_makefile(config.platform, config.hash_type, config.mod_rand)) - save_render(temp, "main.c", render_main(config.model, config.coords, config.keygen, config.ecdh, config.ecdsa)) - save_render(gen_dir, "defs.h", render_defs(config.model, config.coords)) - point_render = render_coords_impl(config.coords) - for formula in config.formulas: - point_render += "\n" - point_render += render_formula_impl(formula, config.scalarmult.short_circuit) - save_render(gen_dir, "point.c", point_render) - save_render(gen_dir, "curve.c", render_curve_impl(config.model)) - save_render(gen_dir, "mult.c", render_scalarmult_impl(config.scalarmult)) - return temp, "pyecsca-codegen-{}.elf".format( - str(config.platform)), "pyecsca-codegen-{}.hex".format(str(config.platform)) - - -@public -def render_and_build(config, outdir, strip=False, remove=True): - dir, elf_file, hex_file = render(config) - - res = subprocess.run(["make"], cwd=dir, capture_output=True) - if res.returncode != 0: - raise ValueError("Build failed!") - if strip: - subprocess.run(["strip", elf_file], cwd=dir) - full_elf_path = path.join(dir, elf_file) - full_hex_path = path.join(dir, hex_file) - shutil.copy(full_elf_path, outdir) - shutil.copy(full_hex_path, outdir) - if remove: - shutil.rmtree(dir) - - -def get_model(ctx: click.Context, param, value: str) -> CurveModel: - if value is None: - return None - classes = { - "shortw": ShortWeierstrassModel, - "montgom": MontgomeryModel, - "edwards": EdwardsModel, - "twisted": TwistedEdwardsModel - } - if value not in classes: - raise click.BadParameter("Cannot create CurveModel from '{}'.".format(value)) - model = classes[value]() - ctx.meta["model"] = model - return model - - -def get_coords(ctx: click.Context, param, value: Optional[str]) -> Optional[CoordinateModel]: - if value is None: - return None - model = ctx.meta["model"] - if value not in model.coordinates: - raise click.BadParameter( - "Coordinate model '{}' is not a model in '{}'.".format(value, - model.__class__.__name__)) - coords = model.coordinates[value] - ctx.meta["coords"] = coords - return coords +from pyecsca.codegen.render import render +from .common import Platform, DeviceConfiguration, MULTIPLIERS, wrap_enum, get_model, get_coords def get_formula(ctx: click.Context, param, value: Optional[Tuple[str]]) -> List[Formula]: if not value: return [] - coords = ctx.meta["coords"] + ctx.ensure_object(dict) + coords = ctx.obj["coords"] result = [] for formula in value: if formula not in coords.formulas: @@ -287,7 +31,7 @@ def get_formula(ctx: click.Context, param, value: Optional[Tuple[str]]) -> List[ result.append(coords.formulas[formula]) if len(set(formula.__class__ for formula in result)) != len(result): raise click.BadParameter("Duplicate formula types.") - ctx.meta["formulas"] = copy(result) + ctx.obj["formulas"] = copy(result) return result @@ -308,7 +52,8 @@ def get_multiplier(ctx: click.Context, param, value: Optional[str]) -> Optional[ break if mult_class is None: raise click.BadParameter("Unknown multiplier: {}.".format(name)) - formulas = ctx.meta["formulas"] + ctx.ensure_object(dict) + formulas = ctx.obj["formulas"] classes = set(formula.__class__ for formula in formulas) if not all( any(issubclass(cls, required) for cls in classes) for required in mult_class.requires): @@ -331,7 +76,8 @@ def get_multiplier(ctx: click.Context, param, value: Optional[str]) -> Optional[ def get_ecdsa(ctx: click.Context, param, value: bool) -> bool: if not value: return False - formulas = ctx.meta["formulas"] + ctx.ensure_object(dict) + formulas = ctx.obj["formulas"] if not any(isinstance(formula, AdditionFormula) for formula in formulas): raise click.BadParameter("ECDSA needs an addition formula. None was supplied.") @@ -401,8 +147,8 @@ def build_impl(platform, hash, rand, mul, sqr, red, keygen, ecdh, ecdsa, strip, OUTDIR: The output directory for files with the built impl. """ - config = Configuration(platform, hash, rand, mul, sqr, red, model, coords, formulas, scalarmult, - keygen, ecdh, ecdsa) + config = DeviceConfiguration(model, coords, formulas, scalarmult, hash, rand, mul, sqr, red, + platform, keygen, ecdh, ecdsa) dir, elf_file, hex_file = render(config) res = subprocess.run(["make"], cwd=dir, capture_output=True) @@ -520,4 +266,4 @@ def flash(platform, dir): # pragma: no cover if __name__ == "__main__": - main() + main(obj={}) diff --git a/pyecsca/codegen/client.py b/pyecsca/codegen/client.py index 8893a5c..81dca40 100644 --- a/pyecsca/codegen/client.py +++ b/pyecsca/codegen/client.py @@ -1,35 +1,30 @@ #!/usr/bin/env python3 +import subprocess from binascii import hexlify -from typing import Mapping, Union, Optional, Any from os import path +from subprocess import Popen +from typing import Mapping, Union, Optional import click -from pyecsca.ec.params import DomainParameters +from pyecsca.ec.curves import get_params from pyecsca.ec.mod import Mod -from pyecsca.ec.point import Point - -from .common import EnumDefine, wrap_enum - -try: - import chipwhisperer as cw -except ImportError: - cw = None - +from pyecsca.ec.params import DomainParameters +from pyecsca.ec.point import Point, InfinityPoint +from pyecsca.sca import SerialTarget -class Target(EnumDefine): - HOST = "HOST" - SIMPLESERIAL = "CWLITE_SIMPLESERIAL" - JCARD = "CWLITE_JCARD" +from .common import wrap_enum, Platform, get_model, get_coords def encode_scalar(val: Union[int, Mod]) -> bytes: if isinstance(val, int): return val.to_bytes((val.bit_length() + 7) // 8, "big") elif isinstance(val, Mod): - return encode_scalar(val.x) + return encode_scalar(int(val)) def encode_point(point: Point) -> Mapping: + if isinstance(point, InfinityPoint): + return {"n": bytes([1])} return {var: encode_scalar(value) for var, value in point.coords.items()} @@ -53,8 +48,8 @@ def decode_data(data: bytes) -> Mapping: length = data[parsed + 1] if name & 0x80: sub = decode_data(data[parsed + 2: parsed + 2 + length]) - result[chr(name | 0x7f)] = sub - parsed += len(sub) + 2 + result[chr(name & 0x7f)] = sub + parsed += length + 2 else: result[chr(name)] = data[parsed + 2: parsed + 2 + length] parsed += length + 2 @@ -73,7 +68,7 @@ def cmd_set_curve(group: DomainParameters) -> str: } for param, value in group.curve.parameters.items(): data[param] = encode_scalar(value) - data["g"] = encode_point(group.generator) + data["g"] = encode_point(group.generator.to_affine()) data["i"] = encode_point(group.neutral) return "c" + hexlify(encode_data(None, data)).decode() @@ -87,7 +82,7 @@ def cmd_set_privkey(privkey: int) -> str: def cmd_set_pubkey(pubkey: Point) -> str: - return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode() + return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode() def cmd_scalar_mult(scalar: int) -> str: @@ -95,7 +90,7 @@ def cmd_scalar_mult(scalar: int) -> str: def cmd_ecdh(pubkey: Point) -> str: - return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode() + return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode() def cmd_ecdsa_sign(data: bytes) -> str: @@ -106,49 +101,64 @@ def cmd_ecdsa_verify(data: bytes, sig: bytes) -> str: return "v" + hexlify(encode_data(None, {"d": data, "s": sig})).decode() -def setup_scope(target_type: Target): - if target_type in (Target.SIMPLESERIAL, Target.JCARD): - scope = cw.scope() - scope.default_setup() - return scope - return None +def cmd_debug() -> str: + return "d" -def connect(target_type: Target, scope, **kwargs) -> Any: - if target_type == Target.HOST: - return kwargs["binary"] - elif target_type == Target.SIMPLESERIAL: - return cw.target(scope, cw.targets.SimpleSerial) - elif target_type == Target.JCARD: - return None # TODO: this +class BinaryTarget(SerialTarget): + binary: str + process: Optional[Popen] + def __init__(self, binary: str): + self.binary = binary -def send_command(target_type: Target, target: Any, command: str) -> str: - if target_type == Target.SIMPLESERIAL: - target.write(command + "\n") - result = target.read() - target.simpleserial_wait_ack() - return result - return None # TODO: this + def connect(self): + self.process = Popen([self.binary], stdin=subprocess.PIPE, stdout=subprocess.PIPE, + text=True, bufsize=1) + + def write(self, data: bytes): + self.process.stdin.write(data.decode() + "\n") + self.process.stdin.flush() + + def read(self, timeout: int) -> bytes: + return self.process.stdout.readline() + + def disconnect(self): + if self.process.poll() is not None: + self.process.terminate() @click.group(context_settings={"help_option_names": ["-h", "--help"]}) -@click.option("--target", envvar="TARGET", required=True, - type=click.Choice(Target.names()), - callback=wrap_enum(Target), - help="The target to use.") +@click.option("--platform", envvar="PLATFORM", required=True, + type=click.Choice(Platform.names()), + callback=wrap_enum(Platform), + help="The target platform to use.") @click.option("--binary", help="For HOST target only. The binary to run.") @click.version_option() -def main(target, binary): +@click.pass_context +def main(ctx, platform, binary): """ A tool for communicating with built and flashed ECC implementations. """ - if target in (Target.SIMPLESERIAL, Target.JCARD) and cw is None: - click.secho("ChipWhisperer not installed, SIMPLESERIAL and JCARD targets require it.", fg="red", err=True) - raise click.Abort - if target == Target.HOST and (binary is None or not path.isfile(binary)): - click.secho("Binary is required if the target is the host.", fg="red", err=True) - raise click.Abort + ctx.ensure_object(dict) + if platform != Platform.HOST: + from pyecsca.sca.target import has_chipwhisperer + if not has_chipwhisperer: + click.secho("ChipWhisperer not installed, targets require it.", fg="red", err=True) + raise click.Abort + from pyecsca.sca.target import SimpleSerialTarget + import chipwhisperer as cw + from chipwhisperer.capture.targets.simpleserial_readers.cwlite import \ + SimpleSerial_ChipWhispererLite + ser = SimpleSerial_ChipWhispererLite() + scope = cw.scope() + scope.default_setup() + ctx.obj["target"] = SimpleSerialTarget(ser, scope) + else: + if binary is None or not path.isfile(binary): + click.secho("Binary is required if the target is the host.", fg="red", err=True) + raise click.Abort + ctx.obj["target"] = BinaryTarget(binary) # model = ShortWeierstrassModel() # coords = model.coordinates["projective"] # p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff @@ -184,20 +194,50 @@ def main(target, binary): # print(ot == res) +def get_curve(ctx: click.Context, param, value: Optional[str]) -> DomainParameters: + if value is None: + return None + ctx.ensure_object(dict) + category, name = value.split("/") + curve = get_params(category, name, ctx.obj["coords"].name) + return curve + + @main.command("gen") -def generate(): - pass +@click.argument("model", required=True, + type=click.Choice(["shortw", "montgom", "edwards", "twisted"]), + callback=get_model) +@click.argument("coords", required=True, + callback=get_coords) +@click.argument("curve", required=True, callback=get_curve) +@click.pass_context +def generate(ctx: click.Context, model, coords, curve): + ctx.ensure_object(dict) + set_curve = cmd_set_curve(curve) + generate = cmd_generate() + target: SerialTarget = ctx.obj["target"] + target.connect() + click.echo(set_curve) + target.write(set_curve.encode()) + click.echo(target.read(1)) + click.echo(generate) + target.write(generate.encode()) + click.echo(target.read(1)) + click.echo(target.read(1)) + target.disconnect() @main.command("ecdh") -def ecdh(): - pass +@click.pass_context +def ecdh(ctx: click.Context): + ctx.ensure_object(dict) @main.command("ecdsa") -def ecdsa(): - pass +@click.pass_context +def ecdsa(ctx: click.Context): + ctx.ensure_object(dict) if __name__ == "__main__": - main() + main(obj={}) diff --git a/pyecsca/codegen/common.py b/pyecsca/codegen/common.py index e29fa00..bd57ab6 100644 --- a/pyecsca/codegen/common.py +++ b/pyecsca/codegen/common.py @@ -1,31 +1,18 @@ from dataclasses import dataclass -from enum import Enum -from typing import List, Type -import click +from typing import Type, Optional +import click from public import public +from pyecsca.ec.configuration import EnumDefine, Configuration from pyecsca.ec.coordinates import CoordinateModel -from pyecsca.ec.formula import Formula -from pyecsca.ec.model import CurveModel -from pyecsca.ec.mult import (ScalarMultiplier, LTRMultiplier, RTLMultiplier, CoronMultiplier, +from pyecsca.ec.model import (CurveModel, ShortWeierstrassModel, MontgomeryModel, EdwardsModel, + TwistedEdwardsModel) +from pyecsca.ec.mult import (LTRMultiplier, RTLMultiplier, CoronMultiplier, LadderMultiplier, SimpleLadderMultiplier, DifferentialLadderMultiplier, WindowNAFMultiplier, BinaryNAFMultiplier) @public -class EnumDefine(Enum): - def __str__(self): - return self.value - - def __repr__(self): - return self.value - - @classmethod - def names(cls): - return list(e.name for e in cls) - - -@public class Platform(EnumDefine): """Platform to build for.""" HOST = "HOST" @@ -35,62 +22,9 @@ class Platform(EnumDefine): @public -class Multiplication(EnumDefine): - """Base multiplication algorithm to use.""" - TOOM_COOK = "MUL_TOOM_COOK" - KARATSUBA = "MUL_KARATSUBA" - COMBA = "MUL_COMBA" - BASE = "MUL_BASE" - - -@public -class Squaring(EnumDefine): - """Base squaring algorithm to use.""" - TOOM_COOK = "SQR_TOOM_COOK" - KARATSUBA = "SQR_KARATSUBA" - COMBA = "SQR_COMBA" - BASE = "SQR_BASE" - - -@public -class Reduction(EnumDefine): - """Modular reduction method used.""" - BARRETT = "RED_BARRETT" - MONTGOMERY = "RED_MONTGOMERY" - BASE = "RED_BASE" - - -@public -class HashType(EnumDefine): - """Hash algorithm used in ECDH and ECDSA.""" - NONE = "HASH_NONE" - SHA1 = "HASH_SHA1" - SHA224 = "HASH_SHA224" - SHA256 = "HASH_SHA256" - SHA384 = "HASH_SHA384" - SHA512 = "HASH_SHA512" - - -@public -class RandomMod(EnumDefine): - """Method of sampling a uniform integer modulo order.""" - SAMPLE = "MOD_RAND_SAMPLE" - REDUCE = "MOD_RAND_REDUCE" - - -@public -@dataclass -class Configuration(object): +@dataclass(frozen=True) +class DeviceConfiguration(Configuration): platform: Platform - hash_type: HashType - mod_rand: RandomMod - mult: Multiplication # TODO: Use this - sqr: Squaring # TODO: Use this - red: Reduction # TODO: Use this - model: CurveModel - coords: CoordinateModel - formulas: List[Formula] - scalarmult: ScalarMultiplier keygen: bool ecdh: bool ecdsa: bool @@ -143,3 +77,32 @@ def wrap_enum(enum_class: Type[EnumDefine]): "Cannot create {} enum from {}.".format(enum_class.__name__, value)) return callback + + +def get_model(ctx: click.Context, param, value: str) -> CurveModel: + if value is None: + return None + classes = { + "shortw": ShortWeierstrassModel, + "montgom": MontgomeryModel, + "edwards": EdwardsModel, + "twisted": TwistedEdwardsModel + } + model = classes[value]() + ctx.ensure_object(dict) + ctx.obj["model"] = model + return model + + +def get_coords(ctx: click.Context, param, value: Optional[str]) -> Optional[CoordinateModel]: + if value is None: + return None + ctx.ensure_object(dict) + model = ctx.obj["model"] + if value not in model.coordinates: + raise click.BadParameter( + "Coordinate model '{}' is not a model in '{}'.".format(value, + model.__class__.__name__)) + coords = model.coordinates[value] + ctx.obj["coords"] = coords + return coords
\ No newline at end of file diff --git a/pyecsca/codegen/hal/host/host_hal.h b/pyecsca/codegen/hal/host/host_hal.h index b78172c..26d8f4d 100644 --- a/pyecsca/codegen/hal/host/host_hal.h +++ b/pyecsca/codegen/hal/host/host_hal.h @@ -10,5 +10,6 @@ #define init_uart init_uart0 #define putch output_ch_0 #define getch input_ch_0 +#define flush flush_ch_0 #endif //HOST_HAL_H_ diff --git a/pyecsca/codegen/hal/host/uart.c b/pyecsca/codegen/hal/host/uart.c index 0410462..1f8d947 100644 --- a/pyecsca/codegen/hal/host/uart.c +++ b/pyecsca/codegen/hal/host/uart.c @@ -4,4 +4,6 @@ void init_uart0(void) {} int input_ch_0(void) { return getchar(); } -void output_ch_0(char data) { putchar(data); }
\ No newline at end of file +void output_ch_0(char data) { putchar(data); } + +void flush_ch_0(void) { fflush(stdout); }
\ No newline at end of file diff --git a/pyecsca/codegen/hal/host/uart.h b/pyecsca/codegen/hal/host/uart.h index ffbc964..e65bfff 100644 --- a/pyecsca/codegen/hal/host/uart.h +++ b/pyecsca/codegen/hal/host/uart.h @@ -10,4 +10,6 @@ int input_ch_0(void); void output_ch_0(char data); +void flush_ch_0(void); + #endif //UART_H_
\ No newline at end of file diff --git a/pyecsca/codegen/hal/stm32f0/stm32f0_hal.h b/pyecsca/codegen/hal/stm32f0/stm32f0_hal.h index 630a583..1cbdefc 100644 --- a/pyecsca/codegen/hal/stm32f0/stm32f0_hal.h +++ b/pyecsca/codegen/hal/stm32f0/stm32f0_hal.h @@ -4,6 +4,7 @@ void init_uart(void); void putch(char c); char getch(void); +#define flush() void trigger_setup(void); void trigger_low(void); diff --git a/pyecsca/codegen/hal/stm32f3/stm32f3_hal.h b/pyecsca/codegen/hal/stm32f3/stm32f3_hal.h index 016c3be..b5fec45 100644 --- a/pyecsca/codegen/hal/stm32f3/stm32f3_hal.h +++ b/pyecsca/codegen/hal/stm32f3/stm32f3_hal.h @@ -24,6 +24,7 @@ void init_uart(void); void putch(char c); char getch(void); +#define flush() void trigger_setup(void); void trigger_low(void); diff --git a/pyecsca/codegen/hal/xmega/xmega_hal.h b/pyecsca/codegen/hal/xmega/xmega_hal.h index d86f9a6..1647824 100644 --- a/pyecsca/codegen/hal/xmega/xmega_hal.h +++ b/pyecsca/codegen/hal/xmega/xmega_hal.h @@ -30,6 +30,7 @@ #define init_uart init_uart0 #define putch output_ch_0 #define getch input_ch_0 +#define flush() #if PLATFORM == CW303 #define led_error(a) if (a) {PORTA.OUTCLR = PIN6_bm;} else {PORTA.OUTSET = PIN6_bm;} diff --git a/pyecsca/codegen/render.py b/pyecsca/codegen/render.py new file mode 100644 index 0000000..ead1470 --- /dev/null +++ b/pyecsca/codegen/render.py @@ -0,0 +1,274 @@ +import os +import shutil +import subprocess +import tempfile +from _ast import Pow +from os import path +from typing import Optional, List, Set, Mapping, MutableMapping, Any, Tuple + +from jinja2 import Environment, PackageLoader +from pkg_resources import resource_filename +from public import public +from pyecsca.ec.configuration import HashType, RandomMod +from pyecsca.ec.coordinates import CoordinateModel +from pyecsca.ec.formula import (Formula, AdditionFormula, DoublingFormula, TriplingFormula, + NegationFormula, ScalingFormula, DifferentialAdditionFormula, + LadderFormula) +from pyecsca.ec.model import CurveModel +from pyecsca.ec.mult import (ScalarMultiplier, LTRMultiplier, RTLMultiplier, CoronMultiplier, + LadderMultiplier, SimpleLadderMultiplier, DifferentialLadderMultiplier, + BinaryNAFMultiplier) +from pyecsca.ec.op import OpType, CodeOp + +from pyecsca.codegen.common import Platform, DeviceConfiguration + +env = Environment( + loader=PackageLoader("pyecsca.codegen") +) + + +env.globals["isinstance"] = isinstance + +def render_op(op: OpType, result: str, left: str, right: str, mod: str) -> Optional[str]: + if op == OpType.Add: + return "bn_mod_add(&{}, &{}, &{}, &{});".format(left, right, mod, result) + elif op == OpType.Sub: + return "bn_mod_sub(&{}, &{}, &{}, &{});".format(left, right, mod, result) + elif op == OpType.Mult: + return "bn_mod_mul(&{}, &{}, &{}, &{});".format(left, right, mod, result) + elif op == OpType.Div or op == OpType.Inv: + return "bn_mod_div(&{}, &{}, &{}, &{});".format(left, right, mod, result) + elif op == OpType.Sqr: + return "bn_mod_sqr(&{}, &{}, &{});".format(left, mod, result) + elif op == OpType.Pow: + return "bn_mod_pow(&{}, &{}, &{}, &{});".format(left, right, mod, result) + elif op == OpType.Id: + return "bn_copy(&{}, &{});".format(left, result) + else: + print(op, result, left, right, mod) + return None + +env.globals["render_op"] = render_op + +def render_defs(model: CurveModel, coords: CoordinateModel) -> str: + return env.get_template("defs.h").render(params=model.parameter_names, + variables=coords.variables) + + +def render_curve_impl(model: CurveModel) -> str: + return env.get_template("curve.c").render(params=model.parameter_names) + + +def transform_ops(ops: List[CodeOp], parameters: List[str], outputs: Set[str], + renames: Mapping[str, str] = None) -> MutableMapping[Any, Any]: + def rename(name: str): + if renames is not None and name not in outputs: + return renames.get(name, name) + return name + + allocations: List[str] = [] + initializations = {} + const_mapping = {} + operations = [] + frees = [] + for op in ops: + if op.result not in allocations: + allocations.append(op.result) + frees.append(op.result) + for param in op.parameters: + if param not in allocations and param not in parameters: + raise ValueError("Should be allocated or parameter: {}".format(param)) + for const in op.constants: + name = "c" + str(const) + if name not in allocations: + allocations.append(name) + initializations[name] = const + const_mapping[const] = name + frees.append(name) + operations.append((op.operator, op.result, rename(op.left), rename(op.right))) + mapped = [] + for op in operations: + o2 = op[2] + if o2 in const_mapping: + o2 = const_mapping[o2] + o3 = op[3] + if o3 in const_mapping and not (isinstance(op[0], Pow) and o3 == 2): + o3 = const_mapping[o3] + mapped.append((op[0], op[1], o2, o3)) + returns = {} + if renames: + for r_from, r_to in renames.items(): + if r_from in outputs: + returns[r_from] = r_to + + return dict(allocations=allocations, + initializations=initializations, + const_mapping=const_mapping, operations=mapped, + frees=frees, returns=returns) + + +def render_ops(ops: List[CodeOp], parameters: List[str], outputs: Set[str], + renames: Mapping[str, str] = None) -> str: + namespace = transform_ops(ops, parameters, outputs, renames) + return env.get_template("ops.c").render(namespace) + + +def render_coords_impl(coords: CoordinateModel) -> str: + ops = [] + for s in coords.satisfying: + try: + ops.append(CodeOp(s)) + except Exception: + pass + renames = {"x": "out_x", "y": "out_y"} + for variable in coords.variables: + renames[variable] = "point->{}".format(variable) + for param in coords.curve_model.parameter_names: + renames[param] = "curve->{}".format(param) + namespace = transform_ops(ops, coords.curve_model.parameter_names, + coords.curve_model.coordinate_names, renames) + returns = namespace["returns"] + namespace["returns"] = {} + frees = namespace["frees"] + namespace["frees"] = {} + + return env.get_template("point.c").render(variables=coords.variables, **namespace, + to_affine_rets=returns, to_affine_frees=frees) + + +def render_formula_impl(formula: Formula, short_circuit: bool = False) -> str: + if isinstance(formula, AdditionFormula): + tname = "formula_add.c" + elif isinstance(formula, DoublingFormula): + tname = "formula_dbl.c" + elif isinstance(formula, TriplingFormula): + tname = "formula_tpl.c" + elif isinstance(formula, NegationFormula): + tname = "formula_neg.c" + elif isinstance(formula, ScalingFormula): + tname = "formula_scl.c" + elif isinstance(formula, DifferentialAdditionFormula): + tname = "formula_dadd.c" + elif isinstance(formula, LadderFormula): + tname = "formula_ladd.c" + else: + raise ValueError + template = env.get_template(tname) + inputs = ["one", "other", "diff"] + outputs = ["out_one", "out_other"] + renames = {} + for input in formula.inputs: + var = input[0] + num = int(input[1:]) - formula.input_index + renames[input] = "{}->{}".format(inputs[num], var) + for param in formula.coordinate_model.curve_model.parameter_names: + renames[param] = "curve->{}".format(param) + for output in formula.outputs: + var = output[0] + num = int(output[1:]) - formula.output_index + renames[output] = "{}->{}".format(outputs[num], var) + namespace = transform_ops(formula.code, formula.coordinate_model.curve_model.parameter_names, + formula.outputs, renames) + namespace["short_circuit"] = short_circuit + return template.render(namespace) + + +def render_scalarmult_impl(scalarmult: ScalarMultiplier) -> str: + return env.get_template("mult.c").render(scalarmult=scalarmult, LTRMultiplier=LTRMultiplier, + RTLMultiplier=RTLMultiplier, + CoronMultiplier=CoronMultiplier, + LadderMultiplier=LadderMultiplier, + SimpleLadderMultiplier=SimpleLadderMultiplier, + DifferentialLadderMultiplier=DifferentialLadderMultiplier, + BinaryNAFMultiplier=BinaryNAFMultiplier) + + +def render_main(model: CurveModel, coords: CoordinateModel, keygen: bool, ecdh: bool, + ecdsa: bool) -> str: + return env.get_template("main.c").render(model=model, coords=coords, + curve_variables=coords.variables, + curve_parameters=model.parameter_names, + keygen=keygen, ecdh=ecdh, ecdsa=ecdsa) + + +def render_makefile(platform: Platform, hash_type: HashType, mod_rand: RandomMod) -> str: + return env.get_template("Makefile").render(platform=str(platform), hash_type=str(hash_type), + mod_rand=str(mod_rand)) + + +def save_render(dir: str, fname: str, rendered: str): + with open(path.join(dir, fname), "w") as f: + f.write(rendered) + + +@public +def render(config: DeviceConfiguration) -> Tuple[str, str, str]: + """ + + :param config: + :return: + """ + temp = tempfile.mkdtemp() + symlinks = ["asn1", "bn", "hal", "hash", "mult", "prng", "simpleserial", "tommath", "fat.h", + "point.h", "curve.h", "mult.h", "Makefile.inc"] + for sym in symlinks: + os.symlink(resource_filename("pyecsca.codegen", sym), path.join(temp, sym)) + gen_dir = path.join(temp, "gen") + os.mkdir(gen_dir) + save_render(temp, "Makefile", + render_makefile(config.platform, config.hash_type, config.mod_rand)) + save_render(temp, "main.c", + render_main(config.model, config.coords, config.keygen, config.ecdh, config.ecdsa)) + save_render(gen_dir, "defs.h", render_defs(config.model, config.coords)) + point_render = render_coords_impl(config.coords) + for formula in config.formulas: + point_render += "\n" + point_render += render_formula_impl(formula, config.scalarmult.short_circuit) + save_render(gen_dir, "point.c", point_render) + save_render(gen_dir, "curve.c", render_curve_impl(config.model)) + save_render(gen_dir, "mult.c", render_scalarmult_impl(config.scalarmult)) + return temp, "pyecsca-codegen-{}.elf".format( + str(config.platform)), "pyecsca-codegen-{}.hex".format(str(config.platform)) + + +@public +def build(dir: str, elf_file: str, hex_file: str, outdir: str, strip: bool = False, + remove : bool = True) -> subprocess.CompletedProcess: + """ + + :param dir: + :param elf_file: + :param hex_file: + :param outdir: + :param strip: + :param remove: + :return: + """ + res = subprocess.run(["make"], cwd=dir, capture_output=True) + if res.returncode != 0: + raise ValueError("Build failed!") + if strip: + subprocess.run(["strip", elf_file], cwd=dir) + full_elf_path = path.join(dir, elf_file) + full_hex_path = path.join(dir, hex_file) + shutil.copy(full_elf_path, outdir) + shutil.copy(full_hex_path, outdir) + if remove: + shutil.rmtree(dir) + return res + + +@public +def render_and_build(config: DeviceConfiguration, outdir: str, strip: bool = False, + remove: bool = True) -> Tuple[str, str, str, subprocess.CompletedProcess]: + """ + + :param config: + :param outdir: + :param strip: + :param remove: + :return: + """ + dir, elf_file, hex_file = render(config) + res = build(dir, elf_file, hex_file, outdir, strip, remove) + return dir, elf_file, hex_file, res
\ No newline at end of file diff --git a/pyecsca/codegen/simpleserial/simpleserial.c b/pyecsca/codegen/simpleserial/simpleserial.c index 4d73ebb..8ca20c4 100644 --- a/pyecsca/codegen/simpleserial/simpleserial.c +++ b/pyecsca/codegen/simpleserial/simpleserial.c @@ -151,4 +151,5 @@ void simpleserial_put(char c, int size, uint8_t* output) // Write trailing '\n' putch('\n'); + flush(); } diff --git a/pyecsca/codegen/templates/defs.h b/pyecsca/codegen/templates/defs.h index f517ea3..071e8a3 100644 --- a/pyecsca/codegen/templates/defs.h +++ b/pyecsca/codegen/templates/defs.h @@ -1,12 +1,14 @@ #ifndef DEFS_H_ #define DEFS_H_ +#include <stdlib.h> #include "bn.h" typedef struct { {%- for variable in variables %} bn_t {{ variable }}; {%- endfor %} + bool infinity; } point_t; typedef struct { diff --git a/pyecsca/codegen/templates/main.c b/pyecsca/codegen/templates/main.c index e39af07..c2e250f 100644 --- a/pyecsca/codegen/templates/main.c +++ b/pyecsca/codegen/templates/main.c @@ -63,11 +63,26 @@ static void parse_set_curve(const char *path, const uint8_t *data, size_t len, v return; } {%- endfor %} - {%- for variable in curve_variables %} - if (strcmp(path, "g{{ variable }}") == 0) { - bn_from_bin(data, len, &curve->generator->{{ variable }}); + + fat_t *affine = (fat_t *) arg; + if (strcmp(path, "gx") == 0) { + affine[0].len = len; + affine[0].value = malloc(len); + memcpy(affine[0].value, data, len); + return; + } + if (strcmp(path, "gy") == 0) { + affine[1].len = len; + affine[1].value = malloc(len); + memcpy(affine[1].value, data, len); + return; + } + + if (strcmp(path, "in") == 0) { + curve->neutral->infinity = *data; return; } + {%- for variable in curve_variables %} if (strcmp(path, "i{{ variable }}") == 0) { bn_from_bin(data, len, &curve->neutral->{{ variable }}); return; @@ -76,13 +91,25 @@ static void parse_set_curve(const char *path, const uint8_t *data, size_t len, v } static uint8_t cmd_set_curve(uint8_t *data, uint16_t len) { - // need p, [params], n, h, g[variables], i[variables] - parse_data(data, len, "", parse_set_curve, NULL); + // need p, [params], n, h, g[xy], i[variables] + fat_t affine[2] = {fat_empty, fat_empty}; + parse_data(data, len, "", parse_set_curve, (void *) affine); + bn_t x; bn_init(&x); + bn_t y; bn_init(&y); + bn_from_bin(affine[0].value, affine[0].len, &x); + bn_from_bin(affine[1].value, affine[1].len, &y); + + point_from_affine(&x, &y, curve, curve->generator); + bn_clear(&x); + bn_clear(&y); + free(affine[0].value); + free(affine[1].value); return 0; } static uint8_t cmd_generate(uint8_t *data, uint16_t len) { // generate a keypair, export privkey and affine pubkey + trigger_high(); bn_init(&privkey); bn_rand_mod(&privkey, &curve->n); size_t priv_size = bn_to_bin_size(&privkey); @@ -92,12 +119,21 @@ static uint8_t cmd_generate(uint8_t *data, uint16_t len) { uint8_t priv[priv_size]; bn_to_bin(&privkey, priv); + + bn_t x; bn_init(&x); + bn_t y; bn_init(&y); + + point_to_affine(pubkey, curve, &x, &y); + + uint8_t pub[coord_size * 2]; + bn_to_binpad(&x, pub, coord_size); + bn_to_binpad(&y, pub + coord_size, coord_size); + bn_clear(&x); + bn_clear(&y); + trigger_low(); + simpleserial_put('s', priv_size, priv); - uint8_t pub[coord_size * {{ curve_variables | length }}]; - {%- for variable in curve_variables %} - bn_to_binpad(&pubkey->{{ variable }}, pub + coord_size * {{ loop.index0 }}, coord_size); - {%- endfor %} - simpleserial_put('w', coord_size * {{ curve_variables | length }}, pub); + simpleserial_put('w', coord_size * 2, pub); return 0; } @@ -115,17 +151,35 @@ static uint8_t cmd_set_privkey(uint8_t *data, uint16_t len) { } static void parse_set_pubkey(const char *path, const uint8_t *data, size_t len, void *arg) { - {%- for variable in curve_variables %} - if (strcmp(path, "w{{ variable }}") == 0) { - bn_from_bin(data, len, &pubkey->{{ variable }}); + fat_t *affine = (fat_t *) arg; + if (strcmp(path, "wx") == 0) { + affine[0].len = len; + affine[0].value = malloc(len); + memcpy(affine[0].value, data, len); + return; + } + if (strcmp(path, "wy") == 0) { + affine[1].len = len; + affine[1].value = malloc(len); + memcpy(affine[1].value, data, len); return; } - {%- endfor %} } static uint8_t cmd_set_pubkey(uint8_t *data, uint16_t len) { // set the current pubkey - parse_data(data, len, "", parse_set_pubkey, NULL); + fat_t affine[2] = {fat_empty, fat_empty}; + parse_data(data, len, "", parse_set_pubkey, (void *) affine); + bn_t x; bn_init(&x); + bn_t y; bn_init(&y); + bn_from_bin(affine[0].value, affine[0].len, &x); + bn_from_bin(affine[1].value, affine[1].len, &y); + + point_from_affine(&x, &y, curve, pubkey); + bn_clear(&x); + bn_clear(&y); + free(affine[0].value); + free(affine[1].value); return 0; } @@ -139,6 +193,7 @@ static void parse_scalar_mult(const char *path, const uint8_t *data, size_t len, static uint8_t cmd_scalar_mult(uint8_t *data, uint16_t len) { // perform base point scalar mult with supplied scalar. + trigger_high(); bn_t scalar; bn_init(&scalar); parse_data(data, len, "", parse_scalar_mult, (void *) &scalar); size_t coord_size = bn_to_bin_size(&curve->p); @@ -151,26 +206,46 @@ static uint8_t cmd_scalar_mult(uint8_t *data, uint16_t len) { {%- for variable in curve_variables %} bn_to_binpad(&result->{{ variable }}, res + coord_size * {{ loop.index0 }}, coord_size); {%- endfor %} - simpleserial_put('w', coord_size * {{ curve_variables | length }}, res); bn_clear(&scalar); point_free(result); + trigger_low(); + + simpleserial_put('w', coord_size * {{ curve_variables | length }}, res); return 0; } static void parse_ecdh(const char *path, const uint8_t *data, size_t len, void *arg) { - point_t *other = (point_t *) arg; - {%- for variable in curve_variables %} - if (strcmp(path, "w{{ variable }}") == 0) { - bn_from_bin(data, len, &other->{{ variable }}); + fat_t *affine = (fat_t *) arg; + if (strcmp(path, "wx") == 0) { + affine[0].len = len; + affine[0].value = malloc(len); + memcpy(affine[0].value, data, len); + return; + } + if (strcmp(path, "wy") == 0) { + affine[1].len = len; + affine[1].value = malloc(len); + memcpy(affine[1].value, data, len); return; } - {%- endfor %} } static uint8_t cmd_ecdh(uint8_t *data, uint16_t len) { //perform ECDH with provided point (and current privkey), output shared secret + trigger_high(); point_t *other = point_new(); - parse_data(data, len, "", parse_ecdh, (void *) other); + fat_t affine[2] = {fat_empty, fat_empty}; + parse_data(data, len, "", parse_ecdh, (void *) affine); + bn_t ox; bn_init(&ox); + bn_t oy; bn_init(&oy); + bn_from_bin(affine[0].value, affine[0].len, &ox); + bn_from_bin(affine[1].value, affine[1].len, &oy); + + point_from_affine(&ox, &oy, curve, other); + bn_clear(&ox); + bn_clear(&oy); + free(affine[0].value); + free(affine[1].value); point_t *result = point_new(); @@ -192,17 +267,18 @@ static uint8_t cmd_ecdh(uint8_t *data, uint16_t len) { uint8_t h_out[h_size]; hash_final(h_ctx, size, x_raw, h_out); hash_free_ctx(h_ctx); - - simpleserial_put('r', h_size, h_out); bn_clear(&x); bn_clear(&y); point_free(result); point_free(other); + trigger_low(); + + simpleserial_put('r', h_size, h_out); return 0; } static void parse_ecdsa_msg(const char *path, const uint8_t *data, size_t len, void *arg) { - fat_t *dest = (fat_t *)arg; + fat_t *dest = (fat_t *) arg; if (strcmp(path, "d") == 0) { dest->len = len; dest->value = malloc(len); @@ -339,6 +415,14 @@ static uint8_t cmd_ecdsa_verify(uint8_t *data, uint16_t len) { return 0; } +static uint8_t cmd_debug(uint8_t *data, uint16_t len) { + const char *debug_string = "{{ ','.join((model.shortname, coords.name))}}"; + size_t debug_len = strlen(debug_string); + + simpleserial_put('d', debug_len, debug_string); + return 0; +} + int main(void) { platform_init(); prng_init(); @@ -365,6 +449,7 @@ int main(void) { simpleserial_addcmd('a', MAX_SS_LEN, cmd_ecdsa_sign); simpleserial_addcmd('v', MAX_SS_LEN, cmd_ecdsa_verify); {%- endif %} + simpleserial_addcmd('d', MAX_SS_LEN, cmd_debug); while(simpleserial_get()); return 0; }
\ No newline at end of file diff --git a/pyecsca/codegen/templates/mult_rtl.c b/pyecsca/codegen/templates/mult_rtl.c index af437e0..01d42a5 100644 --- a/pyecsca/codegen/templates/mult_rtl.c +++ b/pyecsca/codegen/templates/mult_rtl.c @@ -13,7 +13,7 @@ void scalar_mult(bn_t *scalar, point_t *point, curve_t *curve, point_t *out) { bn_copy(scalar, ©); while (!bn_is_0(©)) { - if (bn_get_bit(©, i) == 1) { + if (bn_get_bit(©, 0) == 1) { point_add(q, r, curve, r); } else { {%- if scalarmult.always %} diff --git a/pyecsca/codegen/templates/point.c b/pyecsca/codegen/templates/point.c index 689d598..c12c50a 100644 --- a/pyecsca/codegen/templates/point.c +++ b/pyecsca/codegen/templates/point.c @@ -30,6 +30,13 @@ void point_free(point_t *point) { } bool point_equals(const point_t *one, const point_t *other) { + if (one->infinity && !other->infinity || other->infinity && !one->infinity) { + return false; + } + if (one->infinity && other->infinity) { + return true; + } + {%- for variable in variables %} if (!bn_eq(&one->{{ variable }}, &other->{{ variable }})) { return false; |
