From f95f55d226ee87072ea164880a154e3bd5ac2448 Mon Sep 17 00:00:00 2001 From: J08nY Date: Thu, 6 Nov 2025 15:22:18 +0100 Subject: Progress bar for tree building. --- pyecsca/sca/re/tree.py | 29 +++++++++++++++++++++++------ pyproject.toml | 3 ++- test/sca/test_tree.py | 2 +- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/pyecsca/sca/re/tree.py b/pyecsca/sca/re/tree.py index 43f1901..a45faaa 100644 --- a/pyecsca/sca/re/tree.py +++ b/pyecsca/sca/re/tree.py @@ -45,11 +45,12 @@ Here we grow the trees. from math import ceil, log2 from copy import deepcopy -from typing import Mapping, Any, Set, List, Tuple, Optional, Dict, Union +from typing import Mapping, Any, Callable, Set, List, Tuple, Optional, Dict, Union import numpy as np import pandas as pd from public import public +from tqdm.auto import tqdm from anytree import RenderTree, NodeMixin, AbstractStyle, PreOrderIter from pyecsca.misc.utils import log @@ -556,16 +557,26 @@ class Tree: return tree @classmethod - def build(cls, cfgs: Set[Any], *maps: Map, split: Union[str, SplitCriterion] = "largest") -> "Tree": + def build(cls, cfgs: Set[Any], *maps: Map, split: Union[str, SplitCriterion] = "largest", progress: bool = False) -> "Tree": """ Build a tree. :param cfgs: The set of configs to build the tree for. :param maps: The distinguishing maps to use. :param split: The split criterion to use. Can be one of "degree", "largest", "average". + :param progress: Whether to report progress via a progress bar. :return: The tree. """ - return cls(_build_tree(cfgs, dict(enumerate(maps)), split=split), *maps) + leaf_callback: Optional[Callable[[int], None]] + if progress: + bar = tqdm(total=len(cfgs), desc="Building tree.") + def leaf_callback(n_cfgs): return bar.update(n_cfgs) + else: + leaf_callback = None + res = cls(_build_tree(cfgs, dict(enumerate(maps)), split=split, leaf_callback=leaf_callback), *maps) + if progress: + bar.close() + return res def _build_tree( @@ -576,17 +587,19 @@ def _build_tree( index: Optional[int] = None, breadth: Optional[int] = None, split: Union[str, SplitCriterion] = "largest", + leaf_callback: Optional[Callable[[int], None]] = None ) -> Node: - pad = " " * depth + pad = " " * depth + (f"[{depth}, {index}/{breadth}] " if index is not None and breadth is not None else "") # If there is only one remaining cfg, we do not need to continue and just return (base case 1). # Note that n_cfgs will never be 0 here, as the base case 2 returns if the cfgs cannot # be split into two sets (one would be empty). n_cfgs = len(cfgs) - ident = f"[{index}/{breadth}] " if index is not None and breadth is not None else "" - log(pad + f"{ident}Splitting {n_cfgs}.") + log(pad + f"Splitting {n_cfgs}.") cfgset = set(cfgs) if n_cfgs == 1: log(pad + "Trivial.") + if leaf_callback is not None: + leaf_callback(n_cfgs) return Node(cfgset, response=response) # Choose the split criterion @@ -627,6 +640,8 @@ def _build_tree( # We found nothing distinguishing the configs, so return them all (base case 2). if best_column is None or best_dmap is None: log(pad + "Nothing could split.") + if leaf_callback is not None: + leaf_callback(n_cfgs) return Node(cfgset, response=response) best_distinguishing_element = best_dmap.domain[best_column] @@ -636,6 +651,8 @@ def _build_tree( # We found nothing distinguishing the configs, so return them all (base case 2). if groups.ngroups == 1: log(pad + "Trivial split.") + if leaf_callback is not None: + leaf_callback(n_cfgs) return Node(cfgset, response=response) # Create our node diff --git a/pyproject.toml b/pyproject.toml index 83674c7..da0a9f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,8 @@ "networkx", "importlib-resources", "anytree", - "loky" + "loky", + "tqdm" ] [project.urls] diff --git a/test/sca/test_tree.py b/test/sca/test_tree.py index a3afaf9..715c3e0 100644 --- a/test/sca/test_tree.py +++ b/test/sca/test_tree.py @@ -95,7 +95,7 @@ def test_build_tree(split): codomain2 = {0, 1, 2, 3} mapping2 = pd.DataFrame([(1, 0, 0), (2, 0, 0), (3, 0, 0)]) dmap2 = Map(mapping2, cfg_map, inputs2, codomain2) - tree = Tree.build(set(cfgs), dmap1, dmap2, split=split) + tree = Tree.build(set(cfgs), dmap1, dmap2, split=split, progress=True) tree.render() tree.render_basic() tree.describe() -- cgit v1.2.3-70-g09d2