From 1c5634f09f91b9501a07efce3db3dd8fd23a1a02 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 10 Aug 2023 15:14:49 -0700
Subject: [PATCH] fix

---
 olmo/data/iterable_dataset.py | 6 ++++--
 olmo/train.py                 | 3 +++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/olmo/data/iterable_dataset.py b/olmo/data/iterable_dataset.py
index f127c764d..1ca94f686 100644
--- a/olmo/data/iterable_dataset.py
+++ b/olmo/data/iterable_dataset.py
@@ -8,7 +8,7 @@
 import torch.utils.data
 
 from ..aliases import PathOrStr
-from ..util import barrier, get_global_rank, get_world_size
+from ..util import barrier, get_fs_local_rank, get_global_rank, get_world_size
 
 __all__ = ["IterableDataset"]
 
@@ -35,6 +35,7 @@ def __init__(
         drop_last: bool = False,
         world_size: Optional[int] = None,
         rank: Optional[int] = None,
+        fs_local_rank: Optional[int] = None,
         work_dir: Optional[PathOrStr] = None,
     ):
         self.dataset = dataset
@@ -44,6 +45,7 @@ def __init__(
         self.shuffle = shuffle
         self.drop_last = drop_last
         self.rank = rank if rank is not None else get_global_rank()
+        self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank()
         self.world_size = world_size if world_size is not None else get_world_size()
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
@@ -59,7 +61,7 @@ def __init__(
         self.global_indices_file: Optional[Path] = None
         if work_dir is not None:
             self.global_indices_file = Path(work_dir) / "global_indices.npy"
-            if self.rank == 0:
+            if self.fs_local_rank == 0:
                 log.info("Saving global data order indices...")
                 self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
                 global_indices = self._build_global_indices()
diff --git a/olmo/train.py b/olmo/train.py
index a32dc6be0..a3ac3268e 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -213,6 +213,9 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
         torch.cuda.set_rng_state(rng_state["cuda"])
 
     def save_sharded_checkpoint(self) -> Path:
+        # Zero-gradients to avoid gathering them.
+        self.optim.zero_grad(set_to_none=True)
+
         checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}"
         checkpoint_dir_tmp = Path(self.cfg.save_folder) / f"step{self.global_step}-tmp"
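
Note on the data-order change above: the patch gates the write of global_indices.npy on fs_local_rank rather than global rank, so that exactly one process per shared filesystem produces the cached index order while the other processes wait and then read it. The sketch below is a minimal, self-contained illustration of that "one writer per filesystem" pattern, not OLMo's actual implementation: the get_fs_local_rank fallback to the LOCAL_RANK environment variable and the build_or_load_indices helper are hypothetical stand-ins for this example only.

import os
from pathlib import Path

import numpy as np
import torch.distributed as dist


def get_fs_local_rank() -> int:
    # Hypothetical stand-in: treat the per-node LOCAL_RANK as the
    # "filesystem-local" rank, which fits a node-local work_dir. With a
    # filesystem shared by every node, gating on global rank 0 instead
    # would give the same one-writer behavior.
    return int(os.environ.get("LOCAL_RANK", "0"))


def build_or_load_indices(work_dir: Path, dataset_size: int, seed: int) -> np.ndarray:
    # One writer per filesystem builds the shuffled index order; every
    # other rank waits at the barrier and then memory-maps the same file.
    indices_file = work_dir / "global_indices.npy"
    if get_fs_local_rank() == 0:
        work_dir.mkdir(parents=True, exist_ok=True)
        rng = np.random.Generator(np.random.PCG64(seed))
        np.save(indices_file, rng.permutation(dataset_size).astype(np.uint32))
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return np.load(indices_file, mmap_mode="r")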