-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for dataloader samplers (#713)
- Loading branch information
Showing
16 changed files
with
389 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,18 @@ | ||
"""Datasets API.""" | ||
|
||
from eva.core.data.datasets.base import Dataset | ||
from eva.core.data.datasets.base import Dataset, MapDataset | ||
from eva.core.data.datasets.classification import ( | ||
EmbeddingsClassificationDataset, | ||
MultiEmbeddingsClassificationDataset, | ||
) | ||
from eva.core.data.datasets.dataset import TorchDataset | ||
from eva.core.data.datasets.typings import DataSample | ||
|
||
__all__ = [ | ||
"Dataset", | ||
"MapDataset", | ||
"EmbeddingsClassificationDataset", | ||
"MultiEmbeddingsClassificationDataset", | ||
"TorchDataset", | ||
"DataSample", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Typing definitions for the datasets module.""" | ||
|
||
from typing import Any, Dict, NamedTuple | ||
|
||
import torch | ||
|
||
|
||
class DataSample(NamedTuple): | ||
"""The default input batch data scheme.""" | ||
|
||
data: torch.Tensor | ||
"""The data batch.""" | ||
|
||
targets: torch.Tensor | None = None | ||
"""The target batch.""" | ||
|
||
metadata: Dict[str, Any] | None = None | ||
"""The associated metadata.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
"""Data samplers API.""" | ||
|
||
from eva.core.data.samplers.sampler import Sampler | ||
from eva.core.data.samplers.classification.balanced import BalancedSampler | ||
from eva.core.data.samplers.random import RandomSampler | ||
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource | ||
|
||
__all__ = ["Sampler"] | ||
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Classification data samplers API.""" | ||
|
||
from eva.core.data.samplers.classification.balanced import BalancedSampler | ||
|
||
__all__ = ["BalancedSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Random class sampler for data loading.""" | ||
|
||
from collections import defaultdict | ||
from typing import Dict, Iterator, List | ||
|
||
import numpy as np | ||
from typing_extensions import override | ||
|
||
from eva.core.data import datasets | ||
from eva.core.data.datasets.typings import DataSample | ||
from eva.core.data.samplers.sampler import SamplerWithDataSource | ||
from eva.core.utils.progress_bar import tqdm | ||
|
||
|
||
class BalancedSampler(SamplerWithDataSource[int]): | ||
"""Balanced class sampler for data loading. | ||
The sampler ensures that: | ||
1. Each class has the same number of samples | ||
2. Samples within each class are randomly selected | ||
3. Samples of different classes appear in random order | ||
""" | ||
|
||
def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = 42): | ||
"""Initializes the balanced sampler. | ||
Args: | ||
num_samples: The number of samples to draw per class. | ||
replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` | ||
seed: Random seed for reproducibility. | ||
""" | ||
self._num_samples = num_samples | ||
self._replacement = replacement | ||
self._class_indices: Dict[int, List[int]] = defaultdict(list) | ||
self._random_generator = np.random.default_rng(seed) | ||
|
||
def __len__(self) -> int: | ||
"""Returns the total number of samples.""" | ||
return self._num_samples * len(self._class_indices) | ||
|
||
def __iter__(self) -> Iterator[int]: | ||
"""Creates an iterator that yields indices in a class balanced way. | ||
Returns: | ||
Iterator yielding dataset indices. | ||
""" | ||
indices = [] | ||
|
||
for class_idx in self._class_indices: | ||
class_indices = self._class_indices[class_idx] | ||
sampled_indices = self._random_generator.choice( | ||
class_indices, size=self._num_samples, replace=self._replacement | ||
).tolist() | ||
indices.extend(sampled_indices) | ||
|
||
self._random_generator.shuffle(indices) | ||
|
||
return iter(indices) | ||
|
||
@override | ||
def set_dataset(self, data_source: datasets.MapDataset): | ||
"""Sets the dataset and builds class indices. | ||
Args: | ||
data_source: The dataset to sample from. | ||
Raises: | ||
ValueError: If the dataset doesn't have targets or if any class has | ||
fewer samples than `num_samples` and `replacement` is `False`. | ||
""" | ||
super().set_dataset(data_source) | ||
self._make_indices() | ||
|
||
def _make_indices(self): | ||
"""Builds indices for each class in the dataset.""" | ||
self._class_indices.clear() | ||
|
||
for idx in tqdm( | ||
range(len(self.data_source)), desc="Fetching class indices for balanced sampler" | ||
): | ||
_, target, _ = DataSample(*self.data_source[idx]) | ||
if target is None: | ||
raise ValueError("The dataset must return non-empty targets.") | ||
if target.numel() != 1: | ||
raise ValueError("The dataset must return a single & scalar target.") | ||
|
||
class_idx = int(target.item()) | ||
self._class_indices[class_idx].append(idx) | ||
|
||
if not self._replacement: | ||
for class_idx, indices in self._class_indices.items(): | ||
if len(indices) < self._num_samples: | ||
raise ValueError( | ||
f"Class {class_idx} has only {len(indices)} samples, " | ||
f"which is less than the required {self._num_samples} samples." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Random sampler for data loading.""" | ||
|
||
from typing import Optional | ||
|
||
from torch.utils import data | ||
from typing_extensions import override | ||
|
||
from eva.core.data import datasets | ||
from eva.core.data.samplers.sampler import SamplerWithDataSource | ||
|
||
|
||
class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]): | ||
"""Samples elements randomly.""" | ||
|
||
data_source: datasets.MapDataset # type: ignore | ||
|
||
def __init__( | ||
self, replacement: bool = False, num_samples: Optional[int] = None, generator=None | ||
) -> None: | ||
"""Initializes the random sampler. | ||
Args: | ||
data_source: dataset to sample from | ||
replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` | ||
num_samples: number of samples to draw, default=`len(dataset)`. | ||
generator: Generator used in sampling. | ||
""" | ||
self.replacement = replacement | ||
self._num_samples = num_samples | ||
self.generator = generator | ||
|
||
@override | ||
def set_dataset(self, data_source: datasets.MapDataset) -> None: | ||
super().__init__( | ||
data_source, | ||
replacement=self.replacement, | ||
num_samples=self.num_samples, | ||
generator=self.generator, | ||
) |
Oops, something went wrong.