{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-process data for reference classification\n",
    "\n",
    "This notebook's pipeline is as follows:\n",
    "\n",
    "1. Recover text segments that surround the certificate ID for all references in the CC dataset\n",
    "2. Create a DataFrame `(dgst, cert_id, location, label, sentences)` out of the objects\n",
    "3. Clean and dump into csv\n",
    "4. Check for label noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import json\n",
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import spacy\n",
    "from sec_certs.dataset import CCDataset\n",
    "from sec_certs.sample import CCCertificate\n",
    "from sec_certs.utils.parallel_processing import process_parallel\n",
    "\n",
    "REPO_ROOT = Path(\"../../../\").resolve()\n",
    "\n",
    "# spaCy pipeline used for sentence segmentation in `ReferenceRecord`\n",
    "nlp = spacy.load(\"en_core_web_sm\")\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ReferenceRecord:\n",
    "    \"\"\"\n",
    "    Intermediate object to hold references for a given certificate together with sensible attributes to be extracted\n",
    "    for labeling.\n",
    "\n",
    "    `certificate` is the referencing certificate, `dgst` is its `dgst` attribute and `cert_id` is the referenced ID.\n",
    "    `location` selects the text to search: \"target\" reads `state.st_txt_path`, any other value reads\n",
    "    `state.report_txt_path`.\n",
    "    \"\"\"\n",
    "\n",
    "    certificate: CCCertificate | None\n",
    "    dgst: str\n",
    "    cert_id: str\n",
    "    location: str\n",
    "    label: str | None = None\n",
    "    sentences: set[str] | None = None\n",
    "\n",
    "    @staticmethod\n",
    "    def get_reference_sentences(doc, cert_id: str) -> set[str]:\n",
    "        \"\"\"\n",
    "        Return the set of sentences of `doc` (a parsed spaCy document) whose text contains `cert_id` as a substring.\n",
    "        \"\"\"\n",
    "        return {sent.text for sent in doc.sents if cert_id in sent.text}\n",
    "\n",
    "    @staticmethod\n",
    "    def get_cert_references_with_sentences(record: ReferenceRecord) -> ReferenceRecord:\n",
    "        \"\"\"\n",
    "        Read the text file selected by `record.location`, store the sentences mentioning `record.cert_id` in\n",
    "        `record.sentences` and return the (mutated) record.\n",
    "\n",
    "        `record.sentences` is set to None when the record has no certificate, when the certificate has no path\n",
    "        to the selected text file, or when no sentence contains the certificate ID.\n",
    "        \"\"\"\n",
    "        record.sentences = None\n",
    "\n",
    "        # `certificate` is declared optional; without it there is nothing to read.\n",
    "        if record.certificate is None:\n",
    "            return record\n",
    "\n",
    "        state = record.certificate.state\n",
    "        pth_to_read = state.st_txt_path if record.location == \"target\" else state.report_txt_path\n",
    "\n",
    "        # The text path can be unset (see the `x.state.st_txt_path` filters where the records are built).\n",
    "        if not pth_to_read:\n",
    "            return record\n",
    "\n",
    "        with pth_to_read.open(\"r\") as handle:\n",
    "            data = handle.read()\n",
    "\n",
    "        result = ReferenceRecord.get_reference_sentences(nlp(data), record.cert_id)\n",
    "        # An empty set is normalized to None so that `notnull()` can be used to filter rows later.\n",
    "        record.sentences = result if result else None\n",
    "\n",
    "        return record\n",
    "\n",
    "    def to_pandas_tuple(self) -> tuple[str, str, str, str | None, set[str] | None]:\n",
    "        \"\"\"\n",
    "        Return `(dgst, cert_id, location, label, sentences)`, the column order used by `get_df_from_records`.\n",
    "        \"\"\"\n",
    "        return self.dgst, self.cert_id, self.location, self.label, self.sentences\n",
    "\n",
    "\n",
    "def get_df_from_records(records: list[ReferenceRecord]) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Builds dataframe with [dgst, cert_id, location, label, sentences] with references from list of ReferenceRecords.\n",
    "    Label is None if not defined, sentences is None if no sentence could be recovered.\n",
    "    \"\"\"\n",
    "    results = process_parallel(\n",
    "        ReferenceRecord.get_cert_references_with_sentences, records, use_threading=False, progress_bar=True\n",
    "    )\n",
    "    return pd.DataFrame.from_records(\n",
    "        [x.to_pandas_tuple() for x in results], columns=[\"dgst\", \"cert_id\", \"location\", \"label\", \"sentences\"]\n",
    "    )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract sentences from text files and populate dataframes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 58/58 [00:07<00:00, 8.27it/s]\n",
      "100%|██████████| 944/944 [01:06<00:00, 14.12it/s]\n",
      "100%|██████████| 2259/2259 [00:32<00:00, 69.22it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load annotated references from CSV, drop self-references and normalize labels (upper-case, underscores)\n",
    "annotations_df = (\n",
    "    pd.read_csv(REPO_ROOT / \"data/cert_id_eval/random_references.csv\")\n",
    "    .rename(columns={\"id\": \"dgst\", \"reason\": \"label\"})\n",
    "    .loc[lambda frame: frame.label != \"self\"]\n",
    "    .assign(label=lambda frame: frame.label.str.upper().str.replace(\" \", \"_\", regex=False))\n",
    ")\n",
    "\n",
    "# Load dataset (alternative source: `CCDataset.from_web_latest()`)\n",
    "dset = CCDataset.from_json(REPO_ROOT / \"datasets/cc/cc_dataset.json\")\n",
    "\n",
    "annotated_records = [\n",
    "    ReferenceRecord(dset[x.dgst], x.dgst, x.cert_id, x.location, x.label)\n",
    "    for x in annotations_df.itertuples(index=False)\n",
    "]\n",
    "\n",
    "# Reference records without annotations\n",
    "target_certs = [x for x in dset if x.heuristics.st_references.directly_referencing and x.state.st_txt_path]\n",
    "report_certs = [x for x in dset if x.heuristics.report_references.directly_referencing and x.state.report_txt_path]\n",
    "target_records = [\n",
    "    ReferenceRecord(x, x.dgst, y, \"target\", None, None)\n",
    "    for x in target_certs\n",
    "    for y in x.heuristics.st_references.directly_referencing\n",
    "]\n",
    "report_records = [\n",
    "    ReferenceRecord(x, x.dgst, y, \"report\", None, None)\n",
    "    for x in report_certs\n",
    "    for y in x.heuristics.report_references.directly_referencing\n",
    "]\n",
    "\n",
    "# Filter annotated_records from report_records to avoid duplicates\n",
    "annotated_keys = {(x.dgst, x.cert_id) for x in annotated_records}\n",
    "report_records = [x for x in report_records if (x.dgst, x.cert_id) not in annotated_keys]\n",
    "\n",
    "df_labeled = get_df_from_records(annotated_records)\n",
    "df_targets = get_df_from_records(target_records)\n",
    "df_reports = get_df_from_records(report_records)\n",
    "\n",
    "# `ignore_index=True` gives the combined frame a unique index instead of three overlapping ones\n",
    "df = pd.concat([df_labeled, df_targets, df_reports], ignore_index=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process Dataframes and dump two versions into csv\n",
    "\n",
    "1. Version with `dgst, cert_id, location, single_sentence` as `*_exploded.csv`\n",
    "2. Version where all sentences tied to `(dgst, cert_id)` key are merged into `sentences`. Saved as `*_grouped.csv`\n",
    "\n",
    "*Note*: The test split is not handled yet; only train and valid samples are kept.\n",
    "\n",
    "*Note*: The code below currently writes only the second (merged) version, to `datasets/reference_classification_dataset_merrged.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load split labels\n",
    "with (REPO_ROOT / \"data/reference_annotations_split/train.json\").open(\"r\") as handle:\n",
    "    train_digests = json.load(handle)\n",
    "\n",
    "with (REPO_ROOT / \"data/reference_annotations_split/valid.json\").open(\"r\") as handle:\n",
    "    valid_digests = json.load(handle)\n",
    "\n",
    "split_dct = {**dict.fromkeys(train_digests, \"train\"), **dict.fromkeys(valid_digests, \"valid\")}\n",
    "\n",
    "# Apply filtering (built as a new frame so that this cell can be re-run without side effects on `df`)\n",
    "# TODO: We should investigate the cases when we match no sentence\n",
    "df_split = (\n",
    "    df.loc[df.sentences.notnull()]\n",
    "    .assign(split=lambda frame: frame.dgst.map(split_dct))\n",
    "    .loc[lambda frame: frame.split.notnull()]  # Discard test samples\n",
    ")\n",
    "\n",
    "# TODO: Add language detection\n",
    "\n",
    "# Aggregate sentences from different sources (target, report) into one row\n",
    "# NOTE(review): groupby drops rows whose key is null by default, so rows with `label=None` (the un-annotated\n",
    "# target/report records) do not reach the output. Confirm that this is intended.\n",
    "df_merged = df_split.groupby([\"dgst\", \"cert_id\", \"label\", \"split\"], as_index=False).agg(\n",
    "    {\"sentences\": lambda x: set.union(*x)}\n",
    ")\n",
    "\n",
    "# NOTE(review): \"merrged\" looks like a typo, the path is kept as-is because other code may read this file.\n",
    "df_merged.to_csv(REPO_ROOT / \"datasets/reference_classification_dataset_merrged.csv\", sep=\";\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check for label noise, i.e., search for instances that have different label of a reference.\n",
    "# Rows are unique per (dgst, cert_id, label, split) after the aggregation, so a repeated (dgst, cert_id)\n",
    "# means that the same reference carries more than one label.\n",
    "duplicates_df = df_merged[df_merged.duplicated(subset=[\"dgst\", \"cert_id\"], keep=False)]\n",
    "if not duplicates_df.empty:\n",
    "    print(\"Warning, label noise in dataset. I.e. tuples (dgst, cert_id) with inconsistent reason. See `duplicates_df` frame.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a2ed43df31f510d0b358bd0625493376557b0c4d37aa99c09b398809f951b6a5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}