Skip to content

Commit

Permalink
Merge pull request #490 from Living-with-machines/labels_map_fix
Browse files Browse the repository at this point in the history
Fix for labels_map argument
  • Loading branch information
rwood-97 authored Aug 12, 2024
2 parents 6eec831 + 5c9fc6f commit 6eed2e2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mapreader/classify/load_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,8 @@ def _get_label_index(self, label: str) -> int:
Used to generate the ``label_index`` column.
"""
return self.unique_labels.index(label)
index_map = {v: k for k, v in self.labels_map.items()}
return index_map[label]

def __str__(self):
print(f"[INFO] Number of annotations: {len(self.annotations)}\n")
Expand Down
28 changes: 28 additions & 0 deletions tests/test_classify/test_annotations_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,34 @@ def test_labels_map(sample_dir):
assert annots.labels_map == {0: "railspace", 1: "no", 2: "building"}


def test_get_label_index(sample_dir):
annots = AnnotationsLoader()
annots.load(
f"{sample_dir}/test_annots.csv",
reset_index=True,
remove_broken=False,
ignore_broken=True,
labels_map={
0: "railspace",
1: "no",
}, # different order vs in the csv
)
assert annots.labels_map == {0: "railspace", 1: "no"}
assert annots._get_label_index("railspace") == 0

# test append
annots.load(
f"{sample_dir}/test_annots_append.csv",
append=True,
remove_broken=False,
ignore_broken=True,
)
assert annots.unique_labels == ["no", "railspace", "building"]
assert annots.labels_map == {0: "railspace", 1: "no", 2: "building"}
assert annots._get_label_index("railspace") == 0
assert annots._get_label_index("building") == 2


@pytest.mark.dependency(name="load_annots_df", scope="session")
def test_load_df(sample_dir):
annots = AnnotationsLoader()
Expand Down

0 comments on commit 6eed2e2

Please sign in to comment.