aboutsummaryrefslogtreecommitdiff
path: root/analysis/countermeasures/distinguish.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'analysis/countermeasures/distinguish.ipynb')
-rw-r--r--analysis/countermeasures/distinguish.ipynb1510
1 files changed, 1510 insertions, 0 deletions
diff --git a/analysis/countermeasures/distinguish.ipynb b/analysis/countermeasures/distinguish.ipynb
new file mode 100644
index 0000000..fdb3f6a
--- /dev/null
+++ b/analysis/countermeasures/distinguish.ipynb
@@ -0,0 +1,1510 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "e76983df-053b-450b-976c-295826248978",
+ "metadata": {},
+ "source": [
+ "# Unraveling scalar mults and countermeasures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc1528b8-61cd-4219-993f-e3f1ac79e801",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "import itertools\n",
+ "import glob\n",
+ "import random\n",
+ "import math\n",
+ "\n",
+ "from collections import Counter\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from scipy.stats import binom, entropy\n",
+ "from scipy.spatial import distance\n",
+ "from tqdm.auto import tqdm, trange\n",
+ "from anytree import PreOrderIter, Walker\n",
+ "from matplotlib import pyplot as plt\n",
+ "\n",
+ "from pyecsca.ec.mult import *\n",
+ "from pyecsca.misc.utils import TaskExecutor, silent\n",
+ "from pyecsca.sca.re.tree import Map, Tree\n",
+ "\n",
+ "from common import *\n",
+ "\n",
+ "%matplotlib ipympl"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4868c083-8073-453d-b508-704fcb6d6f2a",
+ "metadata": {},
+ "source": [
+ "## Prepare\n",
+ "Select *divisor name* to restrict the features. Select *kind* to pick the probmap source."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ccb00342-3c48-49c9-bedf-2341e5eae3a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "divisor_name = \"all\"\n",
+ "kind = \"all\"\n",
+ "allfeats = list(filter(lambda feat: feat not in (1,2,3,4,5), divisor_map[divisor_name]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3dbac9be-d098-479a-8ca2-f531f6668f7c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load\n",
+ "try:\n",
+ " with open(f\"{divisor_name}_{kind}_distrs.pickle\", \"rb\") as f:\n",
+ " distributions_mults = pickle.load(f)\n",
+ "except FileNotFoundError:\n",
+ " with open(f\"all_{kind}_distrs.pickle\", \"rb\") as f:\n",
+ " distributions_mults = pickle.load(f)\n",
+ " for probmap in distributions_mults.values():\n",
+ " probmap.narrow(allfeats)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "38c81e38-a37c-4e58-ac9e-927d14dad458",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "allmults = list(distributions_mults.keys())\n",
+ "nmults = len(allmults)\n",
+ "nallfeats = len(allfeats)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f783baf-bc81-40c1-9282-e2dfdacfd17c",
+ "metadata": {},
+ "source": [
+ "## Build dmap and tree"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d161f3e8-6e39-47c1-a26f-23f910f3fe26",
+ "metadata": {},
+ "source": [
+ "Select the n for building the tree."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2307bf7c-4fac-489d-8527-7ddbf536a148",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nbuild = 10000\n",
+ "alpha = 0.05"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0b85fad7-392f-4701-9329-d75d39736bbb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Now go over all divisors, cluster based on overlapping CI for given n?\n",
+ "io_map = {mult:{} for mult in allmults}\n",
+ "for divisor in allfeats:\n",
+ " prev_ci_low = None\n",
+ " prev_ci_high = None\n",
+ " groups = {}\n",
+ " pvals = {}\n",
+ " group = 0\n",
+ " for mult, probmap in sorted(distributions_mults.items(), key=lambda item: -item[1][divisor]):\n",
+ " # We are going from high to low p.\n",
+ " pval = probmap[divisor]\n",
+ " pvals[mult] = pval\n",
+ " ci_low, ci_high = conf_interval(pval, nbuild, alpha)\n",
+ " ci_low = max(ci_low, 0.0)\n",
+ " ci_high = min(ci_high, 1.0)\n",
+ " if (prev_ci_low is None and prev_ci_high is None) or prev_ci_low >= ci_high:\n",
+ " g = groups.setdefault(f\"arbitrary{group}\", set())\n",
+ " g.add(mult)\n",
+ " group += 1\n",
+ " else:\n",
+ " g = groups.setdefault(f\"arbitrary{group}\", set())\n",
+ " g.add(mult)\n",
+ " prev_ci_low = ci_low\n",
+ " prev_ci_high = ci_high\n",
+ " \n",
+ " #print(f\"Divisor: {divisor}, num groups: {group}\", end=\"\\n\\t\")\n",
+ " #for g in groups.values():\n",
+ " # print(len(g), end=\", \")\n",
+ " #print()\n",
+ " for group, mults in groups.items():\n",
+ " mult_pvals = [pvals[mult] for mult in mults]\n",
+ " group_pval_avg = np.mean(mult_pvals)\n",
+ " group_pval_var = np.var(mult_pvals)\n",
+ " group_pval_min = np.min(mult_pvals)\n",
+ " group_pval_max = np.max(mult_pvals)\n",
+ " for mult in mults:\n",
+ " io_map[mult][divisor] = (group, group_pval_avg, group_pval_var, group_pval_min, group_pval_max)\n",
+ "\n",
+ "# then build dmap\n",
+ "dmap = Map.from_io_maps(set(distributions_mults.keys()), io_map)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "06104104-b612-40e9-bc1d-646356a13381",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(dmap.describe())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eb8672ca-b76b-411d-b514-2387b555f184",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# deduplicate dmap\n",
+ "dmap.deduplicate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ccba09b0-31c3-450b-af30-efaa64329743",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(dmap.describe())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5735e7d4-149c-4184-96f7-dcfd6017fbad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build a tree\n",
+ "with silent():\n",
+ " tree = Tree.build(set(allmults), dmap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d41093df-32c4-450d-922d-5ad042539397",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(tree.describe())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "de577429-d87c-4967-be17-75cbb378860c",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "print(tree.render_basic())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "437bcd9c-1da5-428a-a979-0835326777f3",
+ "metadata": {},
+ "source": [
+ "## Simulate distinguishing using a tree\n",
+ "We can now simulate distinguishing using the tree and how it behaves with increasing the number of samples per divisor collected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "109ed95f-0949-4251-874f-9b87cfe97a00",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulations = 1000\n",
+ "\n",
+ "for nattack in trange(100, 10000, 100):\n",
+ " successes = 0\n",
+ " pathiness = 0\n",
+ " for i in range(simulations):\n",
+ " true_mult = random.choice(allmults)\n",
+ " probmap = distributions_mults[true_mult]\n",
+ " node = tree.root\n",
+ " while True:\n",
+ " if node.is_leaf:\n",
+ " break\n",
+ " divisor = node.dmap_input\n",
+ " prob = probmap[divisor]\n",
+ " sampled_prob = binom(nattack, prob).rvs() / nattack\n",
+ " best_child = None\n",
+ " true_child = None\n",
+ " best_group_distance = None\n",
+ " #print(f\"Divisor: {divisor}, prob: {prob}, sampled: {sampled_prob}\")\n",
+ " for child in node.children:\n",
+ " if true_mult in child.cfgs:\n",
+ " true_child = child\n",
+ " group, group_pval_avg, group_pval_var, group_pval_min, group_pval_max = child.response\n",
+ " group_distance = min(abs(sampled_prob - group_pval_min), abs(sampled_prob - group_pval_max))\n",
+ " #print(f\"Child {group}, {group_pval_avg}\")\n",
+ " if best_child is None or \\\n",
+ " (group_distance < best_group_distance):\n",
+ " best_child = child\n",
+ " best_group_distance = group_distance\n",
+ " if sampled_prob > group_pval_min and sampled_prob < group_pval_max:\n",
+ " best_child = child\n",
+ " break\n",
+ " #print(f\"Best {best_child.response}\")\n",
+ " if true_child is not None and true_child != best_child:\n",
+ " pass\n",
+ " #print(f\"Mistake! {prob}, {sampled_prob} true:{true_child.response}, chosen:{best_child.response}\")\n",
+ " node = best_child\n",
+ " if true_mult in node.cfgs:\n",
+ " pathiness += 1\n",
+ " #print(f\"Arrived: {true_mult in node.cfgs}\")\n",
+ " if true_mult in node.cfgs:\n",
+ " successes += 1\n",
+ " print(f\"{nattack}: success rate {successes/simulations}, pathiness {pathiness/simulations}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "308df683-952e-430a-b2bd-f19bcfb98b8e",
+ "metadata": {},
+ "source": [
+ "## Simulate distinguishing using a distance metric\n",
+ "\n",
+ "We need to first select some features (divisors) from the set of all divisors that we will query\n",
+ "the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as\n",
+ "much as possible."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "62d2f2a2-495e-459d-b0e2-89c9a5973b1e",
+ "metadata": {},
+ "source": [
+ "### Feature selection using trees + classification error\n",
+ "\n",
+ "We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building choses as the features. However, we can also use more conventional feature selection approaches."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cc1a9956-bc8c-47cf-b6ec-093c6cf85c7d",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.base import BaseEstimator, ClassifierMixin\n",
+ "from sklearn.utils.validation import validate_data, check_is_fitted\n",
+ "from sklearn.utils.multiclass import unique_labels\n",
+ "from scipy.special import logsumexp\n",
+ "from sklearn.metrics import euclidean_distances, top_k_accuracy_score, make_scorer, accuracy_score\n",
+ "\n",
+ "\n",
+ "class EuclidClassifier(ClassifierMixin, BaseEstimator):\n",
+ " def __init__(self, *, nattack=100):\n",
+ " self.nattack = nattack\n",
+ "\n",
+ " def fit(self, X, y):\n",
+ " X, y = validate_data(self, X, y)\n",
+ " if not np.logical_and(X >= 0, X <= 1).all():\n",
+ " raise TypeError(\"Expects valid probabilities in X.\")\n",
+ " self.classes_ = unique_labels(y)\n",
+ " if len(self.classes_) != len(y):\n",
+ " raise ValueError(\"Expects only one sample per class containing the binomial probabilities.\")\n",
+ " self.X_ = X\n",
+ " self.y_ = y\n",
+ " return self\n",
+ "\n",
+ " def decision_function(self, X):\n",
+ " check_is_fitted(self)\n",
+ " X = validate_data(self, X, reset=False)\n",
+ " distances = euclidean_distances(X / self.nattack, self.X_)\n",
+ " return -distances\n",
+ "\n",
+ " def predict(self, X):\n",
+ " check_is_fitted(self)\n",
+ " X = validate_data(self, X, reset=False)\n",
+ " distances = euclidean_distances(X / self.nattack, self.X_)\n",
+ " closest = np.argmin(distances, axis=1)\n",
+ " return self.classes_[closest]\n",
+ "\n",
+ "\n",
+ "class BayesClassifier(ClassifierMixin, BaseEstimator):\n",
+ " def __init__(self, *, nattack=100):\n",
+ " self.nattack = nattack\n",
+ "\n",
+ " def fit(self, X, y):\n",
+ " # X has (nmults = nsamples, nfeats)\n",
+ " X, y = validate_data(self, X, y)\n",
+ " if not np.logical_and(X >= 0, X <= 1).all():\n",
+ " raise TypeError(\"Expects valid probabilities in X.\")\n",
+ " self.classes_ = unique_labels(y)\n",
+ " if len(self.classes_) != len(y):\n",
+ " raise ValueError(\"Expects only one sample per class containing the binomial probabilities.\")\n",
+ " self.X_ = X\n",
+ " self.y_ = y\n",
+ " return self\n",
+ "\n",
+ " def decision_function(self, X):\n",
+ " check_is_fitted(self)\n",
+ " X = validate_data(self, X, reset=False)\n",
+ " # We have a uniform prior, so we can ignore it.\n",
+ " probas = np.zeros((len(X), len(self.classes_)))\n",
+ " for i, row in enumerate(X):\n",
+ " p = binom(self.nattack, self.X_).logpmf(row)\n",
+ " s = np.sum(p, axis=1)\n",
+ " log_prob_x = logsumexp(s)\n",
+ " res = np.exp(s - log_prob_x)\n",
+ " probas[i, ] = res\n",
+ " return probas\n",
+ "\n",
+ " def predict_proba(self, X):\n",
+ " return self.decision_function(X)\n",
+ "\n",
+ " def predict(self, X):\n",
+ " check_is_fitted(self)\n",
+ " X = validate_data(self, X, reset=False)\n",
+ " # We have a uniform prior, so we can ignore it.\n",
+ " results = np.empty(len(X), dtype=self.classes_.dtype)\n",
+ " for i, row in enumerate(X):\n",
+ " p = binom(self.nattack, self.X_).logpmf(row)\n",
+ " s = np.sum(p, axis=1)\n",
+ " most_likely = np.argmax(s)\n",
+ " results[i] = self.classes_[most_likely]\n",
+ " return results\n",
+ "\n",
+ "\n",
+ "def to_sklearn(mults_map: dict[MultIdent, ProbMap], feats: list[int]):\n",
+ " nfeats = len(feats)\n",
+ " nmults = len(mults_map)\n",
+ " classes = np.arange(nmults, dtype=np.uint32)\n",
+ " probs = np.zeros((nmults, nfeats), dtype=np.float64)\n",
+ " mults = sorted(list(mults_map.keys()))\n",
+ " for i, divisor in enumerate(feats):\n",
+ " for j, mult in enumerate(mults):\n",
+ " probmap = mults_map[mult]\n",
+ " probs[j, i] = probmap[divisor]\n",
+ " return probs, classes\n",
+ "\n",
+ "\n",
+ "def make_instance(nattack: int,\n",
+ " simulations: int,\n",
+ " X,\n",
+ " y,\n",
+ " progress=False):\n",
+ " nmults, nfeats = X.shape\n",
+ " X_samp = np.zeros((simulations, nfeats), dtype=np.uint32)\n",
+ " y_samp = np.zeros(simulations, dtype=np.uint32)\n",
+ "\n",
+ " r = trange(simulations) if progress else range(simulations)\n",
+ " for i in r:\n",
+ " if i < nmults and simulations >= nmults:\n",
+ " j = i\n",
+ " else:\n",
+ " j = random.randrange(nmults)\n",
+ " X_samp[i] = binom(nattack, X[j]).rvs()\n",
+ " y_samp[i] = j\n",
+ " return X_samp, y_samp\n",
+ "\n",
+ "\n",
+ "def evaluate_classifier(nattack: int,\n",
+ " simulations: int,\n",
+ " X,\n",
+ " y,\n",
+ " classifier,\n",
+ " scorer):\n",
+ " classifier.set_params(nattack=nattack)\n",
+ " classifier.fit(X, y)\n",
+ " X_samp, y_samp = make_instance(nattack, simulations, X, y)\n",
+ " return scorer(classifier, X_samp, y_samp)\n",
+ "\n",
+ "\n",
+ "def average_rank_score(y_true, y_pred, labels=None):\n",
+ " y_true = np.asarray(y_true)\n",
+ " y_pred = np.asarray(y_pred)\n",
+ " \n",
+ " n_samples, n_classes = y_pred.shape\n",
+ " if labels is not None:\n",
+ " labels = np.asarray(labels)\n",
+ " if len(labels) != n_classes:\n",
+ " raise ValueError()\n",
+ " label_indexes = np.searchsorted(labels, y_true)\n",
+ " indexes = np.where(labels[label_indexes] == y_true, label_indexes, -1)\n",
+ " else:\n",
+ " indexes = y_true\n",
+ " true_scores = y_pred[np.arange(n_samples), indexes]\n",
+ " \n",
+ " counts_higher = np.sum(y_pred > true_scores[:, None], axis=1)\n",
+ " \n",
+ " ranks = counts_higher + 1\n",
+ " \n",
+ " return ranks.mean()\n",
+ "\n",
+ "X = np.array([[0.7, 0.7, 0.1], [0.3, 0.7, 0.1], [0.2, 0.5, 0.1], [0.1, 0.1, 0.4]])\n",
+ "y = np.array([1, 2, 3, 4])\n",
+ "\n",
+ "euc = EuclidClassifier(nattack=100).fit(X, y)\n",
+ "label = euc.predict(np.array([20, 50, 10]).reshape(1, -1))\n",
+ "dec = euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]]))\n",
+ "print(label)\n",
+ "print(dec)\n",
+ "\n",
+ "clf = BayesClassifier(nattack=100).fit(X, y)\n",
+ "label = clf.predict(np.array([20, 50, 10]).reshape(1, -1))\n",
+ "ps = clf.predict_proba(np.array([20, 50, 10]).reshape(1, -1))\n",
+ "print(label)\n",
+ "print(ps)\n",
+ "\n",
+ "\n",
+ "acc = top_k_accuracy_score(np.array([3, 1]),\n",
+ " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),\n",
+ " labels = [1, 2, 3, 4],\n",
+ " k=1)\n",
+ "print(acc)\n",
+ "acc = top_k_accuracy_score(np.array([3, 1]),\n",
+ " clf.predict_proba(np.array([[20, 50, 10], [70, 60, 20]])),\n",
+ " labels = [1, 2, 3, 4],\n",
+ " k=1)\n",
+ "print(acc)\n",
+ "\n",
+ "avg = average_rank_score(np.array([2, 0]),\n",
+ " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])))\n",
+ "print(avg)\n",
+ "avg = average_rank_score(np.array([3, 1]),\n",
+ " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),\n",
+ " labels = [1, 2, 3, 4])\n",
+ "print(avg)\n",
+ "\n",
+ "accuracy_scorer = make_scorer(\n",
+ " top_k_accuracy_score,\n",
+ " greater_is_better=True,\n",
+ " response_method=(\"decision_function\", \"predict_proba\"),\n",
+ ")\n",
+ "\n",
+ "#accuracy_scorer.__str__ = lambda self: \"Accuracy\"\n",
+ "\n",
+ "top_5_scorer = make_scorer(\n",
+ " top_k_accuracy_score,\n",
+ " greater_is_better=True,\n",
+ " response_method=(\"decision_function\", \"predict_proba\"),\n",
+ " k=5\n",
+ ")\n",
+ "\n",
+ "#top_5_scorer.__str__ = lambda self: \"Top-5 accuracy\"\n",
+ "\n",
+ "top_10_scorer = make_scorer(\n",
+ " top_k_accuracy_score,\n",
+ " greater_is_better=True,\n",
+ " response_method=(\"decision_function\", \"predict_proba\"),\n",
+ " k=10\n",
+ ")\n",
+ "\n",
+ "#top_10_scorer.__str__ = lambda self: \"Top-10 accuracy\"\n",
+ "\n",
+ "avg_rank_scorer = make_scorer(\n",
+ " average_rank_score,\n",
+ " greater_is_better=False,\n",
+ " response_method=(\"decision_function\", \"predict_proba\"),\n",
+ ")\n",
+ "\n",
+ "#avg_rank_scorer.__str__ = lambda self: \"Average rank\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a9fae775-797f-4efe-ac28-d83a8c905372",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class FeatureSelector:\n",
+ " def __init__(self,\n",
+ " allfeats: list[int],\n",
+ " mults: dict[MultIdent, ProbMap],\n",
+ " num_workers: int):\n",
+ " self.allfeats = allfeats\n",
+ " self.mults = mults\n",
+ " self.num_workers = num_workers\n",
+ "\n",
+ " def prepare(self, nattack: int):\n",
+ " self.nattack = nattack\n",
+ "\n",
+ " def select(self, nfeats: int, startwith: list[int] = None) -> list[int]:\n",
+ " pass\n",
+ "\n",
+ "class FeaturesByClassification(FeatureSelector):\n",
+ " def __init__(self,\n",
+ " allfeats: list[int],\n",
+ " mults: dict[MultIdent, ProbMap],\n",
+ " num_workers: int,\n",
+ " simulations: int,\n",
+ " classifier,\n",
+ " scorer):\n",
+ " super().__init__(allfeats, mults, num_workers)\n",
+ " self.simulations = simulations\n",
+ " self.classifier = classifier\n",
+ " self.scorer = scorer\n",
+ "\n",
+ "class RandomFeatures(FeaturesByClassification):\n",
+ "\n",
+ " def __init__(self,\n",
+ " allfeats: list[int],\n",
+ " mults: dict[MultIdent, ProbMap],\n",
+ " num_workers: int,\n",
+ " simulations: int,\n",
+ " classifier,\n",
+ " scorer,\n",
+ " retries: int):\n",
+ " super().__init__(allfeats, mults, num_workers, simulations, classifier, scorer)\n",
+ " self.retries = retries\n",
+ " \n",
+ " def _select_random(self, nfeats: int, startwith: list[int] = None) -> list[int]:\n",
+ " if startwith is None:\n",
+ " startwith = []\n",
+ " toselect = nfeats - len(startwith)\n",
+ " if toselect > 0:\n",
+ " available_feats = list(filter(lambda feat: feat not in startwith, self.allfeats))\n",
+ " selected = random.sample(available_feats, toselect)\n",
+ " return startwith + selected\n",
+ " elif toselect < 0:\n",
+ " return random.sample(startwith, nfeats)\n",
+ " else:\n",
+ " return startwith\n",
+ "\n",
+ " def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:\n",
+ " with TaskExecutor(max_workers=self.num_workers) as pool:\n",
+ " feat_map = []\n",
+ " for i in range(self.retries):\n",
+ " feats = self._select_random(nfeats, startwith)\n",
+ " X, y = to_sklearn(self.mults, feats)\n",
+ " feat_map.append(feats)\n",
+ " pool.submit_task(i,\n",
+ " evaluate_classifier,\n",
+ " self.nattack, self.simulations,\n",
+ " X, y, self.classifier, self.scorer)\n",
+ " best_score = None\n",
+ " best_feats = None\n",
+ " for i, future in tqdm(pool.as_completed(), total=len(pool.tasks), desc=\"retries\", leave=False):\n",
+ " score = future.result()\n",
+ " #print(i, feat_map[i], score)\n",
+ " if best_score is None or score > best_score:\n",
+ " best_score = score\n",
+ " best_feats = feat_map[i]\n",
+ " return best_feats, best_score\n",
+ "\n",
+ "\n",
+ "class GreedyFeatures(FeaturesByClassification):\n",
+ "\n",
+ " def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:\n",
+ " if startwith is None:\n",
+ " startwith = []\n",
+ " toselect = nfeats - len(startwith)\n",
+ " if toselect < 0:\n",
+ " raise ValueError(\"No features to select.\")\n",
+ " available_feats = list(filter(lambda feat: feat not in startwith, self.allfeats))\n",
+ " current = list(startwith)\n",
+ " with TaskExecutor(max_workers=self.num_workers) as pool:\n",
+ " while toselect > 0:\n",
+ " for feat in available_feats:\n",
+ " feats = current + [feat]\n",
+ " X, y = to_sklearn(self.mults, feats)\n",
+ " pool.submit_task(feat,\n",
+ " evaluate_classifier,\n",
+ " self.nattack, self.simulations,\n",
+ " X, y, self.classifier, self.scorer)\n",
+ " best_score = None\n",
+ " best_feat = None\n",
+ " for feat, future in tqdm(pool.as_completed(), total=len(pool.tasks), leave=False):\n",
+ " score = future.result()\n",
+ " if best_score is None or score > best_score:\n",
+ " best_score = score\n",
+ " best_feat = feat\n",
+ " current.append(best_feat)\n",
+ " toselect -= 1\n",
+ " return current, best_score\n",
+ "\n",
+ "\n",
+ "def feature_search(feat_range, nattack_range, selector, restarts=False):\n",
+ " if isinstance(feat_range, int):\n",
+ " feat_range = [feat_range]\n",
+ " if isinstance(nattack_range, int):\n",
+ " nattack_range = [nattack_range]\n",
+ " results = {}\n",
+ " for nattack in tqdm(nattack_range, desc=\"nattack\", smoothing=0):\n",
+ " selector.prepare(nattack)\n",
+ " feats = []\n",
+ " for nfeats in tqdm(feat_range, desc=\"nfeats\", leave=False):\n",
+ " feats, score = selector.select(nfeats, [] if restarts else feats)\n",
+ " results[(nattack, nfeats)] = feats\n",
+ " print(f\"{nattack},{nfeats}: {feats}, {score}\")\n",
+ " return results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1c0bebe-c519-4241-a163-63613b929db2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_performance(classifier, scorer, simulations, feature_map, mults, num_workers = 30):\n",
+ " scores = {}\n",
+ " with TaskExecutor(max_workers=num_workers) as pool:\n",
+ " for (nattack, nfeats), feats in feature_map.items():\n",
+ " X, y = to_sklearn(mults, feats)\n",
+ " pool.submit_task((nattack, nfeats),\n",
+ " evaluate_classifier,\n",
+ " nattack, simulations, X, y, classifier, scorer)\n",
+ " for (nattack, nfeats), future in tqdm(pool.as_completed(), desc=\"Evaluating\", leave=False, total=len(pool.tasks)):\n",
+ " score = future.result()\n",
+ " scores[(nattack, nfeats)] = score\n",
+ "\n",
+ " x_coords = [k[0] for k in scores.keys()]\n",
+ " y_coords = [k[1] for k in scores.keys()]\n",
+ " \n",
+ " x_unique = sorted(set(x_coords))\n",
+ " y_unique = sorted(set(y_coords))\n",
+ "\n",
+ " heatmap_data = np.zeros((len(y_unique), len(x_unique)))\n",
+ " \n",
+ " for (x, y), score in scores.items():\n",
+ " x_index = x_unique.index(x)\n",
+ " y_index = y_unique.index(y)\n",
+ " heatmap_data[y_index, x_index] = score\n",
+ "\n",
+ " x_mesh, y_mesh = np.meshgrid(x_unique, y_unique)\n",
+ " \n",
+ " plt.pcolormesh(x_mesh, y_mesh, heatmap_data, cmap='viridis', shading='auto')\n",
+ " plt.colorbar(label='Score')\n",
+ "\n",
+ " for i in range(len(y_unique)):\n",
+ " for j in range(len(x_unique)):\n",
+ " plt.text(x_unique[j], y_unique[i], f'{heatmap_data[i, j]:.2f}', ha='center', va='center', color='white' if heatmap_data[i, j] < 0.4 else \"black\")\n",
+ " \n",
+ " x_contour, y_contour = np.meshgrid(np.linspace(min(x_unique), max(x_unique), 100), \n",
+ " np.linspace(min(y_unique), max(y_unique), 100))\n",
+ " z_contour = x_contour * y_contour\n",
+ " \n",
+ " contour = plt.contour(x_contour, y_contour, z_contour, levels=[100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], colors='white', zorder=4)\n",
+ " plt.clabel(contour, inline=True, fontsize=10)\n",
+ " \n",
+ " plt.xticks(ticks=x_unique, labels=x_unique)\n",
+ " plt.yticks(ticks=y_unique, labels=y_unique)\n",
+ " plt.xlabel('nattack')\n",
+ " plt.ylabel('nfeats')\n",
+ " plt.title(f'{scorer._score_func.__name__}{scorer._kwargs} ({classifier.__class__.__name__})')\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "beb5720a-f793-4ad9-ad27-1bd943bb325b",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "feats_in_tree = Counter()\n",
+ "for node in PreOrderIter(tree.root):\n",
+ " if node.is_leaf:\n",
+ " continue\n",
+ " feats_in_tree[node.dmap_input] += 1\n",
+ "feats_in_tree = set(feats_in_tree.keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6e3260c9-c0fa-4828-a749-4d34499abacf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulations = 500\n",
+ "retries = 500\n",
+ "nattack = range(50, 350, 50)\n",
+ "nfeats = range(1, 11)\n",
+ "num_workers = 30\n",
+ "\n",
+ "euclid_classifier = EuclidClassifier()\n",
+ "tree_random_subsets = RandomFeatures(sorted(feats_in_tree), distributions_mults, num_workers,\n",
+ " simulations, euclid_classifier, top_5_scorer, retries)\n",
+ "\n",
+ "selected_euclid_fromtree = feature_search(nfeats, nattack, tree_random_subsets, restarts=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0b6a1a5b-82dd-44d4-82dc-83b16ac5bc82",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_performance(euclid_classifier, top_5_scorer, 500, selected_euclid_fromtree, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f16a5868-e92c-4b84-9f19-664627d9848a",
+ "metadata": {},
+ "source": [
+ "## Simulate distinguishing using a Bayes classifier\n",
+ "\n",
+ "We need to first select some features (divisors) from the set of all divisors that we will query\n",
+ "the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as\n",
+ "much as possible.\n",
+ "\n",
+ "Then, we can build a true Bayes classifier. Since our features are conditionally independent (when conditioned on the class label) in our case naive Bayes == non-naive Bayes. We examine four feature selection algorithms:\n",
+ " - Feature selection by pre-selection using tree-building and final selection by random subsets + classification error.\n",
+ " - Feature selection via greedy classification error.\n",
+ " - Feature selection via mRMR (maximal relevance, minimal redundancy) using mutual information.\n",
+ " - Feature selection via JMI (Joint Mutual Information)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed81e076-9ccb-445d-ada9-384b73efb2c5",
+ "metadata": {},
+ "source": [
+ "### Feature selection using trees + classification error\n",
+ "\n",
+ "We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building choses as the features. However, we can also use more conventional feature selection approaches."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f24b323-3604-4e34-a880-9dfd611fb245",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "feats_in_tree = Counter()\n",
+ "for node in PreOrderIter(tree.root):\n",
+ " if node.is_leaf:\n",
+ " continue\n",
+ " feats_in_tree[node.dmap_input] += 1\n",
+ "feats_in_tree = set(feats_in_tree.keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1052222-ad32-4e25-97ca-851cc42bf546",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulations = 500\n",
+ "retries = 500\n",
+ "nattack = range(50, 350, 50)\n",
+ "nfeats = range(1, 11)\n",
+ "num_workers = 30\n",
+ "\n",
+ "bayes_classifier = BayesClassifier()\n",
+ "tree_random_subsets = RandomFeatures(sorted(feats_in_tree), distributions_mults, num_workers,\n",
+ " simulations, bayes_classifier, top_5_scorer, retries)\n",
+ "\n",
+ "selected_bayes_fromtree = feature_search(nfeats, nattack, tree_random_subsets, restarts=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7d1f703-5dc6-4c00-b739-11b47205ed75",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_performance(bayes_classifier, top_5_scorer, 500, bay, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6df03e89-d517-4243-bbfb-a5f52de24bb1",
+ "metadata": {},
+ "source": [
+ "### Feature selection via greedy classification\n",
+ "We can also use the classifier itself for feature selection. We iterate over all the divisors to pick the first feature with the best classifier results in simulation. Then we iteratively add features to it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6e4c2313-83b0-43f8-80d6-14c39be0d9ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "simulations = 500\n",
+ "nattack = range(50, 350, 50)\n",
+ "nfeats = range(1, 11)\n",
+ "num_workers = 30\n",
+ "\n",
+ "bayes_classifier = BayesClassifier()\n",
+ "greedy = GreedyFeatures(sorted(feats_in_tree), distributions_mults, num_workers,\n",
+ " simulations, bayes_classifier, top_5_scorer)\n",
+ "\n",
+ "selected_bayes_greedy_fromtree = feature_search(nfeats, nattack, greedy, restarts=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86f6a319-a61c-41f2-9a7a-561691884198",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_performance(bayes_classifier, top_5_scorer, 500, selected_bayes_greedy_fromtree, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "69ce91fa-7475-41f1-a3ed-bc4dd97d44d6",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "simulations = 500\n",
+ "nattack = range(10, 210, 10)\n",
+ "nfeats = range(1, 16)\n",
+ "num_workers = 30\n",
+ "\n",
+ "bayes_classifier = BayesClassifier()\n",
+ "greedy = GreedyFeatures(allfeats, distributions_mults, num_workers,\n",
+ " simulations, bayes_classifier, top_5_scorer)\n",
+ "\n",
+ "selected_bayes_greedy_fromall = feature_search(nfeats, nattack, greedy, restarts=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54e58342-f2d8-42e9-ae63-0f0349efc8eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_performance(bayes_classifier, top_5_scorer, 500, gre, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea3e2f00-9bdf-4014-9c1e-85fa48304ef3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
+ "\n",
+ "nattack = 50\n",
+ "nfeats = 5\n",
+ "simulations = 20000\n",
+ "bayes = BayesClassifier(nattack=nattack)\n",
+ "X, y = to_sklearn(distributions_mults, selected_bayes_greedy_fromall[(nattack, nfeats)])\n",
+ "bayes.fit(X, y)\n",
+ "\n",
+ "X_samp, y_samp = make_instance(nattack, simulations, X, y, progress=True)\n",
+ "fig, ax = plt.subplots(figsize=(12,8))\n",
+ "disp = ConfusionMatrixDisplay.from_predictions(y_samp, bayes.predict(X_samp), ax=ax, normalize=\"true\", include_values=False, xticks_rotation=\"vertical\")\n",
+ "\n",
+ "ticks = []\n",
+ "labs = []\n",
+ "kls = None\n",
+ "for i, mult in enumerate(sorted(list(distributions_mults.keys()))):\n",
+ " if kls is None or kls != mult.klass:\n",
+ " ticks.append(i)\n",
+ " labs.append(mult.klass.__name__)\n",
+ " kls = mult.klass\n",
+ "ax.set_xticks(ticks, labs)\n",
+ "ax.set_yticks(ticks, labs)\n",
+ "ax.set_xticks([], minor=True)\n",
+ "ax.set_yticks([], minor=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7c030a1c-c13d-401a-bcdb-212c064681e4",
+ "metadata": {},
+ "source": [
+ "### Feature selection via mRMR using mutual information"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cd769175-c188-411c-af36-2973e7a0ffd1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def mutual_information(class_priors, p_ci_list, n):\n",
+ " \"\"\"\n",
+ " Compute mutual information I(X; Y) for a binomial feature with given class parameters.\n",
+ " \n",
+ " Args:\n",
+ " class_priors (np.array): P(Y=c), shape (num_classes,)\n",
+ " p_ci_list (np.array): Binomial parameters [p_{c,i}] for each class c, shape (num_classes,)\n",
+ " n (int): Number of trials in binomial distribution\n",
+ " \n",
+ " Returns:\n",
+ " float: Mutual information I(X; Y)\n",
+ " \"\"\"\n",
+ " num_classes = len(class_priors)\n",
+ " \n",
+ " # Precompute all PMFs across x and classes\n",
+ " x_values = np.arange(0, n + 1)[:, None] # (n+1, 1)\n",
+ " pmfs = binom.pmf(x_values, n, p_ci_list[None, :]) # Shape: (n+1, num_classes)\n",
+ " \n",
+ " # Compute joint probabilities P(Y=c) * P(X=x | Y=c)\n",
+ " # Multiply class_priors (shape C) with each row of pmfs (each x has shape (C,))\n",
+ " # class_priors[None, :] becomes (1, C), so broadcasting works.\n",
+ " joint_probs = pmfs * class_priors[None, :]\n",
+ " \n",
+ " # Compute P(X=x) for all x\n",
+ " px = np.sum(joint_probs, axis=1)\n",
+ "\n",
+ " # Compute H(Y|X):\n",
+ " h_ygx = 0.0\n",
+ "\n",
+ " for x_idx in range(n + 1):\n",
+ " current_px = px[x_idx]\n",
+ " \n",
+ " if current_px < 1e-9: # Skip negligible probabilities\n",
+ " continue\n",
+ " \n",
+ " cond_probs = joint_probs[x_idx] / current_px # P(Y=c | X=x)\n",
+ " \n",
+ " # Compute entropy H(Y|X=x) using scipy's entropy function\n",
+ " h_x = entropy(cond_probs, base=2)\n",
+ " \n",
+ " h_ygx += current_px * h_x\n",
+ " \n",
+ " # Prior entropy H(Y)\n",
+ " h_y = entropy(class_priors, base=2)\n",
+ "\n",
+ " return h_y - h_ygx\n",
+ "\n",
+ "\n",
+ "def mutual_information_between_features(class_priors, p_ci_i, p_ci_j, n):\n",
+ " \"\"\"\n",
+ " Compute mutual information between two features X_i and X_j.\n",
+ " \n",
+ " Parameters:\n",
+ " class_priors (array): Prior probabilities of each class. Shape: (num_classes,)\n",
+ " p_ci_i (array): Binomial parameters for feature i across classes. Shape: (num_classes,)\n",
+ " p_ci_j (array): Binomial parameters for feature j across classes. Shape: (num_classes,)\n",
+ " n (int): Number of trials for the binomial distribution.\n",
+ " \n",
+ " Returns:\n",
+ " float: Mutual information I(X_i; X_j)\n",
+ " \"\"\"\n",
+ " num_classes = len(class_priors)\n",
+ " x_vals = np.arange(0, n + 1) # Possible values of features\n",
+ " \n",
+ " ### Compute marginal distributions P(Xi=x), P(Xj=y) ###\n",
+ " # PMF for feature i across all classes\n",
+ " pmf_i_per_class = binom.pmf(x_vals[:, None], n, p_ci_i[None, :])\n",
+ " px_i = np.sum(pmf_i_per_class * class_priors[None, :], axis=1)\n",
+ " entropy_xi = entropy(px_i, base=2) if not np.allclose(px_i, 0.0) else 0.0\n",
+ " \n",
+ " # PMF for feature j across all classes\n",
+ " pmf_j_per_class = binom.pmf(x_vals[:, None], n, p_ci_j[None, :])\n",
+ " px_j = np.sum(pmf_j_per_class * class_priors[None, :], axis=1)\n",
+ " entropy_xj = entropy(px_j, base=2) if not np.allclose(px_j, 0.0) else 0.0\n",
+ " \n",
+ " ### Compute joint distribution P(Xi=x, Xj=y) ###\n",
+ " joint_xy = np.zeros((n + 1, n + 1))\n",
+ " \n",
+ " for c in range(num_classes):\n",
+ " pmf_i_c = binom.pmf(x_vals, n, p_ci_i[c])\n",
+ " pmf_j_c = binom.pmf(x_vals, n, p_ci_j[c])\n",
+ " \n",
+ " # Outer product gives joint PMF for class c\n",
+ " outer = np.outer(pmf_i_c, pmf_j_c)\n",
+ " joint_xy += class_priors[c] * outer\n",
+ " \n",
+ " # Compute entropy of the joint distribution\n",
+ " epsilon = 1e-10 # To avoid log(0) issues\n",
+ " non_zero = (joint_xy > epsilon)\n",
+ " entropy_joint = -np.sum(joint_xy[non_zero] * np.log2(joint_xy[non_zero]))\n",
+ " \n",
+ " ### Mutual Information ###\n",
+ " mi = entropy_xi + entropy_xj - entropy_joint\n",
+ " \n",
+ " return mi\n",
+ "\n",
+ "\n",
+ "def conditional_mutual_info(class_priors, XJ_params, XK_params, n):\n",
+ " \"\"\"\n",
+ " Compute I(XK; Y | XJ) using vectorization with broadcasting.\n",
+ " \n",
+ " Args:\n",
+ " XJ_params (array): p_{c,J} for all classes c.\n",
+ " XK_params (array): p_{c,K} for all classes c.\n",
+ " class_priors (array): P(Y=c) for all classes c.\n",
+ " n (int): Number of trials in the binomial distribution.\n",
+ "\n",
+ " Returns:\n",
+ " float: Conditional mutual information I(XK; Y | XJ).\n",
+ " \"\"\"\n",
+ " K = len(class_priors)\n",
+ " x_values = np.arange(n + 1)\n",
+ "\n",
+ " # Precompute PMFs for each class\n",
+ " P_XJ_giv_Y = binom.pmf(x_values[:, None], n, XJ_params) \n",
+ " P_XK_giv_Y = binom.pmf(x_values[:, None], n, XK_params) \n",
+ "\n",
+ " P_XJ_T = P_XJ_giv_Y.T # Shape: (K, n+1)\n",
+ " P_XK_T = P_XK_giv_Y.T\n",
+ "\n",
+ " ######################################################################\n",
+ " ### Compute H(Y | XJ) ###############################################\n",
+ " ######################################################################\n",
+ "\n",
+ " # Calculate P(XJ=xj) for all xj\n",
+ " P_XJ_total = np.dot(class_priors, P_XJ_T)\n",
+ "\n",
+ " # Numerators of posterior probabilities P(Y=c | XJ=xj)\n",
+ " numerators_YgXJ = class_priors[:, None] * P_XJ_T \n",
+ "\n",
+ " valid_mask = P_XJ_total > 1e-9\n",
+ " posterior_YgXJ = np.zeros_like(numerators_YgXJ, dtype=float)\n",
+ " posterior_YgXJ[:, valid_mask] = (\n",
+ " numerators_YgXJ[:, valid_mask] / \n",
+ " P_XJ_total[valid_mask]\n",
+ " )\n",
+ "\n",
+ " log_p = np.log2(posterior_YgXJ + 1e-9) \n",
+ " entropy_terms_HYgXJ = -np.sum(\n",
+ " posterior_YgXJ * log_p, \n",
+ " axis=0,\n",
+ " where=(posterior_YgXJ > 1e-9)\n",
+ " )\n",
+ " \n",
+ " H_Y_given_XJ = np.dot(P_XJ_total, entropy_terms_HYgXJ)\n",
+ "\n",
+ " ######################################################################\n",
+ " ### Compute H(Y | XJ, XK) ###########################################\n",
+ " ######################################################################\n",
+ "\n",
+ " # Broadcast to compute joint PMF P(XJ=xj, XK=xk | Y=c)\n",
+ " P_XJ_giv_Y_T = P_XJ_T[..., None] # Shape: (K, n+1, 1)\n",
+ " P_XK_giv_Y_T = P_XK_T[:, None, :] # Shape: (K, 1, n+1)\n",
+ "\n",
+ " joint_pmf_conditional = (\n",
+ " P_XJ_giv_Y_T * \n",
+ " P_XK_giv_Y_T\n",
+ " ) # Shape: (K, n+1, n+1)\n",
+ "\n",
+ " numerators = class_priors[:, None, None] * joint_pmf_conditional \n",
+ "\n",
+ " denominators = np.sum(numerators, axis=0) # Shape: (n+1, n+1)\n",
+ "\n",
+ " valid_mask_3d = (denominators > 1e-9)[None, ...] # Expand for class dimension\n",
+ "\n",
+ " # Compute posterior probabilities using broadcasting and where\n",
+ " posterior_YgXJXK = numerators / denominators[None, ...]\n",
+ " posterior_YgXJXK = np.where(valid_mask_3d, posterior_YgXJXK, 0.0)\n",
+ "\n",
+ " log_p_joint = np.log2(posterior_YgXJXK + 1e-9) \n",
+ " entropy_terms_HYgXJXK = -np.sum(\n",
+ " posterior_YgXJXK * log_p_joint,\n",
+ " axis=0, # Sum over classes (axis 0 is K)\n",
+ " where=(posterior_YgXJXK > 1e-9),\n",
+ " )\n",
+ "\n",
+ " H_Y_given_XJXK = np.sum(denominators * entropy_terms_HYgXJXK)\n",
+ "\n",
+ " ######################################################################\n",
+ " ### Compute CMI #####################################################\n",
+ " ######################################################################\n",
+ "\n",
+ " cmi = H_Y_given_XJ - H_Y_given_XJXK\n",
+ "\n",
+ " return max(cmi, 0.0) "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "03acb79a-040e-4bc7-a235-fd80dd72addb",
+ "metadata": {},
+ "source": [
+ "#### Relevance and redundancy\n",
+ "First, lets pre-compute the relevance and redundancy metrics for mRMR (also used in JMI). We assume a uniform class prior."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6a1ec802-7e8e-4ac1-beb8-72b6e14e5a6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_priors(nmults: int):\n",
+ " return np.full(nmults, 1/nmults, dtype=np.float64)\n",
+ "\n",
+ "def compute_probs(feats: list[int], mults_map: dict[MultIdent, ProbMap]):\n",
+ " probs, _ = to_sklearn(mults_map, feats)\n",
+ " return probs.T\n",
+ "\n",
+ "def compute_relevance(feats: list[int], priors, probs, nattack: int):\n",
+ " relevance = np.zeros(nallfeats, dtype=np.float64)\n",
+ " for i, divisor in enumerate(tqdm(feats)):\n",
+ " mi = mutual_information(priors, probs[i, ], nattack)\n",
+ " relevance[i] = mi\n",
+ " return relevance\n",
+ "\n",
+ "def compute_redundancy(feats: list[int], priors, probs, nattack: int, num_workers: int = 30):\n",
+ " nallfeats = len(feats)\n",
+ " redundancy = np.zeros((nallfeats, nallfeats), dtype=np.float64)\n",
+ " with TaskExecutor(max_workers=num_workers) as pool:\n",
+ " for i in trange(nallfeats):\n",
+ " for j in range(nallfeats):\n",
+ " if i < j:\n",
+ " continue\n",
+ " pool.submit_task((i, j),\n",
+ " mutual_information_between_features,\n",
+ " priors, probs[i, ], probs[j, ], nattack)\n",
+ " for (i, j), future in pool.as_completed():\n",
+ " mi = future.result()\n",
+ " redundancy[i][j] = mi\n",
+ " redundancy[j][i] = mi\n",
+ " return redundancy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e14ed9b-56af-4b36-9615-0bcee45e4b40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nattack = 100\n",
+ "\n",
+ "priors = compute_priors(nmults)\n",
+ "probs = compute_probs(allfeats, distributions_mults)\n",
+ "relevance = compute_relevance(allfeats, priors, probs, nattack)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a49f7b0-f9cf-4862-8638-0b4ba5dd4f07",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "redundancy = compute_redundancy(allfeats, priors, probs, nattack)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95e0b366-188f-4c65-b92c-e9b9587a8083",
+ "metadata": {},
+ "source": [
+ "Store the relevance and redundancy arrays."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d64cf0a1-a83f-4837-8113-5de536fb0f09",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"relevance.pickle\", \"wb\") as f:\n",
+ " pickle.dump(relevance, f)\n",
+ "with open(\"redundancy.pickle\", \"wb\") as f:\n",
+ " pickle.dump(redundancy, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3223edc-f3b2-4137-bc1f-75e311ff075e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MRMRFeatures(FeatureSelector):\n",
+ " def __init__(self,\n",
+ " allfeats: list[int],\n",
+ " mults: dict[MultIdent, ProbMap],\n",
+ " num_workers: int):\n",
+ " self.allfeats = allfeats\n",
+ " self.mults = mults\n",
+ " self.num_workers = num_workers\n",
+ "\n",
+ " def prepare(self, nattack: int):\n",
+ " self.nattack = nattack\n",
+ "\n",
+ " def select(self, nfeats: int, startwith: list[int] = None) -> list[int]:\n",
+ " pass\n",
+ "\n",
+ "\n",
+ "\n",
+ "def mrmr_selection(relevance, redundancy, nfeats):\n",
+ " \"\"\"\n",
+ " Select top features using mRMR.\n",
+ " \n",
+ " Returns:\n",
+ " indices of selected features.\n",
+ " \"\"\"\n",
+ " selected_indices = []\n",
+ " remaining_indices = list(range(nallfeats))\n",
+ " \n",
+ " # Initialize by selecting the most relevant feature\n",
+ " first_feature_idx = np.argmax(relevance)\n",
+ " selected_indices.append(first_feature_idx)\n",
+ " remaining_indices.remove(first_feature_idx)\n",
+ " \n",
+ " while len(selected_indices) < nfeats:\n",
+ " candidates_scores = []\n",
+ " \n",
+ " for candidate in remaining_indices:\n",
+ " # Compute mRMR score: relevance - average redundancy with selected features\n",
+ " current_relevance = relevance[candidate]\n",
+ " \n",
+ " avg_red = 0.0\n",
+ " if len(selected_indices) > 0:\n",
+ " sum_red = np.sum(redundancy[candidate][selected_indices])\n",
+ " avg_red = sum_red / len(selected_indices)\n",
+ " \n",
+ " score = current_relevance - avg_red\n",
+ " candidates_scores.append(score)\n",
+ " \n",
+ " # Select the candidate with highest score\n",
+ " best_candidate_idx = remaining_indices[np.argmax(candidates_scores)]\n",
+ " selected_indices.append(best_candidate_idx)\n",
+ " remaining_indices.remove(best_candidate_idx)\n",
+ " \n",
+ " return selected_indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5604d599-ef63-49fc-a7c8-65fa90c15620",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "selected_mrmr = [allfeats[i] for i in mrmr_selection(relevance, redundancy, nfeats=15)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "37cce925-8479-43e9-ad0c-b1d42be4991c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mrmrs = {}\n",
+ "for nfeats in trange(1, 41, 1):\n",
+ " f = [allfeats[i] for i in mrmr_selection(relevance, redundancy, nfeats=nfeats)]\n",
+ " for nattack in range(5, 105, 5):\n",
+ " mrmrs[(nattack, nfeats)] = f\n",
+ "\n",
+ "bayes_classifier = BayesClassifier()\n",
+ "plot_performance(bayes_classifier, avg_rank_scorer, 2000, mrmrs, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a8bef1c-800b-453b-b239-29ea1481213b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a12b75cd-3c62-4b87-a7df-f0c5f7748386",
+ "metadata": {},
+ "source": [
+ "### Feature selection via JMI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d8b3f827-baef-49c2-af60-0be74ff0efa2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def jmi_selection(features_params_list, class_priors, n_trials, relevance, nfeats):\n",
+ " \"\"\"\n",
+ " Select top features using JMI.\n",
+ " \n",
+ " Returns:\n",
+ " indices of selected features.\n",
+ " \"\"\"\n",
+ " selected_indices = []\n",
+ " remaining_indices = list(range(nallfeats))\n",
+ " \n",
+ " # Initialize by selecting the most relevant feature\n",
+ " first_feature_idx = np.argmax(relevance)\n",
+ " selected_indices.append(first_feature_idx)\n",
+ " remaining_indices.remove(first_feature_idx)\n",
+ " \n",
+ " while len(selected_indices) < nfeats:\n",
+ " candidates_scores = []\n",
+ " \n",
+ " for candidate in tqdm(remaining_indices):\n",
+ " # Compute mRMR score: relevance - average redundancy with selected features\n",
+ " current_relevance = relevance[candidate]\n",
+ " \n",
+ " sum_cmi = 0.0\n",
+ " for selected in selected_indices:\n",
+ " XJ_params = features_params_list[selected]\n",
+ " XK_params = features_params_list[candidate]\n",
+ " \n",
+ " cmi_val = conditional_mutual_info(\n",
+ " class_priors=class_priors,\n",
+ " XJ_params=XJ_params,\n",
+ " XK_params=XK_params,\n",
+ " n=n_trials\n",
+ " )\n",
+ " sum_cmi += cmi_val\n",
+ " avg_cmi = sum_cmi / len(selected_indices)\n",
+ " score = current_relevance + avg_cmi\n",
+ " candidates_scores.append(score)\n",
+ " \n",
+ " # Select the candidate with highest score\n",
+ " best_candidate_idx = remaining_indices[np.argmax(candidates_scores)]\n",
+ " selected_indices.append(best_candidate_idx)\n",
+ " remaining_indices.remove(best_candidate_idx)\n",
+ " \n",
+ " return selected_indices"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "216a22d8-f27f-4584-8769-bd575ed538b1",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6739192e-879a-4862-b862-6a8fc3939b73",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nattack = 100\n",
+ "selected_jmi = [allfeats[i] for i in jmi_selection(probs, priors, nattack, relevance, nfeats=40)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7c65b785-0131-4d41-9131-b519bf446803",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "selected_jmi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "899dd2fa-409a-4e6c-aa13-f2d6b0088627",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mrmrs[(100,5)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9a894db6-0e4d-49c6-b259-cc9bd28e8c8f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "jmis = {}\n",
+ "for nfeats in range(1, 6, 1):\n",
+ " f = selected_jmi[:nfeats]\n",
+ " for nattack in range(5, 105, 5):\n",
+ " jmis[(nattack, nfeats)] = f\n",
+ "\n",
+ "bayes_classifier = BayesClassifier()\n",
+ "plot_performance(bayes_classifier, avg_rank_scorer, 2000, jmis, distributions_mults)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c29cc75-7bee-47d0-b0cb-c0f5a7ec5da0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5598117-4c54-4721-9fc5-68432fb8e230",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}