Skip to content

Commit

Permalink
add labels_map argument when loading annots
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Aug 12, 2024
1 parent 51e8526 commit daa12fc
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions mapreader/classify/load_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self):
def load(
self,
annotations: str | pd.DataFrame,
labels_map: dict | None = None,
delimiter: str | None = ",",
images_dir: str | None = None,
remove_broken: bool | None = True,
Expand All @@ -52,6 +53,8 @@ def load(
annotations : Union[str, pd.DataFrame]
The annotations.
Can either be the path to a csv file or a pandas.DataFrame.
labels_map : Optional[dict], optional
A dictionary mapping labels to indices. If not provided, labels will be mapped to indices based on the order in which they appear in the annotations dataframe. By default None.
delimiter : Optional[str], optional
The delimiter to use when loading the csv file as a dataframe, by default ",".
images_dir : Optional[str], optional
Expand Down Expand Up @@ -132,8 +135,15 @@ def load(

self.unique_labels = self.annotations[self.label_col].unique().tolist()

labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map
if labels_map:
self.labels_map = labels_map
if not set(self.unique_labels).issubset(set(labels_map.keys())):
raise ValueError(
"[ERROR] There are label(s) in the annotations that are not in the labels_map. Please check the labels_map."
)
else:
labels_map = {i: label for i, label in enumerate(self.unique_labels)}
self.labels_map = labels_map

self.annotations["label_index"] = self.annotations[self.label_col].apply(
self._get_label_index
Expand Down

0 comments on commit daa12fc

Please sign in to comment.