Skip to content

Commit

Permalink
Improve parametrization of more complex div tests
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Dec 19, 2024
1 parent b705b82 commit 0b6ea61
Showing 1 changed file with 28 additions and 65 deletions.
93 changes: 28 additions & 65 deletions tests/track_errors/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,44 @@ class TestStandardsDivisions:
Tests are written for sparse annotations
"""

@pytest.mark.parametrize("t_div", [0, 1, 2])
def test_good_div(self, t_div):
@pytest.mark.parametrize("t_div,div_node", [(0, (1, 6)), (1, (2, 6)), (2, (3, 8))])
def test_good_div(self, t_div, div_node):
matched = ex_graphs.good_div(t_div)
_classify_divisions(matched)

div_node = {0: (1, 6), 1: (2, 6), 2: (3, 8)}
assert matched.gt_graph.nodes[div_node[t_div][0]].get(NodeFlag.TP_DIV) is True
assert matched.pred_graph.nodes[div_node[t_div][1]].get(NodeFlag.TP_DIV) is True
assert matched.gt_graph.nodes[div_node[0]].get(NodeFlag.TP_DIV) is True
assert matched.pred_graph.nodes[div_node[1]].get(NodeFlag.TP_DIV) is True

@pytest.mark.parametrize("t_div", [0, 1])
def test_fp_div(self, t_div):
@pytest.mark.parametrize("t_div,pred_div_node", [(0, 6), (1, 6)])
def test_fp_div(self, t_div, pred_div_node):
matched = ex_graphs.fp_div(t_div)
_classify_divisions(matched)

pred_div_node = [6, 6]
pred_node_attr = matched.pred_graph.nodes[pred_div_node[t_div]]
pred_node_attr = matched.pred_graph.nodes[pred_div_node]
assert pred_node_attr.get(NodeFlag.FP_DIV) is True

@pytest.mark.parametrize("t_div", [0, 1])
def test_one_child(self, t_div):
@pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)])
def test_one_child(self, t_div, gt_div_node):
matched = ex_graphs.one_child(t_div)
_classify_divisions(matched)

gt_div_node = [1, 2]
gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]]
gt_node_attr = matched.gt_graph.nodes[gt_div_node]
assert gt_node_attr.get(NodeFlag.FN_DIV) is True

@pytest.mark.parametrize("t_div", [0, 1])
def test_no_children(self, t_div):
@pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)])
def test_no_children(self, t_div, gt_div_node):
matched = ex_graphs.no_children(t_div)
_classify_divisions(matched)

gt_div_node = [1, 2]
gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]]
gt_node_attr = matched.gt_graph.nodes[gt_div_node]
assert gt_node_attr.get(NodeFlag.FN_DIV) is True

@pytest.mark.parametrize("t_div", [0, 1])
def test_wrong_child(self, t_div):
@pytest.mark.parametrize("t_div,gt_div_node", [(0, 1), (1, 2)])
def test_wrong_child(self, t_div, gt_div_node):
matched = ex_graphs.wrong_child(t_div)
_classify_divisions(matched)

gt_div_node = [1, 2]
gt_node_attr = matched.gt_graph.nodes[gt_div_node[t_div]]
gt_node_attr = matched.gt_graph.nodes[gt_div_node]
assert gt_node_attr.get(NodeFlag.FN_DIV) is True


Expand Down Expand Up @@ -112,22 +107,14 @@ class TestStandardShifted:

@pytest.mark.parametrize("n_frames", [1, 2])
@pytest.mark.parametrize(
"get_data",
[ex_graphs.div_1early_end, ex_graphs.div_1early_mid],
"matched, gt_node, pred_node",
[(ex_graphs.div_1early_end(), 2, 9), (ex_graphs.div_1early_mid(), 3, 9)],
ids=["div_1early_end", "div_1early_mid"],
)
def test_div_1early(self, n_frames, get_data):
matched = get_data()
def test_div_1early(self, n_frames, matched, gt_node, pred_node):
_classify_divisions(matched)
shifted = _correct_shifted_divisions(matched, n_frames=n_frames)

if get_data.__name__ == "div_1early_end":
gt_node = 2
pred_node = 9
elif get_data.__name__ == "div_1early_mid":
gt_node = 3
pred_node = 9

attrs = shifted.gt_graph.nodes[gt_node]
assert attrs.get(NodeFlag.TP_DIV) is True
assert attrs.get(NodeFlag.FN_DIV) is False
Expand All @@ -138,22 +125,14 @@ def test_div_1early(self, n_frames, get_data):

@pytest.mark.parametrize("n_frames", [1, 3])
@pytest.mark.parametrize(
"get_data",
[ex_graphs.div_2early_end, ex_graphs.div_2early_mid],
"matched, gt_node, pred_node",
[(ex_graphs.div_2early_end(), 3, 8), (ex_graphs.div_2early_mid(), 4, 8)],
ids=["div_2early_end", "div_2early_mid"],
)
def test_div_2early(self, n_frames, get_data):
matched = get_data()
def test_div_2early(self, n_frames, matched, gt_node, pred_node):
_classify_divisions(matched)
shifted = _correct_shifted_divisions(matched, n_frames=n_frames)

if get_data.__name__ == "div_2early_end":
gt_node = 3
pred_node = 8
elif get_data.__name__ == "div_2early_mid":
gt_node = 4
pred_node = 8

if n_frames == 1: # Not corrected
attrs = shifted.gt_graph.nodes[gt_node]
assert attrs.get(NodeFlag.FN_DIV) is True
Expand All @@ -171,22 +150,14 @@ def test_div_2early(self, n_frames, get_data):

@pytest.mark.parametrize("n_frames", [1, 2])
@pytest.mark.parametrize(
"get_data",
[ex_graphs.div_1late_end, ex_graphs.div_1late_mid],
"matched, gt_node, pred_node",
[(ex_graphs.div_1late_end(), 1, 11), (ex_graphs.div_1late_mid(), 2, 11)],
ids=["div_1late_end", "div_1late_mid"],
)
def test_div_1late(self, n_frames, get_data):
matched = get_data()
def test_div_1late(self, n_frames, matched, gt_node, pred_node):
_classify_divisions(matched)
shifted = _correct_shifted_divisions(matched, n_frames=n_frames)

if get_data.__name__ == "div_1late_end":
gt_node = 1
pred_node = 11
elif get_data.__name__ == "div_1late_mid":
gt_node = 2
pred_node = 11

attrs = shifted.gt_graph.nodes[gt_node]
assert attrs.get(NodeFlag.TP_DIV) is True
assert attrs.get(NodeFlag.FN_DIV) is False
Expand All @@ -197,22 +168,14 @@ def test_div_1late(self, n_frames, get_data):

@pytest.mark.parametrize("n_frames", [1, 3])
@pytest.mark.parametrize(
"get_data",
[ex_graphs.div_2late_end, ex_graphs.div_2late_mid],
"matched, gt_node, pred_node",
[(ex_graphs.div_2late_end(), 1, 12), (ex_graphs.div_2late_mid(), 2, 12)],
ids=["div_2late_end", "div_2late_mid"],
)
def test_div_2late(self, n_frames, get_data):
matched = get_data()
def test_div_2late(self, n_frames, matched, gt_node, pred_node):
_classify_divisions(matched)
shifted = _correct_shifted_divisions(matched, n_frames=n_frames)

if get_data.__name__ == "div_2late_end":
gt_node = 1
pred_node = 12
elif get_data.__name__ == "div_2late_mid":
gt_node = 2
pred_node = 12

if n_frames == 1: # Not corrected
attrs = shifted.gt_graph.nodes[gt_node]
assert attrs.get(NodeFlag.FN_DIV) is True
Expand Down

1 comment on commit 0b6ea61

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 68d63ad Mean (s) HEAD 0b6ea61 Percent Change
test_load_gt_ctc_data[2d] 5.57766 5.56953 -0.15
test_load_gt_ctc_data[3d] 19.4027 18.9305 -2.43
test_load_pred_ctc_data[2d] 1.13963 1.16123 1.9
test_ctc_checks[2d] 0.75891 0.76482 0.78
test_ctc_checks[3d] 9.77632 9.67958 -0.99
test_ctc_matcher[2d] 1.53613 1.51051 -1.67
test_ctc_matcher[3d] 17.0421 16.8817 -0.94
test_ctc_metrics[2d] 0.28009 0.26928 -3.86
test_ctc_metrics[3d] 4.16459 4.1813 0.4
test_iou_matcher[2d] 1.63143 1.60634 -1.54
test_iou_matcher[3d] 18.2327 17.9711 -1.43
test_iou_div_metrics[2d] 0.07561 0.07252 -4.09
test_iou_div_metrics[3d] 0.73452 0.69714 -5.09

Please sign in to comment.