From 12feaaf26e5da986ef479d0db2ebf5e2299378d6 Mon Sep 17 00:00:00 2001 From: Morgan Schwartz Date: Wed, 11 Dec 2024 15:37:10 -0500 Subject: [PATCH 1/3] Delete unused function --- src/traccuracy/matchers/_ctc.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 91aa7093..7285da43 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING -import numpy as np from tqdm import tqdm if TYPE_CHECKING: @@ -96,30 +95,3 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): ) return mapping - - -def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int: - """Check if computed marker overlaps majority of the reference marker. - - Given a reference marker and computer marker in original coordinates, - return True if the computed marker overlaps strictly more than half - of the reference marker's pixels, otherwise False. - - Parameters - ---------- - gt_blob : np.ndarray - 2D or 3D boolean mask representing the pixels of the ground truth - marker - comp_blob : np.ndarray - 2D or 3D boolean mask representing the pixels of the computed - marker - - Returns - ------- - bool - True if computed marker majority overlaps reference marker, else False. - """ - n_gt_pixels = np.sum(gt_blob) - intersection = np.logical_and(gt_blob, comp_blob) - comp_blob_matches_gt_blob = int(np.sum(intersection) > 0.5 * n_gt_pixels) - return comp_blob_matches_gt_blob From 31f7828644cabe2959e5186f9e74a0f39bfd6ecf Mon Sep 17 00:00:00 2001 From: Morgan Schwartz Date: Thu, 12 Dec 2024 14:19:48 -0500 Subject: [PATCH 2/3] Pull the core matching functionality for the ctc matcher into a stand alone function --- src/traccuracy/matchers/_ctc.py | 36 ++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 7285da43..007d8484 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -79,19 +79,27 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): if pred_label_key in G_pred.graph.nodes[node] } - ( - overlapping_gt_labels, - overlapping_pred_labels, - intersection, - ) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt") - - for i in range(len(overlapping_gt_labels)): - gt_label = overlapping_gt_labels[i] - pred_label = overlapping_pred_labels[i] - # CTC metrics only match comp IDs to a single GT ID if there is majority overlap - if intersection[i] > 0.5: - mapping.append( - (gt_label_to_id[gt_label], pred_label_to_id[pred_label]) - ) + frame_map = match_frame_majority(gt_frame, pred_frame) + # Switch from segmentation ids to node ids + for gt_label, pred_label in frame_map: + mapping.append((gt_label_to_id[gt_label], pred_label_to_id[pred_label])) return mapping + + +def match_frame_majority(gt_frame, pred_frame): + mapping = [] + ( + overlapping_gt_labels, + overlapping_pred_labels, + intersection, + ) = get_labels_with_overlap(gt_frame, pred_frame, overlap="iogt") + + for gt_label, pred_label, iogt in zip( + overlapping_gt_labels, overlapping_pred_labels, intersection + ): + # CTC metrics only match comp IDs to a single GT ID if there is majority overlap + if iogt > 0.5: + mapping.append((gt_label, pred_label)) + + return mapping From daaf76831d7a53d4fc74719df265067901419f59 Mon Sep 17 00:00:00 2001 From: Morgan Schwartz Date: Thu, 12 Dec 2024 14:20:35 -0500 Subject: [PATCH 3/3] Write tests for ctc matcher using standard test cases for segmentation --- tests/examples/segs.py | 2 +- tests/matchers/test_ctc.py | 85 +++++++++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/tests/examples/segs.py b/tests/examples/segs.py index 95773412..d529c3c6 100644 --- a/tests/examples/segs.py +++ b/tests/examples/segs.py @@ -122,7 +122,7 @@ def make_split_cell_3d( mask = sphere(center, radius, shape=arr_shape) im[mask] = labels[0] # get indices where y value greater than center - mask[:, 0 : center[1]] = 0 + mask[:, 0 : center[1] + 1] = 0 im[mask] = labels[1] return im diff --git a/tests/matchers/test_ctc.py b/tests/matchers/test_ctc.py index cb1495e7..246a299f 100644 --- a/tests/matchers/test_ctc.py +++ b/tests/matchers/test_ctc.py @@ -1,13 +1,16 @@ +from collections import Counter + import networkx as nx import numpy as np import pytest +import tests.examples.segs as ex_segs from tests.test_utils import get_annotated_movie from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._ctc import CTCMatcher +from traccuracy.matchers._ctc import CTCMatcher, match_frame_majority -def test_match_ctc(): +def test_CTCMatcher(): matcher = CTCMatcher() # shapes don't match @@ -46,3 +49,81 @@ def test_match_ctc(): # gt and pred node should be the same for pair in matched.mapping: assert pair[0] == pair[1] + + +class Test_match_frame_majority: + @pytest.mark.parametrize( + "data", + [ex_segs.good_segmentation_2d(), ex_segs.good_segmentation_3d()], + ids=["2D", "3D"], + ) + def test_good_seg(self, data): + ex_match = [(1, 2)] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ + ex_segs.false_positive_segmentation_2d(), + ex_segs.false_positive_segmentation_3d(), + ], + ids=["2D", "3D"], + ) + def test_false_pos_seg(self, data): + ex_match = [] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ + ex_segs.false_negative_segmentation_2d(), + ex_segs.false_negative_segmentation_3d(), + ], + ids=["2D", "3D"], + ) + def test_false_neg_seg(self, data): + ex_match = [] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ex_segs.oversegmentation_2d(), ex_segs.oversegmentation_3d()], + ids=["2D", "3D"], + ) + def test_split(self, data): + ex_match = [(1, 2)] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ex_segs.undersegmentation_2d(), ex_segs.undersegmentation_3d()], + ids=["2D", "3D"], + ) + def test_merge(self, data): + ex_match = [(1, 3), (2, 3)] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ex_segs.no_overlap_2d(), ex_segs.no_overlap_3d()], + ids=["2D", "3D"], + ) + def test_no_overlap(self, data): + ex_match = [] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match) + + @pytest.mark.parametrize( + "data", + [ex_segs.multicell_2d(), ex_segs.multicell_3d()], + ids=["2D", "3D"], + ) + def test_multicell(self, data): + ex_match = [(1, 3)] + comp_match = match_frame_majority(*data) + assert Counter(ex_match) == Counter(comp_match)