diff --git a/src/laptrack/metric_utils.py b/src/laptrack/metric_utils.py index e0222375..946b3674 100644 --- a/src/laptrack/metric_utils.py +++ b/src/laptrack/metric_utils.py @@ -1,5 +1,7 @@ """Utilities for metric calculation.""" +from typing import List from typing import Tuple +from typing import Union import numpy as np import pandas as pd @@ -32,12 +34,12 @@ def _union_bbox(self, r1, r2): bbox.append((y0, y1)) return bbox - def __init__(self, label_images: IntArray): + def __init__(self, label_images: Union[IntArray, List[IntArray]]): """Summarise the segmentation properties and initialize the object. Parameters ---------- - label_images : IntArray + label_images : Union[IntArray,List[IntArray]] The labeled images. The first dimension is interpreted as the frame dimension. """ if not isinstance(label_images, np.ndarray): diff --git a/tests/test_metric_utils.py b/tests/test_metric_utils.py index df157ca8..5fece351 100644 --- a/tests/test_metric_utils.py +++ b/tests/test_metric_utils.py @@ -1,28 +1,32 @@ from itertools import product +from typing import List +from typing import Union import numpy as np +from laptrack._typing_utils import IntArray from laptrack.metric_utils import LabelOverlap def test_label_overlap() -> None: - labels = np.array( - [ - [[[0, 1, 1, 1, 0], [0, 1, 2, 2, 2]], [[0, 1, 2, 2, 2], [3, 3, 3, 1, 0]]], - [[[0, 1, 1, 1, 2], [0, 4, 1, 2, 2]], [[0, 4, 4, 4, 4], [0, 4, 4, 4, 4]]], - [[[0, 1, 1, 1, 0], [5, 5, 5, 5, 5]], [[0, 1, 1, 1, 0], [0, 1, 1, 1, 0]]], - ] - ) - labelss = [labels, list(labels)] + labels = [ + [[[0, 1, 1, 1, 0], [0, 1, 2, 2, 2]], [[0, 1, 2, 2, 2], [3, 3, 3, 1, 0]]], + [[[0, 1, 1, 1, 2], [0, 4, 1, 2, 2]], [[0, 4, 4, 4, 4], [0, 4, 4, 4, 4]]], + [[[0, 1, 1, 1, 0], [5, 5, 5, 5, 5]], [[0, 1, 1, 1, 0], [0, 1, 1, 1, 0]]], + ] + labelss: List[Union[IntArray, List[IntArray]]] = [ + np.array(labels), + [np.array(label).astype(np.int64) for label in labels], + ] - for labels in labelss: - lo = LabelOverlap(labels) - frame_labels = [np.unique(label) for label in labels] + for _labels in labelss: + lo = LabelOverlap(_labels) + frame_labels = [np.unique(label) for label in _labels] frame_labels = [x[x > 0] for x in frame_labels] for f1, f2 in [(0, 0), (0, 1), (1, 2)]: for l1, l2 in product(frame_labels[f1], frame_labels[f2]): - b1 = labels[f1] == l1 - b2 = labels[f2] == l2 + b1 = _labels[f1] == l1 + b2 = _labels[f2] == l2 intersect = np.sum(b1 & b2) union = np.sum(b1 | b2)