diff options
Diffstat (limited to 'pyecsca/codegen/client.py')
| -rw-r--r-- | pyecsca/codegen/client.py | 158 |
1 files changed, 99 insertions, 59 deletions
diff --git a/pyecsca/codegen/client.py b/pyecsca/codegen/client.py index 8893a5c..81dca40 100644 --- a/pyecsca/codegen/client.py +++ b/pyecsca/codegen/client.py @@ -1,35 +1,30 @@ #!/usr/bin/env python3 +import subprocess from binascii import hexlify -from typing import Mapping, Union, Optional, Any from os import path +from subprocess import Popen +from typing import Mapping, Union, Optional import click -from pyecsca.ec.params import DomainParameters +from pyecsca.ec.curves import get_params from pyecsca.ec.mod import Mod -from pyecsca.ec.point import Point - -from .common import EnumDefine, wrap_enum - -try: - import chipwhisperer as cw -except ImportError: - cw = None - +from pyecsca.ec.params import DomainParameters +from pyecsca.ec.point import Point, InfinityPoint +from pyecsca.sca import SerialTarget -class Target(EnumDefine): - HOST = "HOST" - SIMPLESERIAL = "CWLITE_SIMPLESERIAL" - JCARD = "CWLITE_JCARD" +from .common import wrap_enum, Platform, get_model, get_coords def encode_scalar(val: Union[int, Mod]) -> bytes: if isinstance(val, int): return val.to_bytes((val.bit_length() + 7) // 8, "big") elif isinstance(val, Mod): - return encode_scalar(val.x) + return encode_scalar(int(val)) def encode_point(point: Point) -> Mapping: + if isinstance(point, InfinityPoint): + return {"n": bytes([1])} return {var: encode_scalar(value) for var, value in point.coords.items()} @@ -53,8 +48,8 @@ def decode_data(data: bytes) -> Mapping: length = data[parsed + 1] if name & 0x80: sub = decode_data(data[parsed + 2: parsed + 2 + length]) - result[chr(name | 0x7f)] = sub - parsed += len(sub) + 2 + result[chr(name & 0x7f)] = sub + parsed += length + 2 else: result[chr(name)] = data[parsed + 2: parsed + 2 + length] parsed += length + 2 @@ -73,7 +68,7 @@ def cmd_set_curve(group: DomainParameters) -> str: } for param, value in group.curve.parameters.items(): data[param] = encode_scalar(value) - data["g"] = encode_point(group.generator) + data["g"] = encode_point(group.generator.to_affine()) data["i"] = encode_point(group.neutral) return "c" + hexlify(encode_data(None, data)).decode() @@ -87,7 +82,7 @@ def cmd_set_privkey(privkey: int) -> str: def cmd_set_pubkey(pubkey: Point) -> str: - return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode() + return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode() def cmd_scalar_mult(scalar: int) -> str: @@ -95,7 +90,7 @@ def cmd_scalar_mult(scalar: int) -> str: def cmd_ecdh(pubkey: Point) -> str: - return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode() + return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode() def cmd_ecdsa_sign(data: bytes) -> str: @@ -106,49 +101,64 @@ def cmd_ecdsa_verify(data: bytes, sig: bytes) -> str: return "v" + hexlify(encode_data(None, {"d": data, "s": sig})).decode() -def setup_scope(target_type: Target): - if target_type in (Target.SIMPLESERIAL, Target.JCARD): - scope = cw.scope() - scope.default_setup() - return scope - return None +def cmd_debug() -> str: + return "d" -def connect(target_type: Target, scope, **kwargs) -> Any: - if target_type == Target.HOST: - return kwargs["binary"] - elif target_type == Target.SIMPLESERIAL: - return cw.target(scope, cw.targets.SimpleSerial) - elif target_type == Target.JCARD: - return None # TODO: this +class BinaryTarget(SerialTarget): + binary: str + process: Optional[Popen] + def __init__(self, binary: str): + self.binary = binary -def send_command(target_type: Target, target: Any, command: str) -> str: - if target_type == Target.SIMPLESERIAL: - target.write(command + "\n") - result = target.read() - target.simpleserial_wait_ack() - return result - return None # TODO: this + def connect(self): + self.process = Popen([self.binary], stdin=subprocess.PIPE, stdout=subprocess.PIPE, + text=True, bufsize=1) + + def write(self, data: bytes): + self.process.stdin.write(data.decode() + "\n") + self.process.stdin.flush() + + def read(self, timeout: int) -> bytes: + return self.process.stdout.readline() + + def disconnect(self): + if self.process.poll() is not None: + self.process.terminate() @click.group(context_settings={"help_option_names": ["-h", "--help"]}) -@click.option("--target", envvar="TARGET", required=True, - type=click.Choice(Target.names()), - callback=wrap_enum(Target), - help="The target to use.") +@click.option("--platform", envvar="PLATFORM", required=True, + type=click.Choice(Platform.names()), + callback=wrap_enum(Platform), + help="The target platform to use.") @click.option("--binary", help="For HOST target only. The binary to run.") @click.version_option() -def main(target, binary): +@click.pass_context +def main(ctx, platform, binary): """ A tool for communicating with built and flashed ECC implementations. """ - if target in (Target.SIMPLESERIAL, Target.JCARD) and cw is None: - click.secho("ChipWhisperer not installed, SIMPLESERIAL and JCARD targets require it.", fg="red", err=True) - raise click.Abort - if target == Target.HOST and (binary is None or not path.isfile(binary)): - click.secho("Binary is required if the target is the host.", fg="red", err=True) - raise click.Abort + ctx.ensure_object(dict) + if platform != Platform.HOST: + from pyecsca.sca.target import has_chipwhisperer + if not has_chipwhisperer: + click.secho("ChipWhisperer not installed, targets require it.", fg="red", err=True) + raise click.Abort + from pyecsca.sca.target import SimpleSerialTarget + import chipwhisperer as cw + from chipwhisperer.capture.targets.simpleserial_readers.cwlite import \ + SimpleSerial_ChipWhispererLite + ser = SimpleSerial_ChipWhispererLite() + scope = cw.scope() + scope.default_setup() + ctx.obj["target"] = SimpleSerialTarget(ser, scope) + else: + if binary is None or not path.isfile(binary): + click.secho("Binary is required if the target is the host.", fg="red", err=True) + raise click.Abort + ctx.obj["target"] = BinaryTarget(binary) # model = ShortWeierstrassModel() # coords = model.coordinates["projective"] # p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff @@ -184,20 +194,50 @@ def main(target, binary): # print(ot == res) +def get_curve(ctx: click.Context, param, value: Optional[str]) -> DomainParameters: + if value is None: + return None + ctx.ensure_object(dict) + category, name = value.split("/") + curve = get_params(category, name, ctx.obj["coords"].name) + return curve + + @main.command("gen") -def generate(): - pass +@click.argument("model", required=True, + type=click.Choice(["shortw", "montgom", "edwards", "twisted"]), + callback=get_model) +@click.argument("coords", required=True, + callback=get_coords) +@click.argument("curve", required=True, callback=get_curve) +@click.pass_context +def generate(ctx: click.Context, model, coords, curve): + ctx.ensure_object(dict) + set_curve = cmd_set_curve(curve) + generate = cmd_generate() + target: SerialTarget = ctx.obj["target"] + target.connect() + click.echo(set_curve) + target.write(set_curve.encode()) + click.echo(target.read(1)) + click.echo(generate) + target.write(generate.encode()) + click.echo(target.read(1)) + click.echo(target.read(1)) + target.disconnect() @main.command("ecdh") -def ecdh(): - pass +@click.pass_context +def ecdh(ctx: click.Context): + ctx.ensure_object(dict) @main.command("ecdsa") -def ecdsa(): - pass +@click.pass_context +def ecdsa(ctx: click.Context): + ctx.ensure_object(dict) if __name__ == "__main__": - main() + main(obj={}) |
