Run garbage collection manually in train loop
epwalsh committed Mar 20, 2024
1 parent 74de51d commit 474f1d2
Showing 3 changed files with 41 additions and 6 deletions.
12 changes: 9 additions & 3 deletions olmo/checkpoint.py
@@ -50,7 +50,13 @@
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size
from .torch_util import (
barrier,
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size,
)
from .util import (
_get_s3_client,
default_thread_count,
@@ -191,7 +197,7 @@ def load_fsdp_model_and_optim_state(
),
)
del model_state
torch.cuda.empty_cache()
gc_cuda()
load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])


@@ -212,7 +218,7 @@ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[
v = state[k]
if isinstance(v, torch.Tensor):
state[k] = v.to(device="cpu")
torch.cuda.empty_cache()
gc_cuda()
optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


7 changes: 7 additions & 0 deletions olmo/torch_util.py
@@ -1,3 +1,4 @@
import gc
import os
from typing import Optional, TypeVar

@@ -130,3 +131,9 @@ def synchronize_value(value: V, device: torch.device) -> V:

def synchronize_flag(flag: bool, device: torch.device) -> bool:
return synchronize_value(flag, device)


def gc_cuda():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
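
Taken together, the new gc_cuda() helper pairs Python's cyclic garbage collector with PyTorch's CUDA caching-allocator release: gc.collect() drops the last references to unreachable objects, and torch.cuda.empty_cache() then returns the cached, now-unused blocks to the driver. A minimal, self-contained sketch of how such a helper gets used (the helper definition is repeated so the snippet runs on its own, and big_state is a made-up stand-in for the model/optimizer state dicts dropped in checkpoint.py):

import gc

import torch


def gc_cuda() -> None:
    # Reclaim unreachable Python objects first, so tensors that are only
    # kept alive by reference cycles actually release their storage.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached, unused allocator blocks to the driver. This does
        # not free tensors that are still referenced anywhere.
        torch.cuda.empty_cache()


# Hypothetical usage mirroring the checkpoint-loading change above: drop a
# large state dict that is no longer needed, then reclaim memory before the
# next allocation-heavy step. A CPU tensor is used so the sketch runs
# without a GPU.
big_state = {"weights": torch.zeros(1024, 1024)}
del big_state
gc_cuda()
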
28 changes: 25 additions & 3 deletions olmo/train.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import cProfile
import gc
import logging
import math
import os
@@ -38,6 +39,7 @@
from .optim import Optimizer, Scheduler
from .torch_util import (
barrier,
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size,
@@ -136,6 +138,7 @@ class Trainer:
cur_train_loss: float = float("inf")
indices_file: Optional[TextIO] = None
_start_time: float = 0.0
_gc_init_state: bool = True
loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
last_sharded_checkpoint_step: Optional[int] = None
last_unsharded_checkpoint_step: Optional[int] = None
@@ -537,15 +540,19 @@ def restore_unsharded_checkpoint(
def save_checkpoint(
self, checkpoint_type: CheckpointType = CheckpointType.sharded
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
result: Tuple[PathOrStr, Optional[PathOrStr]]
if checkpoint_type == CheckpointType.sharded:
return self.save_sharded_checkpoint()
result = self.save_sharded_checkpoint()
elif checkpoint_type == CheckpointType.unsharded:
return self.save_unsharded_checkpoint()
result = self.save_unsharded_checkpoint()
elif checkpoint_type == CheckpointType.sharded_ephemeral:
return self.save_ephemeral_checkpoint()
result = self.save_ephemeral_checkpoint()
else:
raise NotImplementedError(checkpoint_type)

gc_cuda()
return result

def restore_checkpoint(
self,
load_path: PathOrStr,
Expand Down Expand Up @@ -576,6 +583,8 @@ def restore_checkpoint(
elif checkpoint_type is not None:
raise NotImplementedError(checkpoint_type)

gc_cuda()

def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
if checkpoint_type == CheckpointType.sharded:
self.remove_sharded_checkpoint(idx=idx)
@@ -936,6 +945,10 @@ def fit(self):
self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)

self._start_time = time.time()
self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.

# Disable automatic garbage collection; FSDP doesn't work well with it.
gc.disable()

if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
eval_metrics = self.eval()
@@ -1141,6 +1154,9 @@ def on_trace_ready(p):
if stop_at is not None and self.global_step >= stop_at:
break

# Run generation 1 garbage collection.
gc.collect(1)

# Python Profiler stuff
# We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
if python_profiler is not None:
@@ -1178,9 +1194,15 @@ def on_trace_ready(p):
log.info(f"Checkpoint saved to {checkpoint_path}")

def close(self, exit_code: int = 0) -> None:
gc_cuda()

if self.indices_file is not None:
self.indices_file.flush()
self.indices_file.close()
if self._gc_init_state:
gc.enable()
else:
gc.disable()
if wandb.run is not None:
wandb.finish(exit_code=exit_code, quiet=True)

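Putting the train.py changes together: fit() remembers whether automatic collection was enabled (_gc_init_state), disables it for the duration of training, runs a cheap generation-1 collection at the bottom of every step, and close() restores the interpreter's original setting (after freeing the CUDA cache via gc_cuda()). The commit only notes that FSDP "doesn't work well" with automatic collection; a common reading is that collections triggered at different times on different ranks cause stalls around collectives, which scheduled manual collection avoids. A minimal sketch of that loop structure under that assumption, with train_step as a hypothetical stand-in for the real forward/backward/optimizer work:

import gc


def train_step(step: int) -> None:
    # Hypothetical stand-in for the real per-step work (forward, backward,
    # optimizer step); nothing here is taken from the repository.
    pass


def fit(max_steps: int = 100) -> None:
    gc_init_state = gc.isenabled()  # remember the interpreter's GC setting
    gc.disable()  # no automatic collections while training
    try:
        for step in range(max_steps):
            train_step(step)
            # Collect only generation 1 each step: cheap enough to run every
            # iteration, yet enough to keep short-lived cycles from piling up.
            gc.collect(1)
    finally:
        # Mirror Trainer.close(): restore whatever GC state we started with.
        if gc_init_state:
            gc.enable()
        else:
            gc.disable()


if __name__ == "__main__":
    fit()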
