From 40a934e1e12d2cfc87bab1ab795c3033aad1af18 Mon Sep 17 00:00:00 2001 From: Ghislain Vaillant Date: Thu, 12 Sep 2024 16:14:39 +0200 Subject: [PATCH] WIP: Fix some type checking errors --- medkit/io/_brat_utils.py | 32 +++++++++++++++++--------------- medkit/text/ner/umls_utils.py | 12 +++++------- medkit/training/utils.py | 4 ++-- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/medkit/io/_brat_utils.py b/medkit/io/_brat_utils.py index 7a3cdbdd..572001f5 100644 --- a/medkit/io/_brat_utils.py +++ b/medkit/io/_brat_utils.py @@ -8,6 +8,7 @@ from smart_open import open if TYPE_CHECKING: + from collections.abc import Mapping, Sequence from pathlib import Path GROUPING_ENTITIES = frozenset(["And-Group", "Or-Group"]) @@ -22,7 +23,7 @@ class BratEntity: uid: str type: str - span: list[tuple[int, int]] + span: Sequence[tuple[int, int]] text: str @property @@ -58,7 +59,7 @@ class BratAttribute: uid: str type: str target: str - value: str = None # Only one value is possible + value: str | None = None # Only one value is possible def to_str(self) -> str: value = ensure_attr_value(self.value) @@ -80,7 +81,7 @@ def to_str(self) -> str: def ensure_attr_value(attr_value: Any) -> str: - """Ensure that the attribue value is a string.""" + """Ensure that the attribute value is a string.""" if isinstance(attr_value, str): return attr_value if attr_value is None or isinstance(attr_value, bool): @@ -98,7 +99,7 @@ class Grouping: uid: str type: str - items: list[BratEntity] + items: Sequence[BratEntity] @property def text(self): @@ -111,11 +112,11 @@ class BratAugmentedEntity: uid: str type: str - span: tuple[tuple[int, int], ...] + span: Sequence[tuple[int, int]] text: str - relations_from_me: tuple[BratRelation, ...] - relations_to_me: tuple[BratRelation, ...] - attributes: tuple[BratAttribute, ...] + relations_from_me: Sequence[BratRelation] + relations_to_me: Sequence[BratRelation] + attributes: Sequence[BratAttribute] @property def start(self) -> int: @@ -128,11 +129,11 @@ def end(self) -> int: @dataclass class BratDocument: - entities: dict[str, BratEntity] - relations: dict[str, BratRelation] - attributes: dict[str, BratAttribute] - notes: dict[str, BratNote] - groups: dict[str, Grouping] = None + entities: Mapping[str, BratEntity] + relations: Mapping[str, BratRelation] + attributes: Mapping[str, BratAttribute] + notes: Mapping[str, BratNote] + groups: Mapping[str, Grouping] | None = None def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]: augmented_entities = {} @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: logger.warning("Ignore annotation %s at line %s", ann_id, line_number) # Process groups - groups = None if detect_groups: - groups: dict[str, Grouping] = {} + groups = {} grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS} for entity in entities.values(): @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid ] groups[entity.uid] = Grouping(entity.uid, entity.type, items) + else: + groups = None return BratDocument(entities, relations, attributes, notes, groups) diff --git a/medkit/text/ner/umls_utils.py b/medkit/text/ner/umls_utils.py index 2fdde0bb..a9e447f0 100644 --- a/medkit/text/ner/umls_utils.py +++ b/medkit/text/ner/umls_utils.py @@ -155,12 +155,10 @@ def load_umls_entries( if lui in luis_seen: continue - if semtypes_by_cui is not None and cui in semtypes_by_cui: - semtypes = semtypes_by_cui[cui] - semgroups = [semgroups_by_semtype[semtype] for semtype in semtypes] - else: - semtypes = None - semgroups = None + semtypes = semtypes_by_cui.get(cui) if semtypes_by_cui else None + semgroups = ( + [semgroups_by_semtype[semtype] for semtype in semtypes] if semgroups_by_semtype and semtypes else None + ) luis_seen.add(lui) yield UMLSEntry(cui, term, semtypes, semgroups) @@ -198,7 +196,7 @@ def load_semtypes_by_cui(mrsty_file: str | Path) -> dict[str, list[str]]: # Source: UMLS project # https://lhncbc.nlm.nih.gov/semanticnetwork/download/sg_archive/SemGroups-v04.txt _UMLS_SEMGROUPS_FILE = Path(__file__).parent / "umls_semgroups_v04.txt" -_SEMGROUPS_BY_SEMTYPE = None +_SEMGROUPS_BY_SEMTYPE: dict[str, str] | None = None def load_semgroups_by_semtype() -> dict[str, str]: diff --git a/medkit/training/utils.py b/medkit/training/utils.py index 5e687d20..1ddffb6c 100644 --- a/medkit/training/utils.py +++ b/medkit/training/utils.py @@ -5,7 +5,7 @@ from typing import Any, runtime_checkable import torch -from typing_extensions import Protocol, Self +from typing_extensions import Protocol class BatchData(dict): @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]: return inner_dict[index] return {key: values[index] for key, values in self.items()} - def to_device(self, device: torch.device) -> Self: + def to_device(self, device: torch.device) -> BatchData: """Ensure that Tensors in the BatchData object are on the specified `device`. Parameters