Add LiTSBalanced dataset #648

Merged · 3 commits · Sep 30, 2024

Changes from 1 commit
138 changes: 138 additions & 0 deletions configs/vision/radiology/offline/segmentation/lits_balanced.yaml

```yaml
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
        init_args:
          log_every_n_epochs: 1
          log_images: false
      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
        init_args:
          filename: best
          save_last: true
          save_top_k: 1
          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore}
          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
      - class_path: lightning.pytorch.callbacks.EarlyStopping
        init_args:
          min_delta: 0
          patience: 100
          monitor: *MONITOR_METRIC
          mode: *MONITOR_METRIC_MODE
      - class_path: eva.callbacks.SegmentationEmbeddingsWriter
        init_args:
          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits
          dataloader_idx_map:
            0: train
            1: val
            2: test
          metadata_keys: ["slice_index"]
          overwrite: false
          backbone:
            class_path: eva.vision.models.ModelFromRegistry
            init_args:
              model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
              model_kwargs:
                out_indices: ${oc.env:OUT_INDICES, 1}
              model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
      init_args:
        in_features: ${oc.env:IN_FEATURES, 384}
        num_classes: &NUM_CLASSES 3
    criterion:
      class_path: eva.vision.losses.DiceLoss
      init_args:
        softmax: true
        batch: true
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: ${oc.env:LR_VALUE, 0.002}
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.PolynomialLR
      init_args:
        total_iters: *MAX_STEPS
        power: 0.9
    postprocess:
      predictions_transforms:
        - class_path: torch.argmax
          init_args:
            dim: 1
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
          init_args:
            num_classes: *NUM_CLASSES
        - class_path: torchmetrics.ClasswiseWrapper
          init_args:
            metric:
              class_path: torchmetrics.segmentation.GeneralizedDiceScore
              init_args:
                num_classes: *NUM_CLASSES
                weight_type: linear
                per_class: true
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args: &DATASET_ARGS
          root: *DATASET_EMBEDDINGS_ROOT
          manifest_file: manifest.csv
          split: train
      val:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args:
          <<: *DATASET_ARGS
          split: val
      test:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args:
          <<: *DATASET_ARGS
          split: test
      predict:
        - class_path: eva.vision.datasets.LiTSBalanced
          init_args: &PREDICT_DATASET_ARGS
            root: ${oc.env:DATA_ROOT, ./data/lits}
            split: train
            transforms:
              class_path: eva.vision.data.transforms.common.ResizeAndClamp
              init_args:
                size: ${oc.env:RESIZE_DIM, 224}
                mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
                std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
        - class_path: eva.vision.datasets.LiTSBalanced
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: val
        - class_path: eva.vision.datasets.LiTSBalanced
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: test
    dataloaders:
      train:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
        shuffle: true
      val:
        batch_size: *BATCH_SIZE
      test:
        batch_size: *BATCH_SIZE
      predict:
        batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
```
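Nearly every hyperparameter in this config is read from an environment variable with a fallback default via `${oc.env:VAR, default}` interpolations. The sketch below illustrates that resolution behaviour, assuming OmegaConf's built-in `oc.env` resolver; eva's own config loader is not shown in this diff, so treat it as an approximation:

```python
# Minimal sketch of `${oc.env:VAR, default}` resolution, assuming
# OmegaConf's built-in `oc.env` resolver semantics.
import os

from omegaconf import OmegaConf

cfg = OmegaConf.create("lr: ${oc.env:LR_VALUE, 0.002}")
print(cfg.lr)  # 0.002 -- the literal default, since LR_VALUE is unset

os.environ["LR_VALUE"] = "0.0005"
print(cfg.lr)  # "0.0005" -- values read from the environment come back as strings
```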
116 changes: 116 additions & 0 deletions configs/vision/radiology/online/segmentation/lits_balanced.yaml

```yaml
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
    log_every_n_steps: 6
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
        init_args:
          log_every_n_epochs: 1
          mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
          std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
        init_args:
          filename: best
          save_last: true
          save_top_k: 1
          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/GeneralizedDiceScore}
          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
      - class_path: lightning.pytorch.callbacks.EarlyStopping
        init_args:
          min_delta: 0
          patience: 100
          monitor: *MONITOR_METRIC
          mode: *MONITOR_METRIC_MODE
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    encoder:
      class_path: eva.vision.models.ModelFromRegistry
      init_args:
        model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
        model_kwargs:
          out_indices: ${oc.env:OUT_INDICES, 1}
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
      init_args:
        in_features: ${oc.env:IN_FEATURES, 384}
        num_classes: &NUM_CLASSES 3
    criterion:
      class_path: eva.vision.losses.DiceLoss
      init_args:
        softmax: true
        batch: true
    lr_multiplier_encoder: 0.0
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: ${oc.env:LR_VALUE, 0.002}
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.PolynomialLR
      init_args:
        total_iters: *MAX_STEPS
        power: 0.9
    postprocess:
      predictions_transforms:
        - class_path: torch.argmax
          init_args:
            dim: 1
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
          init_args:
            num_classes: *NUM_CLASSES
        - class_path: torchmetrics.ClasswiseWrapper
          init_args:
            metric:
              class_path: torchmetrics.segmentation.GeneralizedDiceScore
              init_args:
                num_classes: *NUM_CLASSES
                weight_type: linear
                per_class: true
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.LiTSBalanced
        init_args: &DATASET_ARGS
          root: ${oc.env:DATA_ROOT, ./data/lits}
          split: train
          transforms:
            class_path: eva.vision.data.transforms.common.ResizeAndClamp
            init_args:
              size: ${oc.env:RESIZE_DIM, 224}
              mean: *NORMALIZE_MEAN
              std: *NORMALIZE_STD
      val:
        class_path: eva.vision.datasets.LiTSBalanced
        init_args:
          <<: *DATASET_ARGS
          split: val
      test:
        class_path: eva.vision.datasets.LiTSBalanced
        init_args:
          <<: *DATASET_ARGS
          split: test
    dataloaders:
      train:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
        shuffle: true
      val:
        batch_size: *BATCH_SIZE
        shuffle: true
      test:
        batch_size: *BATCH_SIZE
```
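Both configs lean on YAML anchors (`&DATASET_ARGS`) and merge keys (`<<:`) so the `val` and `test` datasets inherit the `train` arguments and override only `split`. An illustrative PyYAML sketch of that mechanism (the keys mirror the config, but the snippet itself is not part of this PR):

```python
# Illustrative only: anchor/merge-key semantics as used by the dataset
# sections above, loaded with PyYAML's safe loader.
import yaml

doc = """
train: &DATASET_ARGS
  root: ./data/lits
  split: train
val:
  <<: *DATASET_ARGS
  split: val
"""

data = yaml.safe_load(doc)
print(data["val"])  # {'root': './data/lits', 'split': 'val'} -- only split is overridden
```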
2 changes: 2 additions & 0 deletions src/eva/vision/data/datasets/__init__.py

```diff
@@ -15,6 +15,7 @@
     EmbeddingsSegmentationDataset,
     ImageSegmentation,
     LiTS,
+    LiTSBalanced,
     MoNuSAC,
     TotalSegmentator2D,
 )
@@ -34,6 +35,7 @@
     "EmbeddingsSegmentationDataset",
     "ImageSegmentation",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
     "VisionDataset",
```
2 changes: 2 additions & 0 deletions src/eva/vision/data/datasets/segmentation/__init__.py

```diff
@@ -5,6 +5,7 @@
 from eva.vision.data.datasets.segmentation.consep import CoNSeP
 from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
 from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
 from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
@@ -14,6 +15,7 @@
     "CoNSeP",
     "EmbeddingsSegmentationDataset",
     "LiTS",
+    "LiTSBalanced",
     "MoNuSAC",
     "TotalSegmentator2D",
 ]
```
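With these exports in place, the dataset is importable from the same namespace the configs use (`eva.vision.datasets`). A hedged usage sketch based on the constructor shown in `lits_balanced.py` below; the `root` path is a placeholder:

```python
# Hypothetical usage; "./data/lits" is a placeholder path, and depending on
# eva's dataset lifecycle a setup/prepare step may be needed before indexing.
from eva.vision.datasets import LiTSBalanced

dataset = LiTSBalanced(root="./data/lits", split="train")
print(len(dataset))  # expected: 6090 for the train split, per _expected_dataset_lengths
```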
91 changes: 91 additions & 0 deletions src/eva/vision/data/datasets/segmentation/lits_balanced.py

```python
"""Balanced LiTS dataset."""

from typing import Callable, Dict, List, Literal, Tuple

import numpy as np
from typing_extensions import override

from eva.vision.data.datasets.segmentation import lits
from eva.vision.utils import io


class LiTSBalanced(lits.LiTS):
    """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.

    For each volume in the dataset, we sample the same number of slices where
    only the liver is present and where both liver and tumor are present.

    Webpage: https://competitions.codalab.org/competitions/17094

    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
    """

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 6090,
        "val": 1236,
        "test": 1050,
        None: 8376,
    }
    """Expected dataset size per split."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.

        Args:
            root: Path to the root directory of the dataset. The dataset will
                be downloaded and extracted here, if it does not already exist.
            split: Dataset split to use.
            transforms: A function/transforms that takes in an image and a target
                mask and returns the transformed versions of both.
        """
        super().__init__(root=root, split=split, transforms=transforms)

    @override
    def _create_indices(self) -> List[Tuple[int, int]]:
        """Builds the dataset indices for the specified split.

        Returns:
            A list of tuples, where the first value indicates the
            sample index and the second its corresponding slice index.
        """
        split_indices = set(self._get_split_indices())

        indices: List[Tuple[int, int]] = []

        for sample_idx, path in enumerate(self._segmentation_files):
            if sample_idx not in split_indices:
                continue

            segmentation = io.read_nifti(path)
            tumor_filter = segmentation == 2
            tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0

            if tumor_filter.sum() == 0:
                continue

            liver_filter = segmentation == 1
            liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0

            liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
            liver_only_filter = liver_slice_filter & ~tumor_slice_filter

            n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
            tumor_indices = list(np.where(liver_and_tumor_filter)[0])
            tumor_indices = list(
                np.random.choice(tumor_indices, size=n_slice_samples, replace=False)
            )

            liver_indices = list(np.where(liver_only_filter)[0])
            liver_indices = list(
                np.random.choice(liver_indices, size=n_slice_samples, replace=False)
            )

            indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])

        return indices
```
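The balancing rule in `_create_indices` is easiest to see on a toy volume: per slice, check which labels occur, then draw equally many liver-only and liver-plus-tumor slices. A self-contained sketch with synthetic data (the mask shape and labels follow the conventions used above, but the volume itself is made up):

```python
# Synthetic demonstration of the slice-balancing rule; shape is (H, W, num_slices)
# with labels {0: background, 1: liver, 2: tumor}, as in _create_indices.
import numpy as np

segmentation = np.zeros((64, 64, 40), dtype=np.int64)
segmentation[16:48, 16:48, 5:35] = 1   # liver on slices 5..34
segmentation[24:40, 24:40, 20:30] = 2  # tumor on slices 20..29

liver_slices = (segmentation == 1).sum(axis=(0, 1)) > 0  # per-slice liver presence
tumor_slices = (segmentation == 2).sum(axis=(0, 1)) > 0  # per-slice tumor presence

liver_and_tumor = liver_slices & tumor_slices
liver_only = liver_slices & ~tumor_slices

# Sample the same number of slices from each group, bounded by the smaller one.
n = min(liver_and_tumor.sum(), liver_only.sum())
rng = np.random.default_rng(seed=0)
tumor_picks = rng.choice(np.where(liver_and_tumor)[0], size=n, replace=False)
liver_picks = rng.choice(np.where(liver_only)[0], size=n, replace=False)

print(len(tumor_picks), len(liver_picks))  # 10 10 -- the two groups are balanced
```

Note that the dataset itself draws with `np.random.choice` from NumPy's global random state, so the selected slice indices depend on the global seed in effect when the indices are built.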
2 Git LFS files not shown