diff --git a/src/eva/core/data/dataloaders/dataloader.py b/src/eva/core/data/dataloaders/dataloader.py index 70040c618..65f0d6566 100644 --- a/src/eva/core/data/dataloaders/dataloader.py +++ b/src/eva/core/data/dataloaders/dataloader.py @@ -59,17 +59,20 @@ class DataLoader: prefetch_factor: int | None = 2 """Number of batches loaded in advance by each worker.""" - def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader: + def __call__( + self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None + ) -> dataloader.DataLoader: """Returns the dataloader on the provided dataset. Args: dataset: dataset from which to load the data. + sampler: defines the strategy to draw samples from the dataset. """ return dataloader.DataLoader( dataset=dataset, batch_size=self.batch_size, shuffle=self.shuffle, - sampler=self.sampler, + sampler=sampler or self.sampler, batch_sampler=self.batch_sampler, num_workers=self.num_workers or multiprocessing.cpu_count(), collate_fn=self.collate_fn, diff --git a/src/eva/core/data/datamodules/datamodule.py b/src/eva/core/data/datamodules/datamodule.py index 1f050ec75..c9522c227 100644 --- a/src/eva/core/data/datamodules/datamodule.py +++ b/src/eva/core/data/datamodules/datamodule.py @@ -8,6 +8,7 @@ from eva.core.data import dataloaders as dataloaders_lib from eva.core.data import datasets as datasets_lib +from eva.core.data import samplers as samplers_lib from eva.core.data.datamodules import call, schemas @@ -24,17 +25,20 @@ def __init__( self, datasets: schemas.DatasetsSchema | None = None, dataloaders: schemas.DataloadersSchema | None = None, + samplers: schemas.SamplersSchema | None = None, ) -> None: """Initializes the datamodule. Args: datasets: The desired datasets. dataloaders: The desired dataloaders. + samplers: The desired samplers for the dataloaders. 
""" super().__init__() self.datasets = datasets or self.default_datasets self.dataloaders = dataloaders or self.default_dataloaders + self.samplers = samplers or self.default_samplers @property def default_datasets(self) -> schemas.DatasetsSchema: @@ -46,6 +50,11 @@ def default_dataloaders(self) -> schemas.DataloadersSchema: """Returns the default dataloader schema.""" return schemas.DataloadersSchema() + @property + def default_samplers(self) -> schemas.SamplersSchema: + """Returns the default samplers schema.""" + return schemas.SamplersSchema() + @override def prepare_data(self) -> None: call.call_method_if_exists(self.datasets.tolist(), "prepare_data") @@ -64,7 +73,12 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: raise ValueError( "Train dataloader can not be initialized as `self.datasets.train` is `None`." ) - return self.dataloaders.train(self.datasets.train) + if isinstance(self.datasets.train, list) and len(self.datasets.train) > 1: + raise ValueError("Train dataloader can not be initialized with multiple datasets.") + + return self._initialize_dataloaders( + self.dataloaders.train, self.datasets.train, self.samplers.train + )[0] @override def val_dataloader(self) -> EVAL_DATALOADERS: @@ -72,7 +86,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Validation dataloader can not be initialized as `self.datasets.val` is `None`." ) - return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val) + return self._initialize_dataloaders( + self.dataloaders.val, self.datasets.val, self.samplers.val + ) @override def test_dataloader(self) -> EVAL_DATALOADERS: @@ -80,7 +96,9 @@ def test_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Test dataloader can not be initialized as `self.datasets.test` is `None`." 
) - return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test) + return self._initialize_dataloaders( + self.dataloaders.test, self.datasets.test, self.samplers.test + ) @override def predict_dataloader(self) -> EVAL_DATALOADERS: @@ -88,21 +106,40 @@ def predict_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Predict dataloader can not be initialized as `self.datasets.predict` is `None`." ) - return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict) + if isinstance(self.datasets.predict, list) and len(self.datasets.predict) > 1: + # Only apply sampler to the first predict dataset (should correspond to train split) + train_dataloader = self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict[0], self.samplers.predict + ) + return train_dataloader + self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict[1:] + ) + + return self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict, self.samplers.predict + ) def _initialize_dataloaders( self, dataloader: dataloaders_lib.DataLoader, datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset], + sampler: samplers_lib.Sampler | None = None, ) -> EVAL_DATALOADERS: """Initializes dataloaders from a given set of dataset. Args: dataloader: The dataloader to apply to the provided datasets. datasets: The desired dataset(s) to allocate dataloader(s). + sampler: The sampler to use for the dataloader. Returns: A list with the dataloaders of the provided dataset(s). 
""" datasets = datasets if isinstance(datasets, list) else [datasets] - return list(map(dataloader, datasets)) + + dataloaders = [] + for dataset in datasets: + if sampler is not None and isinstance(sampler, samplers_lib.SamplerWithDataSource): + sampler.set_dataset(dataset) # type: ignore + dataloaders.append(dataloader(dataset, sampler=sampler)) + return dataloaders diff --git a/src/eva/core/data/datamodules/schemas.py b/src/eva/core/data/datamodules/schemas.py index 8780ac61d..d19b342e3 100644 --- a/src/eva/core/data/datamodules/schemas.py +++ b/src/eva/core/data/datamodules/schemas.py @@ -3,7 +3,7 @@ import dataclasses from typing import List -from eva.core.data import dataloaders, datasets +from eva.core.data import dataloaders, datasets, samplers TRAIN_DATASET = datasets.TorchDataset | None """Train dataset.""" @@ -60,3 +60,20 @@ class DataloadersSchema: predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader) """Predict dataloader.""" + + +@dataclasses.dataclass(frozen=True) +class SamplersSchema: + """Samplers schema used in DataModule.""" + + train: samplers.Sampler | None = None + """Train sampler.""" + + val: samplers.Sampler | None = None + """Validation sampler.""" + + test: samplers.Sampler | None = None + """Test sampler.""" + + predict: samplers.Sampler | None = None + """Predict sampler.""" diff --git a/src/eva/core/data/datasets/__init__.py b/src/eva/core/data/datasets/__init__.py index ba4da0cff..c5e366827 100644 --- a/src/eva/core/data/datasets/__init__.py +++ b/src/eva/core/data/datasets/__init__.py @@ -1,15 +1,18 @@ """Datasets API.""" -from eva.core.data.datasets.base import Dataset +from eva.core.data.datasets.base import Dataset, MapDataset from eva.core.data.datasets.classification import ( EmbeddingsClassificationDataset, MultiEmbeddingsClassificationDataset, ) from eva.core.data.datasets.dataset import TorchDataset +from eva.core.data.datasets.typings import DataSample __all__ = [ "Dataset", + 
"MapDataset", "EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset", "TorchDataset", + "DataSample", ] diff --git a/src/eva/core/data/datasets/base.py b/src/eva/core/data/datasets/base.py index d83fa74ed..a03eaf736 100644 --- a/src/eva/core/data/datasets/base.py +++ b/src/eva/core/data/datasets/base.py @@ -1,5 +1,7 @@ """Base dataset class.""" +import abc + from eva.core.data.datasets import dataset @@ -51,3 +53,24 @@ def teardown(self) -> None: of fit (train + validate), validate, test, or predict and it will be called from every process (i.e. GPU) across all the nodes in DDP. """ + + +class MapDataset(Dataset): + """Abstract base class for all map-style datasets.""" + + @abc.abstractmethod + def __getitem__(self, index: int): + """Retrieves the item at the given index. + + Args: + index: Index of the item to retrieve. + + Returns: + The data at the given index. + """ + raise NotImplementedError + + @abc.abstractmethod + def __len__(self) -> int: + """Returns the length of the dataset.""" + raise NotImplementedError diff --git a/src/eva/core/data/datasets/typings.py b/src/eva/core/data/datasets/typings.py new file mode 100644 index 000000000..465b23e25 --- /dev/null +++ b/src/eva/core/data/datasets/typings.py @@ -0,0 +1,18 @@ +"""Typing definitions for the datasets module.""" + +from typing import Any, Dict, NamedTuple + +import torch + + +class DataSample(NamedTuple): + """The default input batch data scheme.""" + + data: torch.Tensor + """The data batch.""" + + targets: torch.Tensor | None = None + """The target batch.""" + + metadata: Dict[str, Any] | None = None + """The associated metadata.""" diff --git a/src/eva/core/data/samplers/__init__.py b/src/eva/core/data/samplers/__init__.py index 5cc3a852e..7586d533a 100644 --- a/src/eva/core/data/samplers/__init__.py +++ b/src/eva/core/data/samplers/__init__.py @@ -1,5 +1,7 @@ """Data samplers API.""" -from eva.core.data.samplers.sampler import Sampler +from 
eva.core.data.samplers.classification.balanced import BalancedSampler
+from eva.core.data.samplers.random import RandomSampler
+from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource
 
-__all__ = ["Sampler"]
+__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"]
diff --git a/src/eva/core/data/samplers/classification/__init__.py b/src/eva/core/data/samplers/classification/__init__.py
new file mode 100644
index 000000000..c68235bcc
--- /dev/null
+++ b/src/eva/core/data/samplers/classification/__init__.py
@@ -0,0 +1,5 @@
+"""Classification data samplers API."""
+
+from eva.core.data.samplers.classification.balanced import BalancedSampler
+
+__all__ = ["BalancedSampler"]
diff --git a/src/eva/core/data/samplers/classification/balanced.py b/src/eva/core/data/samplers/classification/balanced.py
new file mode 100644
index 000000000..ed3a19d39
--- /dev/null
+++ b/src/eva/core/data/samplers/classification/balanced.py
@@ -0,0 +1,96 @@
+"""Balanced class sampler for data loading."""
+
+from collections import defaultdict
+from typing import Dict, Iterator, List
+
+import numpy as np
+from typing_extensions import override
+
+from eva.core.data import datasets
+from eva.core.data.datasets.typings import DataSample
+from eva.core.data.samplers.sampler import SamplerWithDataSource
+from eva.core.utils.progress_bar import tqdm
+
+
+class BalancedSampler(SamplerWithDataSource[int]):
+    """Balanced class sampler for data loading.
+
+    The sampler ensures that:
+    1. Each class has the same number of samples
+    2. Samples within each class are randomly selected
+    3. Samples of different classes appear in random order
+    """
+
+    def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = 42):
+        """Initializes the balanced sampler.
+
+        Args:
+            num_samples: The number of samples to draw per class.
+ replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` + seed: Random seed for reproducibility. + """ + self._num_samples = num_samples + self._replacement = replacement + self._class_indices: Dict[int, List[int]] = defaultdict(list) + self._random_generator = np.random.default_rng(seed) + + def __len__(self) -> int: + """Returns the total number of samples.""" + return self._num_samples * len(self._class_indices) + + def __iter__(self) -> Iterator[int]: + """Creates an iterator that yields indices in a class balanced way. + + Returns: + Iterator yielding dataset indices. + """ + indices = [] + + for class_idx in self._class_indices: + class_indices = self._class_indices[class_idx] + sampled_indices = self._random_generator.choice( + class_indices, size=self._num_samples, replace=self._replacement + ).tolist() + indices.extend(sampled_indices) + + self._random_generator.shuffle(indices) + + return iter(indices) + + @override + def set_dataset(self, data_source: datasets.MapDataset): + """Sets the dataset and builds class indices. + + Args: + data_source: The dataset to sample from. + + Raises: + ValueError: If the dataset doesn't have targets or if any class has + fewer samples than `num_samples` and `replacement` is `False`. 
+        """
+        super().set_dataset(data_source)
+        self._make_indices()
+
+    def _make_indices(self):
+        """Builds indices for each class in the dataset."""
+        self._class_indices.clear()
+
+        for idx in tqdm(
+            range(len(self.data_source)), desc="Fetching class indices for balanced sampler"
+        ):
+            _, target, _ = DataSample(*self.data_source[idx])
+            if target is None:
+                raise ValueError("The dataset must return non-empty targets.")
+            if target.numel() != 1:
+                raise ValueError("The dataset must return a single & scalar target.")
+
+            class_idx = int(target.item())
+            self._class_indices[class_idx].append(idx)
+
+        if not self._replacement:
+            for class_idx, indices in self._class_indices.items():
+                if len(indices) < self._num_samples:
+                    raise ValueError(
+                        f"Class {class_idx} has only {len(indices)} samples, "
+                        f"which is less than the required {self._num_samples} samples."
+                    )
diff --git a/src/eva/core/data/samplers/random.py b/src/eva/core/data/samplers/random.py
new file mode 100644
index 000000000..415b8ae3e
--- /dev/null
+++ b/src/eva/core/data/samplers/random.py
@@ -0,0 +1,39 @@
+"""Random sampler for data loading."""
+
+from typing import Optional
+
+from torch.utils import data
+from typing_extensions import override
+
+from eva.core.data import datasets
+from eva.core.data.samplers.sampler import SamplerWithDataSource
+
+
+class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
+    """Samples elements randomly."""
+
+    data_source: datasets.MapDataset  # type: ignore
+
+    def __init__(
+        self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
+    ) -> None:
+        """Initializes the random sampler.
+
+        The data source is set separately via the `set_dataset` method.
+        Args:
+            replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
+            num_samples: number of samples to draw, default=`len(dataset)`.
+            generator: Generator used in sampling.
+ """ + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + @override + def set_dataset(self, data_source: datasets.MapDataset) -> None: + super().__init__( + data_source, + replacement=self.replacement, + num_samples=self.num_samples, + generator=self.generator, + ) diff --git a/src/eva/core/data/samplers/sampler.py b/src/eva/core/data/samplers/sampler.py index 98b3124b2..ff878fa36 100644 --- a/src/eva/core/data/samplers/sampler.py +++ b/src/eva/core/data/samplers/sampler.py @@ -1,6 +1,33 @@ """Core data sampler.""" +from typing import Generic, TypeVar + from torch.utils import data +from eva.core.data import datasets + Sampler = data.Sampler """Core abstract data sampler class.""" + +T_co = TypeVar("T_co", covariant=True) + + +class SamplerWithDataSource(Sampler, Generic[T_co]): + """A sampler base class that enables to specify the data source after initialization. + + The `set_dataset` can also be overwritten to expand the functionality of the derived + sampler classes. + """ + + data_source: datasets.MapDataset + + def set_dataset(self, data_source: datasets.MapDataset) -> None: + """Sets the dataset to sample from. + + This is not done in the constructor because the dataset might not be + available at that time. + + Args: + data_source: The dataset to sample from. + """ + self.data_source = data_source diff --git a/src/eva/vision/data/datasets/vision.py b/src/eva/vision/data/datasets/vision.py index 81b08f57d..ca3387651 100644 --- a/src/eva/vision/data/datasets/vision.py +++ b/src/eva/vision/data/datasets/vision.py @@ -9,7 +9,7 @@ """The data sample type.""" -class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]): +class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]): """Base dataset class for vision tasks.""" @abc.abstractmethod @@ -24,20 +24,3 @@ def filename(self, index: int) -> str: Returns: The filename of the `index`'th data sample. 
""" - - @abc.abstractmethod - def __getitem__(self, index: int) -> DataSample: - """Returns the `index`'th data sample. - - Args: - index: The index of the data-sample to select. - - Returns: - A data sample and its target. - """ - raise NotImplementedError - - @abc.abstractmethod - def __len__(self) -> int: - """Returns the total length of the data.""" - raise NotImplementedError diff --git a/tests/eva/core/data/samplers/__init__.py b/tests/eva/core/data/samplers/__init__.py new file mode 100644 index 000000000..39e9f73a3 --- /dev/null +++ b/tests/eva/core/data/samplers/__init__.py @@ -0,0 +1 @@ +"""Tests for data loader samplers.""" diff --git a/tests/eva/core/data/samplers/_utils.py b/tests/eva/core/data/samplers/_utils.py new file mode 100644 index 000000000..7e09996a0 --- /dev/null +++ b/tests/eva/core/data/samplers/_utils.py @@ -0,0 +1,30 @@ +"""Test utilities for dataloader sampler tests.""" + +from typing import List, Tuple + +import torch +from typing_extensions import override + +from eva.core.data import datasets + + +class MockDataset(datasets.MapDataset): + """Mock map-style dataset class for unit testing.""" + + def __init__(self, samples: List[Tuple[None, torch.Tensor, None]]): + self.samples = samples + + @override + def __getitem__(self, idx): + return self.samples[idx] + + @override + def __len__(self): + return len(self.samples) + + +def multiclass_dataset(num_samples: int, num_classes: int) -> datasets.MapDataset: + samples = ( + [(None, torch.tensor([i]), None)] * (num_samples // num_classes) for i in range(num_classes) + ) + return MockDataset([item for sublist in samples for item in sublist]) diff --git a/tests/eva/core/data/samplers/classification/__init__.py b/tests/eva/core/data/samplers/classification/__init__.py new file mode 100644 index 000000000..ae2610a37 --- /dev/null +++ b/tests/eva/core/data/samplers/classification/__init__.py @@ -0,0 +1 @@ +"""Tests for classification data loader samplers.""" diff --git 
a/tests/eva/core/data/samplers/classification/test_balanced.py b/tests/eva/core/data/samplers/classification/test_balanced.py new file mode 100644 index 000000000..ea30a08a0 --- /dev/null +++ b/tests/eva/core/data/samplers/classification/test_balanced.py @@ -0,0 +1,75 @@ +"""Tests for the balanced sampler.""" + +from collections import Counter + +import pytest +import torch + +from eva.core.data.datasets.typings import DataSample +from eva.core.data.samplers.classification import BalancedSampler +from tests.eva.core.data.samplers import _utils + + +@pytest.mark.parametrize( + "num_class_samples, replacement, num_dataset_samples, num_classes", + [ + (3, False, 15, 2), + (20, True, 15, 2), + (3, False, 33, 5), + ], +) +def test_balanced_sampling( + num_class_samples: int, replacement: bool, num_dataset_samples: int, num_classes: int +): + """Tests if the returned indices are balanced.""" + dataset = _utils.multiclass_dataset(num_dataset_samples, num_classes) + sampler = BalancedSampler(num_samples=num_class_samples, replacement=replacement) + sampler.set_dataset(dataset) + + indices = list(sampler) + class_counts = Counter(DataSample(*dataset[i]).targets.item() for i in indices) # type: ignore + + assert len(sampler) == num_class_samples * num_classes + assert len(class_counts.keys()) == num_classes + for count in class_counts.values(): + assert count == num_class_samples + + +def test_insufficient_samples_without_replacement(): + """Tests if the sampler raises an error when there are insufficient samples.""" + num_dataset_samples, num_classes = 15, 3 + dataset = _utils.multiclass_dataset(num_dataset_samples, num_classes) + sampler = BalancedSampler(num_samples=7, replacement=False) + + with pytest.raises(ValueError, match=f"has only {num_dataset_samples // num_classes} samples"): + sampler.set_dataset(dataset) + + +def test_random_seed(): + """Tests if the sampler is reproducible with the same seed.""" + num_dataset_samples, num_classes = 101, 3 + dataset = 
_utils.multiclass_dataset(num_dataset_samples, num_classes)
+    sampler1 = BalancedSampler(num_samples=10, seed=1)
+    sampler1_duplicate = BalancedSampler(num_samples=10, seed=1)
+    sampler2 = BalancedSampler(num_samples=10, seed=2)
+    sampler1.set_dataset(dataset)
+    sampler1_duplicate.set_dataset(dataset)
+    sampler2.set_dataset(dataset)
+
+    assert list(sampler1) == list(sampler1_duplicate)
+    assert list(sampler1) != list(sampler2)
+
+
+def test_invalid_targets():
+    """Tests if the sampler raises an error for unsupported target formats."""
+    sampler = BalancedSampler(num_samples=10)
+
+    # test multi-dimensional target
+    dataset = _utils.MockDataset([(None, torch.tensor([0, 1]), None)])
+    with pytest.raises(ValueError, match="single & scalar target"):
+        sampler.set_dataset(dataset)
+
+    # test empty target
+    dataset = _utils.MockDataset([(None, None, None)])  # type: ignore
+    with pytest.raises(ValueError, match="non-empty targets"):
+        sampler.set_dataset(dataset)