aboutsummaryrefslogtreecommitdiffhomepage
path: root/notebooks
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/cc/reference_annotations/hyperparameter_search.py4
-rw-r--r--notebooks/fixed_sankey_plot.py38
2 files changed, 21 insertions, 21 deletions
diff --git a/notebooks/cc/reference_annotations/hyperparameter_search.py b/notebooks/cc/reference_annotations/hyperparameter_search.py
index a7286aa0..fc42d148 100644
--- a/notebooks/cc/reference_annotations/hyperparameter_search.py
+++ b/notebooks/cc/reference_annotations/hyperparameter_search.py
@@ -14,11 +14,11 @@ import optuna
import pandas as pd
import torch
from rapidfuzz import fuzz
+from sec_certs.model.references.annotator_trainer import ReferenceAnnotatorTrainer
+from sec_certs.model.references.segment_extractor import ReferenceSegmentExtractor
from sklearn.metrics import f1_score
from sec_certs.dataset import CCDataset
-from sec_certs.model.references.annotator_trainer import ReferenceAnnotatorTrainer
-from sec_certs.model.references.segment_extractor import ReferenceSegmentExtractor
from sec_certs.utils.helpers import compute_heuristics_version
from sec_certs.utils.nlp import prec_recall_metric
diff --git a/notebooks/fixed_sankey_plot.py b/notebooks/fixed_sankey_plot.py
index b8d062f9..8f609227 100644
--- a/notebooks/fixed_sankey_plot.py
+++ b/notebooks/fixed_sankey_plot.py
@@ -9,7 +9,7 @@ This code should fix the problems and should be used to produce figures in the r
import logging
import warnings
from collections import defaultdict
-from typing import Any, Optional, Union
+from typing import Any, Union
import matplotlib.pyplot as plt
import numpy as np
@@ -57,18 +57,18 @@ def check_data_matches_labels(labels: Union[list[str], set[str]], data: Series,
def sankey(
left: Union[list, ndarray, Series],
right: Union[ndarray, Series],
- leftWeight: Optional[ndarray] = None,
- rightWeight: Optional[ndarray] = None,
- colorDict: Optional[dict[str, str]] = None,
- leftLabels: Optional[list[str]] = None,
- rightLabels: Optional[list[str]] = None,
+ leftWeight: ndarray | None = None,
+ rightWeight: ndarray | None = None,
+ colorDict: dict[str, str] | None = None,
+ leftLabels: list[str] | None = None,
+ rightLabels: list[str] | None = None,
aspect: int = 4,
rightColor: bool = False,
fontsize: int = 14,
- figureName: Optional[str] = None,
+ figureName: str | None = None,
closePlot: bool = False,
- figSize: Optional[tuple[int, int]] = None,
- ax: Optional[Any] = None,
+ figSize: tuple[int, int] | None = None,
+ ax: Any | None = None,
) -> Any:
"""
Make Sankey Diagram showing flow from left-->right
@@ -151,7 +151,7 @@ def sankey(
return ax
-def save_image(figureName: Optional[str]) -> None:
+def save_image(figureName: str | None) -> None:
if figureName is not None:
file_name = f"{figureName}.png"
plt.savefig(file_name, bbox_inches="tight", dpi=150)
@@ -173,15 +173,15 @@ def identify_labels(dataFrame: DataFrame, leftLabels: list[str], rightLabels: li
def init_values(
- ax: Optional[Any],
+ ax: Any | None,
closePlot: bool,
- figSize: Optional[tuple[int, int]],
- figureName: Optional[str],
+ figSize: tuple[int, int] | None,
+ figureName: str | None,
left: Union[list, ndarray, Series],
- leftLabels: Optional[list[str]],
- leftWeight: Optional[ndarray],
- rightLabels: Optional[list[str]],
- rightWeight: Optional[ndarray],
+ leftLabels: list[str] | None,
+ leftWeight: ndarray | None,
+ rightLabels: list[str] | None,
+ rightWeight: ndarray | None,
) -> tuple[Any, list[str], ndarray, list[str], ndarray]:
deprecation_warnings(closePlot, figSize, figureName)
if ax is None:
@@ -202,7 +202,7 @@ def init_values(
return ax, leftLabels, leftWeight, rightLabels, rightWeight
-def deprecation_warnings(closePlot: bool, figSize: Optional[tuple[int, int]], figureName: Optional[str]) -> None:
+def deprecation_warnings(closePlot: bool, figSize: tuple[int, int] | None, figureName: str | None) -> None:
warn = []
if figureName is not None:
msg = "use of figureName in sankey() is deprecated"
@@ -286,7 +286,7 @@ def draw_vertical_bars(
def create_colors(
- allLabels: ndarray, colorDict: Optional[dict[str, str]]
+ allLabels: ndarray, colorDict: dict[str, str] | None
) -> Union[dict[str, tuple[float, float, float]], dict[str, str]]:
# If no colorDict given, make one
if colorDict is None: