29 changes: 29 additions & 0 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,19 +244,48 @@ def _get_distributed_sampler(
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)


def _is_simple_sampler_replaceable(sampler: Sampler) -> bool:
"""Check if a sampler can be safely replaced with SequentialSampler for overfit batches."""
simple_sampler_types = (
RandomSampler,
SequentialSampler,
DistributedSampler,
DistributedSamplerWrapper,
UnrepeatedDistributedSamplerWrapper,
)
return isinstance(sampler, simple_sampler_types)


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
"""Resolve overfit batches by disabling shuffling.

    When overfit_batches > 0, this function ensures that sequential, non-shuffled sampling is used so that batches
    are consistent across epochs. Training and validation use different sets of data.

    Simple samplers (RandomSampler, SequentialSampler, etc.) are replaced with SequentialSampler. Custom samplers
    that may use complex indexing are preserved, but a warning is issued.

"""
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
)
if all_have_sequential_sampler:
return

# Check if any dataloaders have custom samplers that shouldn't be replaced
has_custom_samplers = any(
hasattr(dl, "sampler") and not _is_simple_sampler_replaceable(dl.sampler) for dl in combined_loader.flattened
)

if has_custom_samplers:
rank_zero_warn(
f"You requested to overfit but some {mode.dataloader_prefix} dataloaders use custom samplers. "
f"Custom samplers are preserved, but please ensure they provide deterministic, non-shuffled output "
f"for consistent overfitting behavior.",
category=PossibleUserWarning,
)
return

rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
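For context, here is a minimal usage sketch (not part of the diff) of the path the new tests below do not exercise: a plain RandomSampler counts as replaceable, so overfitting turns shuffling off. It assumes the remainder of `_resolve_overfit_batches`, not shown in full here, rebuilds shuffled dataloaders with a SequentialSampler as the docstring describes.

```python
# Minimal sketch, not part of the diff. Assumes the truncated remainder of
# _resolve_overfit_batches rebuilds "simple" shuffled samplers as SequentialSampler.
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from lightning.pytorch.trainer.connectors.data_connector import _resolve_overfit_batches
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.combined_loader import CombinedLoader

dataset = list(range(64))
# RandomSampler is one of the simple, replaceable sampler types.
loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)
combined = CombinedLoader([loader])

_resolve_overfit_batches(combined, RunningStage.TRAINING)

# Shuffling is disabled: the dataloader has been rebuilt with a SequentialSampler.
assert isinstance(combined.flattened[0].sampler, SequentialSampler)
```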
50 changes: 49 additions & 1 deletion tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -19,7 +19,7 @@
import pytest
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler

import lightning.fabric
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
@@ -30,6 +30,8 @@
_check_dataloader_iterable,
_DataHookSelector,
_DataLoaderSource,
_is_simple_sampler_replaceable,
_resolve_overfit_batches,
_worker_check,
warning_cache,
)
@@ -696,3 +698,49 @@ def test_iterable_check_on_known_iterators():
dataloader.__iter__ = Mock()
_check_dataloader_iterable(dataloader, Mock(), Mock())
dataloader.__iter__.assert_not_called()


def test_is_simple_sampler_replaceable():
"""Test that _is_simple_sampler_replaceable correctly identifies simple vs custom samplers."""
dataset = RandomDataset(32, 64)

assert _is_simple_sampler_replaceable(SequentialSampler(dataset)) is True
assert _is_simple_sampler_replaceable(RandomSampler(dataset)) is True

class CustomSampler(Sampler):
def __init__(self, dataset):
self.dataset = dataset

def __iter__(self):
return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])

def __len__(self):
return len(self.dataset)

assert _is_simple_sampler_replaceable(CustomSampler(dataset)) is False


def test_resolve_overfit_batches_preserves_custom_sampler():
"""Test that _resolve_overfit_batches does not alter custom samplers."""
dataset = RandomDataset(32, 64)

class CustomDictSampler(Sampler):
def __init__(self, dataset):
self.dataset = dataset

def __iter__(self):
return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])

def __len__(self):
return len(self.dataset)

custom_sampler = CustomDictSampler(dataset)
dataloader = DataLoader(dataset, sampler=custom_sampler, batch_size=2)
combined_loader = CombinedLoader([dataloader])
original_sampler = dataloader.sampler

_resolve_overfit_batches(combined_loader, RunningStage.TRAINING)

assert combined_loader.flattened[0].sampler is original_sampler
assert combined_loader.flattened[0].sampler is custom_sampler
assert isinstance(combined_loader.flattened[0].sampler, CustomDictSampler)
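
A complementary end-to-end sketch (also not part of the diff) of how the change surfaces to users: with overfit_batches set and a custom sampler, the sampler is kept and a PossibleUserWarning is emitted instead of the sampler being swapped out. BoringModel and RandomDataset are Lightning's demo helpers; the EveryOtherSampler class here is hypothetical.

```python
# Hypothetical end-to-end sketch, not part of the diff. A custom Sampler subclass is
# not in the "simple" list, so it is preserved and a PossibleUserWarning is emitted.
from torch.utils.data import DataLoader, Sampler

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class EveryOtherSampler(Sampler):
    """Hypothetical custom sampler that yields every other index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(range(0, len(self.dataset), 2))

    def __len__(self):
        return (len(self.dataset) + 1) // 2


dataset = RandomDataset(32, 64)
loader = DataLoader(dataset, sampler=EveryOtherSampler(dataset), batch_size=2)

trainer = Trainer(overfit_batches=2, max_epochs=1, logger=False, enable_checkpointing=False)
# Warns ("... dataloaders use custom samplers. Custom samplers are preserved ...")
# and keeps EveryOtherSampler rather than replacing it with a SequentialSampler.
trainer.fit(BoringModel(), train_dataloaders=loader)
```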