From 545a4276613a55043b0c03d7c1cd52becc84dc6e Mon Sep 17 00:00:00 2001 From: J08nY Date: Wed, 26 Mar 2025 23:39:45 +0100 Subject: Add better distinguishers and feature selectors. --- epare/distinguish.ipynb | 5667 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 5490 insertions(+), 177 deletions(-) diff --git a/epare/distinguish.ipynb b/epare/distinguish.ipynb index 9d06f36..e55e828 100644 --- a/epare/distinguish.ipynb +++ b/epare/distinguish.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "bc1528b8-61cd-4219-993f-e3f1ac79e801", "metadata": {}, "outputs": [], @@ -28,8 +28,8 @@ "from scipy.stats import binom, entropy\n", "from scipy.spatial import distance\n", "from tqdm.auto import tqdm, trange\n", - "from statsmodels.stats.proportion import proportion_confint\n", "from anytree import PreOrderIter, Walker\n", + "from matplotlib import pyplot as plt\n", "\n", "from pyecsca.ec.mult import *\n", "from pyecsca.misc.utils import TaskExecutor, silent\n", @@ -38,57 +38,6 @@ "from common import *" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "b3814d6d-af42-4b9a-bbf2-6dbbdddd92d4", - "metadata": {}, - "outputs": [], - "source": [ - "def conf_interval(p: float, samples: int, alpha: float = 0.05) -> tuple[float, float]:\n", - " return proportion_confint(round(p*samples), samples, alpha, method=\"wilson\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0eac736-be48-4925-accf-1ca8ff6aa065", - "metadata": {}, - "outputs": [], - "source": [ - "def powers_of(k, max_power=20):\n", - " return [k**i for i in range(1, max_power)]\n", - "\n", - "def prod_combine(one, other):\n", - " return [a * b for a, b in itertools.product(one, other)]\n", - "\n", - "small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199]\n", - "medium_primes = [211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397]\n", - "large_primes = [401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]\n", - "all_integers = list(range(1, 400))\n", - "all_even = list(range(2, 400, 2))\n", - "all_odd = list(range(1, 400, 2))\n", - "all_primes = small_primes + medium_primes + large_primes\n", - "\n", - "divisor_map = {\n", - " \"small_primes\": small_primes,\n", - " \"medium_primes\": medium_primes,\n", - " \"large_primes\": large_primes,\n", - " \"all_primes\": all_primes,\n", - " \"all_integers\": all_integers,\n", - " \"all_even\": all_even,\n", - " \"all_odd\": all_odd,\n", - " \"powers_of_2\": powers_of(2),\n", - " \"powers_of_2_large\": powers_of(2, 256),\n", - " \"powers_of_2_large_3\": [i * 3 for i in powers_of(2, 256)],\n", - " \"powers_of_2_large_p1\": [i + 1 for i in powers_of(2, 256)],\n", - " \"powers_of_2_large_m1\": [i - 1 for i in powers_of(2, 256)],\n", - " \"powers_of_2_large_pmautobus\": sorted(set([i + j for i in powers_of(2, 256) for j in range(-5,5) if i+j > 
0])),\n", - " \"powers_of_3\": powers_of(3),\n", - "}\n", - "divisor_map[\"all\"] = list(sorted(set().union(*[v for v in divisor_map.values()])))" - ] - }, { "cell_type": "markdown", "id": "4868c083-8073-453d-b508-704fcb6d6f2a", @@ -100,19 +49,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "ccb00342-3c48-49c9-bedf-2341e5eae3a2", "metadata": {}, "outputs": [], "source": [ "divisor_name = \"all\"\n", "kind = \"all\"\n", - "allfeats = divisor_map[divisor_name]" + "allfeats = list(filter(lambda feat: feat not in (1,2,3,4,5), divisor_map[divisor_name]))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "3dbac9be-d098-479a-8ca2-f531f6668f7c", "metadata": {}, "outputs": [], @@ -130,12 +79,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "38c81e38-a37c-4e58-ac9e-927d14dad458", "metadata": {}, "outputs": [], "source": [ - "nmults = len(distributions_mults.keys())\n", + "allmults = list(distributions_mults.keys())\n", + "nmults = len(allmults)\n", "nallfeats = len(allfeats)" ] }, @@ -157,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "2307bf7c-4fac-489d-8527-7ddbf536a148", "metadata": {}, "outputs": [], @@ -168,13 +118,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "0b85fad7-392f-4701-9329-d75d39736bbb", "metadata": {}, "outputs": [], "source": [ "# Now go over all divisors, cluster based on overlapping CI for given n?\n", - "io_map = {mult:{} for mult in distributions_mults.keys()}\n", + "io_map = {mult:{} for mult in allmults}\n", "for divisor in allfeats:\n", " prev_ci_low = None\n", " prev_ci_high = None\n", @@ -217,17 +167,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "06104104-b612-40e9-bc1d-646356a13381", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total configs: 378, (6_048 bytes)\n", + "Rows: 378, (9_722_292 bytes)\n", + "Inputs: 3215\n", + "Codomain: 11998\n", + "None in codomain: False\n" + ] + } + ], "source": [ "print(dmap.describe())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "eb8672ca-b76b-411d-b514-2387b555f184", "metadata": {}, "outputs": [], @@ -238,44 +200,617 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "ccba09b0-31c3-450b-af30-efaa64329743", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total configs: 378, (14_344 bytes)\n", + "Rows: 347, (8_924_972 bytes)\n", + "Inputs: 3215\n", + "Codomain: 11998\n", + "None in codomain: False\n" + ] + } + ], "source": [ "print(dmap.describe())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "5735e7d4-149c-4184-96f7-dcfd6017fbad", "metadata": {}, "outputs": [], "source": [ "# build a tree\n", "with silent():\n", - " tree = Tree.build(set(distributions_mults.keys()), dmap)" + " tree = Tree.build(set(allmults), dmap)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "d41093df-32c4-450d-922d-5ad042539397", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dmaps: 1\n", + "Total cfgs: 378\n", + "Height: 7\n", + "Size: 533\n", + "Leaves: 347\n", + "Precise: False\n", + "Leaf sizes: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5]\n", + "Leaf depths: [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7]\n", + "Average leaf depth: 4.118\n", + "Average leaf size: 1.089\n", + "Random walk leaf depth: 2.642\n", + "Random walk leaf size: 1.041\n", + "Mean result depth: 4.153\n", + "Mean result size: 1.270\n" + ] + } + ], "source": [ "print(tree.describe())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "de577429-d87c-4967-be17-75cbb378860c", "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "378\n", + "├── 1\n", + "├── 10\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 4\n", + "│ │ ├── 1\n", + "│ │ └── 3\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 10\n", + "│ ├── 1\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 4\n", + "│ │ ├── 1\n", + "│ │ └── 3\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ └── 3\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 49\n", + "│ ├── 1\n", + "│ ├── 13\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 10\n", + "│ │ ├── 1\n", + "│ │ ├── 4\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ 
└── 1\n", + "│ │ ├── 1\n", + "│ │ └── 4\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 11\n", + "│ │ ├── 5\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 5\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 4\n", + "│ │ └── 1\n", + "│ ├── 4\n", + "│ │ ├── 1\n", + "│ │ └── 3\n", + "│ ├── 8\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 3\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 20\n", + "│ ├── 4\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 3\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 5\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 53\n", + "│ ├── 20\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 4\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 5\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 3\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 19\n", + "│ │ ├── 6\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 7\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 3\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 3\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ ├── 5\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 7\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ └── 2\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 134\n", + "│ ├── 23\n", + "│ │ ├── 5\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 6\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 1\n", + "│ │ └── 6\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 6\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 7\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 24\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 
1\n", + "│ │ ├── 7\n", + "│ │ │ ├── 2\n", + "│ │ │ └── 5\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 4\n", + "│ │ ├── 2\n", + "│ │ ├── 7\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 5\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 46\n", + "│ │ ├── 11\n", + "│ │ │ ├── 3\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 3\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 2\n", + "│ │ │ │ │ ├── 1\n", + "│ │ │ │ │ └── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ └── 1\n", + "│ │ ├── 8\n", + "│ │ │ ├── 5\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 4\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ ├── 10\n", + "│ │ │ ├── 3\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ └── 7\n", + "│ │ │ ├── 5\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 9\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 2\n", + "│ │ │ │ │ ├── 1\n", + "│ │ │ │ │ └── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ └── 7\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 6\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 12\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ └── 2\n", + "│ ├── 1\n", + "│ └── 1\n", + "├── 97\n", + "│ ├── 27\n", + "│ │ ├── 7\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ ├── 7\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 4\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 11\n", + "│ │ ├── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 4\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 4\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 6\n", + "│ │ ├── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 3\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 1\n", + "│ │ 
└── 1\n", + "│ ├── 13\n", + "│ │ ├── 8\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 2\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 4\n", + "│ │ │ ├── 3\n", + "│ │ │ │ ├── 2\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 4\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 27\n", + "│ │ ├── 11\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 2\n", + "│ │ │ │ │ ├── 1\n", + "│ │ │ │ │ └── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 5\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 4\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 9\n", + "│ │ │ ├── 4\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 3\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 5\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ ├── 2\n", + "│ │ │ │ ├── 1\n", + "│ │ │ │ └── 1\n", + "│ │ │ └── 1\n", + "│ │ ├── 6\n", + "│ │ │ ├── 3\n", + "│ │ │ └── 3\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 2\n", + "│ │ │ ├── 1\n", + "│ │ │ └── 1\n", + "│ │ └── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 1\n", + "│ ├── 3\n", + "│ │ ├── 1\n", + "│ │ └── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ └── 5\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ ├── 2\n", + "│ │ ├── 1\n", + "│ │ └── 1\n", + "│ └── 1\n", + "└── 4\n", + " ├── 1\n", + " ├── 1\n", + " ├── 1\n", + " └── 1\n" + ] + } + ], "source": [ "print(tree.render_basic())" ] @@ -302,7 +837,7 @@ " successes = 0\n", " pathiness = 0\n", " for i in range(simulations):\n", - " true_mult = random.choice(list(distributions_mults.keys()))\n", + " true_mult = random.choice(allmults)\n", " probmap = distributions_mults[true_mult]\n", " node = tree.root\n", " while True:\n", @@ -363,6 +898,424 @@ "We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building choses as the features. However, we can also use more conventional feature selection approaches." ] }, + { + "cell_type": "code", + "execution_count": 86, + "id": "cc1a9956-bc8c-47cf-b6ec-093c6cf85c7d", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[3]\n", + "[[-0.53851648 -0.2236068 -0. 
-0.50990195]\n", + " [-0.14142136 -0.42426407 -0.51961524 -0.80622577]]\n", + "[3]\n", + "[[1.04188836e-27 1.24861430e-05 9.99987514e-01 1.14579810e-34]]\n", + "1.0\n", + "1.0\n", + "1.0\n", + "1.0\n" + ] + } + ], + "source": [ + "from sklearn.base import BaseEstimator, ClassifierMixin\n", + "from sklearn.utils.validation import validate_data, check_is_fitted\n", + "from sklearn.utils.multiclass import unique_labels\n", + "from scipy.special import logsumexp\n", + "from sklearn.metrics import euclidean_distances, top_k_accuracy_score, make_scorer, accuracy_score\n", + "\n", + "\n", + "class EuclidClassifier(ClassifierMixin, BaseEstimator):\n", + " def __init__(self, *, nattack=100):\n", + " self.nattack = nattack\n", + "\n", + " def fit(self, X, y):\n", + " X, y = validate_data(self, X, y)\n", + " if not np.logical_and(X >= 0, X <= 1).all():\n", + " raise TypeError(\"Expects valid probabilities in X.\")\n", + " self.classes_ = unique_labels(y)\n", + " if len(self.classes_) != len(y):\n", + " raise ValueError(\"Expects only one sample per class containing the binomial probabilities.\")\n", + " self.X_ = X\n", + " self.y_ = y\n", + " return self\n", + "\n", + " def decision_function(self, X):\n", + " check_is_fitted(self)\n", + " X = validate_data(self, X, reset=False)\n", + " distances = euclidean_distances(X / self.nattack, self.X_)\n", + " return -distances\n", + "\n", + " def predict(self, X):\n", + " check_is_fitted(self)\n", + " X = validate_data(self, X, reset=False)\n", + " distances = euclidean_distances(X / self.nattack, self.X_)\n", + " closest = np.argmin(distances, axis=1)\n", + " return self.classes_[closest]\n", + "\n", + "\n", + "class BayesClassifier(ClassifierMixin, BaseEstimator):\n", + " def __init__(self, *, nattack=100):\n", + " self.nattack = nattack\n", + "\n", + " def fit(self, X, y):\n", + " # X has (nmults = nsamples, nfeats)\n", + " X, y = validate_data(self, X, y)\n", + " if not np.logical_and(X >= 0, X <= 1).all():\n", + " raise TypeError(\"Expects valid probabilities in X.\")\n", + " self.classes_ = unique_labels(y)\n", + " if len(self.classes_) != len(y):\n", + " raise ValueError(\"Expects only one sample per class containing the binomial probabilities.\")\n", + " self.X_ = X\n", + " self.y_ = y\n", + " return self\n", + "\n", + " def decision_function(self, X):\n", + " check_is_fitted(self)\n", + " X = validate_data(self, X, reset=False)\n", + " # We have a uniform prior, so we can ignore it.\n", + " probas = np.zeros((len(X), len(self.classes_)))\n", + " for i, row in enumerate(X):\n", + " p = binom(self.nattack, self.X_).logpmf(row)\n", + " s = np.sum(p, axis=1)\n", + " log_prob_x = logsumexp(s)\n", + " res = np.exp(s - log_prob_x)\n", + " probas[i, ] = res\n", + " return probas\n", + "\n", + " def predict_proba(self, X):\n", + " return self.decision_function(X)\n", + "\n", + " def predict(self, X):\n", + " check_is_fitted(self)\n", + " X = validate_data(self, X, reset=False)\n", + " # We have a uniform prior, so we can ignore it.\n", + " results = np.empty(len(X), dtype=self.classes_.dtype)\n", + " for i, row in enumerate(X):\n", + " p = binom(self.nattack, self.X_).logpmf(row)\n", + " s = np.sum(p, axis=1)\n", + " most_likely = np.argmax(s)\n", + " results[i] = self.classes_[most_likely]\n", + " return results\n", + "\n", + "\n", + "def to_sklearn(mults_map: dict[MultIdent, ProbMap], feats: list[int]):\n", + " nfeats = len(feats)\n", + " nmults = len(mults_map)\n", + " classes = np.arange(nmults, dtype=np.uint32)\n", + " probs = np.zeros((nmults, nfeats), 
dtype=np.float64)\n", + " for i, divisor in enumerate(feats):\n", + " for j, probmap in enumerate(mults_map.values()):\n", + " probs[j, i] = probmap[divisor]\n", + " return probs, classes\n", + "\n", + "\n", + "def evaluate_classifier(nattack: int,\n", + " simulations: int,\n", + " X,\n", + " y,\n", + " classifier,\n", + " scorer):\n", + " #X, y = to_sklearn(mults, feats)\n", + " nmults, nfeats = X.shape\n", + " classifier.set_params(nattack=nattack)\n", + " classifier.fit(X, y)\n", + "\n", + " X_samp = np.zeros((simulations, nfeats), dtype=np.uint32)\n", + " y_samp = np.zeros(simulations, dtype=np.uint32)\n", + "\n", + " for i in range(simulations):\n", + " if i < nmults and simulations >= nmults:\n", + " j = i\n", + " else:\n", + " j = random.randrange(nmults)\n", + " X_samp[i] = binom(nattack, X[j]).rvs()\n", + " y_samp[i] = j\n", + "\n", + " return scorer(classifier, X_samp, y_samp)\n", + "\n", + "\n", + "def average_rank_score(y_true, y_pred, labels=None):\n", + " y_true = np.asarray(y_true)\n", + " y_pred = np.asarray(y_pred)\n", + " \n", + " n_samples, n_classes = y_pred.shape\n", + " if labels is not None:\n", + " labels = np.asarray(labels)\n", + " if len(labels) != n_classes:\n", + " raise ValueError()\n", + " label_indexes = np.searchsorted(labels, y_true)\n", + " indexes = np.where(labels[label_indexes] == y_true, label_indexes, -1)\n", + " else:\n", + " indexes = y_true\n", + " true_scores = y_pred[np.arange(n_samples), indexes]\n", + " \n", + " counts_higher = np.sum(y_pred > true_scores[:, None], axis=1)\n", + " \n", + " ranks = counts_higher + 1\n", + " \n", + " return ranks.mean()\n", + "\n", + "X = np.array([[0.7, 0.7, 0.1], [0.3, 0.7, 0.1], [0.2, 0.5, 0.1], [0.1, 0.1, 0.4]])\n", + "y = np.array([1, 2, 3, 4])\n", + "\n", + "euc = EuclidClassifier(nattack=100).fit(X, y)\n", + "label = euc.predict(np.array([20, 50, 10]).reshape(1, -1))\n", + "dec = euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]]))\n", + "print(label)\n", + "print(dec)\n", + "\n", + "clf = BayesClassifier(nattack=100).fit(X, y)\n", + "label = clf.predict(np.array([20, 50, 10]).reshape(1, -1))\n", + "ps = clf.predict_proba(np.array([20, 50, 10]).reshape(1, -1))\n", + "print(label)\n", + "print(ps)\n", + "\n", + "\n", + "acc = top_k_accuracy_score(np.array([3, 1]),\n", + " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),\n", + " labels = [1, 2, 3, 4],\n", + " k=1)\n", + "print(acc)\n", + "acc = top_k_accuracy_score(np.array([3, 1]),\n", + " clf.predict_proba(np.array([[20, 50, 10], [70, 60, 20]])),\n", + " labels = [1, 2, 3, 4],\n", + " k=1)\n", + "print(acc)\n", + "\n", + "avg = average_rank_score(np.array([2, 0]),\n", + " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])))\n", + "print(avg)\n", + "avg = average_rank_score(np.array([3, 1]),\n", + " euc.decision_function(np.array([[20, 50, 10], [70, 60, 20]])),\n", + " labels = [1, 2, 3, 4])\n", + "print(avg)\n", + "\n", + "accuracy_scorer = make_scorer(\n", + " top_k_accuracy_score,\n", + " greater_is_better=True,\n", + " response_method=(\"decision_function\", \"predict_proba\"),\n", + ")\n", + "\n", + "#accuracy_scorer.__str__ = lambda self: \"Accuracy\"\n", + "\n", + "top_5_scorer = make_scorer(\n", + " top_k_accuracy_score,\n", + " greater_is_better=True,\n", + " response_method=(\"decision_function\", \"predict_proba\"),\n", + " k=5\n", + ")\n", + "\n", + "#top_5_scorer.__str__ = lambda self: \"Top-5 accuracy\"\n", + "\n", + "top_10_scorer = make_scorer(\n", + " top_k_accuracy_score,\n", + " 
greater_is_better=True,\n", + " response_method=(\"decision_function\", \"predict_proba\"),\n", + " k=10\n", + ")\n", + "\n", + "#top_10_scorer.__str__ = lambda self: \"Top-10 accuracy\"\n", + "\n", + "avg_rank_scorer = make_scorer(\n", + " average_rank_score,\n", + " greater_is_better=False,\n", + " response_method=(\"decision_function\", \"predict_proba\"),\n", + ")\n", + "\n", + "#avg_rank_scorer.__str__ = lambda self: \"Average rank\"" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "a9fae775-797f-4efe-ac28-d83a8c905372", + "metadata": {}, + "outputs": [], + "source": [ + "class FeatureSelector:\n", + " def __init__(self,\n", + " allfeats: list[int],\n", + " mults: dict[MultIdent, ProbMap],\n", + " num_workers: int):\n", + " self.allfeats = allfeats\n", + " self.mults = mults\n", + " self.num_workers = num_workers\n", + "\n", + " def prepare(self, nattack: int):\n", + " self.nattack = nattack\n", + "\n", + " def select(self, nfeats: int, startwith: list[int] = None) -> list[int]:\n", + " pass\n", + "\n", + "class FeaturesByClassification(FeatureSelector):\n", + " def __init__(self,\n", + " allfeats: list[int],\n", + " mults: dict[MultIdent, ProbMap],\n", + " num_workers: int,\n", + " simulations: int,\n", + " classifier,\n", + " scorer):\n", + " super().__init__(allfeats, mults, num_workers)\n", + " self.simulations = simulations\n", + " self.classifier = classifier\n", + " self.scorer = scorer\n", + "\n", + "class RandomFeatures(FeaturesByClassification):\n", + "\n", + " def __init__(self,\n", + " allfeats: list[int],\n", + " mults: dict[MultIdent, ProbMap],\n", + " num_workers: int,\n", + " simulations: int,\n", + " classifier,\n", + " scorer,\n", + " retries: int):\n", + " super().__init__(allfeats, mults, num_workers, simulations, classifier, scorer)\n", + " self.retries = retries\n", + " \n", + " def _select_random(self, nfeats: int, startwith: list[int] = None) -> list[int]:\n", + " if startwith is None:\n", + " startwith = []\n", + " toselect = nfeats - len(startwith)\n", + " if toselect > 0:\n", + " available_feats = list(filter(lambda feat: feat not in startwith, self.allfeats))\n", + " selected = random.sample(available_feats, toselect)\n", + " return startwith + selected\n", + " elif toselect < 0:\n", + " return random.sample(startwith, nfeats)\n", + " else:\n", + " return startwith\n", + "\n", + " def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:\n", + " with TaskExecutor(max_workers=self.num_workers) as pool:\n", + " feat_map = []\n", + " for i in range(self.retries):\n", + " feats = self._select_random(nfeats, startwith)\n", + " X, y = to_sklearn(self.mults, feats)\n", + " feat_map.append(feats)\n", + " pool.submit_task(i,\n", + " evaluate_classifier,\n", + " self.nattack, self.simulations,\n", + " X, y, self.classifier, self.scorer)\n", + " best_score = None\n", + " best_feats = None\n", + " for i, future in tqdm(pool.as_completed(), total=len(pool.tasks), desc=\"retries\", leave=False):\n", + " score = future.result()\n", + " #print(i, feat_map[i], score)\n", + " if best_score is None or score > best_score:\n", + " best_score = score\n", + " best_feats = feat_map[i]\n", + " return best_feats, best_score\n", + "\n", + "\n", + "class GreedyFeatures(FeaturesByClassification):\n", + "\n", + " def select(self, nfeats: int, startwith: list[int] = None) -> tuple[list[int], float]:\n", + " if startwith is None:\n", + " startwith = []\n", + " toselect = nfeats - len(startwith)\n", + " if toselect < 0:\n", + " raise 
ValueError(\"No features to select.\")\n", + " available_feats = list(filter(lambda feat: feat not in startwith, self.allfeats))\n", + " current = list(startwith)\n", + " with TaskExecutor(max_workers=self.num_workers) as pool:\n", + " while toselect > 0:\n", + " for feat in available_feats:\n", + " feats = current + [feat]\n", + " X, y = to_sklearn(self.mults, feats)\n", + " pool.submit_task(feat,\n", + " evaluate_classifier,\n", + " self.nattack, self.simulations,\n", + " X, y, self.classifier, self.scorer)\n", + " best_score = None\n", + " best_feat = None\n", + " for feat, future in tqdm(pool.as_completed(), total=len(pool.tasks), leave=False):\n", + " score = future.result()\n", + " if best_score is None or score > best_score:\n", + " best_score = score\n", + " best_feat = feat\n", + " current.append(best_feat)\n", + " toselect -= 1\n", + " return current, best_score\n", + "\n", + "\n", + "def feature_search(feat_range, nattack_range, selector, restarts=False):\n", + " if isinstance(feat_range, int):\n", + " feat_range = [feat_range]\n", + " if isinstance(nattack_range, int):\n", + " nattack_range = [nattack_range]\n", + " results = {}\n", + " for nattack in tqdm(nattack_range, desc=\"nattack\", smoothing=0):\n", + " selector.prepare(nattack)\n", + " feats = []\n", + " for nfeats in tqdm(feat_range, desc=\"nfeats\", leave=False):\n", + " feats, score = selector.select(nfeats, [] if restarts else feats)\n", + " results[(nattack, nfeats)] = feats\n", + " print(f\"{nattack},{nfeats}: {feats}, {score}\")\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "f1c0bebe-c519-4241-a163-63613b929db2", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_performance(classifier, scorer, simulations, feature_map, mults):\n", + " scores = {}\n", + " for (nattack, nfeats), feats in tqdm(feature_map.items(), desc=\"Evaluating\", leave=False):\n", + " X, y = to_sklearn(mults, feats)\n", + " score = evaluate_classifier(nattack, simulations, X, y, classifier, scorer)\n", + " scores[(nattack, nfeats)] = score\n", + "\n", + " x_coords = [k[0] for k in scores.keys()]\n", + " y_coords = [k[1] for k in scores.keys()]\n", + " \n", + " x_unique = sorted(set(x_coords))\n", + " y_unique = sorted(set(y_coords))\n", + "\n", + " heatmap_data = np.zeros((len(y_unique), len(x_unique)))\n", + " \n", + " for (x, y), score in scores.items():\n", + " x_index = x_unique.index(x)\n", + " y_index = y_unique.index(y)\n", + " heatmap_data[y_index, x_index] = score\n", + "\n", + " x_mesh, y_mesh = np.meshgrid(x_unique, y_unique)\n", + " \n", + " plt.pcolormesh(x_mesh, y_mesh, heatmap_data, cmap='viridis', shading='auto')\n", + " plt.colorbar(label='Score')\n", + "\n", + " for i in range(len(y_unique)):\n", + " for j in range(len(x_unique)):\n", + " plt.text(x_unique[j], y_unique[i], f'{heatmap_data[i, j]:.2f}', ha='center', va='center', color='white' if heatmap_data[i, j] < 0.4 else \"black\")\n", + " \n", + " x_contour, y_contour = np.meshgrid(np.linspace(min(x_unique), max(x_unique), 100), \n", + " np.linspace(min(y_unique), max(y_unique), 100))\n", + " z_contour = x_contour * y_contour\n", + " \n", + " contour = plt.contour(x_contour, y_contour, z_contour, levels=[100, 200, 300, 400, 500], colors='white', zorder=4)\n", + " plt.clabel(contour, inline=True, fontsize=8)\n", + " \n", + " plt.xticks(ticks=x_unique, labels=x_unique)\n", + " plt.yticks(ticks=y_unique, labels=y_unique)\n", + " plt.xlabel('nattack')\n", + " plt.ylabel('nfeats')\n", + " 
plt.title(f'{scorer._score_func.__name__}{scorer._kwargs} ({classifier.__class__.__name__})')\n", + " plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -382,104 +1335,1416 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "fa49b67c-0a52-443e-904c-98f0a3d7febf", + "execution_count": 72, + "id": "6e3260c9-c0fa-4828-a749-4d34499abacf", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2726629e5a434a43a05103635da433eb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "nattack: 0%| | 0/6 [00:00 mean_pos:\n", - " best_feats = feats\n", - " best_feats_mean_pos = mean_pos\n", - " best_successes = successes\n", - " \n", - " print(f\"Best results for {nfeats} feats at {nattack} samples out of {retries} random feat subsets.\")\n", - " print(f\"Features: {best_feats}\")\n", - " print(f\"mean_pos: {best_feats_mean_pos:.2f}\")\n", - " print(f\"top1: {best_successes[1]:.2f}, top2: {best_successes[2]:.2f}, top5: {best_successes[5]:.2f}, top10: {best_successes[10]:.2f}\")" + "simulations = 500\n", + "retries = 500\n", + "nattack = range(50, 350, 50)\n", + "nfeats = range(1, 11)\n", + "num_workers = 30\n", + "\n", + "euclid_classifier = EuclidClassifier()\n", + "tree_random_subsets = RandomFeatures(sorted(feats_in_tree), distributions_mults, num_workers,\n", + " simulations, euclid_classifier, top_5_scorer, retries)\n", + "\n", + "tre = feature_search(nfeats, nattack, tree_random_subsets, restarts=True)" ] }, { "cell_type": "code", - "execution_count": null, - "id": "6e3260c9-c0fa-4828-a749-4d34499abacf", + "execution_count": 77, + "id": "0b6a1a5b-82dd-44d4-82dc-83b16ac5bc82", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a0ec40907ca6452d8f7419a143d06466", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Evaluating: 0%| | 0/60 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "simulations = 500\n", - "retries = 200\n", - "nfeats = trange(1, 11, leave=False, desc=\"nfeats\")\n", - "nattack = trange(50, 350, 50, leave=False, desc=\"nattack\")\n", - "num_workers = 30\n", - "\n", - "selected_random_euclid = find_features_random(feats_in_tree, nfeats, nattack, num_workers, retries, simulations, euclid)" + "plot_performance(euclid_classifier, top_5_scorer, 500, tre, distributions_mults)" ] }, { @@ -512,7 +2777,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 60, "id": "1f24b323-3604-4e34-a880-9dfd611fb245", "metadata": { "scrolled": true @@ -529,18 +2794,1416 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 61, "id": "f1052222-ad32-4e25-97ca-851cc42bf546", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dd0c41b2ba854130a3e11c46e17ec1b0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "nattack: 0%| | 0/6 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(bayes_classifier, top_5_scorer, 500, bay, distributions_mults)" ] }, { @@ -556,7 +4219,11 @@ "cell_type": "code", "execution_count": null, "id": "93c778a4-0855-4248-91a9-750fdd76ffa6", - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "def find_features_greedy(nfeats, nattack, num_workers, simulations, scorer, 
start_features=None):\n", @@ -593,18 +4260,1664 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "id": "6e4c2313-83b0-43f8-80d6-14c39be0d9ec", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8ca4c6298c494e538470d7cc19357ad5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "nattack: 0%| | 0/6 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(bayes_classifier, top_5_scorer, 500, gre, distributions_mults)" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "69ce91fa-7475-41f1-a3ed-bc4dd97d44d6", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "30adc7f9575145209097d858ff84a17f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "nattack: 0%| | 0/6 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_performance(bayes_classifier, top_5_scorer, 500, gre, distributions_mults)" ] }, { -- cgit v1.2.3-70-g09d2