Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Aug 12, 2024
1 parent daa12fc commit 6185eb8
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
21 changes: 16 additions & 5 deletions mapreader/classify/load_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class AnnotationsLoader:

def __init__(self):
self.annotations = pd.DataFrame()
self.labels_map = {}
self.reviewed = pd.DataFrame()
self.patch_paths_col = None
self.label_col = None
Expand Down Expand Up @@ -135,15 +136,25 @@ def load(

self.unique_labels = self.annotations[self.label_col].unique().tolist()

# if labels_map is explicitly provided
if labels_map:
self.labels_map = labels_map
if not set(self.unique_labels).issubset(set(labels_map.keys())):
self.labels_map = dict(
sorted(labels_map.items())
) # sort labels_map by keys
if not set(self.unique_labels).issubset(set(labels_map.values())):
raise ValueError(
"[ERROR] There are label(s) in the annotations that are not in the labels_map. Please check the labels_map."
"[ERROR] There are label(s) in the annotations that are not in the labels map. Please check the labels_map."
)
# if inferring labels_map
else:
labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map
if append:
for label in self.unique_labels:
if label not in self.labels_map.values():
self.labels_map[len(self.labels_map)] = label
else:
# reset labels map
labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map

self.annotations["label_index"] = self.annotations[self.label_col].apply(
self._get_label_index
Expand Down
47 changes: 47 additions & 0 deletions tests/test_classify/test_annotations_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@ def test_load_csv(load_annots, sample_dir):
assert annots.labels_map == {0: "no", 1: "railspace", 2: "building"}


def test_labels_map(sample_dir):
    """Check that an explicit ``labels_map`` is used and grows on append.

    A first load supplies the map directly; a second load with
    ``append=True`` and no map is expected to keep the existing entries
    and add the label that was not yet present.
    """
    initial_csv = f"{sample_dir}/test_annots.csv"
    extra_csv = f"{sample_dir}/test_annots_append.csv"
    loader = AnnotationsLoader()

    # The map is deliberately written with its keys out of order.
    # NOTE: dict equality ignores insertion order, so the assertion below
    # only pins the key -> label mapping, not the ordering of the keys.
    loader.load(
        initial_csv,
        labels_map={1: "no", 0: "railspace"},
        reset_index=True,
        remove_broken=False,
        ignore_broken=True,
    )
    assert len(loader.annotations) == 81
    assert loader.labels_map == {0: "railspace", 1: "no"}

    # test append: no labels_map is passed on this second call.
    loader.load(
        extra_csv,
        append=True,
        remove_broken=False,
        ignore_broken=True,
    )
    assert len(loader.annotations) == 83
    assert loader.unique_labels == ["no", "railspace", "building"]
    assert loader.labels_map == {0: "railspace", 1: "no", 2: "building"}


@pytest.mark.dependency(name="load_annots_df", scope="session")
def test_load_df(sample_dir):
annots = AnnotationsLoader()
Expand Down Expand Up @@ -162,6 +185,30 @@ def test_create_dataloaders_no_sampler(load_annots):
# errors


def test_labels_map_errors(sample_dir):
    """Check that an incomplete ``labels_map`` raises ``ValueError``.

    The same one-entry map is tried against a CSV path and against a
    DataFrame read from that CSV; both loads are expected to fail with a
    message containing "not in the labels map".
    """
    csv_path = f"{sample_dir}/test_annots.csv"
    expected_message = "not in the labels map"

    # csv
    csv_loader = AnnotationsLoader()
    with pytest.raises(ValueError, match=expected_message):
        csv_loader.load(
            csv_path,
            labels_map={0: "no"},
            reset_index=True,
            remove_broken=False,
            ignore_broken=True,
        )

    # dataframe
    df_loader = AnnotationsLoader()
    annotations_df = pd.read_csv(csv_path, sep=",", index_col=0)
    with pytest.raises(ValueError, match=expected_message):
        df_loader.load(
            annotations_df,
            labels_map={0: "no"},
            reset_index=True,
            remove_broken=False,
            ignore_broken=True,
        )


def test_load_fake_csv_errors():
annots = AnnotationsLoader()
with pytest.raises(ValueError, match="cannot be found"):
Expand Down

0 comments on commit 6185eb8

Please sign in to comment.