{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train / validation / test split\n",
    "\n",
    "This notebook was used to split the CC dataset into train/valid/test samples for the reference annotation NLP task.\n",
    "\n",
    "Only certificates that have `directly_referencing` entries (in `st_references` or `report_references`) together with the matching `st_txt_path` / `report_txt_path` are kept. Certificates from a few very rare categories go straight to the test set; the remaining ones are split 30/20/50 (train/valid/test), stratified by category, with a fixed random seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from sec_certs.dataset import CCDataset\n",
    "from sec_certs.sample import CCCertificate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so that re-running the notebook reproduces the same split.\n",
    "SEED = 42\n",
    "\n",
    "# Share of the stratified certificates that goes to the test set.\n",
    "TEST_SIZE = 0.5\n",
    "# Share of the remaining certificates that goes to the validation set (0.4 * 0.5 = 20 % overall).\n",
    "VALID_SIZE = 0.4\n",
    "\n",
    "# Categories that are too rare to be split; their certificates go straight to the test set.\n",
    "RARE_CATEGORIES = {\"Multi-Function Devices\", \"Mobility\", \"Data Protection\"}\n",
    "\n",
    "# Directory the split files are written to (it must already exist).\n",
    "SPLIT_DIR = \"../../../data/reference_annotations_split\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and filter the certificates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dset = CCDataset.from_web()\n",
    "df = dset.to_pandas()\n",
    "\n",
    "# Keep a certificate when it has `directly_referencing` entries and the matching txt path,\n",
    "# for either `st` or `report`.\n",
    "reference_rich_certs = {\n",
    "    x.dgst\n",
    "    for x in dset\n",
    "    if (x.heuristics.st_references.directly_referencing and x.state.st_txt_path)\n",
    "    or (x.heuristics.report_references.directly_referencing and x.state.report_txt_path)\n",
    "}\n",
    "df = df.loc[df.index.isin(reference_rich_certs)]\n",
    "\n",
    "# The following certs go straight to the test set as they represent super rare categories that we cannot split\n",
    "certs_from_rare_categories = df.loc[df.category.isin(RARE_CATEGORIES)].index.tolist()\n",
    "df = df.loc[~df.index.isin(certs_from_rare_categories)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stratified split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This splits 30/20/50 (train, valid, test)\n",
    "# The category Series is split (rather than the bare index) so that every label stays attached\n",
    "# to its index entry. Stratification labels are matched to samples by position, so labels looked\n",
    "# up in the original frame order would not line up with the shuffled output of the first split.\n",
    "categories = df.category\n",
    "train_valid_categories, test_categories = train_test_split(\n",
    "    categories,\n",
    "    test_size=TEST_SIZE,\n",
    "    shuffle=True,\n",
    "    stratify=categories,\n",
    "    random_state=SEED,\n",
    ")\n",
    "train_categories, valid_categories = train_test_split(\n",
    "    train_valid_categories,\n",
    "    test_size=VALID_SIZE,\n",
    "    shuffle=True,\n",
    "    stratify=train_valid_categories,\n",
    "    random_state=SEED,\n",
    ")\n",
    "\n",
    "x_train = train_categories.index\n",
    "x_valid = valid_categories.index\n",
    "# Certificates from the rare categories are added on top of the 50 % test share.\n",
    "x_test = list(test_categories.index) + list(certs_from_rare_categories)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One JSON file per split, each holding a plain list of the index values.\n",
    "splits = {\n",
    "    \"train\": x_train.tolist(),\n",
    "    \"valid\": x_valid.tolist(),\n",
    "    \"test\": x_test,\n",
    "}\n",
    "\n",
    "for split_name, split_ids in splits.items():\n",
    "    with open(f\"{SPLIT_DIR}/{split_name}.json\", \"w\") as handle:\n",
    "        json.dump(split_ids, handle, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a5b8c5b127d2cfe5bc3a1c933e197485eb9eba25154c3661362401503b4ef9d4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}