diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index c704d24..b271bfc 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -20,49 +20,44 @@ class TestStandardsDivisions: Tests are written for sparse annotations """ - @pytest.mark.parametrize("t_div", [0, 1, 2]) - def test_good_div(self, t_div): + @pytest.mark.parametrize("t_div,div_node", [(0, (1, 6)), (1, (2, 6)), (2, (3, 8))]) + def test_good_div(self, t_div, div_node): matched = ex_graphs.good_div(t_div) _classify_divisions(matched) - div_node = {0: (1, 6), 1: (2, 6), 2: (3, 8)} - assert matched.gt_graph.nodes[div_node[t_div][0]].get(NodeFlag.TP_DIV) is True - assert matched.pred_graph.nodes[div_node[t_div][1]].get(NodeFlag.TP_DIV) is True + assert matched.gt_graph.nodes[div_node[0]].get(NodeFlag.TP_DIV) is True + assert matched.pred_graph.nodes[div_node[1]].get(NodeFlag.TP_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_fp_div(self, t_div): + @pytest.mark.parametrize("t_div,pred_div_node", [(0, 6), (1, 6)]) + def test_fp_div(self, t_div, pred_div_node): matched = ex_graphs.fp_div(t_div) _classify_divisions(matched) - pred_div_node = [6, 6] - pred_node_attr = matched.pred_graph.nodes[pred_div_node[t_div]] + pred_node_attr = matched.pred_graph.nodes[pred_div_node] assert pred_node_attr.get(NodeFlag.FP_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_one_child(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_one_child(self, t_div, gt_div_node): matched = ex_graphs.one_child(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_no_children(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_no_children(self, t_div, gt_div_node): matched = ex_graphs.no_children(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True - @pytest.mark.parametrize("t_div", [0, 1]) - def test_wrong_child(self, t_div): + @pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)]) + def test_wrong_child(self, t_div, gt_div_node): matched = ex_graphs.wrong_child(t_div) _classify_divisions(matched) - gt_div_node = [1, 2] - gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]] + gt_node_attr = matched.gt_graph.nodes[gt_div_node] assert gt_node_attr.get(NodeFlag.FN_DIV) is True @@ -112,22 +107,14 @@ class TestStandardShifted: @pytest.mark.parametrize("n_frames", [1, 2]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_1early_end, ex_graphs.div_1early_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_1early_end(), 2, 9), (ex_graphs.div_1early_mid(), 3, 9)], ids=["div_1early_end", "div_1early_mid"], ) - def test_div_1early(self, n_frames, get_data): - matched = get_data() + def test_div_1early(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_1early_end": - gt_node = 2 - pred_node = 9 - elif get_data.__name__ == "div_1early_mid": - gt_node = 3 - pred_node = 9 - attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.TP_DIV) is True assert attrs.get(NodeFlag.FN_DIV) is False @@ -138,22 +125,14 @@ def test_div_1early(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 3]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_2early_end, ex_graphs.div_2early_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_2early_end(), 3, 8), (ex_graphs.div_2early_mid(), 4, 8)], ids=["div_2early_end", "div_2early_mid"], ) - def test_div_2early(self, n_frames, get_data): - matched = get_data() + def test_div_2early(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_2early_end": - gt_node = 3 - pred_node = 8 - elif get_data.__name__ == "div_2early_mid": - gt_node = 4 - pred_node = 8 - if n_frames == 1: # Not corrected attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.FN_DIV) is True @@ -171,22 +150,14 @@ def test_div_2early(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 2]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_1late_end, ex_graphs.div_1late_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_1late_end(), 1, 11), (ex_graphs.div_1late_mid(), 2, 11)], ids=["div_1late_end", "div_1late_mid"], ) - def test_div_1late(self, n_frames, get_data): - matched = get_data() + def test_div_1late(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_1late_end": - gt_node = 1 - pred_node = 11 - elif get_data.__name__ == "div_1late_mid": - gt_node = 2 - pred_node = 11 - attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.TP_DIV) is True assert attrs.get(NodeFlag.FN_DIV) is False @@ -197,22 +168,14 @@ def test_div_1late(self, n_frames, get_data): @pytest.mark.parametrize("n_frames", [1, 3]) @pytest.mark.parametrize( - "get_data", - [ex_graphs.div_2late_end, ex_graphs.div_2late_mid], + "matched, gt_node, pred_node", + [(ex_graphs.div_2late_end(), 1, 12), (ex_graphs.div_2late_mid(), 2, 12)], ids=["div_2late_end", "div_2late_mid"], ) - def test_div_2late(self, n_frames, get_data): - matched = get_data() + def test_div_2late(self, n_frames, matched, gt_node, pred_node): _classify_divisions(matched) shifted = _correct_shifted_divisions(matched, n_frames=n_frames) - if get_data.__name__ == "div_2late_end": - gt_node = 1 - pred_node = 12 - elif get_data.__name__ == "div_2late_mid": - gt_node = 2 - pred_node = 12 - if n_frames == 1: # Not corrected attrs = shifted.gt_graph.nodes[gt_node] assert attrs.get(NodeFlag.FN_DIV) is True