From 334a33afabaaaf3a1be83a0573ee3dc2ddbe575d Mon Sep 17 00:00:00 2001 From: J08nY Date: Thu, 20 Mar 2025 16:37:46 +0100 Subject: Add countermeasure and scalarmult distinguishing notebooks. --- epare/countermeasures.ipynb | 1131 +++++++++++++++++++++++++++++++++++++++++++ epare/distinguish.ipynb | 641 ++++++++++++++++++++++++ 2 files changed, 1772 insertions(+) create mode 100644 epare/countermeasures.ipynb create mode 100644 epare/distinguish.ipynb diff --git a/epare/countermeasures.ipynb b/epare/countermeasures.ipynb new file mode 100644 index 0000000..1ee3dea --- /dev/null +++ b/epare/countermeasures.ipynb @@ -0,0 +1,1131 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bafc2f4e-05a3-4120-bcd6-5d1f5fb91cd9", + "metadata": {}, + "source": [ + "# Distinguishing countermeasures by output" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "33ee6084-2ac3-4f95-9610-0fbc06026538", + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import random\n", + "\n", + "from collections import Counter\n", + "from tqdm.auto import tqdm, trange\n", + "\n", + "from pyecsca.ec.mod import mod\n", + "from pyecsca.ec.point import Point\n", + "from pyecsca.ec.model import ShortWeierstrassModel\n", + "from pyecsca.ec.params import load_params_ectester\n", + "from pyecsca.ec.mult import LTRMultiplier\n", + "from pyecsca.ec.countermeasures import GroupScalarRandomization, AdditiveSplitting, MultiplicativeSplitting, EuclideanSplitting, BrumleyTuveri" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b1b9596c-1eba-4ace-af84-8cb279d84cc2", + "metadata": {}, + "outputs": [], + "source": [ + "model = ShortWeierstrassModel()\n", + "coords = model.coordinates[\"projective\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b0afb195-8390-44c5-931e-75a70ccd4e9e", + "metadata": {}, + "outputs": [], + "source": [ + "add = coords.formulas[\"add-2015-rcb\"]\n", + "dbl = coords.formulas[\"dbl-2015-rcb\"]\n", + "mult = LTRMultiplier(add, dbl, complete=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "52c877e1-5021-4ec2-9daa-dd20bec6bcb2", + "metadata": {}, + "outputs": [], + "source": [ + "gsr = GroupScalarRandomization(mult)\n", + "asplit = AdditiveSplitting(mult)\n", + "msplit = MultiplicativeSplitting(mult)\n", + "esplit = EuclideanSplitting(mult)\n", + "bt = BrumleyTuveri(mult)" + ] + }, + { + "cell_type": "markdown", + "id": "27626337-dcbc-497c-a54e-02d50e2b8f34", + "metadata": {}, + "source": [ + "## 3n test" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c3088419-161b-4193-a1b6-6f623f217fcd", + "metadata": {}, + "outputs": [], + "source": [ + "key3n = 0x20959f2b437de1e522baf6d814911938157390d3ea5118660b852ab0d5387006\n", + "params3n = load_params_ectester(io.BytesIO(b\"0xc381bb0394f34b5ed061c9107b66974f4d0a8ec89b9fe73b98b6d1368c7d974d,0x5ca6c5ee0a10097af291a8f125303fb1a3e35e8100411902245d691e0e5cb497,0x385a5a8bb8af94721f6fd10b562606d9b9df931f7fd966e96859bb9bd7c05836,0x4616af1898b92cac0f902a9daee24bbae63571cead270467c6a7886ced421f5e,0x34e896bdb1337e0ae5960fa3389fb59c2c8d6c7dbfd9aac33a844f8f98e433ef,0x412b3e5686fbc3ca4575edb0292232702ae721a7d4a230cc170a5561aa70e00f,0x01\"), \"projective\")\n", + "bits3n = params3n.full_order.bit_length()\n", + "point3n = Point(X=mod(0x4a48addb2e471767b7cd0f6f1d4c27fe46f4a828fc20f950bd1f72c939b36a84, params3n.curve.prime),\n", + " Y=mod(0x13384d38c353f862832c0f067e46a3e510bb6803c20745dfb31929f4a18d890d, params3n.curve.prime),\n", + " Z=mod(1, params3n.curve.prime), model=coords)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a8dde7e6-cd48-4f99-9677-23a19e4c2e5b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "prime:\t0xc381bb0394f34b5ed061c9107b66974f4d0a8ec89b9fe73b98b6d1368c7d974d\n", + "a:\t0x5ca6c5ee0a10097af291a8f125303fb1a3e35e8100411902245d691e0e5cb497\n", + "b:\t0x385a5a8bb8af94721f6fd10b562606d9b9df931f7fd966e96859bb9bd7c05836\n", + "G:\t[0x4616af1898b92cac0f902a9daee24bbae63571cead270467c6a7886ced421f5e,\n", + "\t 0x34e896bdb1337e0ae5960fa3389fb59c2c8d6c7dbfd9aac33a844f8f98e433ef]\n", + "n:\t0x412b3e5686fbc3ca4575edb0292232702ae721a7d4a230cc170a5561aa70e00f\n", + "3n:\t0xc381bb0394f34b5ed061c9107b66975080b564f77de69264451f0024ff52a02d\n", + "\n", + "P:\t[0x4a48addb2e471767b7cd0f6f1d4c27fe46f4a828fc20f950bd1f72c939b36a84,\n", + "\t 0x13384d38c353f862832c0f067e46a3e510bb6803c20745dfb31929f4a18d890d]\n" + ] + } + ], + "source": [ + "print(f\"prime:\\t0x{params3n.curve.prime:x}\")\n", + "print(f\"a:\\t0x{params3n.curve.parameters['a']:x}\")\n", + "print(f\"b:\\t0x{params3n.curve.parameters['b']:x}\")\n", + "print(f\"G:\\t[0x{params3n.generator.X:x},\\n\\t 0x{params3n.generator.Y:x}]\")\n", + "print(f\"n:\\t0x{params3n.order:x}\")\n", + "print(f\"3n:\\t0x{3 * params3n.order:x}\")\n", + "print(f\"\\nP:\\t[0x{point3n.X:x},\\n\\t 0x{point3n.Y:x}]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "cd6f8500-7509-45b0-8b23-471ee5014f42", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_scalars_mod3(rem, samples):\n", + " scalars = []\n", + " while True:\n", + " scalar = random.randint(0, params3n.full_order)\n", + " if scalar % 3 == rem:\n", + " scalars.append(scalar)\n", + " if len(scalars) == samples:\n", + " break\n", + " return scalars\n", + "\n", + "def test_3n(countermeasure, scalars):\n", + " ctr = Counter()\n", + " for k in tqdm(scalars, leave=False):\n", + " mult.init(params3n, point3n)\n", + " kP = mult.multiply(k).to_affine()\n", + " mult.init(params3n, point3n)\n", + " knP = mult.multiply(k + params3n.full_order).to_affine()\n", + " mult.init(params3n, point3n)\n", + " k2nP = mult.multiply(k + 2 * params3n.full_order).to_affine()\n", + "\n", + " countermeasure.init(params3n, point3n)\n", + " res = countermeasure.multiply(k)\n", + " aff = res.to_affine()\n", + " if aff.equals(kP):\n", + " ctr[\"k\"] += 1\n", + " elif aff.equals(knP):\n", + " ctr[\"k + 1n\"] += 1\n", + " elif aff.equals(k2nP):\n", + " ctr[\"k + 2n\"] += 1\n", + " else:\n", + " ctr[aff] += 1\n", + " for name, count in sorted(ctr.items()):\n", + " print(f\"{name}:\\t{count}\")\n", + "\n", + "def test_3n_fixed_scalar(countermeasure, samples):\n", + " test_3n(countermeasure, [key3n for _ in range(samples)])\n", + "\n", + "def test_3n_random_scalar(countermeasure, samples):\n", + " test_3n(countermeasure, [random.randint(0, params3n.full_order) for _ in range(samples)])\n", + "\n", + "def test_3n_random_scalar_projected(countermeasure, samples):\n", + " print(\"k = 0 mod 3\")\n", + " test_3n(countermeasure, generate_scalars_mod3(0, samples))\n", + " print()\n", + " print(\"k = 1 mod 3\")\n", + " test_3n(countermeasure, generate_scalars_mod3(1, samples))\n", + " print()\n", + " print(\"k = 2 mod 3\")\n", + " test_3n(countermeasure, generate_scalars_mod3(2, samples))" + ] + }, + { + "cell_type": "markdown", + "id": "46b8f74a-433d-48c9-b5b9-6bb7d2731246", + "metadata": {}, + "source": [ + "### Fixed scalar experiments" + ] + }, + { + "cell_type": "markdown", + "id": "fc82d4b9-91cd-423c-83aa-89721efa1ae9", + "metadata": {}, + "source": [ + "#### Group scalar randomization" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "86532d50-2db7-4370-b449-c545b330a852", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a645f19f86484d3f8154c39c2b851de2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00 tuple[float, float]:\n", + " return proportion_confint(round(p*samples), samples, alpha, method=\"wilson\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0eac736-be48-4925-accf-1ca8ff6aa065", + "metadata": {}, + "outputs": [], + "source": [ + "def powers_of(k, max_power=20):\n", + " return [k**i for i in range(1, max_power)]\n", + "\n", + "def prod_combine(one, other):\n", + " return [a * b for a, b in itertools.product(one, other)]\n", + "\n", + "small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199]\n", + "medium_primes = [211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397]\n", + "large_primes = [401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]\n", + "all_integers = list(range(1, 400))\n", + "all_even = list(range(2, 400, 2))\n", + "all_odd = list(range(1, 400, 2))\n", + "all_primes = small_primes + medium_primes + large_primes\n", + "\n", + "divisor_map = {\n", + " \"small_primes\": small_primes,\n", + " \"medium_primes\": medium_primes,\n", + " \"large_primes\": large_primes,\n", + " \"all_primes\": all_primes,\n", + " \"all_integers\": all_integers,\n", + " \"all_even\": all_even,\n", + " \"all_odd\": all_odd,\n", + " \"powers_of_2\": powers_of(2),\n", + " \"powers_of_2_large\": powers_of(2, 256),\n", + " \"powers_of_2_large_3\": [i * 3 for i in powers_of(2, 256)],\n", + " \"powers_of_2_large_p1\": [i + 1 for i in powers_of(2, 256)],\n", + " \"powers_of_2_large_m1\": [i - 1 for i in powers_of(2, 256)],\n", + " \"powers_of_2_large_pmautobus\": sorted(set([i + j for i in powers_of(2, 256) for j in range(-5,5) if i+j > 0])),\n", + " \"powers_of_3\": powers_of(3),\n", + "}\n", + "divisor_map[\"all\"] = list(sorted(set().union(*[v for v in divisor_map.values()])))" + ] + }, + { + "cell_type": "markdown", + "id": "4868c083-8073-453d-b508-704fcb6d6f2a", + "metadata": {}, + "source": [ + "## Prepare" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccb00342-3c48-49c9-bedf-2341e5eae3a2", + "metadata": {}, + "outputs": [], + "source": [ + "selected_mults = all_mults\n", + "divisor_name = \"all\"\n", + "kind = \"precomp+necessary\"\n", + "selected_divisors = divisor_map[divisor_name]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dbac9be-d098-479a-8ca2-f531f6668f7c", + "metadata": {}, + "outputs": [], + "source": [ + "# Load\n", + "try:\n", + " with open(f\"{divisor_name}_{kind}_distrs.pickle\", \"rb\") as f:\n", + " distributions_mults = pickle.load(f)\n", + "except FileNotFoundError:\n", + " with open(f\"all_{kind}_distrs.pickle\", \"rb\") as f:\n", + " distributions_mults = pickle.load(f)\n", + " for probmap in distributions_mults.values():\n", + " probmap.narrow(selected_divisors)" + ] + }, + { + "cell_type": "markdown", + "id": "1f783baf-bc81-40c1-9282-e2dfdacfd17c", + "metadata": {}, + "source": [ + "## Build dmap and tree" + ] + }, + { + "cell_type": "markdown", + "id": "d161f3e8-6e39-47c1-a26f-23f910f3fe26", + "metadata": {}, + "source": [ + "Select the n for building the tree." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2307bf7c-4fac-489d-8527-7ddbf536a148", + "metadata": {}, + "outputs": [], + "source": [ + "nbuild = 10000\n", + "alpha = 0.05" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b85fad7-392f-4701-9329-d75d39736bbb", + "metadata": {}, + "outputs": [], + "source": [ + "# Now go over all divisors, cluster based on overlapping CI for given n?\n", + "io_map = {mult:{} for mult in distributions_mults.keys()}\n", + "for divisor in selected_divisors:\n", + " prev_ci_low = None\n", + " prev_ci_high = None\n", + " groups = {}\n", + " pvals = {}\n", + " group = 0\n", + " for mult, probmap in sorted(distributions_mults.items(), key=lambda item: -item[1][divisor]):\n", + " # We are going from high to low p.\n", + " pval = probmap[divisor]\n", + " pvals[mult] = pval\n", + " ci_low, ci_high = conf_interval(pval, nbuild, alpha)\n", + " ci_low = max(ci_low, 0.0)\n", + " ci_high = min(ci_high, 1.0)\n", + " if (prev_ci_low is None and prev_ci_high is None) or prev_ci_low >= ci_high:\n", + " g = groups.setdefault(f\"arbitrary{group}\", set())\n", + " g.add(mult)\n", + " group += 1\n", + " else:\n", + " g = groups.setdefault(f\"arbitrary{group}\", set())\n", + " g.add(mult)\n", + " prev_ci_low = ci_low\n", + " prev_ci_high = ci_high\n", + " \n", + " #print(f\"Divisor: {divisor}, num groups: {group}\", end=\"\\n\\t\")\n", + " #for g in groups.values():\n", + " # print(len(g), end=\", \")\n", + " #print()\n", + " for group, mults in groups.items():\n", + " mult_pvals = [pvals[mult] for mult in mults]\n", + " group_pval_avg = np.mean(mult_pvals)\n", + " group_pval_var = np.var(mult_pvals)\n", + " group_pval_min = np.min(mult_pvals)\n", + " group_pval_max = np.max(mult_pvals)\n", + " for mult in mults:\n", + " io_map[mult][divisor] = (group, group_pval_avg, group_pval_var, group_pval_min, group_pval_max)\n", + "\n", + "# then build dmap\n", + "dmap = Map.from_io_maps(set(distributions_mults.keys()), io_map)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06104104-b612-40e9-bc1d-646356a13381", + "metadata": {}, + "outputs": [], + "source": [ + "print(dmap.describe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb8672ca-b76b-411d-b514-2387b555f184", + "metadata": {}, + "outputs": [], + "source": [ + "# deduplicate dmap\n", + "dmap.deduplicate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccba09b0-31c3-450b-af30-efaa64329743", + "metadata": {}, + "outputs": [], + "source": [ + "print(dmap.describe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5735e7d4-149c-4184-96f7-dcfd6017fbad", + "metadata": {}, + "outputs": [], + "source": [ + "# build a tree\n", + "with silent():\n", + " tree = Tree.build(set(distributions_mults.keys()), dmap)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d41093df-32c4-450d-922d-5ad042539397", + "metadata": {}, + "outputs": [], + "source": [ + "print(tree.describe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de577429-d87c-4967-be17-75cbb378860c", + "metadata": {}, + "outputs": [], + "source": [ + "print(tree.render_basic())" + ] + }, + { + "cell_type": "markdown", + "id": "437bcd9c-1da5-428a-a979-0835326777f3", + "metadata": {}, + "source": [ + "## Simulate distinguishing using a tree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "109ed95f-0949-4251-874f-9b87cfe97a00", + "metadata": {}, + "outputs": [], + "source": [ + "simulations = 1000\n", + "\n", + "for nattack in trange(100, 10000, 100):\n", + " successes = 0\n", + " pathiness = 0\n", + " for i in range(simulations):\n", + " true_mult = random.choice(list(distributions_mults.keys()))\n", + " probmap = distributions_mults[true_mult]\n", + " node = tree.root\n", + " while True:\n", + " if node.is_leaf:\n", + " break\n", + " divisor = node.dmap_input\n", + " prob = probmap[divisor]\n", + " sampled_prob = binom(nattack, prob).rvs() / nattack\n", + " best_child = None\n", + " true_child = None\n", + " best_group_distance = None\n", + " #print(f\"Divisor: {divisor}, prob: {prob}, sampled: {sampled_prob}\")\n", + " for child in node.children:\n", + " if true_mult in child.cfgs:\n", + " true_child = child\n", + " group, group_pval_avg, group_pval_var, group_pval_min, group_pval_max = child.response\n", + " group_distance = min(abs(sampled_prob - group_pval_min), abs(sampled_prob - group_pval_max))\n", + " #print(f\"Child {group}, {group_pval_avg}\")\n", + " if best_child is None or \\\n", + " (group_distance < best_group_distance):\n", + " best_child = child\n", + " best_group_distance = group_distance\n", + " if sampled_prob > group_pval_min and sampled_prob < group_pval_max:\n", + " best_child = child\n", + " break\n", + " #print(f\"Best {best_child.response}\")\n", + " if true_child is not None and true_child != best_child:\n", + " pass\n", + " #print(f\"Mistake! {prob}, {sampled_prob} true:{true_child.response}, chosen:{best_child.response}\")\n", + " node = best_child\n", + " if true_mult in node.cfgs:\n", + " pathiness += 1\n", + " #print(f\"Arrived: {true_mult in node.cfgs}\")\n", + " if true_mult in node.cfgs:\n", + " successes += 1\n", + " print(f\"{nattack}: success rate {successes/simulations}, pathiness {pathiness/simulations}\")" + ] + }, + { + "cell_type": "markdown", + "id": "308df683-952e-430a-b2bd-f19bcfb98b8e", + "metadata": {}, + "source": [ + "## Simulate distinguishing using a distance metric\n", + "\n", + "We need to first select some features (divisors) from the set of all divisors that we will query\n", + "the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as\n", + "much as possible." + ] + }, + { + "cell_type": "markdown", + "id": "62d2f2a2-495e-459d-b0e2-89c9a5973b1e", + "metadata": {}, + "source": [ + "### Feature selection using trees\n", + "\n", + "We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building choses as the features. However, we can also use more conventional feature selection approaches." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "beb5720a-f793-4ad9-ad27-1bd943bb325b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "good_inputs = Counter()\n", + "for node in PreOrderIter(tree.root):\n", + " if node.is_leaf:\n", + " continue\n", + " good_inputs[node.dmap_input] += 1\n", + "for good in sorted(good_inputs):\n", + " print(good)\n", + " print(bin(good))\n", + " print(f\"used {good_inputs[good]} times\")\n", + " print(f\"nbits {good.bit_length()}\")\n", + " for div_name, div_group in divisor_map.items():\n", + " if good in div_group and div_name != \"all\":\n", + " print(div_name, end=\", \")\n", + " print(\"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8af8c2ad-45e1-42ec-a757-5d5966c2f4b7", + "metadata": {}, + "outputs": [], + "source": [ + "simulations = 400\n", + "retries = 1000\n", + "\n", + "for nfeats in (6,): #trange(1, 7)\n", + " for nattack in range(100, 200, 100):\n", + " best_feats = None\n", + " best_feats_mean_pos = None\n", + " best_successes = None\n", + " for _ in trange(retries):\n", + " feats = random.sample(sorted(good_inputs), nfeats)\n", + " successes = {k:0 for k in range(1, 11)}\n", + " mean_pos = 0\n", + " for _ in range(simulations):\n", + " true_mult = random.choice(list(distributions_mults.keys()))\n", + " probmap = distributions_mults[true_mult]\n", + " feat_vector = np.zeros(nfeats)\n", + " for i, divisor in enumerate(feats):\n", + " prob = probmap[divisor]\n", + " sampled_prob = binom(nattack, prob).rvs() / nattack\n", + " feat_vector[i] = sampled_prob\n", + " scoring = []\n", + " for other_mult, other_probmap in distributions_mults.items():\n", + " other_vector = np.zeros(nfeats)\n", + " for i, divisor in enumerate(feats):\n", + " other_vector[i] = other_probmap[divisor]\n", + " similarity = distance.euclidean(feat_vector, other_vector)\n", + " scoring.append((similarity, other_mult))\n", + " scoring.sort(key=lambda item: item[0])\n", + " for i, (sim, other) in enumerate(scoring):\n", + " if other == true_mult:\n", + " mean_pos += i\n", + " for k in range(10):\n", + " if i <= k:\n", + " successes[k+1] +=1\n", + " for i in successes.keys():\n", + " successes[i] /= simulations\n", + " #print(f\"{nattack:<10}: mean position {mean_pos/simulations}\")\n", + " #print(f\" top1: {successes[1]}, top5: {successes[5]}, top10: {successes[10]}\")\n", + " if best_feats is None or best_feats_mean_pos > mean_pos/simulations:\n", + " best_feats = feats\n", + " best_feats_mean_pos = mean_pos/simulations\n", + " best_successes = successes\n", + " print(flush=True)\n", + " print(nattack)\n", + " print(f\"Features: ({nfeats}) {best_feats}\")\n", + " print(f\"mean_pos: {best_feats_mean_pos}\")\n", + " print(f\"top1: {best_successes[1]}, top2: {best_successes[2]}, top5: {best_successes[5]}, top10: {best_successes[10]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9b94be24-d0ee-4597-ad99-4c55558a38c9", + "metadata": {}, + "source": [ + "### Feature selection using variance" + ] + }, + { + "cell_type": "markdown", + "id": "11ee99e6-a90e-4304-bd7d-6491ce936b57", + "metadata": {}, + "source": [ + "### Feature selection using synthetic data\n", + "The below contains some experiments that mostly do not work. Ignore." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e8440f3-f856-41e0-8d37-56b750e1309d", + "metadata": {}, + "outputs": [], + "source": [ + "# Lets pick n as if we were doing the reversing\n", + "n = 100\n", + "# Lets pick m as the number of repeats\n", + "m = 100\n", + "# then for each mult and each divisor (thus each point) do binom(n, p) m times, save this synthetic data\n", + "nmults = len(distributions_mults)\n", + "ndivs = len(selected_divisors)\n", + "base_X = np.zeros((nmults, ndivs))\n", + "base_y = np.zeros(nmults)\n", + "synthetic_X = np.zeros((nmults * m, ndivs))\n", + "synthetic_y = np.zeros(nmults * m)\n", + "for i, (mult, probmap) in enumerate(distributions_mults.items()):\n", + " for j, divisor in enumerate(selected_divisors):\n", + " p = probmap[divisor]\n", + " r = binom.rvs(n, p, size=m) / n\n", + " synthetic_X[i*m:(i+1)*m, j] = r\n", + " base_X[i, j] = p\n", + " synthetic_y[i*m:(i+1)*m] = i\n", + " base_y[i] = i\n", + "print(synthetic_X)\n", + "# so we have !mults! classes and !mults! * m samples\n", + "# on this synthetic data we can run whatever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6396296e-9352-4599-8ee9-45f9b4f4ce70", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.feature_selection import SelectKBest, SelectFdr, SelectFpr, SelectFwe, SequentialFeatureSelector\n", + "from sklearn.feature_selection import f_classif, mutual_info_classif, chi2\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "\n", + "from sklearn.datasets import load_iris" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d7050a1-b1ef-4eed-a885-cc11d8703b24", + "metadata": {}, + "outputs": [], + "source": [ + "selection = SelectKBest(f_classif, k=10).fit(synthetic_X, synthetic_y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "393d9d99-67e1-4d0a-b4ad-a0adcd6491d8", + "metadata": {}, + "outputs": [], + "source": [ + "len(selection.get_feature_names_out())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f324988-04bd-4c87-9af0-45abe1ebb6e9", + "metadata": {}, + "outputs": [], + "source": [ + "for divisor, present in zip(selected_divisors, selection.get_support()):\n", + " if present:\n", + " print(divisor)\n", + " print(bin(divisor))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0a40bd9-753e-4bc4-9bc7-f0eb2f96ce7b", + "metadata": {}, + "outputs": [], + "source": [ + "X_new = selection.transform(synthetic_X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88809007-7b21-4985-83f9-f4cd9247fccf", + "metadata": {}, + "outputs": [], + "source": [ + "X_new.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cddc8885-37ad-4225-b83f-4798018f80f3", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import tree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e11a8fc4-0df9-4cdc-a6d3-2d6297b8e085", + "metadata": {}, + "outputs": [], + "source": [ + "clf = tree.DecisionTreeClassifier()\n", + "clf = clf.fit(synthetic_X, synthetic_y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e9a91df-845c-4eaa-944e-62d07d7cb1c6", + "metadata": {}, + "outputs": [], + "source": [ + "clf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21516983-06be-4ad9-91f4-7454eacbf121", + "metadata": {}, + "outputs": [], + "source": [ + "from mrmr import mrmr_classif" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c2215a5-c073-4118-a21e-e78afb724eda", + "metadata": {}, + "outputs": [], + "source": [ + "selected_features = mrmr_classif(X=pd.DataFrame(synthetic_X), y=pd.Series(synthetic_y), K=35)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c05c7b33-6a75-4477-97a0-6b70808d0e1e", + "metadata": {}, + "outputs": [], + "source": [ + "for selected in selected_features:\n", + " divisor = selected_divisors[selected]\n", + " print(divisor, bin(divisor))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2787faf-a487-4f28-aa3c-8fdd9562550d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3-70-g09d2