diff --git a/olmo/data/iterable_dataset.py b/olmo/data/iterable_dataset.py
index 3cfc8bfcc..50a1aa653 100644
--- a/olmo/data/iterable_dataset.py
+++ b/olmo/data/iterable_dataset.py
@@ -69,9 +69,11 @@ def __init__(
     def _build_global_indices(self) -> List[int]:
         if self.shuffle:
             # Deterministically shuffle based on epoch and seed
-            g = torch.Generator()
-            g.manual_seed(self.seed)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+            # Torch built-in randomness is not very random, so we use numpy.
+            rng = np.random.Generator(np.random.PCG64(seed=self.seed))
+            indices = np.arange(len(self.dataset))
+            rng.shuffle(indices)
+            indices = indices.tolist()  # built-in ints (not np.int64), as List[int] promises
         else:
             indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
 