Skip to content

Commit

Permalink
solved mypy issue
Browse files Browse the repository at this point in the history
  • Loading branch information
yfukai committed Nov 6, 2022
1 parent 3ac3723 commit 720815b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
6 changes: 4 additions & 2 deletions src/laptrack/metric_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for metric calculation."""
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -32,12 +34,12 @@ def _union_bbox(self, r1, r2):
bbox.append((y0, y1))
return bbox

def __init__(self, label_images: IntArray):
def __init__(self, label_images: Union[IntArray, List[IntArray]]):
"""Summarise the segmentation properties and initialize the object.
Parameters
----------
label_images : IntArray
label_images : Union[IntArray,List[IntArray]]
The labeled images. The first dimension is interpreted as the frame dimension.
"""
if not isinstance(label_images, np.ndarray):
Expand Down
30 changes: 17 additions & 13 deletions tests/test_metric_utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
from itertools import product
from typing import List
from typing import Union

import numpy as np

from laptrack._typing_utils import IntArray
from laptrack.metric_utils import LabelOverlap


def test_label_overlap() -> None:
labels = np.array(
[
[[[0, 1, 1, 1, 0], [0, 1, 2, 2, 2]], [[0, 1, 2, 2, 2], [3, 3, 3, 1, 0]]],
[[[0, 1, 1, 1, 2], [0, 4, 1, 2, 2]], [[0, 4, 4, 4, 4], [0, 4, 4, 4, 4]]],
[[[0, 1, 1, 1, 0], [5, 5, 5, 5, 5]], [[0, 1, 1, 1, 0], [0, 1, 1, 1, 0]]],
]
)
labelss = [labels, list(labels)]
labels = [
[[[0, 1, 1, 1, 0], [0, 1, 2, 2, 2]], [[0, 1, 2, 2, 2], [3, 3, 3, 1, 0]]],
[[[0, 1, 1, 1, 2], [0, 4, 1, 2, 2]], [[0, 4, 4, 4, 4], [0, 4, 4, 4, 4]]],
[[[0, 1, 1, 1, 0], [5, 5, 5, 5, 5]], [[0, 1, 1, 1, 0], [0, 1, 1, 1, 0]]],
]
labelss: List[Union[IntArray, List[IntArray]]] = [
np.array(labels),
[np.array(label).astype(np.int64) for label in labels],
]

for labels in labelss:
lo = LabelOverlap(labels)
frame_labels = [np.unique(label) for label in labels]
for _labels in labelss:
lo = LabelOverlap(_labels)
frame_labels = [np.unique(label) for label in _labels]
frame_labels = [x[x > 0] for x in frame_labels]
for f1, f2 in [(0, 0), (0, 1), (1, 2)]:
for l1, l2 in product(frame_labels[f1], frame_labels[f2]):
b1 = labels[f1] == l1
b2 = labels[f2] == l2
b1 = _labels[f1] == l1
b2 = _labels[f2] == l2

intersect = np.sum(b1 & b2)
union = np.sum(b1 | b2)
Expand Down

0 comments on commit 720815b

Please sign in to comment.