Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements #409: added strict monotonicity flag for hierarchical segmentation metrics #414

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ Thumbs.db
# Vim
*.swp

# pycharm
# IDEs
.idea/*
.vscode/*

# docs
docs/_build/*
Expand Down
92 changes: 75 additions & 17 deletions mir_eval/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None):
]


def _lca(intervals_hier, frame_size):
def _lca(intervals_hier, frame_size, strict_mono=False):
"""Compute the (sparse) least-common-ancestor (LCA) matrix for a
hierarchical segmentation.

Expand All @@ -147,6 +147,10 @@ def _lca(intervals_hier, frame_size):
The list is assumed to be ordered by increasing specificity (depth).
frame_size : number
The length of the sample frames (in seconds)
strict_mono : bool, optional
If True, enforce monotonic updates for the LCA matrix. Only positions that were set to
the previous level (i.e., equal to level - 1) will be updated to the current level.
If False, the current level is applied unconditionally. Default is False.

Returns
-------
Expand All @@ -170,18 +174,25 @@ def _lca(intervals_hier, frame_size):
int
):
idx = slice(ival[0], ival[1])
lca_matrix[idx, idx] = level
if level == 1 or not strict_mono:
lca_matrix[idx, idx] = level
else:
# Check if the segments' parents have matching labeling
current_meet = lca_matrix[idx, idx].toarray()
matching_parents_mask = current_meet == level - 1
# Update only at positions where the previous level also matches
current_meet[matching_parents_mask] = level
lca_matrix[idx, idx] = current_meet

return lca_matrix.tocsr()


def _meet(intervals_hier, labels_hier, frame_size):
"""Compute the (sparse) least-common-ancestor (LCA) matrix for a
def _meet(intervals_hier, labels_hier, frame_size, strict_mono=False):
"""Compute the (sparse) annotation meet matrix for a
hierarchical segmentation.

For any pair of frames ``(s, t)``, the LCA is the deepest level in
the hierarchy such that ``(s, t)`` are contained within a single
segment at that level.
For any pair of frames ``(s, t)``, the annotation meet matrix is the deepest level
in the hierarchy such that ``(s, t)`` receive the same segment label, i.e. they meet.

Parameters
----------
Expand All @@ -193,6 +204,10 @@ def _meet(intervals_hier, labels_hier, frame_size):
``i``th layer of the annotations
frame_size : number
The length of the sample frames (in seconds)
strict_mono : bool, optional
If True, enforce monotonic updates for the LCA matrix. Only positions that were set to
the previous level (i.e., equal to level - 1) will be updated to the current level.
If False, the current level is applied unconditionally. Default is False.

Returns
-------
Expand Down Expand Up @@ -225,9 +240,21 @@ def _meet(intervals_hier, labels_hier, frame_size):
for seg_i, seg_j in zip(*np.where(int_agree)):
idx_i = slice(*list(int_frames[seg_i]))
idx_j = slice(*list(int_frames[seg_j]))
meet_matrix[idx_i, idx_j] = level
if seg_i != seg_j:
meet_matrix[idx_j, idx_i] = level

if level == 1 or not strict_mono:
meet_matrix[idx_i, idx_j] = level
if seg_i != seg_j:
meet_matrix[idx_j, idx_i] = level

else:
# Extract current submatrix and update elementwise
current_meet = meet_matrix[idx_i, idx_j].toarray()
mask = current_meet == (level - 1)
current_meet[mask] = level
meet_matrix[idx_i, idx_j] = current_meet

if seg_i != seg_j:
meet_matrix[idx_j, idx_i] = current_meet.T

return scipy.sparse.csr_matrix(meet_matrix)

Expand Down Expand Up @@ -446,21 +473,40 @@ def validate_hier_intervals(intervals_hier):
# Synthesize a label array for the top layer.
label_top = util.generate_labels(intervals_hier[0])

boundaries = set(util.intervals_to_boundaries(intervals_hier[0]))

for level, intervals in enumerate(intervals_hier[1:], 1):
for intervals in intervals_hier[1:]:
# Make sure this level is consistent with the root
label_current = util.generate_labels(intervals)
validate_structure(intervals_hier[0], label_top, intervals, label_current)

check_monotonic_boundaries(intervals_hier)


def check_monotonic_boundaries(intervals_hier):
"""Check that a hierarchical annotation has monotnoic boundaries.

Parameters
----------
intervals_hier : ordered list of segmentations

Returns
-------
bool
True if the annotation has monotnoic boundaries, False otherwise
"""
result = True
boundaries = set(util.intervals_to_boundaries(intervals_hier[0]))

for level, intervals in enumerate(intervals_hier[1:], 1):
# Make sure all previous boundaries are accounted for
new_bounds = set(util.intervals_to_boundaries(intervals))

if boundaries - new_bounds:
warnings.warn(
"Segment hierarchy is inconsistent " "at level {:d}".format(level)
)
result = False
boundaries |= new_bounds
return result


def tmeasure(
Expand All @@ -470,6 +516,7 @@ def tmeasure(
window=15.0,
frame_size=0.1,
beta=1.0,
strict_mono=False,
):
"""Compute the tree measures for hierarchical segment annotations.

Expand Down Expand Up @@ -533,8 +580,8 @@ def tmeasure(
validate_hier_intervals(estimated_intervals_hier)

# Build the least common ancestor matrices
ref_lca = _lca(reference_intervals_hier, frame_size)
est_lca = _lca(estimated_intervals_hier, frame_size)
ref_lca = _lca(reference_intervals_hier, frame_size, strict_mono=strict_mono)
est_lca = _lca(estimated_intervals_hier, frame_size, strict_mono=strict_mono)

# Compute precision and recall
t_recall = _gauc(ref_lca, est_lca, transitive, window_frames)
Expand All @@ -552,6 +599,7 @@ def lmeasure(
estimated_labels_hier,
frame_size=0.1,
beta=1.0,
strict_mono=False,
):
"""Compute the tree measures for hierarchical segment annotations.

Expand Down Expand Up @@ -604,8 +652,18 @@ def lmeasure(
validate_hier_intervals(estimated_intervals_hier)

# Build the least common ancestor matrices
ref_meet = _meet(reference_intervals_hier, reference_labels_hier, frame_size)
est_meet = _meet(estimated_intervals_hier, estimated_labels_hier, frame_size)
ref_meet = _meet(
reference_intervals_hier,
reference_labels_hier,
frame_size,
strict_mono=strict_mono,
)
est_meet = _meet(
estimated_intervals_hier,
estimated_labels_hier,
frame_size,
strict_mono=strict_mono,
)

# Compute precision and recall
l_recall = _gauc(ref_meet, est_meet, True, None)
Expand Down
113 changes: 107 additions & 6 deletions tests/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

@pytest.mark.parametrize("window", [5, 10, 15, 30, 90, None])
@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0])
def test_tmeasure_pass(window, frame_size):
@pytest.mark.parametrize("strict_mono", [True, False])
def test_tmeasure_pass(window, frame_size, strict_mono):
# The estimate here gets none of the structure correct.
ref = [[[0, 30]], [[0, 15], [15, 30]]]
# convert to arrays
Expand All @@ -27,13 +28,17 @@ def test_tmeasure_pass(window, frame_size):
est = ref[:1]

# The estimate should get 0 score here
scores = mir_eval.hierarchy.tmeasure(ref, est, window=window, frame_size=frame_size)
scores = mir_eval.hierarchy.tmeasure(
ref, est, window=window, frame_size=frame_size, strict_mono=strict_mono
)

for k in scores:
assert k == 0.0

# The reference should get a perfect score here
scores = mir_eval.hierarchy.tmeasure(ref, ref, window=window, frame_size=frame_size)
scores = mir_eval.hierarchy.tmeasure(
ref, ref, window=window, frame_size=frame_size, strict_mono=strict_mono
)

for k in scores:
assert k == 1.0
Expand Down Expand Up @@ -91,7 +96,8 @@ def test_tmeasure_fail_frame_size(window, frame_size):


@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0])
def test_lmeasure_pass(frame_size):
@pytest.mark.parametrize("strict_mono", [True, False])
def test_lmeasure_pass(frame_size, strict_mono):
# The estimate here gets none of the structure correct.
ref = [[[0, 30]], [[0, 15], [15, 30]]]
ref_lab = [["A"], ["a", "b"]]
Expand All @@ -104,15 +110,15 @@ def test_lmeasure_pass(frame_size):

# The estimate should get 0 score here
scores = mir_eval.hierarchy.lmeasure(
ref, ref_lab, est, est_lab, frame_size=frame_size
ref, ref_lab, est, est_lab, frame_size=frame_size, strict_mono=strict_mono
)

for k in scores:
assert k == 0.0

# The reference should get a perfect score here
scores = mir_eval.hierarchy.lmeasure(
ref, ref_lab, ref, ref_lab, frame_size=frame_size
ref, ref_lab, ref, ref_lab, frame_size=frame_size, strict_mono=strict_mono
)

for k in scores:
Expand Down Expand Up @@ -286,6 +292,101 @@ def test_meet():
assert np.all(meet == meet_truth)


def test_strict_mono():
frame_size = 1
int_hier = [
np.array([[0, 10]]),
np.array([[0, 6], [6, 10]]),
np.array([[0, 2], [2, 4], [4, 8], [8, 10]]),
]

lab_hier = [["X"], ["A", "B"], ["a", "b", "c", "b"]]

# Target output
meet_truth = np.asarray(
[
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb)
[2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb)
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
[1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb)
[1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb)
]
)
meet_truth_strict_mono = np.asarray(
[
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBb)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBb)
]
)
lca_truth = np.asarray(
[
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
]
)
lca_truth_strict_mono = np.asarray(
[
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
]
)

meet = mir_eval.hierarchy._meet(int_hier, lab_hier, frame_size, strict_mono=False)
meet_strict_mono = mir_eval.hierarchy._meet(
int_hier, lab_hier, frame_size, strict_mono=True
)
lca = mir_eval.hierarchy._lca(int_hier, frame_size, strict_mono=False)
lca_strict_mono = mir_eval.hierarchy._lca(int_hier, frame_size, strict_mono=True)
# Is it the right type?
assert isinstance(meet, scipy.sparse.csr_matrix)
meet = meet.toarray()
meet_strict_mono = meet_strict_mono.toarray()
assert isinstance(lca_strict_mono, scipy.sparse.csr_matrix)
lca = lca.toarray()
lca_strict_mono = lca_strict_mono.toarray()

# Does it have the right shape?
assert meet.shape == (10, 10)
assert meet_strict_mono.shape == (10, 10)
assert lca.shape == (10, 10)
assert lca_strict_mono.shape == (10, 10)

# Does it have the right value?
assert np.all(meet == meet_truth)
assert np.all(meet_strict_mono == meet_truth_strict_mono)
assert np.all(lca == lca_truth)
assert np.all(lca_strict_mono == lca_truth_strict_mono)


def test_compare_frame_rankings():
# number of pairs (i, j)
# where ref[i] < ref[j] and est[i] >= est[j]
Expand Down