aboutsummaryrefslogtreecommitdiff
path: root/analysis/countermeasures/utils.py
blob: a937c882d3096f4c2a4d0a9759b91d29d055da1d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from pyecsca.ec.mult import DoubleAndAddMultiplier
from pyecsca.ec.signature import ECDSA_SHA1, SignatureResult
from pyecsca.ec.model import ShortWeierstrassModel
from pyecsca.ec.mod import Mod, mod
from pyecsca.ec.error import NonInvertibleError
import os
from hashlib import sha1
from pyasn1.codec.der.decoder import decode
from pyasn1.type.univ import Sequence
from pyecsca.ec.point import Point

def get_point_bytes(path):
    """Read a "0x<X>,0x<Y>" coordinate file and return an uncompressed point encoding.

    The file is expected to hold exactly one comma, separating two
    "0x"-prefixed hex strings. The first two characters of each part are
    dropped unconditionally, so the prefix must be present.

    :param path: path of the coordinate file.
    :return: ``b"\\x04" + X + Y`` as bytes.
    """
    with open(path, "r") as handle:
        raw_x, raw_y = handle.read().split(",")
    # Skip the leading "0x" of each coordinate before hex-decoding it.
    return b"\x04" + bytes.fromhex(raw_x[2:]) + bytes.fromhex(raw_y[2:])

def tuple_to_point(tuple_int, params, coords):
    """Build a ``Point`` from an integer ``(x, y)`` pair with ``Z = 1``.

    Each coordinate is wrapped with ``mod`` over ``params.curve.prime``.

    :param tuple_int: two-element iterable of ints ``(x, y)``.
    :param params: domain parameters exposing ``curve.prime``.
    :param coords: coordinate model passed to ``Point`` as ``model``.
    :return: the constructed ``Point``.
    """
    x, y = tuple_int
    prime = params.curve.prime
    return Point(
        X=mod(x, prime),
        Y=mod(y, prime),
        Z=mod(1, prime),
        model=coords,
    )


def csv_to_point(path, params, coords):
    """Read a "<X>,<Y>" hex coordinate file and convert it to a ``Point``.

    Each part is parsed with ``int(..., 16)``, which also accepts an optional
    "0x" prefix and surrounding whitespace.

    :param path: path of the coordinate file (exactly one comma expected).
    :param params: domain parameters forwarded to ``tuple_to_point``.
    :param coords: coordinate model forwarded to ``tuple_to_point``.
    :return: the resulting ``Point``.
    """
    with open(path, "r") as handle:
        hex_x, hex_y = handle.read().split(",")
    coordinates = (int(hex_x, 16), int(hex_y, 16))
    return tuple_to_point(coordinates, params, coords)

def parse_04point(string):
    """Parse an uncompressed point hex string ``"04" || X || Y`` into integers.

    The hex payload after the ``"04"`` prefix is split into two equal halves,
    which are decoded as the X and Y coordinates.

    :param string: hex string beginning with ``"04"``.
    :return: tuple ``(x, y)`` of ints.
    :raises ValueError: if the prefix is not ``"04"``, or the payload is empty
        or has odd length (cannot be split into equal X and Y halves).
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not string.startswith("04"):
        raise ValueError(f"expected an uncompressed point starting with '04': {string!r}")
    payload = string[2:]
    half, remainder = divmod(len(payload), 2)
    if half == 0 or remainder:
        raise ValueError(
            f"point payload must be a non-empty, even-length hex string: {string!r}"
        )
    return int(payload[:half], 16), int(payload[half:], 16)

def read_curve_params(path):
    """Return the contents of *path* with leading/trailing whitespace removed."""
    with open(path) as handle:
        contents = handle.read()
    return contents.strip()


def serialize_ecdh_response(ecdhresponse, curve, point, key):
    """Serialize one ECDH measurement into a ';'-separated line.

    Column order: success;error;secret;priv;pub;curve;params;apdu;sws.
    Booleans become "0"/"1", byte strings become lowercase hex, the private
    key is rendered with ``hex()``, and ``params`` / ``sws`` are joined with
    commas.

    :param ecdhresponse: response object exposing ``success``, ``error``,
        ``secret``, ``params`` (items with ``.hex()``), ``resp.data`` and ``sws``.
    :param curve: already-serialized curve string, inserted verbatim.
    :param point: encoded public point as bytes.
    :param key: private key as an int.
    """
    columns = [
        str(int(ecdhresponse.success)),
        str(int(ecdhresponse.error)),
        ecdhresponse.secret.hex(),
        hex(key),
        point.hex(),
        curve,
        ",".join(param.hex() for param in ecdhresponse.params),
        ecdhresponse.resp.data.hex(),
        ",".join(str(sw) for sw in ecdhresponse.sws),
    ]
    return ";".join(columns)


def recover_nonce(params, data, key, point_bytes, signature_result, verify=False):
    """Recover the ECDSA nonce ``k`` from a signature and the known private key.

    Solves the ECDSA signing equation for the nonce:
    ``k = s^-1 * (z + r * d) mod n``. Here ``z`` is the SHA-1 digest of *data*,
    truncated to the bit length of the group order ``n``, and ``d`` is *key*.

    :param params: domain parameters; must provide ``curve.decode_point``,
        ``curve.to_coords`` and ``to_coords`` (pyecsca-style — confirm with caller).
    :param data: the signed message bytes.
    :param key: the private key as an int.
    :param point_bytes: encoded public key point.
    :param signature_result: object with ``r`` and ``s`` attributes.
    :param verify: when True, additionally check that ``x(k*G) mod n == r``.
        This costs a scalar multiplication, so it is off by default.
    :return: the nonce as a ``Mod`` modulo the group order, or ``0`` when ``s``
        is not invertible.
    :raises ValueError: if *verify* is True and the recovered nonce does not
        reproduce ``r``.
    """
    pubkey = params.curve.decode_point(point_bytes)
    model = ShortWeierstrassModel().coordinates["projective"]
    sig = ECDSA_SHA1(
        DoubleAndAddMultiplier(
            model.formulas["add-2007-bl"], model.formulas["dbl-2007-bl"]
        ),
        params.to_coords(model),
        pubkey=pubkey.to_model(model, params.curve.to_coords(model)),
        privkey=key,
    )
    order = sig.params.order
    digest = sig.hash_algo(data).digest()
    z = int.from_bytes(digest, byteorder="big")
    # ECDSA uses only the leftmost bit_length(n) bits of the hash.
    excess_bits = len(digest) * 8 - order.bit_length()
    if excess_bits > 0:
        z >>= excess_bits
    s = mod(int(signature_result.s), order)
    r = mod(int(signature_result.r), order)
    try:
        nonce = s.inverse() * (mod(z, order) + r * sig.privkey)
    except NonInvertibleError:
        # Best-effort contract kept from before: callers receive 0 here.
        return 0
    if verify:
        # Previously this multiplication ran unconditionally and its result was
        # discarded (the comparison was commented out); it now only runs on request.
        sig.mult.init(sig.params, sig.params.generator)
        affine_point = sig.mult.multiply(int(nonce)).to_affine()
        if r != mod(int(affine_point.x), order):
            raise ValueError("recovered nonce does not reproduce the signature's r")
    return nonce


def serialize_ecdsa_response(
    response, data, domainparams, key, curve_csv, point_bytes, valid=None
):
    """Serialize one ECDSA signing measurement into a ';'-separated line.

    The nonce is recovered from the DER signature with ``recover_nonce``.
    Column order: success;error;signature;valid;data;nonce;priv;pub;curve;
    params;apdu;sws. ``valid`` is rendered as an empty string when None.

    :param response: response object exposing ``success``, ``error``,
        ``signature`` (DER bytes), ``params``, ``resp.data`` and ``sws``.
    :param data: the signed message bytes.
    :param domainparams: domain parameters forwarded to ``recover_nonce``.
    :param key: private key as an int.
    :param curve_csv: already-serialized curve string, inserted verbatim.
    :param point_bytes: encoded public point as bytes.
    :param valid: optional verification outcome; stringified when given.
    """
    signature_der = response.signature
    nonce = recover_nonce(
        domainparams,
        data,
        key,
        point_bytes,
        SignatureResult.from_DER(signature_der),
    )
    columns = [
        str(int(response.success)),
        str(int(response.error)),
        signature_der.hex(),
        str(valid) if valid is not None else "",
        data.hex(),
        hex(int(nonce)),
        hex(key),
        point_bytes.hex(),
        curve_csv,
        ",".join(param.hex() for param in response.params),
        response.resp.data.hex(),
        ",".join(str(sw) for sw in response.sws),
    ]
    return ";".join(columns)


def serialize_keygen_response(response, key, curve_csv, point_bytes):
    """Serialize one key-generation measurement into a ';'-separated line.

    Column order: success;error;priv;pub;curve;params;apdu;sws.

    :param response: response object exposing ``success``, ``error``,
        ``params`` (items with ``.hex()``), ``resp.data`` and ``sws``.
    :param key: private key as an int.
    :param curve_csv: already-serialized curve string, inserted verbatim.
    :param point_bytes: encoded public point as bytes.
    """
    columns = [
        str(int(response.success)),
        str(int(response.error)),
        hex(key),
        point_bytes.hex(),
        curve_csv,
        ",".join(param.hex() for param in response.params),
        response.resp.data.hex(),
        ",".join(str(sw) for sw in response.sws),
    ]
    return ";".join(columns)


def safe_save(header, result_lines, filename):
    """Write *header* and *result_lines* to *filename* without clobbering it.

    Each item is written on its own line (formatted with ``str``). If
    *filename* already exists, a message is printed and nothing is written.
    Otherwise the data goes to ``filename + ".tmp"`` first. That file is flushed
    and fsynced, then renamed onto *filename*, so a crash mid-write does not
    leave a truncated result file behind.

    :param header: first line of the file.
    :param result_lines: iterable of lines to write after the header.
    :param filename: destination path.
    """
    if os.path.isfile(filename):
        print(f"Measurement already exists ({filename})")
        return
    tmp = filename + ".tmp"
    try:
        with open(tmp, "w") as f:
            f.write(f"{header}\n")
            f.writelines(f"{line}\n" for line in result_lines)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
    except BaseException:
        # Don't leave a partial temp file behind; propagate the original error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
    os.rename(tmp, filename)


def save_ecdh(result_lines, filename):
    """Save serialized ECDH measurement lines to *filename* via ``safe_save``."""
    safe_save(
        "success;error;secret[SHA1];priv;pub;curve;params;apdu;sws",
        result_lines,
        filename,
    )


def save_ecdsa(result_lines, filename):
    """Save serialized ECDSA measurement lines to *filename* via ``safe_save``."""
    safe_save(
        "success;error;signature;valid;data;nonce;priv;pub;curve;params;apdu;sws",
        result_lines,
        filename,
    )


def save_keygen(result_lines, filename):
    """Save serialized key-generation measurement lines to *filename* via ``safe_save``."""
    safe_save(
        "success;error;priv;pub;curve;params;apdu;sws",
        result_lines,
        filename,
    )


def parse_ecdsa_signature(signature_der):
    """Decode a DER-encoded ECDSA signature into its integer ``(r, s)`` pair.

    Only the first two components of the decoded SEQUENCE are read. The second
    value returned by pyasn1's ``decode`` (presumably any undecoded remainder)
    is ignored.
    """
    sequence = decode(signature_der, asn1Spec=Sequence())[0]
    return int(sequence[0]), int(sequence[1])

def sha(value):
    """Return the SHA-1 digest of *value*'s minimal big-endian encoding, as an int.

    *value* is first converted with ``int()``. It is then encoded in the fewest
    bytes that can hold it (zero encodes as a single ``0x00`` byte) and hashed
    with SHA-1. The 20-byte digest is read back as a big-endian integer.

    :param value: a non-negative integer, or anything ``int()`` accepts.
    :return: the digest as an int.
    :raises ValueError: if *value* is negative.
    """
    number = int(value)
    if number < 0:
        # The previous hex-string round trip also raised ValueError here, just unhelpfully.
        raise ValueError(f"sha() requires a non-negative value, got {number}")
    length = max(1, (number.bit_length() + 7) // 8)
    digest = sha1(number.to_bytes(length, "big")).digest()
    return int.from_bytes(digest, "big")