Skip to content

Commit

Permalink
WIP: Fix some type checking errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ghisvail committed Sep 13, 2024
1 parent 62bb242 commit 079a6e9
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 26 deletions.
32 changes: 17 additions & 15 deletions medkit/io/_brat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from smart_open import open

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from pathlib import Path

GROUPING_ENTITIES = frozenset(["And-Group", "Or-Group"])
Expand All @@ -22,7 +23,7 @@ class BratEntity:

uid: str
type: str
span: list[tuple[int, int]]
span: Sequence[tuple[int, int]]
text: str

@property
Expand Down Expand Up @@ -58,7 +59,7 @@ class BratAttribute:
uid: str
type: str
target: str
value: str = None # Only one value is possible
value: str | None = None # Only one value is possible

def to_str(self) -> str:
value = ensure_attr_value(self.value)
Expand All @@ -80,7 +81,7 @@ def to_str(self) -> str:


def ensure_attr_value(attr_value: Any) -> str:
"""Ensure that the attribue value is a string."""
"""Ensure that the attribute value is a string."""
if isinstance(attr_value, str):
return attr_value
if attr_value is None or isinstance(attr_value, bool):
Expand All @@ -98,7 +99,7 @@ class Grouping:

uid: str
type: str
items: list[BratEntity]
items: Sequence[BratEntity]

@property
def text(self):
Expand All @@ -111,11 +112,11 @@ class BratAugmentedEntity:

uid: str
type: str
span: tuple[tuple[int, int], ...]
span: Sequence[tuple[int, int]]
text: str
relations_from_me: tuple[BratRelation, ...]
relations_to_me: tuple[BratRelation, ...]
attributes: tuple[BratAttribute, ...]
relations_from_me: Sequence[BratRelation]
relations_to_me: Sequence[BratRelation]
attributes: Sequence[BratAttribute]

@property
def start(self) -> int:
Expand All @@ -128,11 +129,11 @@ def end(self) -> int:

@dataclass
class BratDocument:
entities: dict[str, BratEntity]
relations: dict[str, BratRelation]
attributes: dict[str, BratAttribute]
notes: dict[str, BratNote]
groups: dict[str, Grouping] = None
entities: Mapping[str, BratEntity]
relations: Mapping[str, BratRelation]
attributes: Mapping[str, BratAttribute]
notes: Mapping[str, BratNote]
groups: Mapping[str, Grouping] | None = None

def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]:
augmented_entities = {}
Expand Down Expand Up @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
logger.warning("Ignore annotation %s at line %s", ann_id, line_number)

# Process groups
groups = None
if detect_groups:
groups: dict[str, Grouping] = {}
groups = {}
grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS}

for entity in entities.values():
Expand All @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid
]
groups[entity.uid] = Grouping(entity.uid, entity.type, items)
else:
groups = None

return BratDocument(entities, relations, attributes, notes, groups)

Expand Down
12 changes: 5 additions & 7 deletions medkit/text/ner/umls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,10 @@ def load_umls_entries(
if lui in luis_seen:
continue

if semtypes_by_cui is not None and cui in semtypes_by_cui:
semtypes = semtypes_by_cui[cui]
semgroups = [semgroups_by_semtype[semtype] for semtype in semtypes]
else:
semtypes = None
semgroups = None
semtypes = semtypes_by_cui.get(cui) if semtypes_by_cui else None
semgroups = (
[semgroups_by_semtype[semtype] for semtype in semtypes] if semgroups_by_semtype and semtypes else None
)

luis_seen.add(lui)
yield UMLSEntry(cui, term, semtypes, semgroups)
Expand Down Expand Up @@ -198,7 +196,7 @@ def load_semtypes_by_cui(mrsty_file: str | Path) -> dict[str, list[str]]:
# Source: UMLS project
# https://lhncbc.nlm.nih.gov/semanticnetwork/download/sg_archive/SemGroups-v04.txt
_UMLS_SEMGROUPS_FILE = Path(__file__).parent / "umls_semgroups_v04.txt"
_SEMGROUPS_BY_SEMTYPE = None
_SEMGROUPS_BY_SEMTYPE: dict[str, str] | None = None


def load_semgroups_by_semtype() -> dict[str, str]:
Expand Down
6 changes: 4 additions & 2 deletions medkit/training/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from tqdm import tqdm

if TYPE_CHECKING:
from collections.abc import Mapping

from medkit.training.trainer_config import TrainerConfig


Expand All @@ -23,7 +25,7 @@ def on_train_end(self):
def on_epoch_begin(self, epoch: int):
"""Event called at the beginning of an epoch."""

def on_epoch_end(self, metrics: dict[str, float], epoch: int, epoch_time: float):
def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float):
"""Event called at the end of an epoch."""

def on_step_begin(self, step_idx: int, nb_batches: int, phase: str):
Expand Down Expand Up @@ -66,7 +68,7 @@ def on_train_begin(self, config):
)
self.logger.info(message)

def on_epoch_end(self, metrics, epoch, epoch_duration):
def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float):
message = f"Epoch {epoch} ended (duration: {epoch_duration:.2f}s)\n"

train_metrics = metrics.get("train", None)
Expand Down
4 changes: 2 additions & 2 deletions medkit/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, runtime_checkable

import torch
from typing_extensions import Protocol, Self
from typing_extensions import Protocol


class BatchData(dict):
Expand All @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]:
return inner_dict[index]
return {key: values[index] for key, values in self.items()}

def to_device(self, device: torch.device) -> Self:
def to_device(self, device: torch.device) -> BatchData:
"""Ensure that Tensors in the BatchData object are on the specified `device`.
Parameters
Expand Down

0 comments on commit 079a6e9

Please sign in to comment.