diff options
Diffstat (limited to 'notebooks')
| -rw-r--r-- | notebooks/cc/reference_annotations/hyperparameter_search.py | 4 | ||||
| -rw-r--r-- | notebooks/fixed_sankey_plot.py | 38 |
2 files changed, 21 insertions, 21 deletions
diff --git a/notebooks/cc/reference_annotations/hyperparameter_search.py b/notebooks/cc/reference_annotations/hyperparameter_search.py index a7286aa0..fc42d148 100644 --- a/notebooks/cc/reference_annotations/hyperparameter_search.py +++ b/notebooks/cc/reference_annotations/hyperparameter_search.py @@ -14,11 +14,11 @@ import optuna import pandas as pd import torch from rapidfuzz import fuzz +from sec_certs.model.references.annotator_trainer import ReferenceAnnotatorTrainer +from sec_certs.model.references.segment_extractor import ReferenceSegmentExtractor from sklearn.metrics import f1_score from sec_certs.dataset import CCDataset -from sec_certs.model.references.annotator_trainer import ReferenceAnnotatorTrainer -from sec_certs.model.references.segment_extractor import ReferenceSegmentExtractor from sec_certs.utils.helpers import compute_heuristics_version from sec_certs.utils.nlp import prec_recall_metric diff --git a/notebooks/fixed_sankey_plot.py b/notebooks/fixed_sankey_plot.py index b8d062f9..8f609227 100644 --- a/notebooks/fixed_sankey_plot.py +++ b/notebooks/fixed_sankey_plot.py @@ -9,7 +9,7 @@ This code should fix the problems and should be used to produce figures in the r import logging import warnings from collections import defaultdict -from typing import Any, Optional, Union +from typing import Any, Union import matplotlib.pyplot as plt import numpy as np @@ -57,18 +57,18 @@ def check_data_matches_labels(labels: Union[list[str], set[str]], data: Series, def sankey( left: Union[list, ndarray, Series], right: Union[ndarray, Series], - leftWeight: Optional[ndarray] = None, - rightWeight: Optional[ndarray] = None, - colorDict: Optional[dict[str, str]] = None, - leftLabels: Optional[list[str]] = None, - rightLabels: Optional[list[str]] = None, + leftWeight: ndarray | None = None, + rightWeight: ndarray | None = None, + colorDict: dict[str, str] | None = None, + leftLabels: list[str] | None = None, + rightLabels: list[str] | None = None, aspect: int = 4, rightColor: bool = False, fontsize: int = 14, - figureName: Optional[str] = None, + figureName: str | None = None, closePlot: bool = False, - figSize: Optional[tuple[int, int]] = None, - ax: Optional[Any] = None, + figSize: tuple[int, int] | None = None, + ax: Any | None = None, ) -> Any: """ Make Sankey Diagram showing flow from left-->right @@ -151,7 +151,7 @@ def sankey( return ax -def save_image(figureName: Optional[str]) -> None: +def save_image(figureName: str | None) -> None: if figureName is not None: file_name = f"{figureName}.png" plt.savefig(file_name, bbox_inches="tight", dpi=150) @@ -173,15 +173,15 @@ def identify_labels(dataFrame: DataFrame, leftLabels: list[str], rightLabels: li def init_values( - ax: Optional[Any], + ax: Any | None, closePlot: bool, - figSize: Optional[tuple[int, int]], - figureName: Optional[str], + figSize: tuple[int, int] | None, + figureName: str | None, left: Union[list, ndarray, Series], - leftLabels: Optional[list[str]], - leftWeight: Optional[ndarray], - rightLabels: Optional[list[str]], - rightWeight: Optional[ndarray], + leftLabels: list[str] | None, + leftWeight: ndarray | None, + rightLabels: list[str] | None, + rightWeight: ndarray | None, ) -> tuple[Any, list[str], ndarray, list[str], ndarray]: deprecation_warnings(closePlot, figSize, figureName) if ax is None: @@ -202,7 +202,7 @@ def init_values( return ax, leftLabels, leftWeight, rightLabels, rightWeight -def deprecation_warnings(closePlot: bool, figSize: Optional[tuple[int, int]], figureName: Optional[str]) -> None: +def deprecation_warnings(closePlot: bool, figSize: tuple[int, int] | None, figureName: str | None) -> None: warn = [] if figureName is not None: msg = "use of figureName in sankey() is deprecated" @@ -286,7 +286,7 @@ def draw_vertical_bars( def create_colors( - allLabels: ndarray, colorDict: Optional[dict[str, str]] + allLabels: ndarray, colorDict: dict[str, str] | None ) -> Union[dict[str, tuple[float, float, float]], dict[str, str]]: # If no colorDict given, make one if colorDict is None: |
