From 0262e89bedb079c9f8510f85953ff945b0faf3be Mon Sep 17 00:00:00 2001 From: ioangatop Date: Tue, 15 Oct 2024 16:31:44 +0200 Subject: [PATCH 01/11] add kits23 --- .../radiology/online/segmentation/kits23.yaml | 119 ++++++++++ src/eva/vision/data/datasets/__init__.py | 2 + .../data/datasets/segmentation/__init__.py | 2 + .../data/datasets/segmentation/kits23.py | 212 ++++++++++++++++++ 4 files changed, 335 insertions(+) create mode 100644 configs/vision/radiology/online/segmentation/kits23.yaml create mode 100644 src/eva/vision/data/datasets/segmentation/kits23.py diff --git a/configs/vision/radiology/online/segmentation/kits23.yaml b/configs/vision/radiology/online/segmentation/kits23.yaml new file mode 100644 index 00000000..8bb24821 --- /dev/null +++ b/configs/vision/radiology/online/segmentation/kits23.yaml @@ -0,0 +1,119 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} + log_every_n_steps: 6 + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 + mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 100 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + logger: + - class_path: 
lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + encoder: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_kwargs: + out_indices: ${oc.env:OUT_INDICES, 1} + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + init_args: + in_features: ${oc.env:IN_FEATURES, 384} + num_classes: &NUM_CLASSES 4 + criterion: + class_path: eva.vision.losses.DiceLoss + init_args: + softmax: true + batch: true + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.002} + lr_scheduler: + class_path: torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + postprocess: + predictions_transforms: + - class_path: torch.argmax + init_args: + dim: 1 + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: torchmetrics.ClasswiseWrapper + init_args: + metric: + class_path: eva.vision.metrics.GeneralizedDiceScore + init_args: + num_classes: *NUM_CLASSES + weight_type: linear + per_class: true +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.KiTS23 + init_args: &DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/kits23} + split: train + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndClamp + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: *NORMALIZE_MEAN + std: *NORMALIZE_STD + val: + class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *DATASET_ARGS + split: train + test: + class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *DATASET_ARGS + split: train + dataloaders: + train: + batch_size: &BATCH_SIZE 
${oc.env:BATCH_SIZE, 64} + shuffle: true + val: + batch_size: *BATCH_SIZE + shuffle: true + test: + batch_size: *BATCH_SIZE diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py index bec918af..4fac6499 100644 --- a/src/eva/vision/data/datasets/__init__.py +++ b/src/eva/vision/data/datasets/__init__.py @@ -14,6 +14,7 @@ CoNSeP, EmbeddingsSegmentationDataset, ImageSegmentation, + KiTS23, LiTS, LiTSBalanced, MoNuSAC, @@ -34,6 +35,7 @@ "CoNSeP", "EmbeddingsSegmentationDataset", "ImageSegmentation", + "KiTS23", "LiTS", "LiTSBalanced", "MoNuSAC", diff --git a/src/eva/vision/data/datasets/segmentation/__init__.py b/src/eva/vision/data/datasets/segmentation/__init__.py index b954fa39..edabcf8a 100644 --- a/src/eva/vision/data/datasets/segmentation/__init__.py +++ b/src/eva/vision/data/datasets/segmentation/__init__.py @@ -4,6 +4,7 @@ from eva.vision.data.datasets.segmentation.bcss import BCSS from eva.vision.data.datasets.segmentation.consep import CoNSeP from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset +from eva.vision.data.datasets.segmentation.kits23 import KiTS23 from eva.vision.data.datasets.segmentation.lits import LiTS from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced from eva.vision.data.datasets.segmentation.monusac import MoNuSAC @@ -14,6 +15,7 @@ "BCSS", "CoNSeP", "EmbeddingsSegmentationDataset", + "KiTS23", "LiTS", "LiTSBalanced", "MoNuSAC", diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py new file mode 100644 index 00000000..153e992b --- /dev/null +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -0,0 +1,212 @@ +"""KiTS23 dataset.""" + +import functools +import os +import time +from typing import Any, Callable, Dict, List, Literal, Tuple +from urllib import request + +import torch +from torchvision import tv_tensors +from typing_extensions import override + +from 
eva.core.utils.progress_bar import tqdm +from eva.vision.data.datasets import _utils, _validators +from eva.vision.data.datasets.segmentation import base +from eva.vision.utils import io + + +class KiTS23(base.ImageSegmentation): + """KiTS23 - The 2023 Kidney and Kidney Tumor Segmentation challenge. + + Webpage: https://kits-challenge.org/kits23/ + """ + + _train_index_ranges: List[Tuple[int, int]] = [(0, 300), (400, 589)] + """Train range indices.""" + + _expected_dataset_lengths: Dict[str | None, int] = { + "train": 250911, + } + """Dataset version and split to the expected size.""" + + _sample_every_n_slices: int | None = None + """The amount of slices to sub-sample per 3D CT scan image.""" + + _license: str = "CC BY-NC-SA 4.0" + """Dataset license.""" + + def __init__( + self, + root: str, + split: Literal["train"], + download: bool = False, + transforms: Callable | None = None, + ) -> None: + """Initialize dataset. + + Args: + root: Path to the root directory of the dataset. The dataset will + be downloaded and extracted here, if it does not already exist. + split: Dataset split to use. + download: Whether to download the data for the specified split. + Note that the download will be executed only by additionally + calling the :meth:`prepare_data` method and if the data does + not yet exist on disk. + transforms: A function/transforms that takes in an image and a target + mask and returns the transformed versions of both. 
+ """ + super().__init__(transforms=transforms) + + self._root = root + self._split = split + self._download = download + + self._indices: List[Tuple[int, int]] = [] + + @property + @override + def classes(self) -> List[str]: + return ["kidney", "tumor", "cyst"] + + @functools.cached_property + @override + def class_to_idx(self) -> Dict[str, int]: + return {label: index for index, label in enumerate(self.classes)} + + @override + def filename(self, index: int) -> str: + sample_index, _ = self._indices[index] + return self._volume_filename(sample_index) + + @override + def prepare_data(self) -> None: + if self._download: + self._download_dataset() + + @override + def configure(self) -> None: + self._indices = self._create_indices() + + @override + def validate(self) -> None: + _validators.check_dataset_integrity( + self, + length=self._expected_dataset_lengths.get(self._split, 0), + n_classes=3, + first_and_last_labels=("kidney", "cyst"), + ) + + @override + def load_image(self, index: int) -> tv_tensors.Image: + sample_index, slice_index = self._indices[index] + volume_path = self._volume_path(sample_index) + image_array = io.read_nifti(volume_path, slice_index) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) + + @override + def load_mask(self, index: int) -> tv_tensors.Mask: + sample_index, slice_index = self._indices[index] + segmentation_path = self._segmentation_path(sample_index) + semantic_labels = io.read_nifti(segmentation_path, slice_index) + return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + _, slice_index = self._indices[index] + return {"slice_index": slice_index} + + @override + def __len__(self) -> int: + return len(self._indices) + + def _create_indices(self) -> List[Tuple[int, int]]: + """Builds the dataset indices for the specified split. 
+ + Returns: + A list of tuples, where the first value indicates the + sample index which the second its corresponding slice + index. + """ + indices = [ + (sample_idx, slide_idx) + for sample_idx in self._get_split_indices() + for slide_idx in range(self._get_number_of_slices_per_volume(sample_idx)) + if slide_idx % (self._sample_every_n_slices or 1) == 0 + ] + return indices + + def _get_split_indices(self) -> List[int]: + """Builds the dataset indices for the specified split.""" + split_index_ranges = { + "train": self._train_index_ranges, + } + index_ranges = split_index_ranges.get(self._split) + if index_ranges is None: + raise ValueError("Invalid data split. Use 'train'.") + + return _utils.ranges_to_indices(index_ranges) + + def _get_number_of_slices_per_volume(self, sample_index: int) -> int: + """Returns the total amount of slices of a volume.""" + volume_shape = io.fetch_nifti_shape(self._volume_path(sample_index)) + return volume_shape[-1] + + def _volume_filename(self, sample_index: int) -> str: + return os.path.join(f"case_{sample_index}", "imaging.nii.gz") + + def _segmentation_filename(self, sample_index: int) -> str: + return os.path.join(f"case_{sample_index}", "segmentation.nii.gz") + + def _volume_path(self, sample_index: int) -> str: + return os.path.join(self._root, self._volume_filename(sample_index)) + + def _segmentation_path(self, sample_index: int) -> str: + return os.path.join(self._root, self._segmentation_filename(sample_index)) + + def _download_dataset(self) -> None: + """Downloads the dataset.""" + self._print_license() + for case_id in tqdm( + self._get_split_indices(), + desc=">> Downloading dataset", + leave=False, + ): + image_path, segmentation_path = self._volume_path(case_id), self._segmentation_path( + case_id + ) + if os.path.isfile(image_path) and os.path.isfile(segmentation_path): + continue + + _download_case_with_retry(case_id, image_path, segmentation_path) + + def _print_license(self) -> None: + """Prints the dataset 
license.""" + print(f"Dataset license: {self._license}") + + +def _download_case_with_retry( + case_id: int, + image_path: str, + segmentation_path: str, + *, + retries: int = 2, +) -> None: + for attempt in range(retries): + try: + os.makedirs(os.path.dirname(image_path), exist_ok=True) + request.urlretrieve( + url=f"https://kits19.sfo2.digitaloceanspaces.com/master_{case_id:05d}.nii.gz", # nosec + filename=image_path, + ) + request.urlretrieve( + url=f"https://github.com/neheller/kits23/raw/refs/heads/main/dataset/case_{case_id:05d}/segmentation.nii.gz", # nosec + filename=segmentation_path, + ) + return + + except Exception as e: + if attempt < retries - 1: + time.sleep(5) + else: + raise e From 6cd96044566796a4d6ecaaf5dbf86f68ced8bf7f Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 6 Dec 2024 10:09:16 +0100 Subject: [PATCH 02/11] fixed paths --- src/eva/vision/data/datasets/segmentation/kits23.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 153e992b..e5b850c2 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -153,10 +153,10 @@ def _get_number_of_slices_per_volume(self, sample_index: int) -> int: return volume_shape[-1] def _volume_filename(self, sample_index: int) -> str: - return os.path.join(f"case_{sample_index}", "imaging.nii.gz") + return os.path.join(f"case_{sample_index:05d}", f"master_{sample_index:05d}.nii.gz") def _segmentation_filename(self, sample_index: int) -> str: - return os.path.join(f"case_{sample_index}", "segmentation.nii.gz") + return os.path.join(f"case_{sample_index:05d}", "segmentation.nii.gz") def _volume_path(self, sample_index: int) -> str: return os.path.join(self._root, self._volume_filename(sample_index)) @@ -200,7 +200,7 @@ def _download_case_with_retry( filename=image_path, ) request.urlretrieve( - 
url=f"https://github.com/neheller/kits23/raw/refs/heads/main/dataset/case_{case_id:05d}/segmentation.nii.gz", # nosec + url=f"https://raw.githubusercontent.com/neheller/kits23/e282208/dataset/case_{case_id:05d}/segmentation.nii.gz", # nosec filename=segmentation_path, ) return From 8d5f34c4005370a5e7dc6a0e884cc0ff64151723 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 6 Dec 2024 10:47:19 +0100 Subject: [PATCH 03/11] add background class --- src/eva/vision/data/datasets/segmentation/kits23.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index e5b850c2..21a07819 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -67,7 +67,7 @@ def __init__( @property @override def classes(self) -> List[str]: - return ["kidney", "tumor", "cyst"] + return ["background", "kidney", "tumor", "cyst"] @functools.cached_property @override @@ -93,8 +93,8 @@ def validate(self) -> None: _validators.check_dataset_integrity( self, length=self._expected_dataset_lengths.get(self._split, 0), - n_classes=3, - first_and_last_labels=("kidney", "cyst"), + n_classes=4, + first_and_last_labels=("background", "cyst"), ) @override From 2bf71f3d5ebb006ee167ab921cdc86ecf4ecdb21 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 6 Dec 2024 11:18:55 +0100 Subject: [PATCH 04/11] add decompression logic --- .../data/datasets/segmentation/kits23.py | 48 +++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 21a07819..08b9bac4 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -3,6 +3,9 @@ import functools import os import time +from pathlib import Path +import numpy as np +import 
numpy.typing as npt from typing import Any, Callable, Dict, List, Literal, Tuple from urllib import request @@ -10,6 +13,8 @@ from torchvision import tv_tensors from typing_extensions import override +from eva.core.utils import io as core_io +from eva.core.utils import multiprocessing from eva.core.utils.progress_bar import tqdm from eva.vision.data.datasets import _utils, _validators from eva.vision.data.datasets.segmentation import base @@ -30,6 +35,9 @@ class KiTS23(base.ImageSegmentation): } """Dataset version and split to the expected size.""" + _fix_orientation: bool = True + """Whether to fix the orientation of the images to match the default for radiologists.""" + _sample_every_n_slices: int | None = None """The amount of slices to sub-sample per 3D CT scan image.""" @@ -41,6 +49,8 @@ def __init__( root: str, split: Literal["train"], download: bool = False, + decompress: bool = True, + num_workers: int = 10, transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -53,6 +63,10 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not yet exist on disk. + decompress: Whether to decompress the .nii.gz files when preparing the data. + Without decompression, data loading will be very slow. + num_workers: The number of workers to use for optimizing the masks & + decompressing the .gz files. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. 
""" @@ -60,6 +74,8 @@ def __init__( self._root = root self._split = split + self._decompress = decompress + self._num_workers = num_workers self._download = download self._indices: List[Tuple[int, int]] = [] @@ -74,6 +90,10 @@ def classes(self) -> List[str]: def class_to_idx(self) -> Dict[str, int]: return {label: index for index, label in enumerate(self.classes)} + @property + def _file_suffix(self) -> str: + return "nii" if self._decompress else "nii.gz" + @override def filename(self, index: int) -> str: sample_index, _ = self._indices[index] @@ -83,6 +103,8 @@ def filename(self, index: int) -> str: def prepare_data(self) -> None: if self._download: self._download_dataset() + if self._decompress: + self._decompress_files() @override def configure(self) -> None: @@ -102,7 +124,9 @@ def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] volume_path = self._volume_path(sample_index) image_array = io.read_nifti(volume_path, slice_index) - return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) + if self._fix_orientation: + image_array = self._orientation(image_array, sample_index) + return tv_tensors.Image(image_array.transpose(2, 0, 1)) @override def load_mask(self, index: int) -> tv_tensors.Mask: @@ -116,6 +140,14 @@ def load_metadata(self, index: int) -> Dict[str, Any]: _, slice_index = self._indices[index] return {"slice_index": slice_index} + def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray: + # orientation = io.fetch_nifti_axis_direction_code(self._volume_path(sample_index)) + # array = np.rot90(array, axes=(0, 1)) + # if orientation == "LPS": + # array = np.flip(array, axis=0) + # TODO: Implement orientation correction + return array.copy() + @override def __len__(self) -> int: return len(self._indices) @@ -153,10 +185,10 @@ def _get_number_of_slices_per_volume(self, sample_index: int) -> int: return volume_shape[-1] def _volume_filename(self, sample_index: int) -> 
str: - return os.path.join(f"case_{sample_index:05d}", f"master_{sample_index:05d}.nii.gz") + return f"case_{sample_index:05d}/master_{sample_index:05d}.{self._file_suffix}" def _segmentation_filename(self, sample_index: int) -> str: - return os.path.join(f"case_{sample_index:05d}", "segmentation.nii.gz") + return f"case_{sample_index:05d}/segmentation.{self._file_suffix}" def _volume_path(self, sample_index: int) -> str: return os.path.join(self._root, self._volume_filename(sample_index)) @@ -180,6 +212,16 @@ def _download_dataset(self) -> None: _download_case_with_retry(case_id, image_path, segmentation_path) + def _decompress_files(self) -> None: + compressed_paths = Path(self._root).rglob("*.nii.gz") + multiprocessing.run_with_threads( + functools.partial(core_io.gunzip_file, keep=False), + [(str(path),) for path in compressed_paths], + num_workers=self._num_workers, + progress_desc=">> Decompressing .gz files", + return_results=False, + ) + def _print_license(self) -> None: """Prints the dataset license.""" print(f"Dataset license: {self._license}") From 71804264779293d5301f7b184ec505879448f362 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 6 Dec 2024 11:31:52 +0100 Subject: [PATCH 05/11] addded val & test splits --- .../data/datasets/segmentation/kits23.py | 61 ++++++++++++------- .../vision/data/datasets/segmentation/lits.py | 2 +- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 08b9bac4..e279922f 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -4,15 +4,15 @@ import os import time from pathlib import Path -import numpy as np -import numpy.typing as npt from typing import Any, Callable, Dict, List, Literal, Tuple from urllib import request +import numpy.typing as npt import torch from torchvision import tv_tensors from typing_extensions import 
override +from eva.core.data import splitting from eva.core.utils import io as core_io from eva.core.utils import multiprocessing from eva.core.utils.progress_bar import tqdm @@ -27,11 +27,19 @@ class KiTS23(base.ImageSegmentation): Webpage: https://kits-challenge.org/kits23/ """ - _train_index_ranges: List[Tuple[int, int]] = [(0, 300), (400, 589)] - """Train range indices.""" + _index_ranges: List[Tuple[int, int]] = [(0, 300), (400, 589)] + """Dataset index ranges.""" + + _train_ratio: float = 0.7 + _val_ratio: float = 0.15 + _test_ratio: float = 0.15 + """Ratios for dataset splits.""" _expected_dataset_lengths: Dict[str | None, int] = { - "train": 250911, + "train": 175527, + "val": 37376, + "test": 38008, + None: 250911, } """Dataset version and split to the expected size.""" @@ -47,18 +55,19 @@ class KiTS23(base.ImageSegmentation): def __init__( self, root: str, - split: Literal["train"], + split: Literal["train", "val", "test"] | None = None, download: bool = False, decompress: bool = True, num_workers: int = 10, transforms: Callable | None = None, + seed: int = 8, ) -> None: """Initialize dataset. Args: root: Path to the root directory of the dataset. The dataset will be downloaded and extracted here, if it does not already exist. - split: Dataset split to use. + split: Dataset split to use. If `None`, the entire dataset will be used. download: Whether to download the data for the specified split. Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does @@ -69,6 +78,7 @@ def __init__( decompressing the .gz files. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. + seed: Seed used for generating the dataset splits. 
""" super().__init__(transforms=transforms) @@ -77,6 +87,7 @@ def __init__( self._decompress = decompress self._num_workers = num_workers self._download = download + self._seed = seed self._indices: List[Tuple[int, int]] = [] @@ -170,14 +181,21 @@ def _create_indices(self) -> List[Tuple[int, int]]: def _get_split_indices(self) -> List[int]: """Builds the dataset indices for the specified split.""" - split_index_ranges = { - "train": self._train_index_ranges, + indices = _utils.ranges_to_indices(self._index_ranges) + + train_indices, val_indices, test_indices = splitting.random_split( + indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed + ) + split_indices_dict = { + "train": [indices[i] for i in train_indices], + "val": [indices[i] for i in val_indices], + "test": [indices[i] for i in test_indices], # type: ignore + None: indices, } - index_ranges = split_index_ranges.get(self._split) - if index_ranges is None: - raise ValueError("Invalid data split. Use 'train'.") + if self._split not in split_indices_dict: + raise ValueError("Invalid data split. 
Use 'train', 'val', 'test' or `None`.") - return _utils.ranges_to_indices(index_ranges) + return list(split_indices_dict[self._split]) def _get_number_of_slices_per_volume(self, sample_index: int) -> int: """Returns the total amount of slices of a volume.""" @@ -213,14 +231,15 @@ def _download_dataset(self) -> None: _download_case_with_retry(case_id, image_path, segmentation_path) def _decompress_files(self) -> None: - compressed_paths = Path(self._root).rglob("*.nii.gz") - multiprocessing.run_with_threads( - functools.partial(core_io.gunzip_file, keep=False), - [(str(path),) for path in compressed_paths], - num_workers=self._num_workers, - progress_desc=">> Decompressing .gz files", - return_results=False, - ) + compressed_paths = list(Path(self._root).rglob("*.nii.gz")) + if len(compressed_paths) > 0: + multiprocessing.run_with_threads( + functools.partial(core_io.gunzip_file, keep=False), + [(str(path),) for path in compressed_paths], + num_workers=self._num_workers, + progress_desc=">> Decompressing .gz files", + return_results=False, + ) def _print_license(self) -> None: """Prints the dataset license.""" diff --git a/src/eva/vision/data/datasets/segmentation/lits.py b/src/eva/vision/data/datasets/segmentation/lits.py index fb354155..e9794fc5 100644 --- a/src/eva/vision/data/datasets/segmentation/lits.py +++ b/src/eva/vision/data/datasets/segmentation/lits.py @@ -27,7 +27,7 @@ class LiTS(base.ImageSegmentation): _train_ratio: float = 0.7 _val_ratio: float = 0.15 _test_ratio: float = 0.15 - """Index ranges per split.""" + """Ratios for dataset splits.""" _fix_orientation: bool = True """Whether to fix the orientation of the images to match the default for radiologists.""" From 4a0d96e7f80dfc6cf6dfcfaea9d546b545a03717 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Fri, 6 Dec 2024 12:50:33 +0100 Subject: [PATCH 06/11] remove orientation & load slices from first array dimension (this is super slow for some reason) --- .../data/datasets/segmentation/kits23.py 
| 33 ++++++------------- src/eva/vision/utils/io/nifti.py | 14 ++++++-- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index e279922f..54fd2d82 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -36,16 +36,13 @@ class KiTS23(base.ImageSegmentation): """Ratios for dataset splits.""" _expected_dataset_lengths: Dict[str | None, int] = { - "train": 175527, - "val": 37376, - "test": 38008, - None: 250911, + "train": 67582, + "val": 13751, + "test": 13888, + None: 95221, } """Dataset version and split to the expected size.""" - _fix_orientation: bool = True - """Whether to fix the orientation of the images to match the default for radiologists.""" - _sample_every_n_slices: int | None = None """The amount of slices to sub-sample per 3D CT scan image.""" @@ -134,30 +131,20 @@ def validate(self) -> None: def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] volume_path = self._volume_path(sample_index) - image_array = io.read_nifti(volume_path, slice_index) - if self._fix_orientation: - image_array = self._orientation(image_array, sample_index) - return tv_tensors.Image(image_array.transpose(2, 0, 1)) + image_array = io.read_nifti(volume_path, slice_index, slice_dim=0) + return tv_tensors.Image(image_array) @override def load_mask(self, index: int) -> tv_tensors.Mask: sample_index, slice_index = self._indices[index] segmentation_path = self._segmentation_path(sample_index) - semantic_labels = io.read_nifti(segmentation_path, slice_index) + semantic_labels = io.read_nifti(segmentation_path, slice_index, slice_dim=0) return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] @override def load_metadata(self, index: int) -> Dict[str, Any]: - _, slice_index = self._indices[index] - return 
{"slice_index": slice_index} - - def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray: - # orientation = io.fetch_nifti_axis_direction_code(self._volume_path(sample_index)) - # array = np.rot90(array, axes=(0, 1)) - # if orientation == "LPS": - # array = np.flip(array, axis=0) - # TODO: Implement orientation correction - return array.copy() + sample_index, slice_index = self._indices[index] + return {"case_id": f"{sample_index:05d}", "slice_index": slice_index} @override def __len__(self) -> int: @@ -200,7 +187,7 @@ def _get_split_indices(self) -> List[int]: def _get_number_of_slices_per_volume(self, sample_index: int) -> int: """Returns the total amount of slices of a volume.""" volume_shape = io.fetch_nifti_shape(self._volume_path(sample_index)) - return volume_shape[-1] + return volume_shape[0] def _volume_filename(self, sample_index: int) -> str: return f"case_{sample_index:05d}/master_{sample_index:05d}.{self._file_suffix}" diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 49ca8fda..d6d0dc11 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -11,13 +11,18 @@ def read_nifti( - path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True + path: str, + slice_index: int | None = None, + *, + slice_dim: int = -1, + use_storage_dtype: bool = True, ) -> npt.NDArray[Any]: """Reads and loads a NIfTI image from a file path. Args: path: The path to the NIfTI file. slice_index: Whether to read only a slice from the file. + slice_dim: The array dimension to slice the image. Default is -1. use_storage_dtype: Whether to cast the raw image array to the inferred type. 
@@ -30,8 +35,13 @@ def read_nifti( """ _utils.check_file(path) image_data: nib.Nifti1Image = nib.load(path) # type: ignore + if slice_index is not None: - image_data = image_data.slicer[:, :, slice_index : slice_index + 1] + if slice_dim not in {-1, 0, 1, 2}: + raise ValueError(f"Expected slice_dim to be -1, 0, 1 or 2, but got {slice_dim}.") + array_indices = [slice(None)] * 3 + array_indices[2 if slice_dim == -1 else slice_dim] = slice(slice_index, slice_index + 1) + image_data = image_data.slicer[tuple(array_indices)] image_array = image_data.get_fdata() if use_storage_dtype: From 5e738e2426d7da4bec525691cb8f1c6e9659d892 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 9 Dec 2024 08:40:40 +0100 Subject: [PATCH 07/11] orientation --- .../data/datasets/segmentation/kits23.py | 6 ++-- src/eva/vision/utils/io/nifti.py | 33 +++++++++++-------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 54fd2d82..3f4e8326 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -131,14 +131,14 @@ def validate(self) -> None: def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] volume_path = self._volume_path(sample_index) - image_array = io.read_nifti(volume_path, slice_index, slice_dim=0) - return tv_tensors.Image(image_array) + image_array = io.read_nifti(volume_path, slice_index, target_orientation="LAS", use_storage_dtype=True) + return tv_tensors.Image(image_array, dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: sample_index, slice_index = self._indices[index] segmentation_path = self._segmentation_path(sample_index) - semantic_labels = io.read_nifti(segmentation_path, slice_index, slice_dim=0) + semantic_labels = io.read_nifti(segmentation_path, slice_index, 
target_orientation="LAS", use_storage_dtype=True) return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] @override diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index d6d0dc11..8397bbf4 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -11,18 +11,13 @@ def read_nifti( - path: str, - slice_index: int | None = None, - *, - slice_dim: int = -1, - use_storage_dtype: bool = True, + path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True, target_orientation: str | None = None ) -> npt.NDArray[Any]: """Reads and loads a NIfTI image from a file path. Args: path: The path to the NIfTI file. slice_index: Whether to read only a slice from the file. - slice_dim: The array dimension to slice the image. Default is -1. use_storage_dtype: Whether to cast the raw image array to the inferred type. @@ -35,20 +30,30 @@ def read_nifti( """ _utils.check_file(path) image_data: nib.Nifti1Image = nib.load(path) # type: ignore - + if target_orientation is not None: + image_data = reorient(image_data, target_orientation) + if slice_index is not None: - if slice_dim not in {-1, 0, 1, 2}: - raise ValueError(f"Expected slice_dim to be -1, 0, 1 or 2, but got {slice_dim}.") - array_indices = [slice(None)] * 3 - array_indices[2 if slice_dim == -1 else slice_dim] = slice(slice_index, slice_index + 1) - image_data = image_data.slicer[tuple(array_indices)] - - image_array = image_data.get_fdata() + image_array = image_data.dataobj[:, :, slice_index] + else: + image_array = image_data.get_fdata() + if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) return image_array +def reorient( + nii: nib.Nifti1Image, + orientation: str | tuple[str, str, str] = "RAS", +) -> nib.Nifti1Image: + """Reorients a nifti image to specified orientation. 
Orientation string or tuple + must consist of "R" or "L", "A" or "P", and "I" or "S" in any order.""" + orig_ornt = nib.io_orientation(nii.affine) + targ_ornt = orientations.axcodes2ornt(orientation) + transform = orientations.ornt_transform(orig_ornt, targ_ornt) + reoriented_nii = nii.as_reoriented(transform) + return reoriented_nii def save_array_as_nifti( array: npt.ArrayLike, From f39c625894bf470f83f907d394b6f39fd03a9d35 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 9 Dec 2024 13:31:36 +0100 Subject: [PATCH 08/11] added yaml configs --- .../offline/segmentation/kits23.yaml | 158 ++++++++++++++++++ .../radiology/online/segmentation/kits23.yaml | 39 +++-- .../data/datasets/segmentation/kits23.py | 68 ++++---- .../vision/data/datasets/segmentation/lits.py | 2 +- .../segmentation/total_segmentator_2d.py | 2 +- src/eva/vision/utils/io/nifti.py | 27 ++- 6 files changed, 244 insertions(+), 52 deletions(-) create mode 100644 configs/vision/radiology/offline/segmentation/kits23.yaml diff --git a/configs/vision/radiology/offline/segmentation/kits23.yaml b/configs/vision/radiology/offline/segmentation/kits23.yaml new file mode 100644 index 00000000..103c2f03 --- /dev/null +++ b/configs/vision/radiology/offline/segmentation/kits23.yaml @@ -0,0 +1,158 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 + log_images: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: 
true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 5} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.SegmentationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23 + dataloader_idx_map: + 0: train + 1: val + 2: test + metadata_keys: ["slice_index"] + overwrite: false + backbone: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_kwargs: + out_indices: ${oc.env:OUT_INDICES, 1} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + init_args: + in_features: ${oc.env:IN_FEATURES, 384} + num_classes: &NUM_CLASSES 4 + criterion: + class_path: eva.vision.losses.DiceLoss + init_args: + softmax: true + batch: true + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.002} + lr_scheduler: + class_path: torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + postprocess: + predictions_transforms: + - class_path: torch.argmax + init_args: + dim: 1 + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: torchmetrics.ClasswiseWrapper + init_args: + metric: + class_path: 
eva.vision.metrics.MonaiDiceScore + init_args: + include_background: true + num_classes: *NUM_CLASSES + reduction: none + labels: + - background + - kidney + - tumor + - cyst +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.KiTS23 + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/kits23} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset automatically from the official source. + # The KiTS23 dataset is distributed under the following license: + # "Attribution-NonCommercial-ShareAlike 4.0 International" + # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndClamp + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + - class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + shuffle: true + test: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff 
--git a/configs/vision/radiology/online/segmentation/kits23.yaml b/configs/vision/radiology/online/segmentation/kits23.yaml index 8bb24821..d804e864 100644 --- a/configs/vision/radiology/online/segmentation/kits23.yaml +++ b/configs/vision/radiology/online/segmentation/kits23.yaml @@ -2,10 +2,10 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} - log_every_n_steps: 6 + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -21,12 +21,12 @@ trainer: filename: best save_last: true save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore} + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 5} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: @@ -43,8 +43,9 @@ model: model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} model_kwargs: out_indices: ${oc.env:OUT_INDICES, 1} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage init_args: in_features: ${oc.env:IN_FEATURES, 384} num_classes: &NUM_CLASSES 4 @@ -78,11 +79,16 @@ model: - class_path: torchmetrics.ClasswiseWrapper init_args: metric: - class_path: 
eva.vision.metrics.GeneralizedDiceScore + class_path: eva.vision.metrics.MonaiDiceScore init_args: + include_background: true num_classes: *NUM_CLASSES - weight_type: linear - per_class: true + reduction: none + labels: + - background + - kidney + - tumor + - cyst data: class_path: eva.DataModule init_args: @@ -92,28 +98,35 @@ data: init_args: &DATASET_ARGS root: ${oc.env:DATA_ROOT, ./data/kits23} split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset automatically from the official source. + # The KiTS23 dataset is distributed under the following license: + # "Attribution-NonCommercial-ShareAlike 4.0 International" + # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) transforms: class_path: eva.vision.data.transforms.common.ResizeAndClamp init_args: - size: ${oc.env:RESIZE_DIM, 224} mean: *NORMALIZE_MEAN std: *NORMALIZE_STD val: class_path: eva.vision.datasets.KiTS23 init_args: <<: *DATASET_ARGS - split: train + split: val test: class_path: eva.vision.datasets.KiTS23 init_args: <<: *DATASET_ARGS - split: train + split: test dataloaders: train: batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} shuffle: true val: batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS shuffle: true test: batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 3f4e8326..6a6b2250 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -7,18 +7,18 @@ from typing import Any, Callable, Dict, List, Literal, Tuple from urllib import request -import numpy.typing as npt +import nibabel as nib import torch from torchvision import tv_tensors from typing_extensions import override from eva.core.data import splitting -from eva.core.utils import io as core_io from eva.core.utils import 
multiprocessing from eva.core.utils.progress_bar import tqdm from eva.vision.data.datasets import _utils, _validators from eva.vision.data.datasets.segmentation import base from eva.vision.utils import io +from eva.vision.utils.io import nifti class KiTS23(base.ImageSegmentation): @@ -46,6 +46,9 @@ class KiTS23(base.ImageSegmentation): _sample_every_n_slices: int | None = None """The amount of slices to sub-sample per 3D CT scan image.""" + _processed_dir: str = "processed" + """Directory where the processed data (reoriented to LPS & uncompressed) is stored.""" + _license: str = "CC BY-NC-SA 4.0" """Dataset license.""" @@ -54,7 +57,6 @@ def __init__( root: str, split: Literal["train", "val", "test"] | None = None, download: bool = False, - decompress: bool = True, num_workers: int = 10, transforms: Callable | None = None, seed: int = 8, @@ -69,10 +71,7 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not yet exist on disk. - decompress: Whether to decompress the .nii.gz files when preparing the data. - Without decompression, data loading will be very slow. - num_workers: The number of workers to use for optimizing the masks & - decompressing the .gz files. + num_workers: The number of workers to use for preprocessing the dataset. transforms: A function/transforms that takes in an image and a target mask and returns the transformed versions of both. seed: Seed used for generating the dataset splits. 
@@ -81,9 +80,8 @@ def __init__( self._root = root self._split = split - self._decompress = decompress - self._num_workers = num_workers self._download = download + self._num_workers = num_workers self._seed = seed self._indices: List[Tuple[int, int]] = [] @@ -99,8 +97,8 @@ def class_to_idx(self) -> Dict[str, int]: return {label: index for index, label in enumerate(self.classes)} @property - def _file_suffix(self) -> str: - return "nii" if self._decompress else "nii.gz" + def _processed_root(self) -> str: + return os.path.join(self._root, self._processed_dir) @override def filename(self, index: int) -> str: @@ -111,8 +109,7 @@ def filename(self, index: int) -> str: def prepare_data(self) -> None: if self._download: self._download_dataset() - if self._decompress: - self._decompress_files() + self._preprocess() @override def configure(self) -> None: @@ -131,14 +128,14 @@ def validate(self) -> None: def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] volume_path = self._volume_path(sample_index) - image_array = io.read_nifti(volume_path, slice_index, target_orientation="LAS", use_storage_dtype=True) - return tv_tensors.Image(image_array, dtype=torch.float32) # type: ignore[reportCallIssue] + image_array = io.read_nifti(volume_path, slice_index) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: sample_index, slice_index = self._indices[index] segmentation_path = self._segmentation_path(sample_index) - semantic_labels = io.read_nifti(segmentation_path, slice_index, target_orientation="LAS", use_storage_dtype=True) + semantic_labels = io.read_nifti(segmentation_path, slice_index) return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] @override @@ -187,19 +184,19 @@ def _get_split_indices(self) -> List[int]: def _get_number_of_slices_per_volume(self, 
sample_index: int) -> int: """Returns the total amount of slices of a volume.""" volume_shape = io.fetch_nifti_shape(self._volume_path(sample_index)) - return volume_shape[0] + return volume_shape[-1] def _volume_filename(self, sample_index: int) -> str: - return f"case_{sample_index:05d}/master_{sample_index:05d}.{self._file_suffix}" + return f"case_{sample_index:05d}/master_{sample_index:05d}.nii" def _segmentation_filename(self, sample_index: int) -> str: - return f"case_{sample_index:05d}/segmentation.{self._file_suffix}" + return f"case_{sample_index:05d}/segmentation.nii" def _volume_path(self, sample_index: int) -> str: - return os.path.join(self._root, self._volume_filename(sample_index)) + return os.path.join(self._processed_root, self._volume_filename(sample_index)) def _segmentation_path(self, sample_index: int) -> str: - return os.path.join(self._root, self._segmentation_filename(sample_index)) + return os.path.join(self._processed_root, self._segmentation_filename(sample_index)) def _download_dataset(self) -> None: """Downloads the dataset.""" @@ -217,16 +214,27 @@ def _download_dataset(self) -> None: _download_case_with_retry(case_id, image_path, segmentation_path) - def _decompress_files(self) -> None: + def _preprocess(self): + def _reorient_and_save(path: Path) -> None: + relative_path = str(path.relative_to(self._root)) + save_path = os.path.join(self._processed_root, relative_path.rstrip(".gz")) + if os.path.isfile(save_path): + return + os.makedirs(os.path.dirname(save_path), exist_ok=True) + nifti.reorient(nib.load(path), "LPS").to_filename(str(save_path)) + compressed_paths = list(Path(self._root).rglob("*.nii.gz")) - if len(compressed_paths) > 0: - multiprocessing.run_with_threads( - functools.partial(core_io.gunzip_file, keep=False), - [(str(path),) for path in compressed_paths], - num_workers=self._num_workers, - progress_desc=">> Decompressing .gz files", - return_results=False, - ) + multiprocessing.run_with_threads( + _reorient_and_save, 
+ [(path,) for path in compressed_paths], + num_workers=self._num_workers, + progress_desc=">> Preprocessing dataset", + return_results=False, + ) + + processed_paths = list(Path(self._processed_root).rglob("*.nii")) + if len(compressed_paths) != len(processed_paths): + raise RuntimeError(f"Preprocessing failed, missing files in {self._processed_root}.") def _print_license(self) -> None: """Prints the dataset license.""" diff --git a/src/eva/vision/data/datasets/segmentation/lits.py b/src/eva/vision/data/datasets/segmentation/lits.py index a9098311..0b88bf4a 100644 --- a/src/eva/vision/data/datasets/segmentation/lits.py +++ b/src/eva/vision/data/datasets/segmentation/lits.py @@ -116,7 +116,7 @@ def load_image(self, index: int) -> tv_tensors.Image: image_array = io.read_nifti(volume_path, slice_index) if self._fix_orientation: image_array = self._orientation(image_array, sample_index) - return tv_tensors.Image(image_array.transpose(2, 0, 1)) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 53cc0c5f..ff8d4891 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -211,7 +211,7 @@ def load_image(self, index: int) -> tv_tensors.Image: image_path = self._get_image_path(sample_index) image_array = io.read_nifti(image_path, slice_index) image_array = self._fix_orientation(image_array) - return tv_tensors.Image(image_array.copy().transpose(2, 0, 1)) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 
8397bbf4..4efab18c 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -11,7 +11,11 @@ def read_nifti( - path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True, target_orientation: str | None = None + path: str, + slice_index: int | None = None, + *, + use_storage_dtype: bool = True, + target_orientation: str | None = None, ) -> npt.NDArray[Any]: """Reads and loads a NIfTI image from a file path. @@ -20,6 +24,7 @@ def read_nifti( slice_index: Whether to read only a slice from the file. use_storage_dtype: Whether to cast the raw image array to the inferred type. + target_orientation: The target orientation to reorient the image. E.g. "LPS". Returns: The image as a numpy array (height, width, channels). @@ -32,29 +37,37 @@ def read_nifti( image_data: nib.Nifti1Image = nib.load(path) # type: ignore if target_orientation is not None: image_data = reorient(image_data, target_orientation) - + if slice_index is not None: - image_array = image_data.dataobj[:, :, slice_index] + image_array = np.expand_dims(image_data.dataobj[:, :, slice_index], -1) else: image_array = image_data.get_fdata() - + if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) return image_array + def reorient( nii: nib.Nifti1Image, - orientation: str | tuple[str, str, str] = "RAS", + orientation: str | Tuple[str, str, str] = "LPS", ) -> nib.Nifti1Image: - """Reorients a nifti image to specified orientation. Orientation string or tuple - must consist of "R" or "L", "A" or "P", and "I" or "S" in any order.""" + """Reorients a nifti image to specified orientation. + + Args: + nii: The input nifti image. + orientation: The target orientation to reorient the image. E.g. "LPS" or ("L", "P", "S"). 
+ """ orig_ornt = nib.io_orientation(nii.affine) targ_ornt = orientations.axcodes2ornt(orientation) + if orig_ornt == targ_ornt: + return nii transform = orientations.ornt_transform(orig_ornt, targ_ornt) reoriented_nii = nii.as_reoriented(transform) return reoriented_nii + def save_array_as_nifti( array: npt.ArrayLike, filename: str, From 644e30c501a595b65990439d4177cdf574f23155 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 9 Dec 2024 15:13:45 +0100 Subject: [PATCH 09/11] added unit tests --- .gitattributes | 1 + .gitignore | 3 - .../data/datasets/segmentation/kits23.py | 9 +- src/eva/vision/utils/io/nifti.py | 2 +- .../kits23/case_00036/master_00036.nii.gz | 3 + .../kits23/case_00036/segmentation.nii.gz | 3 + .../kits23/case_00240/master_00240.nii.gz | 3 + .../kits23/case_00240/segmentation.nii.gz | 3 + .../data/datasets/segmentation/test_kits23.py | 83 +++++++++++++++++++ 9 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz create mode 100644 tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz create mode 100644 tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz create mode 100644 tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz create mode 100644 tests/eva/vision/data/datasets/segmentation/test_kits23.py diff --git a/.gitattributes b/.gitattributes index 29bcbf6f..d4a50a98 100644 --- a/.gitattributes +++ b/.gitattributes @@ -9,3 +9,4 @@ tests/eva/assets/**/*.npy filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.xml filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.mat filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.nii filter=lfs diff=lfs merge=lfs -text +tests/eva/assets/**/*.nii.gz filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index cd585658..a9042f9b 100644 --- a/.gitignore +++ b/.gitignore @@ -179,6 +179,3 @@ cython_debug/ # numpy data *.npy - -# 
NiFti data -*.nii.gz diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index 6a6b2250..fbf1b8e9 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -24,6 +24,11 @@ class KiTS23(base.ImageSegmentation): """KiTS23 - The 2023 Kidney and Kidney Tumor Segmentation challenge. + To optimize data loading, the dataset is preprocessed by reorienting the images + from IPL to LPS and uncompressing them. The reorientation is necessary, because + loading slices from the first dimension is significantly slower than from the last, + due to data not being stored in a contiguous manner on disk across all dimensions. + Webpage: https://kits-challenge.org/kits23/ """ @@ -215,6 +220,8 @@ def _download_dataset(self) -> None: _download_case_with_retry(case_id, image_path, segmentation_path) def _preprocess(self): + """Reorienting the images to LPS and uncompressing them.""" + def _reorient_and_save(path: Path) -> None: relative_path = str(path.relative_to(self._root)) save_path = os.path.join(self._processed_root, relative_path.rstrip(".gz")) @@ -227,7 +234,7 @@ def _reorient_and_save(path: Path) -> None: multiprocessing.run_with_threads( _reorient_and_save, [(path,) for path in compressed_paths], - num_workers=self._num_workers, + num_workers=1, progress_desc=">> Preprocessing dataset", return_results=False, ) diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 4efab18c..47d9e236 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -61,7 +61,7 @@ def reorient( """ orig_ornt = nib.io_orientation(nii.affine) targ_ornt = orientations.axcodes2ornt(orientation) - if orig_ornt == targ_ornt: + if np.all(orig_ornt == targ_ornt): return nii transform = orientations.ornt_transform(orig_ornt, targ_ornt) reoriented_nii = nii.as_reoriented(transform) diff --git 
a/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz new file mode 100644 index 00000000..1cceb422 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7218db62c033db92b9d6ed29e0925a2c8793d5206fa65fb7b47f92a5699e085 +size 1094157 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz new file mode 100644 index 00000000..90ca9d76 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ece3061b8c97179b5ddf91a26fdbfc124276f789fe6fca942d5678396810fd2c +size 8328 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz new file mode 100644 index 00000000..f8b162d8 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:092410c27a5c04537b6c359e336cfe97693029a2dbe3c6e5ea0892d52ba3a5d6 +size 1004338 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz new file mode 100644 index 00000000..90ca9d76 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ece3061b8c97179b5ddf91a26fdbfc124276f789fe6fca942d5678396810fd2c +size 8328 diff --git a/tests/eva/vision/data/datasets/segmentation/test_kits23.py b/tests/eva/vision/data/datasets/segmentation/test_kits23.py new file mode 100644 index 00000000..466f7a41 --- /dev/null +++ 
b/tests/eva/vision/data/datasets/segmentation/test_kits23.py @@ -0,0 +1,83 @@ +"""KiTS23 dataset tests.""" + +import os +import shutil +from typing import Literal +from unittest.mock import patch + +import pytest +from torchvision import tv_tensors + +from eva.vision.data import datasets + + +@pytest.mark.parametrize( + "split, expected_length", + [(None, 8)], +) +def test_length(kits23_dataset: datasets.KiTS23, expected_length: int) -> None: + """Tests the length of the dataset.""" + assert len(kits23_dataset) == expected_length + + +@pytest.mark.parametrize( + "split, index", + [ + (None, 0), + ], +) +def test_sample(kits23_dataset: datasets.KiTS23, index: int) -> None: + """Tests the format of a dataset sample.""" + # assert data sample is a tuple + sample = kits23_dataset[index] + assert isinstance(sample, tuple) + assert len(sample) == 3 + # assert the format of the `image` and `mask` + image, mask, metadata = sample + assert isinstance(image, tv_tensors.Image) + assert image.shape == (1, 512, 512) + assert isinstance(mask, tv_tensors.Mask) + assert mask.shape == (512, 512) + assert isinstance(metadata, dict) + assert "slice_index" in metadata + + +@pytest.mark.parametrize("split", [None]) +def test_processed_dir_exists(kits23_dataset: datasets.KiTS23) -> None: + """Tests the existence of the processed directory.""" + assert os.path.isdir(kits23_dataset._processed_root) + + for index in ["00036", "00240"]: + assert os.path.isfile( + os.path.join(kits23_dataset._processed_root, f"case_{index}/master_{index}.nii") + ) + assert os.path.isfile( + os.path.join(kits23_dataset._processed_root, f"case_{index}/segmentation.nii") + ) + + +@pytest.fixture(scope="function") +def kits23_dataset(split: Literal["train", "val", "test"] | None, assets_path: str): + """KiTS23 dataset fixture.""" + dataset = datasets.KiTS23( + root=os.path.join( + assets_path, + "vision", + "datasets", + "kits23", + ), + split=split, + ) + dataset.prepare_data() + dataset.configure() + yield 
dataset + + if os.path.isdir(dataset._processed_root): + shutil.rmtree(dataset._processed_root) + + +@pytest.fixture(autouse=True) +def mock_indices(): + """Mocks the split indices so that only the two test asset cases (36 and 240) are used.""" + with patch.object(datasets.KiTS23, "_get_split_indices", return_value=[36, 240]): + yield From 3cf1e2b3306d8ff4a085570e0f53e669b7d2db50 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 9 Dec 2024 15:49:07 +0100 Subject: [PATCH 10/11] added nifti unit tests --- tests/eva/vision/utils/io/test_nifti.py | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/eva/vision/utils/io/test_nifti.py diff --git a/tests/eva/vision/utils/io/test_nifti.py b/tests/eva/vision/utils/io/test_nifti.py new file mode 100644 index 00000000..10c3d226 --- /dev/null +++ b/tests/eva/vision/utils/io/test_nifti.py @@ -0,0 +1,53 @@ +"""Tests for the nifti IO functions.""" + +import os + +import nibabel as nib +import numpy as np +import pytest +from nibabel import orientations + +from eva.vision.utils.io import nifti + + +@pytest.mark.parametrize( + "use_storage_dtype, target_orientation", + [ + [False, None], + [False, "LPS"], + [True, "RAS"], + ], +) +def test_read_nifti(nifti_path: str, use_storage_dtype: bool, target_orientation: str): + """Tests the function to read a nifti file as array (full & slice).""" + image = nifti.read_nifti( + nifti_path, use_storage_dtype=use_storage_dtype, target_orientation=target_orientation + ) + assert image.shape == (512, 512, 4) + + slice_image = nifti.read_nifti( + nifti_path, + slice_index=0, + use_storage_dtype=use_storage_dtype, + target_orientation=target_orientation, + ) + assert slice_image.shape == (512, 512, 1) + + expected_dtype = np.dtype(" Date: Mon, 9 Dec 2024 15:59:02 +0100 Subject: [PATCH 11/11] fixed download paths --- .../vision/data/datasets/segmentation/kits23.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git 
a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py index fbf1b8e9..0ed351bd 100644 --- a/src/eva/vision/data/datasets/segmentation/kits23.py +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -197,11 +197,13 @@ def _volume_filename(self, sample_index: int) -> str: def _segmentation_filename(self, sample_index: int) -> str: return f"case_{sample_index:05d}/segmentation.nii" - def _volume_path(self, sample_index: int) -> str: - return os.path.join(self._processed_root, self._volume_filename(sample_index)) + def _volume_path(self, sample_index: int, processed: bool = True) -> str: + root = self._processed_root if processed else self._root + return os.path.join(root, self._volume_filename(sample_index)) - def _segmentation_path(self, sample_index: int) -> str: - return os.path.join(self._processed_root, self._segmentation_filename(sample_index)) + def _segmentation_path(self, sample_index: int, processed: bool = True) -> str: + root = self._processed_root if processed else self._root + return os.path.join(root, self._segmentation_filename(sample_index)) def _download_dataset(self) -> None: """Downloads the dataset.""" @@ -211,9 +213,8 @@ def _download_dataset(self) -> None: desc=">> Downloading dataset", leave=False, ): - image_path, segmentation_path = self._volume_path(case_id), self._segmentation_path( - case_id - ) + image_path = self._volume_path(case_id, processed=False) + segmentation_path = self._segmentation_path(case_id, processed=False) if os.path.isfile(image_path) and os.path.isfile(segmentation_path): continue