Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test CTCMatcher using standard test cases #174

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 18 additions & 38 deletions src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,46 +79,27 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
if pred_label_key in G_pred.graph.nodes[node]
}

(
overlapping_gt_labels,
overlapping_pred_labels,
intersection,
) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt")

for i in range(len(overlapping_gt_labels)):
gt_label = overlapping_gt_labels[i]
pred_label = overlapping_pred_labels[i]
# CTC metrics only match comp IDs to a single GT ID if there is majority overlap
if intersection[i] > 0.5:
mapping.append(
(gt_label_to_id[gt_label], pred_label_to_id[pred_label])
)
frame_map = match_frame_majority(gt_frame, pred_frame)
# Switch from segmentation ids to node ids
for gt_label, pred_label in frame_map:
mapping.append((gt_label_to_id[gt_label], pred_label_to_id[pred_label]))

return mapping


def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int:
"""Check if computed marker overlaps majority of the reference marker.
def match_frame_majority(gt_frame, pred_frame):
mapping = []
(
overlapping_gt_labels,
overlapping_pred_labels,
intersection,
) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt")

Given a reference marker and computer marker in original coordinates,
return True if the computed marker overlaps strictly more than half
of the reference marker's pixels, otherwise False.
for gt_label, pred_label, iogt in zip(
overlapping_gt_labels, overlapping_pred_labels, intersection
):
# CTC metrics only match comp IDs to a single GT ID if there is majority overlap
if iogt > 0.5:
mapping.append((gt_label, pred_label))

Parameters
----------
gt_blob : np.ndarray
2D or 3D boolean mask representing the pixels of the ground truth
marker
comp_blob : np.ndarray
2D or 3D boolean mask representing the pixels of the computed
marker

Returns
-------
bool
True if computed marker majority overlaps reference marker, else False.
"""
n_gt_pixels = np.sum(gt_blob)
intersection = np.logical_and(gt_blob, comp_blob)
comp_blob_matches_gt_blob = int(np.sum(intersection) > 0.5 * n_gt_pixels)
return comp_blob_matches_gt_blob
return mapping
2 changes: 1 addition & 1 deletion tests/examples/segs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def make_split_cell_3d(
mask = sphere(center, radius, shape=arr_shape)
im[mask] = labels[0]
# get indices where y value greater than center
mask[:, 0 : center[1]] = 0
mask[:, 0 : center[1] + 1] = 0
im[mask] = labels[1]
return im

Expand Down
85 changes: 83 additions & 2 deletions tests/matchers/test_ctc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections import Counter

import networkx as nx
import numpy as np
import pytest

import tests.examples.segs as ex_segs
from tests.test_utils import get_annotated_movie
from traccuracy._tracking_graph import TrackingGraph
from traccuracy.matchers._ctc import CTCMatcher
from traccuracy.matchers._ctc import CTCMatcher, match_frame_majority


def test_match_ctc():
def test_CTCMatcher():
matcher = CTCMatcher()

# shapes don't match
Expand Down Expand Up @@ -46,3 +49,81 @@ def test_match_ctc():
# gt and pred node should be the same
for pair in matched.mapping:
assert pair[0] == pair[1]


class Test_match_frame_majority:
@pytest.mark.parametrize(
"data",
[ex_segs.good_segmentation_2d(), ex_segs.good_segmentation_3d()],
ids=["2D", "3D"],
)
def test_good_seg(self, data):
ex_match = [(1, 2)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[
ex_segs.false_positive_segmentation_2d(),
ex_segs.false_positive_segmentation_3d(),
],
ids=["2D", "3D"],
)
def test_false_pos_seg(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[
ex_segs.false_negative_segmentation_2d(),
ex_segs.false_negative_segmentation_3d(),
],
ids=["2D", "3D"],
)
def test_false_neg_seg(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.oversegmentation_2d(), ex_segs.oversegmentation_3d()],
ids=["2D", "3D"],
)
def test_split(self, data):
ex_match = [(1, 2)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.undersegmentation_2d(), ex_segs.undersegmentation_3d()],
ids=["2D", "3D"],
)
def test_merge(self, data):
ex_match = [(1, 3), (2, 3)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.no_overlap_2d(), ex_segs.no_overlap_3d()],
ids=["2D", "3D"],
)
def test_no_overlap(self, data):
ex_match = []
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)

@pytest.mark.parametrize(
"data",
[ex_segs.multicell_2d(), ex_segs.multicell_3d()],
ids=["2D", "3D"],
)
def test_multicell(self, data):
ex_match = [(1, 3)]
comp_match = match_frame_majority(*data)
assert Counter(ex_match) == Counter(comp_match)
Loading