From b705b82cdb67251be97d873bc35a6265b5c17067 Mon Sep 17 00:00:00 2001 From: Morgan Schwartz Date: Wed, 18 Dec 2024 16:20:37 -0500 Subject: [PATCH 1/2] Finish division standard test cases --- tests/track_errors/test_divisions.py | 452 ++++++++++++++------------- 1 file changed, 243 insertions(+), 209 deletions(-) diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index a939ab1b..c704d245 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -1,7 +1,7 @@ -import networkx as nx import numpy as np import pytest +import tests.examples.graphs as ex_graphs from tests.test_utils import get_division_graphs from traccuracy import NodeFlag, TrackingGraph from traccuracy.matchers import Matched @@ -14,219 +14,253 @@ ) -@pytest.fixture -def g(): - """ - 1_0 -- 1_1 -- 1_2 -- 1_3 - 3_3 - 2_0 -- 2_1 -- 2_2 -< - 4_3 - """ - g = nx.DiGraph() - g.add_edge("1_0", "1_1") - g.add_edge("1_1", "1_2") - g.add_edge("1_2", "1_3") - - g.add_edge("2_0", "2_1") - g.add_edge("2_1", "2_2") - - # node 2 divides into 3 and 4 in frame 3 - g.add_edge("2_2", "3_3") - g.add_edge("2_2", "4_3") - - # Set node attributes - attrs = {} - for node in g.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(g, attrs) - - return g - - -def test_classify_divisions_tp(g): - # Define mapper assuming all nodes match - mapper = [(n, n) for n in g.nodes] - matched_data = Matched( - TrackingGraph(g.copy()), - TrackingGraph(g.copy()), - mapper, - {"name": "DummyMatcher"}, - ) - - # Test true positive - _classify_divisions(matched_data) - - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 - assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] - assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] - - # Check division flag - assert matched_data.gt_graph.division_annotations - assert matched_data.pred_graph.division_annotations - +class TestStandardsDivisions: + """Test _classify_divisions against standard cases -def test_classify_divisions_fp(g): + Tests are written for sparse annotations """ - 5_3 - 1_0 -- 1_1 -- 1_2 -< - 1_3 - 3_3 - 2_0 -- 2_1 -- 2_2 -< - 4_3 - """ - h = g.copy() - # Add false positive division edge - h.add_edge("1_2", "5_3") - nx.set_node_attributes(h, {"5_3": {"t": 3, "x": 0, "y": 0}}) - mapper = [(n, n) for n in h.nodes] - matched_data = Matched( - TrackingGraph(g), TrackingGraph(h), mapper, {"name": "DummyMatcher"} + @pytest.mark.parametrize("t_div", [0, 1, 2]) + def test_good_div(self, t_div): + matched = ex_graphs.good_div(t_div) + _classify_divisions(matched) + + div_node = {0: (1, 6), 1: (2, 6), 2: (3, 8)} + assert matched.gt_graph.nodes[div_node[t_div][0]].get(NodeFlag.TP_DIV) is True + assert matched.pred_graph.nodes[div_node[t_div][1]].get(NodeFlag.TP_DIV) is True + + @pytest.mark.parametrize("t_div", [0, 1]) + def test_fp_div(self, t_div): + matched = ex_graphs.fp_div(t_div) + _classify_divisions(matched) + + pred_div_node = [6, 6] + pred_node_attr = matched.pred_graph.nodes[pred_div_node[t_div]] + assert pred_node_attr.get(NodeFlag.FP_DIV) is True + + @pytest.mark.parametrize("t_div", [0, 1]) + def test_one_child(self, t_div): + matched = ex_graphs.one_child(t_div) + _classify_divisions(matched) + + gt_div_node = [1, 2] + gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + assert gt_node_attr.get(NodeFlag.FN_DIV) is True + + @pytest.mark.parametrize("t_div", [0, 1]) + def test_no_children(self, t_div): + matched = ex_graphs.no_children(t_div) + _classify_divisions(matched) + + gt_div_node = [1, 2] + gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + assert gt_node_attr.get(NodeFlag.FN_DIV) is True + + @pytest.mark.parametrize("t_div", [0, 1]) + def test_wrong_child(self, t_div): + matched = ex_graphs.wrong_child(t_div) + _classify_divisions(matched) + + gt_div_node = [1, 2] + gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + assert gt_node_attr.get(NodeFlag.FN_DIV) is True + + +class Test_get_pred_by_t: + g = ex_graphs.basic_graph() + + def test_predecessor_available(self): + start_node = 3 + delta = 2 + node = _get_pred_by_t(self.g, start_node, delta) + assert node == 1 + + def test_no_predecessor(self): + start_node = 2 + delta = 2 + node = _get_pred_by_t(self.g, start_node, delta) + assert node is None + + +class Test_get_succ_by_t: + g = ex_graphs.basic_division(2) + + def across_division(self): + # Return none if looking across division + start_node = 3 + delta = 2 + succ = _get_succ_by_t(self.g, start_node, delta) + assert succ is None + + def valid_succ(self): + # Find 2 frames forward with valid node + start_node = 1 + delta = 2 + succ = _get_succ_by_t(self.g, start_node, delta) + assert succ == 3 + + def no_succ(self): + # Forward without valid node returns None + start_node = 4 + delta = 1 + succ = _get_succ_by_t(self.g, start_node, delta) + assert succ is None + + +class TestStandardShifted: + """Test correct_shifted_divisions against standard shifted cases""" + + @pytest.mark.parametrize("n_frames", [1, 2]) + @pytest.mark.parametrize( + "get_data", + [ex_graphs.div_1early_end, ex_graphs.div_1early_mid], + ids=["div_1early_end", "div_1early_mid"], ) - - _classify_divisions(matched_data) - - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 - assert NodeFlag.FP_DIV in matched_data.pred_graph.nodes["1_2"] - assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] - assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] - - -def test_classify_divisions_fn(g): - """ - 1_0 -- 1_1 -- 1_2 -- 1_3 - 2_0 -- 2_1 -- 2_2 - """ - # Remove daughters to create false negative - h = g.copy() - h.remove_nodes_from(["3_3", "4_3"]) - mapper = [(n, n) for n in h.nodes] - - matched_data = Matched( - TrackingGraph(g), TrackingGraph(h), mapper, {"name": "DummyMatcher"} + def test_div_1early(self, n_frames, get_data): + matched = get_data() + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=n_frames) + + if get_data.__name__ == "div_1early_end": + gt_node = 2 + pred_node = 9 + elif get_data.__name__ == "div_1early_mid": + gt_node = 3 + pred_node = 9 + + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FN_DIV) is False + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FP_DIV) is False + + @pytest.mark.parametrize("n_frames", [1, 3]) + @pytest.mark.parametrize( + "get_data", + [ex_graphs.div_2early_end, ex_graphs.div_2early_mid], + ids=["div_2early_end", "div_2early_mid"], ) - - _classify_divisions(matched_data) - - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 - assert NodeFlag.FN_DIV in matched_data.gt_graph.nodes["2_2"] - - -@pytest.fixture -def straight_graph(): - g = nx.DiGraph() - for t in range(2, 10): - g.add_edge(f"1_{t}", f"1_{t+1}") - - # Set node attributes - attrs = {} - for node in g.nodes: - attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0} - nx.set_node_attributes(g, attrs) - - return g - - -def test__get_pred_by_t(straight_graph): - # Linear graph with node id 1 from frame 2-10 - g = TrackingGraph(straight_graph) - - # Predecessor available - start_frame = 10 - target_frame = 5 - node = _get_pred_by_t(g, f"1_{start_frame}", start_frame - target_frame) - assert node == f"1_{target_frame}" - - # Predecessor does not exist - start_frame = 10 - target_frame = 1 - node = _get_pred_by_t(g, f"1_{start_frame}", start_frame - target_frame) - assert node is None - - -def test__get_succ_by_t(): - _, g2, _, _ = get_division_graphs() - g2 = TrackingGraph(g2) - - # Find 2 frames forward correctly - start_node = "5_2" - delta_t = 2 - end_node = "5_4" - node = _get_succ_by_t(g2, start_node, delta_t) - assert node == end_node - - # 3 frames forward returns None - start_node = "5_2" - delta_t = 3 - end_node = None - node = _get_succ_by_t(g2, start_node, delta_t) - assert node == end_node - - -class Test_correct_shifted_divisions: - def test_no_change(self): - # Early division in gt - g_pred, g_gt, map_pred, map_gt = get_division_graphs() - mapper = list(zip(map_gt, map_pred)) - g_gt.nodes["4_1"][NodeFlag.FN_DIV] = True - g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True - - matched_data = Matched( - TrackingGraph(g_gt), TrackingGraph(g_pred), mapper, {"name": "DummyMatcher"} - ) - - # buffer of 1, no change - new_matched = _correct_shifted_divisions(matched_data, n_frames=1) - ng_pred = new_matched.pred_graph - ng_gt = new_matched.gt_graph - - assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is True - assert ng_gt.nodes["4_1"][NodeFlag.FN_DIV] is True - assert len(ng_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 - - def test_fn_early(self): - # Early division in gt - g_pred, g_gt, map_pred, map_gt = get_division_graphs() - mapper = list(zip(map_gt, map_pred)) - g_gt.nodes["4_1"][NodeFlag.FN_DIV] = True - g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True - - matched_data = Matched( - TrackingGraph(g_gt), TrackingGraph(g_pred), mapper, {"name": "DummyMatcher"} - ) - - # buffer of 3, corrections - new_matched = _correct_shifted_divisions(matched_data, n_frames=3) - ng_pred = new_matched.pred_graph - ng_gt = new_matched.gt_graph - - assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is False - assert ng_gt.nodes["4_1"][NodeFlag.FN_DIV] is False - assert ng_pred.nodes["1_3"][NodeFlag.TP_DIV] is True - assert ng_gt.nodes["4_1"][NodeFlag.TP_DIV] is True - - def test_fp_early(self): - # Early division in pred - g_gt, g_pred, map_gt, map_pred = get_division_graphs() - mapper = list(zip(map_gt, map_pred)) - g_pred.nodes["4_1"][NodeFlag.FP_DIV] = True - g_gt.nodes["1_3"][NodeFlag.FN_DIV] = True - - matched_data = Matched( - TrackingGraph(g_gt), TrackingGraph(g_pred), mapper, {"name": "DummyMatcher"} - ) - - # buffer of 3, corrections - new_matched = _correct_shifted_divisions(matched_data, n_frames=3) - ng_pred = new_matched.pred_graph - ng_gt = new_matched.gt_graph - - assert ng_pred.nodes["4_1"][NodeFlag.FP_DIV] is False - assert ng_gt.nodes["1_3"][NodeFlag.FN_DIV] is False - assert ng_pred.nodes["4_1"][NodeFlag.TP_DIV] is True - assert ng_gt.nodes["1_3"][NodeFlag.TP_DIV] is True + def test_div_2early(self, n_frames, get_data): + matched = get_data() + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=n_frames) + + if get_data.__name__ == "div_2early_end": + gt_node = 3 + pred_node = 8 + elif get_data.__name__ == "div_2early_mid": + gt_node = 4 + pred_node = 8 + + if n_frames == 1: # Not corrected + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.FN_DIV) is True + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.FP_DIV) is True + elif n_frames == 3: # corrected + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FN_DIV) is False + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FP_DIV) is False + + @pytest.mark.parametrize("n_frames", [1, 2]) + @pytest.mark.parametrize( + "get_data", + [ex_graphs.div_1late_end, ex_graphs.div_1late_mid], + ids=["div_1late_end", "div_1late_mid"], + ) + def test_div_1late(self, n_frames, get_data): + matched = get_data() + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=n_frames) + + if get_data.__name__ == "div_1late_end": + gt_node = 1 + pred_node = 11 + elif get_data.__name__ == "div_1late_mid": + gt_node = 2 + pred_node = 11 + + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FN_DIV) is False + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FP_DIV) is False + + @pytest.mark.parametrize("n_frames", [1, 3]) + @pytest.mark.parametrize( + "get_data", + [ex_graphs.div_2late_end, ex_graphs.div_2late_mid], + ids=["div_2late_end", "div_2late_mid"], + ) + def test_div_2late(self, n_frames, get_data): + matched = get_data() + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=n_frames) + + if get_data.__name__ == "div_2late_end": + gt_node = 1 + pred_node = 12 + elif get_data.__name__ == "div_2late_mid": + gt_node = 2 + pred_node = 12 + + if n_frames == 1: # Not corrected + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.FN_DIV) is True + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.FP_DIV) is True + elif n_frames == 3: # corrected + attrs = shifted.gt_graph.nodes[gt_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FN_DIV) is False + + attrs = shifted.pred_graph.nodes[pred_node] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FP_DIV) is False + + def test_minimal_matching(self): + matched = ex_graphs.div_shift_min_match() + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=1) + + attrs = shifted.gt_graph.nodes[2] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FN_DIV) is False + + attrs = shifted.pred_graph.nodes[11] + assert attrs.get(NodeFlag.TP_DIV) is True + assert attrs.get(NodeFlag.FP_DIV) is False + + @pytest.mark.parametrize( + "matched", + [ + ex_graphs.div_shift_bad_match_pred(), + ex_graphs.div_shift_bad_match_daughter(), + ], + ids=["pred", "daughters"], + ) + def test_bad_matching(self, matched): + _classify_divisions(matched) + shifted = _correct_shifted_divisions(matched, n_frames=1) + + # No correction of shifted divisions b/c matching criteria not met + attrs = shifted.gt_graph.nodes[2] + assert attrs.get(NodeFlag.TP_DIV) is None + assert attrs.get(NodeFlag.FN_DIV) is True + + attrs = shifted.pred_graph.nodes[11] + assert attrs.get(NodeFlag.TP_DIV) is None + assert attrs.get(NodeFlag.FP_DIV) is True def test_evaluate_division_events(): From 0b6ea611bf29ea07e192f4ab0c1470c1792d86a7 Mon Sep 17 00:00:00 2001 From: Morgan Schwartz Date: Thu, 19 Dec 2024 14:47:49 -0500 Subject: [PATCH 2/2] Improve parametrization of more complex div tests --- tests/track_errors/test_divisions.py | 93 +++++++++------------------- 1 file changed, 28 insertions(+), 65 deletions(-) diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index c704d245..b271bfc8 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -20,49 +20,44 @@ class TestStandardsDivisions: Tests are written for sparse annotations """ - @pytest.mark.parametrize("t_div", [0, 1, 2]) - def test_good_div(self, t_div): + @pytest.mark.parametrize("t_div,div_node", [(0, (1, 6)), (1, (2, 6)), (2, (3, 8))]) + def test_good_div(self, t_div, div_node): matched = ex_graphs.good_div(t_div) _classify_divisions(matched) - div_node = {0: (1, 6), 1: (2, 6), 2: (3, 8)} - assert matched.gt_graph.nodes[div_node[t_div][0]].get(NodeFlag.TP_DIV) is True - assert matched.pred_graph.nodes[div_node[t_div][1]].get(NodeFlag.TP_DIV) is True + assert matched.gt_graph.nodes[div_node[0]].get(NodeFlag.TP_DIV) is True + assert matched.pred_graph.nodes[div_node[1]].get(NodeFlag.TP_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_fp_div(self, t_div): + @pytest.mark.parametrize("t_div,pred_div_node", [(0, 6), (1, 6)]) + def test_fp_div(self, t_div, pred_div_node): matched = ex_graphs.fp_div(t_div) _classify_divisions(matched) - pred_div_node = [6, 6] - pred_node_attr = matched.pred_graph.nodes[pred_div_node[t_div]] + pred_node_attr = matched.pred_graph.nodes[pred_div_node] assert pred_node_attr.get(NodeFlag.FP_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_one_child(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_one_child(self, t_div, gt_div_node): matched = ex_graphs.one_child(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_no_children(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_no_children(self, t_div, gt_div_node): matched = ex_graphs.no_children(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_wrong_child(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_wrong_child(self, t_div, gt_div_node): matched = ex_graphs.wrong_child(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True @@ -112,22 +107,14 @@ class TestStandardShifted: @pytest.mark.parametrize("n_frames", [1, 2]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_1early_end, ex_graphs.div_1early_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_1early_end(), 2, 9), (ex_graphs.div_1early_mid(), 3, 9)], ids=["div_1early_end", "div_1early_mid"], ) - def test_div_1early(self, n_frames, get_data): - matched = get_data() + def test_div_1early(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_1early_end": - gt_node = 2 - pred_node = 9 - elif get_data.__name__ == "div_1early_mid": - gt_node = 3 - pred_node = 9 - attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.TP_DIV) is True assert attrs.get(NodeFlag.FN_DIV) is False @@ -138,22 +125,14 @@ def test_div_1early(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 3]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_2early_end, ex_graphs.div_2early_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_2early_end(), 3, 8), (ex_graphs.div_2early_mid(), 4, 8)], ids=["div_2early_end", "div_2early_mid"], ) - def test_div_2early(self, n_frames, get_data): - matched = get_data() + def test_div_2early(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_2early_end": - gt_node = 3 - pred_node = 8 - elif get_data.__name__ == "div_2early_mid": - gt_node = 4 - pred_node = 8 - if n_frames == 1: # Not corrected attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.FN_DIV) is True @@ -171,22 +150,14 @@ def test_div_2early(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 2]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_1late_end, ex_graphs.div_1late_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_1late_end(), 1, 11), (ex_graphs.div_1late_mid(), 2, 11)], ids=["div_1late_end", "div_1late_mid"], ) - def test_div_1late(self, n_frames, get_data): - matched = get_data() + def test_div_1late(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_1late_end": - gt_node = 1 - pred_node = 11 - elif get_data.__name__ == "div_1late_mid": - gt_node = 2 - pred_node = 11 - attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.TP_DIV) is True assert attrs.get(NodeFlag.FN_DIV) is False @@ -197,22 +168,14 @@ def test_div_1late(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 3]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_2late_end, ex_graphs.div_2late_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_2late_end(), 1, 12), (ex_graphs.div_2late_mid(), 2, 12)], ids=["div_2late_end", "div_2late_mid"], ) - def test_div_2late(self, n_frames, get_data): - matched = get_data() + def test_div_2late(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_2late_end": - gt_node = 1 - pred_node = 12 - elif get_data.__name__ == "div_2late_mid": - gt_node = 2 - pred_node = 12 - if n_frames == 1: # Not corrected attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.FN_DIV) is True