aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/gdb_script.py
blob: 93d4472801006a13b8016488b8643b585216568e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import json

import gdb

# Destination of the trace; its path is taken from the TRACE_FILE environment
# variable. The encoding is pinned so the output does not depend on the locale.
trace_file = open(os.environ["TRACE_FILE"], "w", encoding="utf-8")


def extract_bn(bn):
    """Rebuild the Python integer stored in the inferior bignum ``bn``.

    The first ``bn["used"]`` limbs of the ``bn["dp"]`` array are combined,
    least significant limb first; each limb is shifted by a multiple of the
    inferior's global ``bn_digit_bits``.
    NOTE(review): no sign field is read, so this presumably assumes
    non-negative values -- confirm against the bignum definition.
    """
    digits = bn["dp"]
    limb_count = int(bn["used"])
    digit_bits = int(gdb.lookup_global_symbol("bn_digit_bits").value())
    return sum(
        int(digits[index]) << (index * digit_bits) for index in range(limb_count)
    )


def extract_point(point):
    """Return the coordinates of the inferior ``point`` value as ``{name: int}``.

    Every field of ``point`` whose name is a single character is treated as a
    coordinate and decoded with ``extract_bn``. All other fields are ignored,
    including anonymous members: gdb reports their ``Field.name`` as ``None``,
    which previously crashed ``len()``.
    """
    return {
        field.name: extract_bn(point[field.name])
        for field in point.type.fields()
        if field.name is not None and len(field.name) == 1
    }


def _set_point_set_tracing(enabled):
    """Enable or disable the ``point_set`` breakpoint ``set_bp``, if registered."""
    if set_bp is not None:
        set_bp.enabled = enabled


class TraceFunction(gdb.Breakpoint):
    """Breakpoint that logs the ``point_t`` arguments of a traced function.

    On each hit it writes the function name and every input ``point_t *``
    argument (as JSON) to ``trace_file``. It then arms a ``TraceExit`` on the
    current frame, which logs the output arguments (those whose name contains
    ``"out"``) once the function returns. The ``point_set`` breakpoint stays
    enabled while at least one traced call is in flight.
    Always returns False, so the inferior keeps running.
    """

    # Traced calls that were entered and whose TraceExit has not fired yet.
    # A counter rather than an on/off flag: when a nested traced call
    # (including point_set itself) returns, point_set tracing must stay on
    # while the outer traced call is still running.
    active_calls = 0

    def stop(self):
        try:
            frame = gdb.newest_frame()
            print(frame.name(), flush=True, file=trace_file)
            out = []
            for sym in frame.block():
                if not sym.is_argument:
                    continue
                name = sym.name
                try:
                    value = frame.read_var(name)
                except (RuntimeError, ValueError):
                    # Unreadable argument (e.g. optimized out): nothing to log.
                    continue
                # Only pointers can refer to a point_t; dereferencing any
                # other argument raises gdb.error and would abort the trace
                # of the whole call.
                if value.type.strip_typedefs().code != gdb.TYPE_CODE_PTR:
                    continue
                deref = value.dereference()
                if deref.type.name != "point_t":
                    continue
                if "out" in name:
                    # Contents are read later in TraceExit, after the callee
                    # has written them. NOTE(review): this relies on gdb
                    # fetching the dereferenced value lazily.
                    out.append(deref)
                else:
                    pt = extract_point(deref)
                    print(f"{name}: {json.dumps(pt)}", flush=True, file=trace_file)
            exit_bp = TraceExit(frame)
            exit_bp.silent = True
            exit_bp.target = out
        except (RuntimeError, ValueError) as exc:
            # Best effort: report the problem and keep the inferior running.
            gdb.write(f"trace: {self.location}: {exc}\n", gdb.STDERR)
            return False
        # Count the call only once its TraceExit exists to balance it.
        TraceFunction.active_calls += 1
        _set_point_set_tracing(True)
        return False  # Continue execution


class TraceExit(gdb.FinishBreakpoint):
    """Finish breakpoint that logs a traced call's output points on return.

    The creator assigns ``target``: the list of dereferenced ``point_t *``
    output arguments. Each is written to ``trace_file`` as
    ``out_<index>: <json>``.
    """

    def stop(self):
        self._leave()
        for i, point in enumerate(self.target):
            print(f"out_{i}: {json.dumps(extract_point(point))}", flush=True, file=trace_file)
        return False  # Continue execution

    def out_of_scope(self):
        # gdb calls this when the frame was left without hitting the finish
        # breakpoint (e.g. longjmp); still balance the in-flight call count.
        self._leave()

    def _leave(self):
        """Mark this call finished; turn point_set tracing off when none remain."""
        if getattr(self, "_left", False):
            return  # stop() and out_of_scope() must count a call only once
        self._left = True
        TraceFunction.active_calls = max(TraceFunction.active_calls - 1, 0)
        if TraceFunction.active_calls == 0:
            _set_point_set_tracing(False)


def register_bp(name):
    """Install a silent ``TraceFunction`` breakpoint on the function ``name``.

    Returns the new breakpoint. Returns ``None`` without installing anything
    when gdb finds no global symbol called ``name``.
    """
    if gdb.lookup_global_symbol(name) is None:
        return None
    trace_bp = TraceFunction(name)
    trace_bp.silent = True
    return trace_bp


# Functions traced on every call. Each name must appear once: two breakpoints
# on the same function would log every call twice.
TRACED_FUNCTIONS = (
    "point_add",
    "point_dadd",
    "point_ladd",
    "point_dbl",
    "point_neg",
    "point_scl",
    "point_tpl",
)

try:
    for traced in TRACED_FUNCTIONS:
        register_bp(traced)
    # point_set starts disabled; the trace breakpoints toggle it around
    # traced calls.
    set_bp = register_bp("point_set")
    if set_bp is None:
        raise gdb.GdbError("point_set not found; cannot trace point_set calls")
    set_bp.enabled = False
    gdb.execute("run")
finally:
    # Close (and flush) the trace even if setup or the run fails.
    trace_file.close()