From 3177aa199c3650c0e7eb1c31d41653cb203ace6e Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Wed, 25 Sep 2024 16:45:08 +0200
Subject: [PATCH 1/2] add LiTSBalanced dataset class

---
 .../offline/segmentation/lits_balanced.yaml   | 138 ++++++++++++++++++
 .../online/segmentation/lits_balanced.yaml    | 116 +++++++++++++++
 src/eva/vision/data/datasets/__init__.py      |   2 +
 .../data/datasets/segmentation/__init__.py    |   2 +
 .../datasets/segmentation/lits_balanced.py    |  91 ++++++++++++
 .../lits/Training_Batch2/segmentation-31.nii  |   4 +-
 .../lits/Training_Batch2/segmentation-45.nii  |   4 +-
 .../segmentation/test_lits_balanced.py        |  59 ++++++++
 8 files changed, 412 insertions(+), 4 deletions(-)
 create mode 100644 configs/vision/radiology/offline/segmentation/lits_balanced.yaml
 create mode 100644 configs/vision/radiology/online/segmentation/lits_balanced.yaml
 create mode 100644 src/eva/vision/data/datasets/segmentation/lits_balanced.py
 create mode 100644 tests/eva/vision/data/datasets/segmentation/test_lits_balanced.py

diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
new file mode 100644
index 00000000..b24cf572
--- /dev/null
+++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
@@ -0,0 +1,138 @@
+---
+trainer:
+  class_path: eva.Trainer
+  init_args:
+    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
+    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
+    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
+    callbacks:
+      - class_path: eva.callbacks.ConfigurationLogger
+      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
+        init_args:
+          log_every_n_epochs: 1
+          log_images: false
+      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+        init_args:
+          filename: best
+          save_last: true
+          save_top_k: 1
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore}
+          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+      - class_path: lightning.pytorch.callbacks.EarlyStopping
+        init_args:
+          min_delta: 0
+          patience: 100
+          monitor: *MONITOR_METRIC
+          mode: *MONITOR_METRIC_MODE
+      - class_path: eva.callbacks.SegmentationEmbeddingsWriter
+        init_args:
+          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits
+          dataloader_idx_map:
+            0: train
+            1: val
+            2: test
+          metadata_keys: ["slice_index"]
+          overwrite: false
+          backbone:
+            class_path: eva.vision.models.ModelFromRegistry
+            init_args:
+              model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
+              model_kwargs:
+                out_indices: ${oc.env:OUT_INDICES, 1}
+              model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
+    logger:
+      - class_path: lightning.pytorch.loggers.TensorBoardLogger
+        init_args:
+          save_dir: *OUTPUT_ROOT
+          name: ""
+model:
+  class_path: eva.vision.models.modules.SemanticSegmentationModule
+  init_args:
+    decoder:
+      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
+      init_args:
+        in_features: ${oc.env:IN_FEATURES, 384}
+        num_classes: &NUM_CLASSES 3
+    criterion:
+      class_path: eva.vision.losses.DiceLoss
+      init_args:
+        softmax: true
+        batch: true
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: ${oc.env:LR_VALUE, 0.002}
+    lr_scheduler:
+      class_path: torch.optim.lr_scheduler.PolynomialLR
+      init_args:
+        total_iters: *MAX_STEPS
+        power: 0.9
+    postprocess:
+      predictions_transforms:
+        - class_path: torch.argmax
+          init_args:
+            dim: 1
+    metrics:
+      common:
+        - class_path: eva.metrics.AverageLoss
+      evaluation:
+        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+          init_args:
+            num_classes: *NUM_CLASSES
+        - class_path: torchmetrics.ClasswiseWrapper
+          init_args:
+            metric:
+              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              init_args:
+                num_classes: *NUM_CLASSES
+                weight_type: linear
+                per_class: true
+data:
+  class_path: eva.DataModule
+  init_args:
+    datasets:
+      train:
+        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
+        init_args: &DATASET_ARGS
+          root: *DATASET_EMBEDDINGS_ROOT
+          manifest_file: manifest.csv
+          split: train
+      val:
+        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
+        init_args:
+          <<: *DATASET_ARGS
+          split: val
+      test:
+        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
+        init_args:
+          <<: *DATASET_ARGS
+          split: test
+      predict:
+        - class_path: eva.vision.datasets.LiTSBalanced
+          init_args: &PREDICT_DATASET_ARGS
+            root: ${oc.env:DATA_ROOT, ./data/lits}
+            split: train
+            transforms:
+              class_path: eva.vision.data.transforms.common.ResizeAndClamp
+              init_args:
+                size: ${oc.env:RESIZE_DIM, 224}
+                mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+                std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+        - class_path: eva.vision.datasets.LiTSBalanced
+          init_args:
+            <<: *PREDICT_DATASET_ARGS
+            split: val
+        - class_path: eva.vision.datasets.LiTSBalanced
+          init_args:
+            <<: *PREDICT_DATASET_ARGS
+            split: test
+    dataloaders:
+      train:
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
+        shuffle: true
+      val:
+        batch_size: *BATCH_SIZE
+      test:
+        batch_size: *BATCH_SIZE
+      predict:
+        batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
new file mode 100644
index 00000000..0847de90
--- /dev/null
+++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
@@ -0,0 +1,116 @@
+---
+trainer:
+  class_path: eva.Trainer
+  init_args:
+    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
+    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
+    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
+    log_every_n_steps: 6
+    callbacks:
+      - class_path: eva.callbacks.ConfigurationLogger
+      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
+        init_args:
+          log_every_n_epochs: 1
+          mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+          std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+        init_args:
+          filename: best
+          save_last: true
+          save_top_k: 1
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore}
+          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+      - class_path: lightning.pytorch.callbacks.EarlyStopping
+        init_args:
+          min_delta: 0
+          patience: 100
+          monitor: *MONITOR_METRIC
+          mode: *MONITOR_METRIC_MODE
+    logger:
+      - class_path: lightning.pytorch.loggers.TensorBoardLogger
+        init_args:
+          save_dir: *OUTPUT_ROOT
+          name: ""
+model:
+  class_path: eva.vision.models.modules.SemanticSegmentationModule
+  init_args:
+    encoder:
+      class_path: eva.vision.models.ModelFromRegistry
+      init_args:
+        model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
+        model_kwargs:
+          out_indices: ${oc.env:OUT_INDICES, 1}
+    decoder:
+      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
+      init_args:
+        in_features: ${oc.env:IN_FEATURES, 384}
+        num_classes: &NUM_CLASSES 3
+    criterion:
+      class_path: eva.vision.losses.DiceLoss
+      init_args:
+        softmax: true
+        batch: true
+    lr_multiplier_encoder: 0.0
+    optimizer:
+      class_path: torch.optim.AdamW
+      init_args:
+        lr: ${oc.env:LR_VALUE, 0.002}
+    lr_scheduler:
+      class_path: torch.optim.lr_scheduler.PolynomialLR
+      init_args:
+        total_iters: *MAX_STEPS
+        power: 0.9
+    postprocess:
+      predictions_transforms:
+        - class_path: torch.argmax
+          init_args:
+            dim: 1
+    metrics:
+      common:
+        - class_path: eva.metrics.AverageLoss
+      evaluation:
+        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+          init_args:
+            num_classes: *NUM_CLASSES
+        - class_path: torchmetrics.ClasswiseWrapper
+          init_args:
+            metric:
+              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              init_args:
+                num_classes: *NUM_CLASSES
+                weight_type: linear
+                per_class: true
+data:
+  class_path: eva.DataModule
+  init_args:
+    datasets:
+      train:
+        class_path: eva.vision.datasets.LiTSBalanced
+        init_args: &DATASET_ARGS
+          root: ${oc.env:DATA_ROOT, ./data/lits}
+          split: train
+          transforms:
+            class_path: eva.vision.data.transforms.common.ResizeAndClamp
+            init_args:
+              size: ${oc.env:RESIZE_DIM, 224}
+              mean: *NORMALIZE_MEAN
+              std: *NORMALIZE_STD
+      val:
+        class_path: eva.vision.datasets.LiTSBalanced
+        init_args:
+          <<: *DATASET_ARGS
+          split: val
+      test:
+        class_path: eva.vision.datasets.LiTSBalanced
+        init_args:
+          <<: *DATASET_ARGS
+          split: test
+    dataloaders:
+      train:
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
+        shuffle: true
+      val:
+        batch_size: *BATCH_SIZE
+        shuffle: true
+      test:
+        batch_size: *BATCH_SIZE
diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py
index 0cb0dcbe..bec918af 100644
--- a/src/eva/vision/data/datasets/__init__.py
+++ b/src/eva/vision/data/datasets/__init__.py
@@ -15,6 +15,7 @@
     EmbeddingsSegmentationDataset,
     ImageSegmentation,
     LiTS,
+    LiTSBalanced,
     MoNuSAC,
     TotalSegmentator2D,
 )
@@ -34,6 +35,7 @@
     "EmbeddingsSegmentationDataset",
     "ImageSegmentation",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",
diff --git a/src/eva/vision/data/datasets/segmentation/__init__.py b/src/eva/vision/data/datasets/segmentation/__init__.py
index 3e0c5970..b954fa39 100644
--- a/src/eva/vision/data/datasets/segmentation/__init__.py
+++ b/src/eva/vision/data/datasets/segmentation/__init__.py
@@ -5,6 +5,7 @@
 from eva.vision.data.datasets.segmentation.consep import CoNSeP
 from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
 from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
 from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
@@ -14,6 +15,7 @@
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
 ]
diff --git a/src/eva/vision/data/datasets/segmentation/lits_balanced.py b/src/eva/vision/data/datasets/segmentation/lits_balanced.py
new file mode 100644
index 00000000..ad3fe045
--- /dev/null
+++ b/src/eva/vision/data/datasets/segmentation/lits_balanced.py
@@ -0,0 +1,91 @@
+"""Balanced LiTS dataset."""
+
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets.segmentation import lits
+from eva.vision.utils import io
+
+
+class LiTSBalanced(lits.LiTS):
+    """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.
+
+    For each volume in the dataset, we sample the same number of slices where
+    only the liver is present and where both liver and tumor are present.
+
+    Webpage: https://competitions.codalab.org/competitions/17094
+
+    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
+    """
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 6090,
+        "val": 1236,
+        "test": 1050,
+        None: 8376,
+    }
+    """Expected dataset length per split."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+        """
+        super().__init__(root=root, split=split, transforms=transforms)
+
+    @override
+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value indicates the
+            sample index and the second its corresponding slice
+            index.
+        """
+        split_indices = set(self._get_split_indices())
+
+        indices: List[Tuple[int, int]] = []
+
+        for sample_idx, path in enumerate(self._segmentation_files):
+            if sample_idx not in split_indices:
+                continue
+
+            segmentation = io.read_nifti(path)
+            tumor_filter = segmentation == 2
+            tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
+
+            if tumor_filter.sum() == 0:
+                continue
+
+            liver_filter = segmentation == 1
+            liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0
+
+            liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
+            liver_only_filter = liver_slice_filter & ~tumor_slice_filter
+
+            n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
+            tumor_indices = list(np.where(liver_and_tumor_filter)[0])
+            tumor_indices = list(
+                np.random.choice(tumor_indices, size=n_slice_samples, replace=False)
+            )
+
+            liver_indices = list(np.where(liver_only_filter)[0])
+            liver_indices = list(
+                np.random.choice(liver_indices, size=n_slice_samples, replace=False)
+            )
+
+            indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
+
+        return indices
diff --git a/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-31.nii b/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-31.nii
index ee1409db..8e43f996 100644
--- a/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-31.nii
+++ b/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-31.nii
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0623e62b3f4b3bc8fea08c0e2d0613834bc291db637bde434b17c1ca44875d26
-size 1048928
+oid sha256:764be4ca61551d83a885236b05558fda078c9ffb1a16f80ed49e41f76574a5a2
+size 8388960
diff --git a/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-45.nii b/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-45.nii
index ee1409db..8e43f996 100644
--- a/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-45.nii
+++ b/tests/eva/assets/vision/datasets/lits/Training_Batch2/segmentation-45.nii
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0623e62b3f4b3bc8fea08c0e2d0613834bc291db637bde434b17c1ca44875d26
-size 1048928
+oid sha256:764be4ca61551d83a885236b05558fda078c9ffb1a16f80ed49e41f76574a5a2
+size 8388960
diff --git a/tests/eva/vision/data/datasets/segmentation/test_lits_balanced.py b/tests/eva/vision/data/datasets/segmentation/test_lits_balanced.py
new file mode 100644
index 00000000..ee5dd9fe
--- /dev/null
+++ b/tests/eva/vision/data/datasets/segmentation/test_lits_balanced.py
@@ -0,0 +1,59 @@
+"""LiTSBalanced dataset tests."""
+
+import os
+from typing import Literal
+
+import pytest
+from torchvision import tv_tensors
+
+from eva.vision.data import datasets
+
+
+@pytest.mark.parametrize(
+    "split, expected_length",
+    [(None, 4)],
+)
+def test_length(lits_balanced_dataset: datasets.LiTSBalanced, expected_length: int) -> None:
+    """Tests the length of the dataset."""
+    assert len(lits_balanced_dataset) == expected_length
+
+
+@pytest.mark.parametrize(
+    "split, index",
+    [
+        (None, 0),
+    ],
+)
+def test_sample(lits_balanced_dataset: datasets.LiTSBalanced, index: int) -> None:
+    """Tests the format of a dataset sample."""
+    # assert data sample is a tuple
+    sample = lits_balanced_dataset[index]
+    assert isinstance(sample, tuple)
+    assert len(sample) == 3
+    # assert the format of the `image`, `mask` and `metadata`
+    image, mask, metadata = sample
+    assert isinstance(image, tv_tensors.Image)
+    assert image.shape == (1, 512, 512)
+    assert isinstance(mask, tv_tensors.Mask)
+    assert mask.shape == (512, 512)
+    assert isinstance(metadata, dict)
+    assert "slice_index" in metadata
+
+
+@pytest.fixture(scope="function")
+def lits_balanced_dataset(
+    split: Literal["train", "val", "test"] | None, assets_path: str
+) -> datasets.LiTSBalanced:
+    """LiTSBalanced dataset fixture."""
+    dataset = datasets.LiTSBalanced(
+        root=os.path.join(
+            assets_path,
+            "vision",
+            "datasets",
+            "lits",
+        ),
+        split=split,
+    )
+    dataset.prepare_data()
+    dataset.configure()
+    return dataset

From 581488a305e41ac39c57af431d0ecccd2bf72609 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Mon, 30 Sep 2024 10:44:09 +0200
Subject: [PATCH 2/2] added TQDMProgressBar to yaml configs

---
 .../vision/radiology/offline/segmentation/lits_balanced.yaml | 3 +++
 .../vision/radiology/online/segmentation/lits_balanced.yaml  | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
index b24cf572..28e76c59 100644
--- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
@@ -7,6 +7,9 @@ trainer:
     max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
     callbacks:
       - class_path: eva.callbacks.ConfigurationLogger
+      - class_path: lightning.pytorch.callbacks.TQDMProgressBar
+        init_args:
+          refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
       - class_path: eva.vision.callbacks.SemanticSegmentationLogger
         init_args:
           log_every_n_epochs: 1
diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
index 0847de90..e767c097 100644
--- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
@@ -8,6 +8,9 @@ trainer:
     log_every_n_steps: 6
     callbacks:
       - class_path: eva.callbacks.ConfigurationLogger
+      - class_path: lightning.pytorch.callbacks.TQDMProgressBar
+        init_args:
+          refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
       - class_path: eva.vision.callbacks.SemanticSegmentationLogger
         init_args:
           log_every_n_epochs: 1
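
For anyone who wants to poke at the new class locally, here is a minimal usage sketch. It assumes the LiTS NIfTI volumes already sit under `./data/lits` (the configs' `DATA_ROOT` default) and follows the `prepare_data()`/`configure()` lifecycle exercised by the test fixture above; the path and the spot-check size are illustrative, not part of this patch:

```python
import numpy as np

from eva.vision.data import datasets

# Illustrative local path; the configs default to ${DATA_ROOT:-./data/lits}.
dataset = datasets.LiTSBalanced(root="./data/lits", split="train")
dataset.prepare_data()  # verifies the raw NIfTI files are in place
dataset.configure()  # builds the balanced (sample_idx, slice_idx) indices

image, mask, metadata = dataset[0]
print(image.shape, mask.shape, metadata["slice_index"])

# Spot-check the balance: by construction, each volume contributes an equal
# number of tumor and liver-only slices, so roughly half of the sampled
# slices should contain tumor voxels (label 2).
rng = np.random.default_rng(0)
sample_ids = rng.choice(len(dataset), size=min(len(dataset), 64), replace=False)
tumor_fraction = np.mean(
    [bool((np.asarray(dataset[int(i)][1]) == 2).any()) for i in sample_ids]
)
print(f"fraction of checked slices with tumor: {tumor_fraction:.2f}")
```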
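
One behavior worth noting: `_create_indices` subsamples slices via `np.random.choice` on NumPy's global, unseeded RNG, so two instantiations can select different slices (the per-split lengths stay fixed, but their composition does not). Below is a sketch of how a caller could pin the selection for reproducibility, assuming nothing else re-seeds the global RNG in between; a dedicated `seed` argument on the class would be the cleaner long-term fix:

```python
import numpy as np

from eva.vision.data import datasets

np.random.seed(42)  # pin the global RNG before the indices are built

dataset = datasets.LiTSBalanced(root="./data/lits", split="val")
dataset.prepare_data()
dataset.configure()  # _create_indices() now draws the same slices on every run
```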