diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
index 20fc2a1b..9707cd60 100644
--- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
+++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -245,12 +245,14 @@ def _load_masks_as_semantic_label(
             slice_index: Whether to return only a specific slice.
         """
         masks_dir = self._get_masks_dir(sample_index)
-        classes = self._class_mappings.keys() if self._class_mappings else self.classes
+        classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
         mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
         binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
 
         if self._class_mappings:
-            mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(self.classes[1:])
+            mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
+                self.classes[1:]
+            )
             for original_class, mapped_class in self._class_mappings.items():
                 mapped_index = self.class_to_idx[mapped_class] - 1
                 original_index = list(self._class_mappings.keys()).index(original_class)
diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
index b30e2c72..4e4e1283 100644
--- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
+++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py
@@ -59,7 +59,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i
     assert "slice_index" in metadata
 
     # check the number of classes with v.s. without class mappings
-    n_classes_expected = 2 if total_segmentator_dataset._class_mappings is not None else 3
+    n_classes_expected = 3 if total_segmentator_dataset._class_mappings is not None else 4
     assert len(total_segmentator_dataset.classes) == n_classes_expected