
Commit

fix
epwalsh committed Aug 10, 2023
1 parent fea1de2 commit 1c5634f
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 4 additions & 2 deletions olmo/data/iterable_dataset.py

@@ -8,7 +8,7 @@
 import torch.utils.data
 
 from ..aliases import PathOrStr
-from ..util import barrier, get_global_rank, get_world_size
+from ..util import barrier, get_fs_local_rank, get_global_rank, get_world_size
 
 __all__ = ["IterableDataset"]
 
@@ -35,6 +35,7 @@ def __init__(
         drop_last: bool = False,
         world_size: Optional[int] = None,
         rank: Optional[int] = None,
+        fs_local_rank: Optional[int] = None,
         work_dir: Optional[PathOrStr] = None,
     ):
         self.dataset = dataset
@@ -44,6 +45,7 @@ def __init__(
         self.shuffle = shuffle
         self.drop_last = drop_last
         self.rank = rank if rank is not None else get_global_rank()
+        self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank()
         self.world_size = world_size if world_size is not None else get_world_size()
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
@@ -59,7 +61,7 @@ def __init__(
         self.global_indices_file: Optional[Path] = None
         if work_dir is not None:
             self.global_indices_file = Path(work_dir) / "global_indices.npy"
-            if self.rank == 0:
+            if self.fs_local_rank == 0:
                 log.info("Saving global data order indices...")
                 self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
                 global_indices = self._build_global_indices()
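
The substantive change in this file is the switch from self.rank == 0 to self.fs_local_rank == 0 as the condition for writing global_indices.npy into work_dir: if the working directory is not on a filesystem shared by every rank, letting only global rank 0 write would leave every other filesystem without the file, so instead one writer per filesystem does the job. The sketch below shows what such a helper could look like; the real get_fs_local_rank is imported from olmo/util.py, and the OLMO_SHARED_FS environment variable used here is an assumption for illustration, not necessarily how OLMo detects a shared filesystem.

# A minimal sketch (not OLMo's actual implementation) of a "filesystem-local rank"
# helper, assuming the standard torchrun environment variables RANK and LOCAL_RANK.
# The OLMO_SHARED_FS switch below is hypothetical, purely for illustration.
import os


def get_global_rank() -> int:
    return int(os.environ.get("RANK", 0))


def get_local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", 0))


def get_fs_local_rank() -> int:
    """Rank of this process among all processes that share its filesystem.

    With a filesystem shared by the whole job, this equals the global rank, so
    exactly one process in the job has fs_local_rank == 0 and writes shared files.
    With node-local filesystems, it equals the node-local rank, so exactly one
    process per node writes the copy that its node needs.
    """
    if os.environ.get("OLMO_SHARED_FS") == "1":
        return get_global_rank()
    return get_local_rank()

Either way, the rank with fs_local_rank == 0 can create the file that the other ranks on its filesystem will read, which is exactly the property the indices-saving block needs.
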
3 changes: 3 additions & 0 deletions olmo/train.py

@@ -213,6 +213,9 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
         torch.cuda.set_rng_state(rng_state["cuda"])
 
     def save_sharded_checkpoint(self) -> Path:
+        # Zero-gradients to avoid gathering them.
+        self.optim.zero_grad(set_to_none=True)
+
         checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}"
         checkpoint_dir_tmp = Path(self.cfg.save_folder) / f"step{self.global_step}-tmp"
 
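
The added lines clear gradients at the top of save_sharded_checkpoint so that, by the time the sharded checkpointing logic starts collecting state, the parameters no longer carry gradient tensors that could be gathered or serialized along with them. The toy example below, which uses a plain nn.Linear and AdamW rather than OLMo's Trainer, only demonstrates the set_to_none=True behaviour the change relies on.

# Demonstration of optim.zero_grad(set_to_none=True): it drops the .grad tensors
# entirely (rather than filling them with zeros), so nothing gradient-shaped is
# left attached to the parameters when checkpointing code walks over them.
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.AdamW(model.parameters())

loss = model(torch.randn(2, 4)).sum()
loss.backward()
assert all(p.grad is not None for p in model.parameters())  # gradients allocated

optim.zero_grad(set_to_none=True)
assert all(p.grad is None for p in model.parameters())  # gradient memory released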
