Add support for dataloader samplers (#713)
nkaenzig authored Dec 6, 2024
1 parent 9c1c29d commit 936e96d
Showing 16 changed files with 389 additions and 29 deletions.
7 changes: 5 additions & 2 deletions src/eva/core/data/dataloaders/dataloader.py
@@ -59,17 +59,20 @@ class DataLoader:
prefetch_factor: int | None = 2
"""Number of batches loaded in advance by each worker."""

def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader:
def __call__(
self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None
) -> dataloader.DataLoader:
"""Returns the dataloader on the provided dataset.
Args:
dataset: dataset from which to load the data.
sampler: defines the strategy to draw samples from the dataset.
"""
return dataloader.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
sampler=self.sampler,
sampler=sampler or self.sampler,
batch_sampler=self.batch_sampler,
num_workers=self.num_workers or multiprocessing.cpu_count(),
collate_fn=self.collate_fn,
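A usage sketch for the new per-call sampler, assuming the dataclass-style constructor suggested by the fields above; the WeightedRandomSampler and toy dataset are illustrative, not part of this commit:

# The sampler passed at call time takes precedence over the instance
# field, per `sampler or self.sampler` above.
import torch
from torch.utils.data import TensorDataset, WeightedRandomSampler

from eva.core.data import dataloaders

dataset = TensorDataset(torch.randn(100, 8))  # toy map-style dataset
factory = dataloaders.DataLoader(batch_size=32, shuffle=False)  # shuffle must stay off with a sampler
sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
loader = factory(dataset, sampler=sampler)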
47 changes: 42 additions & 5 deletions src/eva/core/data/datamodules/datamodule.py
@@ -8,6 +8,7 @@

from eva.core.data import dataloaders as dataloaders_lib
from eva.core.data import datasets as datasets_lib
from eva.core.data import samplers as samplers_lib
from eva.core.data.datamodules import call, schemas


@@ -24,17 +25,20 @@ def __init__(
self,
datasets: schemas.DatasetsSchema | None = None,
dataloaders: schemas.DataloadersSchema | None = None,
samplers: schemas.SamplersSchema | None = None,
) -> None:
"""Initializes the datamodule.
Args:
datasets: The desired datasets.
dataloaders: The desired dataloaders.
samplers: The desired samplers for the dataloaders.
"""
super().__init__()

self.datasets = datasets or self.default_datasets
self.dataloaders = dataloaders or self.default_dataloaders
self.samplers = samplers or self.default_samplers

@property
def default_datasets(self) -> schemas.DatasetsSchema:
@@ -46,6 +50,11 @@ def default_dataloaders(self) -> schemas.DataloadersSchema:
"""Returns the default dataloader schema."""
return schemas.DataloadersSchema()

@property
def default_samplers(self) -> schemas.SamplersSchema:
"""Returns the default samplers schema."""
return schemas.SamplersSchema()

@override
def prepare_data(self) -> None:
call.call_method_if_exists(self.datasets.tolist(), "prepare_data")
@@ -64,45 +73,73 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
raise ValueError(
"Train dataloader can not be initialized as `self.datasets.train` is `None`."
)
return self.dataloaders.train(self.datasets.train)
if isinstance(self.datasets.train, list) and len(self.datasets.train) > 1:
raise ValueError("Train dataloader can not be initialized with multiple datasets.")

return self._initialize_dataloaders(
self.dataloaders.train, self.datasets.train, self.samplers.train
)[0]

@override
def val_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.val is None:
raise ValueError(
"Validation dataloader can not be initialized as `self.datasets.val` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val)
return self._initialize_dataloaders(
self.dataloaders.val, self.datasets.val, self.samplers.val
)

@override
def test_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.test is None:
raise ValueError(
"Test dataloader can not be initialized as `self.datasets.test` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test)
return self._initialize_dataloaders(
self.dataloaders.test, self.datasets.test, self.samplers.test
)

@override
def predict_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.predict is None:
raise ValueError(
"Predict dataloader can not be initialized as `self.datasets.predict` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict)
if isinstance(self.datasets.predict, list) and len(self.datasets.predict) > 1:
# Only apply sampler to the first predict dataset (should correspond to train split)
train_dataloader = self._initialize_dataloaders(
self.dataloaders.predict, self.datasets.predict[0], self.samplers.predict
)
return train_dataloader + self._initialize_dataloaders(
self.dataloaders.predict, self.datasets.predict[1:]
)

return self._initialize_dataloaders(
self.dataloaders.predict, self.datasets.predict, self.samplers.predict
)

def _initialize_dataloaders(
self,
dataloader: dataloaders_lib.DataLoader,
datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset],
sampler: samplers_lib.Sampler | None = None,
) -> EVAL_DATALOADERS:
"""Initializes dataloaders from a given set of dataset.
Args:
dataloader: The dataloader to apply to the provided datasets.
datasets: The desired dataset(s) to allocate dataloader(s).
sampler: The sampler to use for the dataloader.
Returns:
A list with the dataloaders of the provided dataset(s).
"""
datasets = datasets if isinstance(datasets, list) else [datasets]
return list(map(dataloader, datasets))

dataloaders = []
for dataset in datasets:
if sampler is not None and isinstance(sampler, samplers_lib.SamplerWithDataSource):
sampler.set_dataset(dataset) # type: ignore
dataloaders.append(dataloader(dataset, sampler=sampler))
return dataloaders
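
To see the new plumbing end to end, a hedged sketch: the DataModule class name and the DatasetsSchema/DataloadersSchema fields are assumed from the module paths in this diff, and the TensorDataset stands in for a real eva dataset:

# Hypothetical wiring (names assumed, not shown in this diff). The sampler
# is declared up front without a dataset; train_dataloader() binds it via
# set_dataset and passes it into the dataloader factory.
import torch
from torch.utils.data import TensorDataset

from eva.core.data import dataloaders, samplers
from eva.core.data.datamodules import schemas
from eva.core.data.datamodules.datamodule import DataModule  # assumed name

datamodule = DataModule(
    datasets=schemas.DatasetsSchema(train=TensorDataset(torch.randn(100, 8))),
    dataloaders=schemas.DataloadersSchema(
        train=dataloaders.DataLoader(batch_size=16, shuffle=False),
    ),
    samplers=schemas.SamplersSchema(train=samplers.RandomSampler(num_samples=10)),
)
train_loader = datamodule.train_dataloader()  # sampler.set_dataset runs here

Note that `shuffle` is disabled explicitly: PyTorch rejects a DataLoader configured with both `shuffle=True` and a sampler.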
19 changes: 18 additions & 1 deletion src/eva/core/data/datamodules/schemas.py
@@ -3,7 +3,7 @@
import dataclasses
from typing import List

from eva.core.data import dataloaders, datasets
from eva.core.data import dataloaders, datasets, samplers

TRAIN_DATASET = datasets.TorchDataset | None
"""Train dataset."""
@@ -60,3 +60,20 @@ class DataloadersSchema:

predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
"""Predict dataloader."""


@dataclasses.dataclass(frozen=True)
class SamplersSchema:
"""Samplers schema used in DataModule."""

train: samplers.Sampler | None = None
"""Train sampler."""

val: samplers.Sampler | None = None
"""Validation sampler."""

test: samplers.Sampler | None = None
"""Test sampler."""

predict: samplers.Sampler | None = None
"""Predict sampler."""
5 changes: 4 additions & 1 deletion src/eva/core/data/datasets/__init__.py
@@ -1,15 +1,18 @@
"""Datasets API."""

from eva.core.data.datasets.base import Dataset
from eva.core.data.datasets.base import Dataset, MapDataset
from eva.core.data.datasets.classification import (
EmbeddingsClassificationDataset,
MultiEmbeddingsClassificationDataset,
)
from eva.core.data.datasets.dataset import TorchDataset
from eva.core.data.datasets.typings import DataSample

__all__ = [
"Dataset",
"MapDataset",
"EmbeddingsClassificationDataset",
"MultiEmbeddingsClassificationDataset",
"TorchDataset",
"DataSample",
]
23 changes: 23 additions & 0 deletions src/eva/core/data/datasets/base.py
@@ -1,5 +1,7 @@
"""Base dataset class."""

import abc

from eva.core.data.datasets import dataset


@@ -51,3 +53,24 @@ def teardown(self) -> None:
of fit (train + validate), validate, test, or predict and it will be
called from every process (i.e. GPU) across all the nodes in DDP.
"""


class MapDataset(Dataset):
"""Abstract base class for all map-style datasets."""

@abc.abstractmethod
def __getitem__(self, index: int):
"""Retrieves the item at the given index.
Args:
index: Index of the item to retrieve.
Returns:
The data at the given index.
"""
raise NotImplementedError

@abc.abstractmethod
def __len__(self) -> int:
"""Returns the length of the dataset."""
raise NotImplementedError
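
For illustration (not part of this commit), a minimal concrete subclass; only the two abstract methods need implementing, since the hooks inherited from Dataset are no-op defaults:

# Minimal sketch of a concrete MapDataset. The class and tensors here
# are illustrative.
from typing import Tuple

import torch

from eva.core.data.datasets import MapDataset


class InMemoryDataset(MapDataset):
    """Serves (data, target) pairs from two pre-built tensors."""

    def __init__(self, data: torch.Tensor, targets: torch.Tensor) -> None:
        self._data = data
        self._targets = targets

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._data[index], self._targets[index]

    def __len__(self) -> int:
        return len(self._data)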
18 changes: 18 additions & 0 deletions src/eva/core/data/datasets/typings.py
@@ -0,0 +1,18 @@
"""Typing definitions for the datasets module."""

from typing import Any, Dict, NamedTuple

import torch


class DataSample(NamedTuple):
"""The default input batch data scheme."""

data: torch.Tensor
"""The data batch."""

targets: torch.Tensor | None = None
"""The target batch."""

metadata: Dict[str, Any] | None = None
"""The associated metadata."""
6 changes: 4 additions & 2 deletions src/eva/core/data/samplers/__init__.py
@@ -1,5 +1,7 @@
"""Data samplers API."""

from eva.core.data.samplers.sampler import Sampler
from eva.core.data.samplers.classification.balanced import BalancedSampler
from eva.core.data.samplers.random import RandomSampler
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource

__all__ = ["Sampler"]
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"]
5 changes: 5 additions & 0 deletions src/eva/core/data/samplers/classification/__init__.py
@@ -0,0 +1,5 @@
"""Classification data samplers API."""

from eva.core.data.samplers.classification.balanced import BalancedSampler

__all__ = ["BalancedSampler"]
96 changes: 96 additions & 0 deletions src/eva/core/data/samplers/classification/balanced.py
@@ -0,0 +1,96 @@
"""Random class sampler for data loading."""

from collections import defaultdict
from typing import Dict, Iterator, List

import numpy as np
from typing_extensions import override

from eva.core.data import datasets
from eva.core.data.datasets.typings import DataSample
from eva.core.data.samplers.sampler import SamplerWithDataSource
from eva.core.utils.progress_bar import tqdm


class BalancedSampler(SamplerWithDataSource[int]):
"""Balanced class sampler for data loading.
The sampler ensures that:
1. Each class has the same number of samples
2. Samples within each class are randomly selected
3. Samples of different classes appear in random order
"""

def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = 42):
"""Initializes the balanced sampler.
Args:
num_samples: The number of samples to draw per class.
replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
seed: Random seed for reproducibility.
"""
self._num_samples = num_samples
self._replacement = replacement
self._class_indices: Dict[int, List[int]] = defaultdict(list)
self._random_generator = np.random.default_rng(seed)

def __len__(self) -> int:
"""Returns the total number of samples."""
return self._num_samples * len(self._class_indices)

def __iter__(self) -> Iterator[int]:
"""Creates an iterator that yields indices in a class balanced way.
Returns:
Iterator yielding dataset indices.
"""
indices = []

for class_idx in self._class_indices:
class_indices = self._class_indices[class_idx]
sampled_indices = self._random_generator.choice(
class_indices, size=self._num_samples, replace=self._replacement
).tolist()
indices.extend(sampled_indices)

self._random_generator.shuffle(indices)

return iter(indices)

@override
def set_dataset(self, data_source: datasets.MapDataset):
"""Sets the dataset and builds class indices.
Args:
data_source: The dataset to sample from.
Raises:
ValueError: If the dataset doesn't have targets or if any class has
fewer samples than `num_samples` and `replacement` is `False`.
"""
super().set_dataset(data_source)
self._make_indices()

def _make_indices(self):
"""Builds indices for each class in the dataset."""
self._class_indices.clear()

for idx in tqdm(
range(len(self.data_source)), desc="Fetching class indices for balanced sampler"
):
_, target, _ = DataSample(*self.data_source[idx])
if target is None:
raise ValueError("The dataset must return non-empty targets.")
if target.numel() != 1:
raise ValueError("The dataset must return a single & scalar target.")

class_idx = int(target.item())
self._class_indices[class_idx].append(idx)

if not self._replacement:
for class_idx, indices in self._class_indices.items():
if len(indices) < self._num_samples:
raise ValueError(
f"Class {class_idx} has only {len(indices)} samples, "
f"which is less than the required {self._num_samples} samples."
)
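
A hedged usage sketch; the toy dataset is illustrative, and it assumes SamplerWithDataSource.set_dataset (not shown in this diff) simply stores the dataset:

# Illustrative only: ten samples with alternating binary targets, so each
# class has five indices and num_samples=3 can be drawn without replacement.
import torch

from eva.core.data import datasets
from eva.core.data.samplers import BalancedSampler


class _ToyDataset(datasets.MapDataset):
    def __getitem__(self, index: int):
        return torch.randn(4), torch.tensor(index % 2)

    def __len__(self) -> int:
        return 10


sampler = BalancedSampler(num_samples=3, seed=0)
sampler.set_dataset(_ToyDataset())  # builds the per-class index lists
indices = list(sampler)             # 3 indices per class, shuffled together
assert len(indices) == len(sampler) == 6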
39 changes: 39 additions & 0 deletions src/eva/core/data/samplers/random.py
@@ -0,0 +1,39 @@
"""Random sampler for data loading."""

from typing import Optional

from torch.utils import data
from typing_extensions import override

from eva.core.data import datasets
from eva.core.data.samplers.sampler import SamplerWithDataSource


class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
"""Samples elements randomly."""

data_source: datasets.MapDataset # type: ignore

def __init__(
self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
) -> None:
"""Initializes the random sampler.
Args:
replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples: number of samples to draw, default=`len(dataset)`.
generator: Generator used in sampling.
"""
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator

@override
def set_dataset(self, data_source: datasets.MapDataset) -> None:
super().__init__(
data_source,
replacement=self.replacement,
num_samples=self.num_samples,
generator=self.generator,
)
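
Finally, a sketch of the deferred initialization this wrapper enables: the sampler can be declared before any dataset exists (e.g. in a SamplersSchema) and bound later. The TensorDataset stands in for an eva MapDataset:

# set_dataset forwards to torch's RandomSampler.__init__, so the sampler
# stays unconfigured until a dataset is available.
import torch
from torch.utils.data import TensorDataset

from eva.core.data.samplers import RandomSampler

sampler = RandomSampler(num_samples=4, generator=torch.Generator().manual_seed(0))
sampler.set_dataset(TensorDataset(torch.arange(10)))
assert len(list(sampler)) == 4  # four indices drawn from range(10)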
