# Unraveling scalar mults and countermeasures

In [None]:
import pickle
import itertools
import glob
import random
import math

from collections import Counter

import numpy as np
import pandas as pd
from scipy.stats import binom, entropy
from scipy.spatial import distance
from tqdm.auto import tqdm, trange
from anytree import PreOrderIter, Walker
from matplotlib import pyplot as plt

from pyecsca.ec.mult import *
from pyecsca.misc.utils import TaskExecutor, silent
from pyecsca.sca.re.tree import Map, Tree

from common import *

%matplotlib ipympl

## Prepare
Select *divisor name* to restrict the features. Select *kind* to pick the probmap source.

In [None]:
divisor_name = "all"
kind = "all"
# Feature set: all divisors of the chosen set except 1, 2, 3, 4 and 5.
allfeats = [feat for feat in divisor_map[divisor_name] if feat not in (1, 2, 3, 4, 5)]

In [None]:
# Load
# Load the per-multiplier probability maps for the chosen divisor set and kind.
# If there is no pickle for this divisor set, fall back to the pickle for "all"
# divisors and call narrow(allfeats) on every probability map (presumably
# restricting it to those divisors).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
try:
    with open(f"{divisor_name}_{kind}_distrs.pickle", "rb") as f:
        distributions_mults = pickle.load(f)
except FileNotFoundError:
    with open(f"all_{kind}_distrs.pickle", "rb") as f:
        distributions_mults = pickle.load(f)
    for probmap in distributions_mults.values():
        probmap.narrow(allfeats)

In [None]:
# Fix an ordering of the multipliers and record the problem dimensions.
allmults = list(distributions_mults)
nallfeats = len(allfeats)
nmults = len(allmults)

## Build dmap and tree

Select the n for building the tree.

In [None]:
# Sample count n used for the confidence intervals when clustering multipliers
# for the tree (passed to conf_interval together with alpha; alpha is
# presumably the significance level -- confirm in conf_interval).
nbuild = 10000
alpha = 0.05

In [None]:
# Now go over all divisors and cluster the multipliers based on overlapping
# confidence intervals (CI) for the given n (nbuild) and alpha.
#
# For each divisor the multipliers are sorted by probability from high to low.
# Consecutive multipliers stay in the same group as long as their CIs overlap.
# A new group starts at the first multiplier whose CI lies entirely below the
# CI of the previous multiplier.
io_map = {mult: {} for mult in allmults}
for divisor in allfeats:
    prev_ci_low = None
    groups = {}
    pvals = {}
    group = 0
    for mult, probmap in sorted(distributions_mults.items(), key=lambda item: -item[1][divisor]):
        # We are going from high to low p.
        pval = probmap[divisor]
        pvals[mult] = pval
        ci_low, ci_high = conf_interval(pval, nbuild, alpha)
        ci_low = max(ci_low, 0.0)
        ci_high = min(ci_high, 1.0)
        # Disjoint from the previous (higher-p) interval: open a new group
        # *before* adding this multiplier. The breaking multiplier must start
        # the new group rather than join the group it does not overlap with.
        # Likewise, the very first multiplier must not end up alone.
        if prev_ci_low is not None and prev_ci_low >= ci_high:
            group += 1
        groups.setdefault(f"arbitrary{group}", set()).add(mult)
        prev_ci_low = ci_low

    # Every multiplier in a group gets the same response for this divisor:
    # (group name, mean, variance, min, max) of the group's probabilities.
    for group_name, mults in groups.items():
        mult_pvals = [pvals[mult] for mult in mults]
        group_stats = (group_name,
                       np.mean(mult_pvals),
                       np.var(mult_pvals),
                       np.min(mult_pvals),
                       np.max(mult_pvals))
        for mult in mults:
            io_map[mult][divisor] = group_stats

# then build dmap
dmap = Map.from_io_maps(set(distributions_mults.keys()), io_map)

In [None]:
# Summary of the map before deduplication.
print(dmap.describe())

In [None]:
# deduplicate dmap
dmap.deduplicate()

In [None]:
# Summary of the map after deduplication.
print(dmap.describe())

In [None]:
# build a tree
with silent():
    tree = Tree.build(set(allmults), dmap)

In [None]:
# Print the tree's textual description.
print(tree.describe())

In [None]:
# Print a basic rendering of the tree.
print(tree.render_basic())

## Simulate distinguishing using a tree
We can now simulate distinguishing multipliers using the tree and observe how the success rate changes as the number of samples collected per divisor increases.

In [None]:
# Monte-Carlo evaluation of tree-based distinguishing. For every attack budget
# `nattack` (samples per queried divisor), draw a random true multiplier and walk
# the tree from the root. At each inner node, sample the empirical probability
# for the node's divisor. Descend into the child whose (min, max) probability
# range strictly contains it; if none does, descend into the child whose range
# endpoint is closest.
simulations = 1000

for nattack in trange(100, 10000, 100):
    successes = 0
    pathiness = 0  # number of descents after which the true multiplier was still in the node
    for _ in range(simulations):
        true_mult = random.choice(allmults)
        probmap = distributions_mults[true_mult]
        node = tree.root
        while not node.is_leaf:
            divisor = node.dmap_input
            sampled_prob = binom(nattack, probmap[divisor]).rvs() / nattack
            best_child = None
            best_group_distance = None
            for child in node.children:
                _, _, _, group_pval_min, group_pval_max = child.response
                if group_pval_min < sampled_prob < group_pval_max:
                    # The sample falls inside this group's range: take it directly.
                    best_child = child
                    break
                group_distance = min(abs(sampled_prob - group_pval_min),
                                     abs(sampled_prob - group_pval_max))
                if best_child is None or group_distance < best_group_distance:
                    best_child = child
                    best_group_distance = group_distance
            node = best_child
            if true_mult in node.cfgs:
                pathiness += 1
        if true_mult in node.cfgs:
            successes += 1
    print(f"{nattack}: success rate {successes/simulations}, pathiness {pathiness/simulations}")

## Simulate distinguishing using a distance metric

We need to first select some features (divisors) from the set of all divisors that we will query
the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as
much as possible.

### Feature selection using trees + classification error

We can reuse the clustering + tree building approach above and simply take the inputs that the greedy tree building chooses as the features. However, we can also use more conventional feature selection approaches.

In [None]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import validate_data, check_is_fitted
from sklearn.utils.multiclass import unique_labels
from scipy.special import logsumexp
from sklearn.metrics import euclidean_distances, top_k_accuracy_score, make_scorer, accuracy_score


class EuclidClassifier(ClassifierMixin, BaseEstimator):
    """Nearest-template classifier on empirical success probabilities.

    ``fit`` takes exactly one row per class holding that class's per-feature
    binomial success probabilities. Samples passed to ``predict`` /
    ``decision_function`` are success counts out of ``nattack`` trials per
    feature. They are divided by ``nattack`` and compared to the class rows by
    Euclidean distance.

    Args:
        nattack: Number of trials per feature in the samples being classified.
    """

    def __init__(self, *, nattack=100):
        self.nattack = nattack

    def fit(self, X, y):
        """Store the class templates.

        Args:
            X: Array of shape (nclasses, nfeats) with probabilities in [0, 1].
            y: Array of shape (nclasses,) with one distinct label per row.

        Returns:
            self

        Raises:
            TypeError: If ``X`` contains values outside [0, 1].
            ValueError: If ``y`` contains repeated labels.
        """
        X, y = validate_data(self, X, y)
        if not np.logical_and(X >= 0, X <= 1).all():
            raise TypeError("Expects valid probabilities in X.")
        self.classes_ = unique_labels(y)
        if len(self.classes_) != len(y):
            raise ValueError("Expects only one sample per class containing the binomial probabilities.")
        # unique_labels() returns the labels in sorted order. Reorder the templates
        # the same way so that row k of X_ belongs to classes_[k]. Otherwise
        # predict() and the columns of decision_function() are misaligned with
        # classes_ whenever y is not already sorted.
        order = np.argsort(y, kind="stable")
        self.X_ = X[order]
        self.y_ = y[order]
        return self

    def _distances(self, X):
        """Return the (nsamples, nclasses) distances of the samples to the templates."""
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)
        return euclidean_distances(X / self.nattack, self.X_)

    def decision_function(self, X):
        """Return negated distances (higher = closer); columns ordered as ``classes_``."""
        return -self._distances(X)

    def predict(self, X):
        """Return the label of the closest class template for every sample."""
        return self.classes_[np.argmin(self._distances(X), axis=1)]


class BayesClassifier(ClassifierMixin, BaseEstimator):
    """Bayes classifier for independent binomial features with a uniform class prior.

    ``fit`` takes one row per class with the per-feature success probabilities.
    A sample of success counts (out of ``nattack`` trials per feature) has, under
    class c, the log-likelihood ``sum_f log Binom(x_f; nattack, p_{c,f})``. With
    a uniform prior, the posterior is the softmax of these log-likelihoods.

    Args:
        nattack: Number of trials per feature in the samples being classified.
    """

    def __init__(self, *, nattack=100):
        self.nattack = nattack

    def fit(self, X, y):
        """Store the class templates.

        Args:
            X: Array of shape (nclasses, nfeats) with probabilities in [0, 1].
            y: Array of shape (nclasses,) with one distinct label per row.

        Returns:
            self

        Raises:
            TypeError: If ``X`` contains values outside [0, 1].
            ValueError: If ``y`` contains repeated labels.
        """
        # X has (nmults = nsamples, nfeats)
        X, y = validate_data(self, X, y)
        if not np.logical_and(X >= 0, X <= 1).all():
            raise TypeError("Expects valid probabilities in X.")
        self.classes_ = unique_labels(y)
        if len(self.classes_) != len(y):
            raise ValueError("Expects only one sample per class containing the binomial probabilities.")
        # unique_labels() returns the labels in sorted order. Reorder the templates
        # so that row k of X_ belongs to classes_[k]; predict() and the
        # probability columns rely on this alignment.
        order = np.argsort(y, kind="stable")
        self.X_ = X[order]
        self.y_ = y[order]
        return self

    def _log_likelihoods(self, X):
        """Return the (nsamples, nclasses) log-likelihoods; columns ordered as ``classes_``."""
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)
        # One frozen distribution covers all classes. Evaluating row by row keeps
        # memory at O(nclasses * nfeats) per sample.
        dist = binom(self.nattack, self.X_)
        loglik = np.empty((len(X), len(self.classes_)))
        for i, row in enumerate(X):
            loglik[i] = np.sum(dist.logpmf(row), axis=1)
        return loglik

    def decision_function(self, X):
        """Return posterior class probabilities; columns ordered as ``classes_``."""
        loglik = self._log_likelihoods(X)
        # We have a uniform prior, so it cancels in the normalization.
        return np.exp(loglik - logsumexp(loglik, axis=1, keepdims=True))

    def predict_proba(self, X):
        """Alias of :meth:`decision_function` (posterior probabilities)."""
        return self.decision_function(X)

    def predict(self, X):
        """Return the maximum-likelihood (= maximum-posterior) label per sample."""
        return self.classes_[np.argmax(self._log_likelihoods(X), axis=1)]


def to_sklearn(mults_map: dict[MultIdent, ProbMap], feats: list[int]):
    """Convert probability maps into an sklearn-style ``(X, y)`` pair.

    Args:
        mults_map: Mapping of multipliers to probability maps indexable by divisor.
        feats: Divisors to use as features (columns), in this order.

    Returns:
        ``(probs, classes)``. ``probs`` has shape (nmults, nfeats), and
        ``probs[j, i]`` is the probability of the j-th multiplier (in sorted
        order) for divisor ``feats[i]``. ``classes`` is ``0..nmults-1`` as
        uint32, labelling the rows.
    """
    ordered = sorted(mults_map)
    classes = np.arange(len(ordered), dtype=np.uint32)
    probs = np.zeros((len(ordered), len(feats)), dtype=np.float64)
    for row, mult in enumerate(ordered):
        probmap = mults_map[mult]
        probs[row, :] = [probmap[divisor] for divisor in feats]
    return probs, classes


def make_instance(nattack: int,
                  simulations: int,
                  X,
                  y,
                  progress=False):
    """Simulate attack observations drawn from the class templates.

    Args:
        nattack: Number of binomial trials (queries) per feature.
        simulations: Number of observations to draw.
        X: Array of shape (nclasses, nfeats) with each class's success
            probabilities, one row per class.
        y: Labels of the rows of ``X``; ``y[j]`` labels row ``j``.
        progress: Show a progress bar while sampling.

    Returns:
        ``(X_samp, y_samp)``. ``X_samp`` has shape (simulations, nfeats) and
        holds success counts (uint32). ``y_samp[i]`` is the label (taken from
        ``y``) of the row that generated ``X_samp[i]``.

    If ``simulations >= nclasses``, the first ``nclasses`` observations use
    every row once. The remaining observations use rows chosen uniformly at
    random.
    """
    labels = np.asarray(y)
    nmults, nfeats = X.shape
    X_samp = np.zeros((simulations, nfeats), dtype=np.uint32)
    # Use the label dtype (uint32 for to_sklearn output) so arbitrary labels fit.
    y_samp = np.zeros(simulations, dtype=labels.dtype)

    r = trange(simulations) if progress else range(simulations)
    for i in r:
        if i < nmults and simulations >= nmults:
            j = i
        else:
            j = random.randrange(nmults)
        X_samp[i] = binom(nattack, X[j]).rvs()
        # Emit the row's actual label rather than its index, so that y values
        # other than 0..nclasses-1 stay consistent with the fitted classifier.
        y_samp[i] = labels[j]
    return X_samp, y_samp


def evaluate_classifier(nattack: int,
                        simulations: int,
                        X,
                        y,
                        classifier,
                        scorer):
    """Fit ``classifier`` on the class templates and score it on simulated samples.

    Args:
        nattack: Trials per feature; set on the classifier and used for sampling.
        simulations: Number of simulated observations to score on.
        X, y: Class templates and labels, e.g. from ``to_sklearn``.
        classifier: Estimator with an ``nattack`` parameter.
        scorer: sklearn scorer called as ``scorer(classifier, X_samp, y_samp)``.

    Returns:
        The scorer's result.
    """
    classifier.set_params(nattack=nattack)
    classifier.fit(X, y)
    # Score on freshly drawn observations from the same templates.
    return scorer(classifier, *make_instance(nattack, simulations, X, y))


def average_rank_score(y_true, y_pred, labels=None):
    """Return the mean rank of the true class within the per-class scores.

    Rank 1 means the true class received the highest score. Ties resolve
    optimistically: only strictly higher scores push the rank down.

    Args:
        y_true: Array of shape (n_samples,) with the true labels. Without
            ``labels``, these must be column indices into ``y_pred``.
        y_pred: Array of shape (n_samples, n_classes) of per-class scores
            (higher = more likely).
        labels: Optional sequence of length n_classes naming the columns of
            ``y_pred``. It does not need to be sorted.

    Returns:
        float: The average rank (lower is better).

    Raises:
        ValueError: If ``labels`` has the wrong length, or if ``y_true``
            contains a label missing from ``labels``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    n_samples, n_classes = y_pred.shape
    if labels is not None:
        labels = np.asarray(labels)
        if len(labels) != n_classes:
            raise ValueError(
                f"Number of labels ({len(labels)}) does not match the number of "
                f"columns in y_pred ({n_classes}).")
        # Explicit label -> column lookup. This works for unsorted labels, and
        # unknown labels raise instead of silently mapping to the last column.
        column_of = {label: idx for idx, label in enumerate(labels.tolist())}
        true_values = y_true.tolist()
        missing = [value for value in true_values if value not in column_of]
        if missing:
            raise ValueError(f"y_true contains labels not present in labels: {list(dict.fromkeys(missing))}")
        indexes = np.array([column_of[value] for value in true_values], dtype=np.intp)
    else:
        indexes = y_true
    true_scores = y_pred[np.arange(n_samples), indexes]

    # Rank = 1 + number of classes scored strictly higher than the true class.
    counts_higher = np.sum(y_pred > true_scores[:, None], axis=1)
    ranks = counts_higher + 1

    return ranks.mean()

# Sanity checks of the classifiers and scoring helpers on a toy problem with
# four classes (labels 1-4) and three features.
X = np.array([[0.7, 0.7, 0.1], [0.3, 0.7, 0.1], [0.2, 0.5, 0.1], [0.1, 0.1, 0.4]])
y = np.array([1, 2, 3, 4])

# Samples are raw success counts out of nattack=100 trials per feature.
euc = EuclidClassifier(nattack=100).fit(X, y)
label = euc.predict(np.array([20, 50, 10]).reshape(1, -1))
dec = euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]]))
print(label)
print(dec)

clf = BayesClassifier(nattack=100).fit(X, y)
label = clf.predict(np.array([20, 50, 10]).reshape(1, -1))
ps = clf.predict_proba(np.array([20, 50, 10]).reshape(1, -1))
print(label)
print(ps)


# Top-1 accuracy for true labels [3, 1] using each classifier's scores.
acc = top_k_accuracy_score(np.array([3, 1]),
                     euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),
                     labels = [1, 2, 3, 4],
                     k=1)
print(acc)
acc = top_k_accuracy_score(np.array([3, 1]),
                     clf.predict_proba(np.array([[20, 50, 10], [70, 60, 20]])),
                     labels = [1, 2, 3, 4],
                     k=1)
print(acc)

# average_rank_score, first with column indices and then with explicit labels.
# Indices [2, 0] are the columns of labels [3, 1], so both prints should agree.
avg = average_rank_score(np.array([2, 0]),
                     euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])))
print(avg)
avg = average_rank_score(np.array([3, 1]),
                         euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),
                         labels = [1, 2, 3, 4])
print(avg)

# Scorers used for feature selection and evaluation below. Each accepts either
# decision_function or predict_proba outputs.
accuracy_scorer = make_scorer(
    top_k_accuracy_score,
    greater_is_better=True,
    response_method=("decision_function", "predict_proba"),
)

#accuracy_scorer.__str__ = lambda self: "Accuracy"

top_5_scorer = make_scorer(
    top_k_accuracy_score,
    greater_is_better=True,
    response_method=("decision_function", "predict_proba"),
    k=5
)

#top_5_scorer.__str__ = lambda self: "Top-5 accuracy"

top_10_scorer = make_scorer(
    top_k_accuracy_score,
    greater_is_better=True,
    response_method=("decision_function", "predict_proba"),
    k=10
)

#top_10_scorer.__str__ = lambda self: "Top-10 accuracy"

# Lower average rank is better, hence greater_is_better=False (sklearn negates it).
avg_rank_scorer = make_scorer(
    average_rank_score,
    greater_is_better=False,
    response_method=("decision_function", "predict_proba"),
)

#avg_rank_scorer.__str__ = lambda self: "Average rank"

In [None]:
class FeatureSelector:
    """Base class for strategies that choose a subset of divisors (features).

    Args:
        allfeats: Candidate divisors to choose from.
        mults: Mapping of multipliers to their probability maps.
        num_workers: Number of parallel workers subclasses may use.
    """

    def __init__(self,
                 allfeats: list[int],
                 mults: dict[MultIdent, ProbMap],
                 num_workers: int):
        self.allfeats = allfeats
        self.mults = mults
        self.num_workers = num_workers
        # Attack budget (samples per divisor); set by prepare() before select().
        self.nattack = None

    def prepare(self, nattack: int):
        """Set the number of samples per divisor used when evaluating selections."""
        self.nattack = nattack

    def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:
        """Select ``nfeats`` features, optionally extending ``startwith``.

        Subclasses must return ``(selected_features, score)``.

        Raises:
            NotImplementedError: Always. The base class has no selection
                strategy; previously it silently returned None, which then
                failed at the caller's tuple unpacking.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement select().")

class FeaturesByClassification(FeatureSelector):
    """Feature selector that scores candidate subsets by simulated classification.

    Stores, in addition to the base settings, the simulation settings that
    subclasses use to score a candidate feature subset: the number of simulated
    observations, the classifier and the scorer.
    """

    def __init__(self,
                 allfeats: list[int],
                 mults: dict[MultIdent, ProbMap],
                 num_workers: int,
                 simulations: int,
                 classifier,
                 scorer):
        super().__init__(allfeats, mults, num_workers)
        self.simulations, self.classifier, self.scorer = simulations, classifier, scorer

class RandomFeatures(FeaturesByClassification):
    """Pick the best of ``retries`` random feature subsets by simulated score.

    Each retry draws a random subset (extending ``startwith`` if given). The
    subsets are scored in parallel, and the highest-scoring one is returned.
    """

    def __init__(self,
                 allfeats: list[int],
                 mults: dict[MultIdent, ProbMap],
                 num_workers: int,
                 simulations: int,
                 classifier,
                 scorer,
                 retries: int):
        super().__init__(allfeats, mults, num_workers, simulations, classifier, scorer)
        self.retries = retries

    def _select_random(self, nfeats: int, startwith: list[int] = None) -> list[int]:
        """Return a random list of ``nfeats`` features.

        If ``startwith`` has fewer than ``nfeats`` features, it is extended with
        randomly chosen unused features. If it has more, a random sample of it
        is returned. If it has exactly ``nfeats``, it is returned unchanged.
        """
        base = [] if startwith is None else startwith
        missing = nfeats - len(base)
        if missing < 0:
            return random.sample(base, nfeats)
        if missing == 0:
            return base
        candidates = [feat for feat in self.allfeats if feat not in base]
        return base + random.sample(candidates, missing)

    def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:
        """Return ``(best_features, best_score)`` over ``self.retries`` random subsets."""
        with TaskExecutor(max_workers=self.num_workers) as pool:
            candidates = []
            for attempt in range(self.retries):
                feats = self._select_random(nfeats, startwith)
                X, y = to_sklearn(self.mults, feats)
                candidates.append(feats)
                pool.submit_task(attempt,
                                 evaluate_classifier,
                                 self.nattack, self.simulations,
                                 X, y, self.classifier, self.scorer)
            best_feats, best_score = None, None
            for attempt, future in tqdm(pool.as_completed(), total=len(pool.tasks), desc="retries", leave=False):
                score = future.result()
                if best_score is None or score > best_score:
                    best_feats, best_score = candidates[attempt], score
            return best_feats, best_score


class GreedyFeatures(FeaturesByClassification):
    """Forward greedy feature selection driven by simulated classification score.

    Starting from ``startwith``, each round evaluates every remaining feature
    appended to the current selection (in parallel) and keeps the one with the
    highest score. Rounds repeat until ``nfeats`` features are selected.
    """

    def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:
        """Greedily extend ``startwith`` to ``nfeats`` features.

        Returns:
            ``(features, score)``: the selection and the score of the last round.
            ``score`` is None if no round ran, i.e. ``startwith`` already had
            ``nfeats`` features or no candidates were left. Fewer than ``nfeats``
            features are returned if the candidate pool runs out.

        Raises:
            ValueError: If ``startwith`` already has more than ``nfeats`` features.
        """
        if startwith is None:
            startwith = []
        toselect = nfeats - len(startwith)
        if toselect < 0:
            raise ValueError("No features to select.")
        current = list(startwith)
        available_feats = [feat for feat in self.allfeats if feat not in current]
        # Initialized up front so that a call with nothing to select returns
        # instead of raising UnboundLocalError.
        best_score = None
        with TaskExecutor(max_workers=self.num_workers) as pool:
            while toselect > 0 and available_feats:
                for feat in available_feats:
                    feats = current + [feat]
                    X, y = to_sklearn(self.mults, feats)
                    pool.submit_task(feat,
                                     evaluate_classifier,
                                     self.nattack, self.simulations,
                                     X, y, self.classifier, self.scorer)
                best_score = None
                best_feat = None
                for feat, future in tqdm(pool.as_completed(), total=len(pool.tasks), leave=False):
                    score = future.result()
                    if best_score is None or score > best_score:
                        best_score = score
                        best_feat = feat
                current.append(best_feat)
                # Remove the chosen feature from the pool. Otherwise later rounds
                # re-evaluate it and may select the same divisor again.
                available_feats.remove(best_feat)
                toselect -= 1
        return current, best_score


def feature_search(feat_range, nattack_range, selector, restarts=False):
    """Run ``selector`` over a grid of attack budgets and feature counts.

    Args:
        feat_range: An int or iterable of feature counts to select.
        nattack_range: An int or iterable of attack budgets (samples per divisor).
        selector: A FeatureSelector; ``prepare`` is called once per budget.
        restarts: If False, each feature count starts from the selection made
            for the previous count (within the same budget). If True, every
            selection starts from scratch.

    Returns:
        dict mapping ``(nattack, nfeats)`` to the selected feature list.
    """
    feat_values = [feat_range] if isinstance(feat_range, int) else feat_range
    nattack_values = [nattack_range] if isinstance(nattack_range, int) else nattack_range
    results = {}
    for nattack in tqdm(nattack_values, desc="nattack", smoothing=0):
        selector.prepare(nattack)
        previous = []
        for nfeats in tqdm(feat_values, desc="nfeats", leave=False):
            seed = [] if restarts else previous
            previous, score = selector.select(nfeats, seed)
            results[(nattack, nfeats)] = previous
            print(f"{nattack},{nfeats}: {previous}, {score}")
    return results

In [None]:
def plot_performance(classifier, scorer, simulations, feature_map, mults, num_workers = 30):
    """Evaluate selected feature sets and plot the scores as an (nattack, nfeats) heatmap.

    Args:
        classifier: Classifier passed to ``evaluate_classifier`` for every cell.
        scorer: sklearn scorer used for evaluation. Its score function name and
            keyword arguments appear in the plot title.
        simulations: Number of simulated observations per evaluation.
        feature_map: Mapping ``(nattack, nfeats) -> list of divisors``, e.g. the
            output of ``feature_search``.
        mults: Mapping of multipliers to probability maps.
        num_workers: Number of parallel workers.

    White contour lines mark constant total query counts ``nattack * nfeats``.
    """
    scores = {}
    with TaskExecutor(max_workers=num_workers) as pool:
        for (nattack, nfeats), feats in feature_map.items():
            X, y = to_sklearn(mults, feats)
            pool.submit_task((nattack, nfeats),
                             evaluate_classifier,
                             nattack, simulations, X, y, classifier, scorer)
        for (nattack, nfeats), future in tqdm(pool.as_completed(), desc="Evaluating", leave=False, total=len(pool.tasks)):
            scores[(nattack, nfeats)] = future.result()

    x_unique = sorted({nattack for nattack, _ in scores})
    y_unique = sorted({nfeats for _, nfeats in scores})

    # Score grid: rows are nfeats values, columns are nattack values.
    heatmap_data = np.zeros((len(y_unique), len(x_unique)))
    for (x, y), score in scores.items():
        heatmap_data[y_unique.index(y), x_unique.index(x)] = score

    # Open a fresh figure. With interactive backends such as ipympl, plt.show()
    # leaves the current figure open, so without this a second call would draw
    # into the previous plot and stack another colorbar on it.
    plt.figure()

    x_mesh, y_mesh = np.meshgrid(x_unique, y_unique)
    plt.pcolormesh(x_mesh, y_mesh, heatmap_data, cmap='viridis', shading='auto')
    plt.colorbar(label='Score')

    for i, y_val in enumerate(y_unique):
        for j, x_val in enumerate(x_unique):
            value = heatmap_data[i, j]
            plt.text(x_val, y_val, f'{value:.2f}', ha='center', va='center',
                     color='white' if value < 0.4 else "black")

    # Contours of constant total number of queries nattack * nfeats.
    x_contour, y_contour = np.meshgrid(np.linspace(min(x_unique), max(x_unique), 100),
                                       np.linspace(min(y_unique), max(y_unique), 100))
    z_contour = x_contour * y_contour

    contour = plt.contour(x_contour, y_contour, z_contour, levels=[100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], colors='white', zorder=4)
    plt.clabel(contour, inline=True, fontsize=10)

    plt.xticks(ticks=x_unique, labels=x_unique)
    plt.yticks(ticks=y_unique, labels=y_unique)
    plt.xlabel('nattack')
    plt.ylabel('nfeats')
    plt.title(f'{scorer._score_func.__name__}{scorer._kwargs} ({classifier.__class__.__name__})')
    plt.show()

In [None]:
# Distinct divisors queried at the inner (non-leaf) nodes of the tree.
feats_in_tree = {node.dmap_input for node in PreOrderIter(tree.root) if not node.is_leaf}

In [None]:
# Random-subset feature selection restricted to the divisors used in the tree.
# Subsets are scored with the Euclidean classifier and top-5 accuracy, for every
# (nattack, nfeats) combination in the grid below; each nfeats starts from scratch.
simulations = 500
retries = 500
nattack = range(50, 350, 50)
nfeats = range(1, 11)
num_workers = 30

euclid_classifier = EuclidClassifier()
tree_random_subsets = RandomFeatures(sorted(feats_in_tree), distributions_mults, num_workers,
                                     simulations, euclid_classifier, top_5_scorer, retries)

selected_euclid_fromtree = feature_search(nfeats, nattack, tree_random_subsets, restarts=True)

In [None]:
# Heatmap of top-5 accuracy of the Euclidean classifier on the selected features.
plot_performance(euclid_classifier, top_5_scorer, 500, selected_euclid_fromtree, distributions_mults)

## Simulate distinguishing using a Bayes classifier

We need to first select some features (divisors) from the set of all divisors that we will query
the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as
much as possible.

Then, we can build a true Bayes classifier. Since our features are conditionally independent given the class label, naive Bayes coincides with the full (non-naive) Bayes classifier in our case. We examine four feature selection algorithms:
 - Feature selection by pre-selection using tree-building and final selection by random subsets + classification error.
 - Feature selection via greedy classification error.
 - Feature selection via mRMR (maximal relevance, minimal redundancy) using mutual information.
 - Feature selection via JMI (Joint Mutual Information).

### Feature selection using trees + classification error

We can reuse the clustering + tree building approach above and simply take the inputs that the greedy tree building chooses as the features. However, we can also use more conventional feature selection approaches.

In [None]:
# Recompute the set of divisors queried at the inner (non-leaf) tree nodes.
feats_in_tree = set()
for node in PreOrderIter(tree.root):
    if not node.is_leaf:
        feats_in_tree.add(node.dmap_input)

In [None]:
# Same random-subset search over the tree's divisors, now scored with the Bayes
# classifier and top-5 accuracy; each nfeats starts from scratch.
simulations = 500
retries = 500
nattack = range(50, 350, 50)
nfeats = range(1, 11)
num_workers = 30

bayes_classifier = BayesClassifier()
tree_random_subsets = RandomFeatures(sorted(feats_in_tree), distributions_mults, num_workers,
                                     simulations, bayes_classifier, top_5_scorer, retries)

selected_bayes_fromtree = feature_search(nfeats, nattack, tree_random_subsets, restarts=True)

In [None]:
# Heatmap of top-5 accuracy of the Bayes classifier on the random subsets of tree
# divisors selected above (previously referenced the undefined name `bay`).
plot_performance(bayes_classifier, top_5_scorer, 500, selected_bayes_fromtree, distributions_mults)

### Feature selection via greedy classification
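We can also use the classifier itself for feature selection. We iterate over all the divisors and pick as the first feature the one that gives the best classifier results in simulation. We then add features iteratively, each time picking the divisor that gives the best results together with the ones already selected.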

In [None]:
# Greedy forward selection from the tree's divisors with the Bayes classifier.
# restarts=False lets each nfeats extend the selection made for the previous one.
simulations = 500
nattack = range(50, 350, 50)
nfeats = range(1, 11)
num_workers = 30

bayes_classifier = BayesClassifier()
greedy = GreedyFeatures(sorted(feats_in_tree), distributions_mults, num_workers,
                        simulations, bayes_classifier, top_5_scorer)

selected_bayes_greedy_fromtree = feature_search(nfeats, nattack, greedy, restarts=False)

In [None]:
# Heatmap of top-5 accuracy for the greedy selection from the tree's divisors.
plot_performance(bayes_classifier, top_5_scorer, 500, selected_bayes_greedy_fromtree, distributions_mults)

In [None]:
# Greedy forward selection from all divisors (not just the tree's) with the
# Bayes classifier, over a finer nattack grid and up to 15 features.
simulations = 500
nattack = range(10, 210, 10)
nfeats = range(1, 16)
num_workers = 30

bayes_classifier = BayesClassifier()
greedy = GreedyFeatures(allfeats, distributions_mults, num_workers,
                        simulations, bayes_classifier, top_5_scorer)

selected_bayes_greedy_fromall = feature_search(nfeats, nattack, greedy, restarts=False)

In [None]:
# Heatmap of top-5 accuracy for the greedy selection from all divisors
# (previously referenced the undefined name `gre`).
plot_performance(bayes_classifier, top_5_scorer, 500, selected_bayes_greedy_fromall, distributions_mults)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix of the Bayes classifier for one (nattack, nfeats) point,
# using the features from the greedy selection over all divisors.
nattack = 50
nfeats = 5
simulations = 20000
bayes = BayesClassifier(nattack=nattack)
X, y = to_sklearn(distributions_mults, selected_bayes_greedy_fromall[(nattack, nfeats)])
bayes.fit(X, y)

X_samp, y_samp = make_instance(nattack, simulations, X, y, progress=True)
fig, ax = plt.subplots(figsize=(12,8))
# Rows are normalized over the true class; per-cell values are not drawn.
disp = ConfusionMatrixDisplay.from_predictions(y_samp, bayes.predict(X_samp), ax=ax, normalize="true", include_values=False, xticks_rotation="vertical")

# Label the axes by multiplier class. Walk the multipliers in the same sorted
# order that to_sklearn uses and put a tick wherever mult.klass changes.
ticks = []
labs = []
kls = None
for i, mult in enumerate(sorted(list(distributions_mults.keys()))):
    if kls is None or kls != mult.klass:
        ticks.append(i)
        labs.append(mult.klass.__name__)
        kls = mult.klass
ax.set_xticks(ticks, labs)
ax.set_yticks(ticks, labs)
ax.set_xticks([], minor=True)
ax.set_yticks([], minor=True);

### Feature selection via mRMR using mutual information

In [None]:
def mutual_information(class_priors, p_ci_list, n):
    """
    Compute mutual information I(X; Y) for a binomial feature with given class parameters.

    Conditioned on the class Y=c, the feature is X ~ Binomial(n, p_c).

    Args:
        class_priors (np.array): P(Y=c), shape (num_classes,)
        p_ci_list (np.array): Binomial parameters [p_{c,i}] for each class c, shape (num_classes,)
        n (int): Number of trials in binomial distribution

    Returns:
        float: Mutual information I(X; Y) in bits.
    """
    # P(X=x | Y=c) for every x in 0..n and every class c: shape (n+1, num_classes).
    x_values = np.arange(0, n + 1)[:, None]
    pmfs = binom.pmf(x_values, n, p_ci_list[None, :])

    # Joint P(X=x, Y=c) = P(Y=c) * P(X=x | Y=c).
    joint_probs = pmfs * class_priors[None, :]

    # Marginal P(X=x).
    px = np.sum(joint_probs, axis=1)

    # H(Y|X) = sum_x P(X=x) * H(Y | X=x), computed for all x at once.
    # Values of x with negligible probability are skipped to avoid dividing by
    # (nearly) zero.
    keep = ~(px < 1e-9)
    cond_probs = joint_probs[keep] / px[keep, None]  # P(Y=c | X=x), one row per kept x
    h_ygx = np.dot(px[keep], entropy(cond_probs, base=2, axis=1))

    # Prior entropy H(Y)
    h_y = entropy(class_priors, base=2)

    return h_y - h_ygx


def mutual_information_between_features(class_priors, p_ci_i, p_ci_j, n):
    """
    Compute mutual information between two features X_i and X_j.

    The features are conditionally independent given the class, each binomial
    with n trials and class-dependent success probability.

    Parameters:
        class_priors (array): Prior probabilities of each class. Shape: (num_classes,)
        p_ci_i (array): Binomial parameters for feature i across classes. Shape: (num_classes,)
        p_ci_j (array): Binomial parameters for feature j across classes. Shape: (num_classes,)
        n (int): Number of trials for the binomial distribution.

    Returns:
        float: Mutual information I(X_i; X_j) in bits.
    """
    x_vals = np.arange(0, n + 1)  # Possible values of features

    # P(X_i = x | Y = c) and P(X_j = x | Y = c): shape (n+1, num_classes).
    pmf_i_per_class = binom.pmf(x_vals[:, None], n, p_ci_i[None, :])
    pmf_j_per_class = binom.pmf(x_vals[:, None], n, p_ci_j[None, :])

    ### Marginal distributions P(Xi=x), P(Xj=y) ###
    px_i = np.sum(pmf_i_per_class * class_priors[None, :], axis=1)
    entropy_xi = entropy(px_i, base=2) if not np.allclose(px_i, 0.0) else 0.0

    px_j = np.sum(pmf_j_per_class * class_priors[None, :], axis=1)
    entropy_xj = entropy(px_j, base=2) if not np.allclose(px_j, 0.0) else 0.0

    ### Joint distribution P(Xi=x, Xj=y) = sum_c P(c) P(Xi=x|c) P(Xj=y|c) ###
    # A single matrix product that reuses the per-class PMFs computed above.
    joint_xy = (pmf_i_per_class * class_priors[None, :]) @ pmf_j_per_class.T

    # Entropy of the joint distribution
    epsilon = 1e-10  # To avoid log(0) issues
    non_zero = (joint_xy > epsilon)
    entropy_joint = -np.sum(joint_xy[non_zero] * np.log2(joint_xy[non_zero]))

    ### Mutual Information ###
    return entropy_xi + entropy_xj - entropy_joint


def conditional_mutual_info(class_priors, XJ_params, XK_params, n):
    """
    Compute I(XK; Y | XJ) using vectorization with broadcasting.

    XJ and XK are binomial with n trials and are conditionally independent
    given the class Y.

    Args:
        class_priors (array): P(Y=c) for all classes c.
        XJ_params (array): p_{c,J} for all classes c.
        XK_params (array): p_{c,K} for all classes c.
        n (int): Number of trials in the binomial distribution.

    Returns:
        float: Conditional mutual information I(XK; Y | XJ) in bits,
        clipped at 0.
    """
    x_values = np.arange(n + 1)

    # P(X=x | Y=c), transposed to shape (num_classes, n+1).
    P_XJ_T = binom.pmf(x_values[:, None], n, XJ_params).T
    P_XK_T = binom.pmf(x_values[:, None], n, XK_params).T

    ######################################################################
    ### Compute H(Y | XJ) ###############################################
    ######################################################################

    # P(XJ=xj) for all xj
    P_XJ_total = np.dot(class_priors, P_XJ_T)

    # Numerators of the posterior probabilities P(Y=c | XJ=xj)
    numerators_YgXJ = class_priors[:, None] * P_XJ_T

    valid_mask = P_XJ_total > 1e-9
    posterior_YgXJ = np.zeros_like(numerators_YgXJ, dtype=float)
    posterior_YgXJ[:, valid_mask] = (
        numerators_YgXJ[:, valid_mask] /
        P_XJ_total[valid_mask]
    )

    log_p = np.log2(posterior_YgXJ + 1e-9)
    entropy_terms_HYgXJ = -np.sum(
        posterior_YgXJ * log_p,
        axis=0,
        where=(posterior_YgXJ > 1e-9)
    )

    H_Y_given_XJ = np.dot(P_XJ_total, entropy_terms_HYgXJ)

    ######################################################################
    ### Compute H(Y | XJ, XK) ###########################################
    ######################################################################

    # Joint PMF P(XJ=xj, XK=xk | Y=c): shape (K, n+1, n+1)
    joint_pmf_conditional = P_XJ_T[:, :, None] * P_XK_T[:, None, :]

    numerators = class_priors[:, None, None] * joint_pmf_conditional

    denominators = np.sum(numerators, axis=0)  # P(XJ=xj, XK=xk), shape (n+1, n+1)

    valid_mask_3d = (denominators > 1e-9)[None, ...]  # Broadcast over the class axis

    # Divide only where the denominator is non-negligible. Underflowed (zero)
    # cells stay 0 instead of producing 0/0 NaNs and RuntimeWarnings.
    posterior_YgXJXK = np.divide(
        numerators, denominators[None, ...],
        out=np.zeros_like(numerators, dtype=float),
        where=valid_mask_3d,
    )

    log_p_joint = np.log2(posterior_YgXJXK + 1e-9)
    entropy_terms_HYgXJXK = -np.sum(
        posterior_YgXJXK * log_p_joint,
        axis=0,  # Sum over classes (axis 0 is K)
        where=(posterior_YgXJXK > 1e-9),
    )

    H_Y_given_XJXK = np.sum(denominators * entropy_terms_HYgXJXK)

    ######################################################################
    ### Compute CMI #####################################################
    ######################################################################

    cmi = H_Y_given_XJ - H_Y_given_XJXK

    # Clip small negative values caused by the 1e-9 smoothing and rounding.
    return max(cmi, 0.0)

#### Relevance and redundancy
First, let's pre-compute the relevance and redundancy metrics for mRMR (also used in JMI). We assume a uniform class prior.

In [None]:
def compute_priors(nmults: int):
    """Return a uniform prior over ``nmults`` classes as a float64 array."""
    share = 1 / nmults
    return np.repeat(np.float64(share), nmults)

def compute_probs(feats: list[int], mults_map: dict[MultIdent, ProbMap]):
    """Return the (len(feats), nmults) matrix of per-divisor class probabilities.

    This is the transpose of the ``X`` matrix produced by ``to_sklearn``: row
    ``i`` belongs to ``feats[i]``, and columns are multipliers in sorted order.
    """
    return to_sklearn(mults_map, feats)[0].T

def compute_relevance(feats: list[int], priors, probs, nattack: int):
    """Return the relevance I(X_f; Y) of every feature in ``feats``.

    Args:
        feats: Divisors; row ``i`` of ``probs`` belongs to ``feats[i]``.
        priors: Class prior probabilities, shape (nmults,).
        probs: Per-feature class probabilities, shape (len(feats), nmults),
            as returned by ``compute_probs``.
        nattack: Number of binomial trials per feature.

    Returns:
        np.ndarray of shape (len(feats),) with the mutual information in bits.
    """
    # Sized from `feats` rather than the global feature count, so that any
    # feature subset can be scored.
    relevance = np.zeros(len(feats), dtype=np.float64)
    for i, _divisor in enumerate(tqdm(feats)):
        relevance[i] = mutual_information(priors, probs[i, ], nattack)
    return relevance

def compute_redundancy(feats: list[int], priors, probs, nattack: int, num_workers: int = 30):
    """Compute the symmetric matrix of pairwise feature redundancy.

    Entry ``[a][b]`` is
    ``mutual_information_between_features(priors, probs[a, ], probs[b, ], nattack)``.
    Only the lower triangle (``b <= a``, diagonal included) is evaluated. Each
    value is mirrored into the upper triangle.

    Args:
        feats: features to compare; only their number is used here.
        priors: class-prior vector.
        probs: matrix whose row ``i`` belongs to ``feats[i]``.
        nattack: ``n`` forwarded to the pairwise MI function.
        num_workers: ``max_workers`` for the ``TaskExecutor`` pool.

    Returns:
        float64 array of shape ``(len(feats), len(feats))``.
    """
    count = len(feats)
    redundancy = np.zeros((count, count), dtype=np.float64)
    with TaskExecutor(max_workers=num_workers) as pool:
        for row in trange(count):
            # Submit one task per column up to and including the diagonal.
            for col in range(row + 1):
                pool.submit_task((row, col),
                                 mutual_information_between_features,
                                 priors, probs[row, ], probs[col, ], nattack)
            # Drain this row's results before submitting the next row.
            for (a, b), done in pool.as_completed():
                value = done.result()
                redundancy[a][b] = value
                redundancy[b][a] = value
    return redundancy

In [None]:
# ``n`` passed to the mutual-information computations below.
nattack = 100

# Uniform prior over all multipliers, the feature-by-multiplier probability
# matrix (row i <-> allfeats[i]), and the per-feature relevance scores.
priors = compute_priors(nmults)
probs = compute_probs(allfeats, distributions_mults)
relevance = compute_relevance(allfeats, priors, probs, nattack)

In [None]:
# Pairwise feature-feature redundancy matrix, computed on a TaskExecutor pool.
redundancy = compute_redundancy(allfeats, priors, probs, nattack)

Store the relevance and redundancy arrays.

In [None]:
# Pickle both arrays to the working directory.
with open("relevance.pickle", "wb") as f:
    pickle.dump(relevance, f)
with open("redundancy.pickle", "wb") as f:
    pickle.dump(redundancy, f)

In [None]:
class MRMRFeatures(FeatureSelector):
    """Feature selector intended to perform mRMR selection.

    NOTE(review): ``select`` is not implemented yet. Use the module-level
    ``mrmr_selection`` with precomputed relevance/redundancy arrays instead.
    """

    def __init__(self,
                 allfeats: list[int],
                 mults: dict[MultIdent, ProbMap],
                 num_workers: int):
        # Candidate features to choose from.
        self.allfeats = allfeats
        # Per-multiplier probability maps.
        self.mults = mults
        # Worker count; presumably for parallel redundancy computation (unused so far).
        self.num_workers = num_workers

    def prepare(self, nattack: int):
        """Record ``nattack`` for later selection."""
        self.nattack = nattack

    def select(self, nfeats: int, startwith: list[int] = None) -> list[int]:
        """Select ``nfeats`` features, optionally starting from ``startwith``.

        Raises:
            NotImplementedError: always. The previous stub silently returned
                ``None``, contradicting the ``list[int]`` return annotation.
        """
        raise NotImplementedError(
            "MRMRFeatures.select is not implemented; use mrmr_selection() instead")



def mrmr_selection(relevance, redundancy, nfeats):
    """
    Greedily select features using mRMR (maximum relevance, minimum redundancy).

    The first feature is the most relevant one. Each further feature ``k``
    maximises ``relevance[k] - mean_{j in selected} redundancy[k][j]``.

    Args:
        relevance: 1-D per-feature relevance scores; its length defines the
            candidate set.
        redundancy: square matrix of pairwise redundancy, indexed like
            ``relevance``.
        nfeats: number of features to select. It is clamped to the number of
            available features; values <= 0 give an empty list.

    Returns:
        indices of selected features (ints), in selection order.
    """
    relevance = np.asarray(relevance, dtype=np.float64)
    redundancy = np.asarray(redundancy, dtype=np.float64)
    # Derive the candidate count from the input instead of the global
    # ``nallfeats``, and never ask for more features than exist.
    total = len(relevance)
    nfeats = min(nfeats, total)
    if nfeats <= 0:
        return []

    # Initialize by selecting the most relevant feature.
    first = int(np.argmax(relevance))
    selected_indices = [first]
    remaining_indices = [i for i in range(total) if i != first]

    while len(selected_indices) < nfeats:
        # Average redundancy of every remaining candidate with the selected set.
        block = redundancy[np.ix_(remaining_indices, selected_indices)]
        avg_red = block.sum(axis=1) / len(selected_indices)
        scores = relevance[remaining_indices] - avg_red
        # np.argmax returns the first maximum, so ties resolve to the lowest
        # remaining position.
        best_candidate_idx = remaining_indices[int(np.argmax(scores))]
        selected_indices.append(best_candidate_idx)
        remaining_indices.remove(best_candidate_idx)

    return selected_indices

In [None]:
# Top-15 mRMR selection, mapped from indices back to feature values in allfeats.
selected_mrmr = [allfeats[i] for i in mrmr_selection(relevance, redundancy, nfeats=15)]

In [None]:
# Build mRMR feature sets of size 1..40 keyed by (nattack, nfeats).
# The selection depends only on nfeats: relevance/redundancy were computed
# once above (with nattack=100). The same list is therefore reused for every
# nattack key.
# NOTE: this loop rebinds the globals ``f``, ``nfeats`` and ``nattack``
# (nattack ends at 100).
mrmrs = {}
for nfeats in trange(1, 41, 1):
    f = [allfeats[i] for i in mrmr_selection(relevance, redundancy, nfeats=nfeats)]
    for nattack in range(5, 105, 5):
        mrmrs[(nattack, nfeats)] = f

# Evaluate a Bayes classifier on each feature set via plot_performance.
# The meaning of the 2000 argument is defined by plot_performance (not visible here).
bayes_classifier = BayesClassifier()
plot_performance(bayes_classifier, avg_rank_scorer, 2000, mrmrs, distributions_mults)

In [None]:
# Close the current matplotlib figure (presumably the one from the cell above).
plt.close()

### Feature selection via JMI

In [None]:
def jmi_selection(features_params_list, class_priors, n_trials, relevance, nfeats):
    """
    Greedily select features using a JMI-style criterion.

    The first feature is the most relevant one. Each further feature ``k``
    maximises ``relevance[k] + mean_{j in selected} cmi(j, k)``, where
    ``cmi(j, k)`` is ``conditional_mutual_info(XJ_params=params[j],
    XK_params=params[k])`` (presumably I(Y; X_k | X_j)).

    NOTE(review): textbook JMI ranks by sum_j I(X_k, X_j; Y), which equals a
    candidate-independent constant plus sum_j I(Y; X_k | X_j). The extra
    ``relevance[k]`` term here deviates from that. Confirm it is intended.

    Args:
        features_params_list: per-feature parameters indexed like
            ``relevance`` (in this notebook, the rows of ``probs``).
        class_priors: class prior vector forwarded to ``conditional_mutual_info``.
        n_trials: ``n`` forwarded to ``conditional_mutual_info``.
        relevance: per-feature relevance scores; its length defines the
            candidate set.
        nfeats: number of features to select. It is clamped to the number of
            available features; values <= 0 give an empty list.

    Returns:
        indices of selected features (ints), in selection order.
    """
    relevance = np.asarray(relevance, dtype=np.float64)
    # Derive the candidate count from the input instead of the global
    # ``nallfeats``, and never ask for more features than exist.
    total = len(relevance)
    nfeats = min(nfeats, total)
    if nfeats <= 0:
        return []

    # Initialize by selecting the most relevant feature.
    first = int(np.argmax(relevance))
    selected_indices = [first]
    remaining_indices = [i for i in range(total) if i != first]

    # Running sum of cmi(selected, candidate) for each candidate. Each round
    # only adds the term for the newly selected feature, so every pair is
    # evaluated once instead of once per round. Terms are added in selection
    # order, so the sums equal a full recomputation exactly.
    cmi_sums = {k: 0.0 for k in remaining_indices}

    while len(selected_indices) < nfeats:
        newest = selected_indices[-1]
        XJ_params = features_params_list[newest]
        candidates_scores = []

        for candidate in tqdm(remaining_indices):
            cmi_sums[candidate] += conditional_mutual_info(
                class_priors=class_priors,
                XJ_params=XJ_params,
                XK_params=features_params_list[candidate],
                n=n_trials,
            )
            avg_cmi = cmi_sums[candidate] / len(selected_indices)
            candidates_scores.append(relevance[candidate] + avg_cmi)

        # Select the candidate with the highest score (first one on ties).
        best_candidate_idx = remaining_indices[int(np.argmax(candidates_scores))]
        selected_indices.append(best_candidate_idx)
        remaining_indices.remove(best_candidate_idx)
        del cmi_sums[best_candidate_idx]

    return selected_indices

In [None]:
# Run JMI selection for 40 features, using the rows of ``probs`` as per-feature
# parameters, and map the returned indices back to feature values.
nattack = 100
selected_jmi = [allfeats[i] for i in jmi_selection(probs, priors, nattack, relevance, nfeats=40)]

In [None]:
# Display the JMI-selected features, in selection order.
selected_jmi

In [None]:
# Display the 5-feature mRMR set (key nattack=100) for comparison with JMI.
mrmrs[(100,5)]

In [None]:
# Build feature sets from the first 1..5 JMI-selected features (out of the 40
# selected above), keyed by (nattack, nfeats). The same list is reused for
# every nattack key.
# NOTE: this loop rebinds the globals ``f``, ``nfeats`` and ``nattack``.
jmis = {}
for nfeats in range(1, 6, 1):
    f = selected_jmi[:nfeats]
    for nattack in range(5, 105, 5):
        jmis[(nattack, nfeats)] = f

# Evaluate a Bayes classifier on each feature set via plot_performance.
bayes_classifier = BayesClassifier()
plot_performance(bayes_classifier, avg_rank_scorer, 2000, jmis, distributions_mults)

In [None]:
# Close the current matplotlib figure (presumably the one from the cell above).
plt.close()