aboutsummaryrefslogtreecommitdiffhomepage
path: root/pyecsca/codegen/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyecsca/codegen/client.py')
-rw-r--r--pyecsca/codegen/client.py158
1 files changed, 99 insertions, 59 deletions
diff --git a/pyecsca/codegen/client.py b/pyecsca/codegen/client.py
index 8893a5c..81dca40 100644
--- a/pyecsca/codegen/client.py
+++ b/pyecsca/codegen/client.py
@@ -1,35 +1,30 @@
#!/usr/bin/env python3
+import subprocess
from binascii import hexlify
-from typing import Mapping, Union, Optional, Any
from os import path
+from subprocess import Popen
+from typing import Mapping, Union, Optional
import click
-from pyecsca.ec.params import DomainParameters
+from pyecsca.ec.curves import get_params
from pyecsca.ec.mod import Mod
-from pyecsca.ec.point import Point
-
-from .common import EnumDefine, wrap_enum
-
-try:
- import chipwhisperer as cw
-except ImportError:
- cw = None
-
+from pyecsca.ec.params import DomainParameters
+from pyecsca.ec.point import Point, InfinityPoint
+from pyecsca.sca import SerialTarget
-class Target(EnumDefine):
- HOST = "HOST"
- SIMPLESERIAL = "CWLITE_SIMPLESERIAL"
- JCARD = "CWLITE_JCARD"
+from .common import wrap_enum, Platform, get_model, get_coords
def encode_scalar(val: Union[int, Mod]) -> bytes:
if isinstance(val, int):
return val.to_bytes((val.bit_length() + 7) // 8, "big")
elif isinstance(val, Mod):
- return encode_scalar(val.x)
+ return encode_scalar(int(val))
def encode_point(point: Point) -> Mapping:
+ if isinstance(point, InfinityPoint):
+ return {"n": bytes([1])}
return {var: encode_scalar(value) for var, value in point.coords.items()}
@@ -53,8 +48,8 @@ def decode_data(data: bytes) -> Mapping:
length = data[parsed + 1]
if name & 0x80:
sub = decode_data(data[parsed + 2: parsed + 2 + length])
- result[chr(name | 0x7f)] = sub
- parsed += len(sub) + 2
+ result[chr(name & 0x7f)] = sub
+ parsed += length + 2
else:
result[chr(name)] = data[parsed + 2: parsed + 2 + length]
parsed += length + 2
@@ -73,7 +68,7 @@ def cmd_set_curve(group: DomainParameters) -> str:
}
for param, value in group.curve.parameters.items():
data[param] = encode_scalar(value)
- data["g"] = encode_point(group.generator)
+ data["g"] = encode_point(group.generator.to_affine())
data["i"] = encode_point(group.neutral)
return "c" + hexlify(encode_data(None, data)).decode()
@@ -87,7 +82,7 @@ def cmd_set_privkey(privkey: int) -> str:
def cmd_set_pubkey(pubkey: Point) -> str:
- return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode()
+ return "w" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode()
def cmd_scalar_mult(scalar: int) -> str:
@@ -95,7 +90,7 @@ def cmd_scalar_mult(scalar: int) -> str:
def cmd_ecdh(pubkey: Point) -> str:
- return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey)})).decode()
+ return "e" + hexlify(encode_data(None, {"w": encode_point(pubkey.to_affine())})).decode()
def cmd_ecdsa_sign(data: bytes) -> str:
@@ -106,49 +101,64 @@ def cmd_ecdsa_verify(data: bytes, sig: bytes) -> str:
return "v" + hexlify(encode_data(None, {"d": data, "s": sig})).decode()
-def setup_scope(target_type: Target):
- if target_type in (Target.SIMPLESERIAL, Target.JCARD):
- scope = cw.scope()
- scope.default_setup()
- return scope
- return None
+def cmd_debug() -> str:
+ return "d"
-def connect(target_type: Target, scope, **kwargs) -> Any:
- if target_type == Target.HOST:
- return kwargs["binary"]
- elif target_type == Target.SIMPLESERIAL:
- return cw.target(scope, cw.targets.SimpleSerial)
- elif target_type == Target.JCARD:
- return None # TODO: this
+class BinaryTarget(SerialTarget):
+ binary: str
+ process: Optional[Popen]
+ def __init__(self, binary: str):
+ self.binary = binary
-def send_command(target_type: Target, target: Any, command: str) -> str:
- if target_type == Target.SIMPLESERIAL:
- target.write(command + "\n")
- result = target.read()
- target.simpleserial_wait_ack()
- return result
- return None # TODO: this
+ def connect(self):
+ self.process = Popen([self.binary], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+ text=True, bufsize=1)
+
+ def write(self, data: bytes):
+ self.process.stdin.write(data.decode() + "\n")
+ self.process.stdin.flush()
+
+ def read(self, timeout: int) -> bytes:
+ return self.process.stdout.readline()
+
+ def disconnect(self):
+ if self.process.poll() is not None:
+ self.process.terminate()
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
-@click.option("--target", envvar="TARGET", required=True,
- type=click.Choice(Target.names()),
- callback=wrap_enum(Target),
- help="The target to use.")
+@click.option("--platform", envvar="PLATFORM", required=True,
+ type=click.Choice(Platform.names()),
+ callback=wrap_enum(Platform),
+ help="The target platform to use.")
@click.option("--binary", help="For HOST target only. The binary to run.")
@click.version_option()
-def main(target, binary):
+@click.pass_context
+def main(ctx, platform, binary):
"""
A tool for communicating with built and flashed ECC implementations.
"""
- if target in (Target.SIMPLESERIAL, Target.JCARD) and cw is None:
- click.secho("ChipWhisperer not installed, SIMPLESERIAL and JCARD targets require it.", fg="red", err=True)
- raise click.Abort
- if target == Target.HOST and (binary is None or not path.isfile(binary)):
- click.secho("Binary is required if the target is the host.", fg="red", err=True)
- raise click.Abort
+ ctx.ensure_object(dict)
+ if platform != Platform.HOST:
+ from pyecsca.sca.target import has_chipwhisperer
+ if not has_chipwhisperer:
+ click.secho("ChipWhisperer not installed, targets require it.", fg="red", err=True)
+ raise click.Abort
+ from pyecsca.sca.target import SimpleSerialTarget
+ import chipwhisperer as cw
+ from chipwhisperer.capture.targets.simpleserial_readers.cwlite import \
+ SimpleSerial_ChipWhispererLite
+ ser = SimpleSerial_ChipWhispererLite()
+ scope = cw.scope()
+ scope.default_setup()
+ ctx.obj["target"] = SimpleSerialTarget(ser, scope)
+ else:
+ if binary is None or not path.isfile(binary):
+ click.secho("Binary is required if the target is the host.", fg="red", err=True)
+ raise click.Abort
+ ctx.obj["target"] = BinaryTarget(binary)
# model = ShortWeierstrassModel()
# coords = model.coordinates["projective"]
# p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff
@@ -184,20 +194,50 @@ def main(target, binary):
# print(ot == res)
+def get_curve(ctx: click.Context, param, value: Optional[str]) -> DomainParameters:
+ if value is None:
+ return None
+ ctx.ensure_object(dict)
+ category, name = value.split("/")
+ curve = get_params(category, name, ctx.obj["coords"].name)
+ return curve
+
+
@main.command("gen")
-def generate():
- pass
+@click.argument("model", required=True,
+ type=click.Choice(["shortw", "montgom", "edwards", "twisted"]),
+ callback=get_model)
+@click.argument("coords", required=True,
+ callback=get_coords)
+@click.argument("curve", required=True, callback=get_curve)
+@click.pass_context
+def generate(ctx: click.Context, model, coords, curve):
+ ctx.ensure_object(dict)
+ set_curve = cmd_set_curve(curve)
+ generate = cmd_generate()
+ target: SerialTarget = ctx.obj["target"]
+ target.connect()
+ click.echo(set_curve)
+ target.write(set_curve.encode())
+ click.echo(target.read(1))
+ click.echo(generate)
+ target.write(generate.encode())
+ click.echo(target.read(1))
+ click.echo(target.read(1))
+ target.disconnect()
@main.command("ecdh")
-def ecdh():
- pass
+@click.pass_context
+def ecdh(ctx: click.Context):
+ ctx.ensure_object(dict)
@main.command("ecdsa")
-def ecdsa():
- pass
+@click.pass_context
+def ecdsa(ctx: click.Context):
+ ctx.ensure_object(dict)
if __name__ == "__main__":
- main()
+ main(obj={})