diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 20fc2a1b..9707cd60 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -245,12 +245,14 @@ def _load_masks_as_semantic_label( slice_index: Whether to return only a specific slice. """ masks_dir = self._get_masks_dir(sample_index) - classes = self._class_mappings.keys() if self._class_mappings else self.classes + classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:] mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes] binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths] if self._class_mappings: - mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(self.classes[1:]) + mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len( + self.classes[1:] + ) for original_class, mapped_class in self._class_mappings.items(): mapped_index = self.class_to_idx[mapped_class] - 1 original_index = list(self._class_mappings.keys()).index(original_class) diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index b30e2c72..4e4e1283 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -59,7 +59,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i assert "slice_index" in metadata # check the number of classes with v.s. without class mappings - n_classes_expected = 2 if total_segmentator_dataset._class_mappings is not None else 3 + n_classes_expected = 3 if total_segmentator_dataset._class_mappings is not None else 4 assert len(total_segmentator_dataset.classes) == n_classes_expected