diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 240dae6296c1f..07cfaf291f9fb 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,12 +244,27 @@ def _get_distributed_sampler(
     return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
 
 
+def _is_simple_sampler_replaceable(sampler: Sampler) -> bool:
+    """Check if a sampler can be safely replaced with SequentialSampler for overfit batches."""
+    simple_sampler_types = (
+        RandomSampler,
+        SequentialSampler,
+        DistributedSampler,
+        DistributedSamplerWrapper,
+        UnrepeatedDistributedSamplerWrapper,
+    )
+    return isinstance(sampler, simple_sampler_types)
+
+
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
     """Resolve overfit batches by disabling shuffling.
 
     When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent
     batches across epochs. Training and validation use different sets of data.
 
+    Simple samplers (RandomSampler, SequentialSampler, etc.) are replaced with SequentialSampler. Custom samplers
+    that may use complex indexing are preserved, but a warning is issued.
+
     """
     all_have_sequential_sampler = all(
         isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
@@ -257,6 +272,20 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage
     if all_have_sequential_sampler:
         return
 
+    # Check if any dataloaders have custom samplers that shouldn't be replaced
+    has_custom_samplers = any(
+        hasattr(dl, "sampler") and not _is_simple_sampler_replaceable(dl.sampler) for dl in combined_loader.flattened
+    )
+
+    if has_custom_samplers:
+        rank_zero_warn(
+            f"You requested to overfit but some {mode.dataloader_prefix} dataloaders use custom samplers. "
+            f"Custom samplers are preserved, but please ensure they provide deterministic, non-shuffled output "
+            f"for consistent overfitting behavior.",
+            category=PossibleUserWarning,
+        )
+        return
+
     rank_zero_warn(
         f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
         f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
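
For context, a minimal hedged sketch (not part of the patch) of the replace-vs-preserve decision the new helper encodes. The names `is_replaceable` and `DictSampler` are hypothetical stand-ins that mirror `_is_simple_sampler_replaceable` using only public `torch.utils.data` classes, so the snippet runs without Lightning internals.

# Sketch of the decision introduced above: built-in samplers with plain integer
# indexing can be swapped for SequentialSampler when overfitting; custom samplers
# (e.g. ones yielding dicts) are preserved and a warning is emitted instead.
from torch.utils.data import RandomSampler, Sampler, SequentialSampler


def is_replaceable(sampler: Sampler) -> bool:
    # Hypothetical stand-in for _is_simple_sampler_replaceable.
    return isinstance(sampler, (RandomSampler, SequentialSampler))


class DictSampler(Sampler):
    """Custom sampler yielding dict items; replacing it would break the dataset's indexing."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __iter__(self):
        return iter({"index": i, "param": 0.5} for i in range(self.size))

    def __len__(self) -> int:
        return self.size


print(is_replaceable(RandomSampler(range(8))))  # True -> would be replaced with SequentialSampler
print(is_replaceable(DictSampler(8)))           # False -> preserved, warning emitted
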
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 367c2340ce542..9580ed278a3e5 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -19,7 +19,7 @@
 import pytest
 from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler
 
 import lightning.fabric
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
@@ -30,6 +30,8 @@
     _check_dataloader_iterable,
     _DataHookSelector,
     _DataLoaderSource,
+    _is_simple_sampler_replaceable,
+    _resolve_overfit_batches,
     _worker_check,
     warning_cache,
 )
@@ -696,3 +698,49 @@ def test_iterable_check_on_known_iterators():
     dataloader.__iter__ = Mock()
     _check_dataloader_iterable(dataloader, Mock(), Mock())
     dataloader.__iter__.assert_not_called()
+
+
+def test_is_simple_sampler_replaceable():
+    """Test that _is_simple_sampler_replaceable correctly identifies simple vs custom samplers."""
+    dataset = RandomDataset(32, 64)
+
+    assert _is_simple_sampler_replaceable(SequentialSampler(dataset)) is True
+    assert _is_simple_sampler_replaceable(RandomSampler(dataset)) is True
+
+    class CustomSampler(Sampler):
+        def __init__(self, dataset):
+            self.dataset = dataset
+
+        def __iter__(self):
+            return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])
+
+        def __len__(self):
+            return len(self.dataset)
+
+    assert _is_simple_sampler_replaceable(CustomSampler(dataset)) is False
+
+
+def test_resolve_overfit_batches_preserves_custom_sampler():
+    """Test that _resolve_overfit_batches does not alter custom samplers."""
+    dataset = RandomDataset(32, 64)
+
+    class CustomDictSampler(Sampler):
+        def __init__(self, dataset):
+            self.dataset = dataset
+
+        def __iter__(self):
+            return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])
+
+        def __len__(self):
+            return len(self.dataset)
+
+    custom_sampler = CustomDictSampler(dataset)
+    dataloader = DataLoader(dataset, sampler=custom_sampler, batch_size=2)
+    combined_loader = CombinedLoader([dataloader])
+    original_sampler = dataloader.sampler
+
+    _resolve_overfit_batches(combined_loader, RunningStage.TRAINING)
+
+    assert combined_loader.flattened[0].sampler is original_sampler
+    assert combined_loader.flattened[0].sampler is custom_sampler
+    assert isinstance(combined_loader.flattened[0].sampler, CustomDictSampler)
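
As a usage note rather than part of the patch, the behaviour exercised by these tests surfaces when `overfit_batches` is set on the Trainer. A hedged end-to-end sketch, assuming the `BoringModel`/`RandomDataset` helpers from `lightning.pytorch.demos.boring_classes` and a hypothetical `EveryOtherSampler`:

# With this patch the custom sampler below is kept and a warning is raised;
# previously it would have been silently replaced by a SequentialSampler.
from torch.utils.data import DataLoader, Sampler

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class EveryOtherSampler(Sampler):
    """Deterministic custom sampler: yields every other index, no shuffling."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2


dataset = RandomDataset(32, 64)
train_loader = DataLoader(dataset, sampler=EveryOtherSampler(dataset), batch_size=2)

trainer = Trainer(overfit_batches=2, max_epochs=1, accelerator="cpu", logger=False)
trainer.fit(BoringModel(), train_dataloaders=train_loader)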