Skip to content

Commit

Permalink
Merge pull request #489 from Living-with-machines/labels_map_fix
Browse files Browse the repository at this point in the history
Adds labels map argument when loading annotations
  • Loading branch information
rwood-97 authored Aug 12, 2024
2 parents 51e8526 + 7de87ef commit 6eec831
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ For example, if you have set up your directory as recommended in our :doc:`Input
annotated_images = AnnotationsLoader()
annotated_images.load("./annotations/railspace_#rosie#.csv")
.. admonition:: Advanced usage
:class: dropdown

Expand All @@ -43,7 +44,9 @@ To view the data loaded in from your ``csv`` as a dataframe, use:
annotated_images.annotations
You will note a ``label_index`` column has been added to your dataframe.
This column contains a numerical reference number for each label, which is needed when training your model.

This column contains a numerical reference number for each label.
This is needed when training your model so that labels can be treated as numerical values instead of strings.

To see how your labels map to their label indices, call the ``annotated_images.labels_map`` attribute:

Expand All @@ -53,6 +56,27 @@ To see how your labels map to their label indices, call the ``annotated_images.l
.. note:: This ``labels_map`` will be needed later.

By default, this ``labels_map`` is automatically generated when loading your annotations by finding unique labels in your annotations and assigning each a numerical index.
The ``0`` index will be assigned to the label that appears first in the annotations, ``1`` to the second label and so on.

.. note:: If you use the ``scramble_frame`` argument when loading your annotations from a file, the order of your labels will be shuffled, so the indices assigned to each label may differ each time you load your annotations.

If, instead, you would like to explicitly define your labels map, you can do so by passing a dictionary to the ``labels_map`` argument when loading your annotations.

.. code-block:: python
#EXAMPLE
labels_map = {0: "no", 1: "railspace", 2: "building", 3: "railspace and building"}
annotated_images.load(
annotations="./path/to/annotations.csv",
labels_map=labels_map
)
Now, calling the ``annotated_images.labels_map`` attribute should return the dictionary you passed in (sorted by its keys).

.. note:: Using the ``labels_map`` argument is important if you are doing a second round of annotations and want to ensure that the labels are consistent between the two rounds!


To view a sample of your annotated images use the ``show_sample()`` method.
The ``label_to_show`` argument specifies which label you would like to show.

Expand Down
25 changes: 23 additions & 2 deletions mapreader/classify/load_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class AnnotationsLoader:

def __init__(self):
self.annotations = pd.DataFrame()
self.labels_map = {}
self.reviewed = pd.DataFrame()
self.patch_paths_col = None
self.label_col = None
Expand All @@ -33,6 +34,7 @@ def __init__(self):
def load(
self,
annotations: str | pd.DataFrame,
labels_map: dict | None = None,
delimiter: str | None = ",",
images_dir: str | None = None,
remove_broken: bool | None = True,
Expand All @@ -52,6 +54,8 @@ def load(
annotations : Union[str, pd.DataFrame]
The annotations.
Can either be the path to a csv file or a pandas.DataFrame.
labels_map : Optional[dict], optional
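A dictionary mapping indices to labels (e.g. ``{0: "no", 1: "railspace"}``). Every label in the annotations must appear in its values. If not provided, labels are assigned indices in the order in which they appear in the annotations dataframe (when ``append=True``, new labels are added to the existing labels map). By default None.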
delimiter : Optional[str], optional
The delimiter to use when loading the csv file as a dataframe, by default ",".
images_dir : Optional[str], optional
Expand Down Expand Up @@ -132,8 +136,25 @@ def load(

self.unique_labels = self.annotations[self.label_col].unique().tolist()

labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map
# if labels_map is explicitly provided
if labels_map:
self.labels_map = dict(
sorted(labels_map.items())
) # sort labels_map by keys
if not set(self.unique_labels).issubset(set(labels_map.values())):
raise ValueError(
"[ERROR] There are label(s) in the annotations that are not in the labels map. Please check the labels_map."
)
# if inferring labels_map
else:
if append:
for label in self.unique_labels:
if label not in self.labels_map.values():
self.labels_map[len(self.labels_map)] = label
else:
# reset labels map
labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map

self.annotations["label_index"] = self.annotations[self.label_col].apply(
self._get_label_index
Expand Down
47 changes: 47 additions & 0 deletions tests/test_classify/test_annotations_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@ def test_load_csv(load_annots, sample_dir):
assert annots.labels_map == {0: "no", 1: "railspace", 2: "building"}


def test_labels_map(sample_dir):
    """Check that an explicit ``labels_map`` is stored and then extended on append."""
    # Options shared by both load calls below.
    broken_opts = {"remove_broken": False, "ignore_broken": True}
    loader = AnnotationsLoader()

    # First load: supply the map explicitly, with its keys given out of order.
    loader.load(
        f"{sample_dir}/test_annots.csv",
        labels_map={1: "no", 0: "railspace"},
        reset_index=True,
        **broken_opts,
    )
    n_loaded = len(loader.annotations)
    assert n_loaded == 81
    assert loader.labels_map == {0: "railspace", 1: "no"}

    # Second load: append more annotations without passing a map; the
    # previously unseen "building" label is expected at the next index.
    loader.load(
        f"{sample_dir}/test_annots_append.csv",
        append=True,
        **broken_opts,
    )
    n_loaded = len(loader.annotations)
    assert n_loaded == 83
    assert loader.unique_labels == ["no", "railspace", "building"]
    assert loader.labels_map == {0: "railspace", 1: "no", 2: "building"}


@pytest.mark.dependency(name="load_annots_df", scope="session")
def test_load_df(sample_dir):
annots = AnnotationsLoader()
Expand Down Expand Up @@ -162,6 +185,30 @@ def test_create_dataloaders_no_sampler(load_annots):
# errors


def test_labels_map_errors(sample_dir):
    """Check that a ``labels_map`` missing some annotated labels raises ``ValueError``.

    The same incomplete map is tried against a csv path and against a
    dataframe read from that csv.
    """
    csv_path = f"{sample_dir}/test_annots.csv"
    load_opts = {
        "reset_index": True,
        "remove_broken": False,
        "ignore_broken": True,
    }

    # Annotations supplied as a csv path.
    loader = AnnotationsLoader()
    with pytest.raises(ValueError, match="not in the labels map"):
        loader.load(csv_path, labels_map={0: "no"}, **load_opts)

    # Annotations supplied as a dataframe, using a fresh loader.
    loader = AnnotationsLoader()
    frame = pd.read_csv(csv_path, sep=",", index_col=0)
    with pytest.raises(ValueError, match="not in the labels map"):
        loader.load(frame, labels_map={0: "no"}, **load_opts)


def test_load_fake_csv_errors():
annots = AnnotationsLoader()
with pytest.raises(ValueError, match="cannot be found"):
Expand Down

0 comments on commit 6eec831

Please sign in to comment.