diff --git a/olmo/config.py b/olmo/config.py
index 8e04957e4..af3327dac 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -821,11 +821,6 @@ class TrainConfig(BaseConfig):
     Settings for compiling the model with ``torch.compile()``.
     """
 
-    activation_checkpointing: bool = False
-    """
-    Use activation checkpointing on transformer blocks.
-    """
-
     fsdp: FSDPConfig = field(default_factory=FSDPConfig)
     """
     Fully sharded data parallel settings.
@@ -866,6 +861,11 @@ class TrainConfig(BaseConfig):
     Stop at a specific step.
     """
 
+    activation_checkpointing: bool = False
+    """
+    Use activation checkpointing on transformer blocks.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
diff --git a/olmo/model.py b/olmo/model.py
index 85135ddce..be80cefd8 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -10,7 +10,18 @@
 import math
 from abc import abstractmethod
 from collections.abc import MutableMapping
-from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
+from functools import partial
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    cast,
+)
 
 import torch
 import torch.backends.cuda
@@ -30,7 +41,7 @@
 )
 from .exceptions import OlmoConfigurationError
 from .initialization import init_weights
-from .util import ensure_finite_
+from .util import ensure_finite_, pass_through_fn
 
 __all__ = [
     "LayerNormBase",
@@ -698,6 +709,10 @@ class OlmoGenerateOutput(NamedTuple):
 
 
 class OlmoBlockGroup(nn.ModuleList):
+    def __init__(self, modules: Optional[Iterable[nn.Module]] = None):
+        super().__init__(modules)
+        self._activation_checkpoint_fn: Callable = pass_through_fn
+
     def forward(
         self,
         x: torch.Tensor,
@@ -708,7 +723,9 @@ def forward(
         attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
         for block_idx, block in enumerate(self):
             layer_past = None if layers_past is None else layers_past[block_idx]
-            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            x, cache = self._activation_checkpoint_fn(
+                block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
@@ -766,6 +783,8 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
                 "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
             )
 
+        self._activation_checkpoint_fn: Callable = pass_through_fn
+
         if not (
             0 < self.config.block_group_size <= self.config.n_layers
             and self.config.n_layers % self.config.block_group_size == 0
@@ -820,6 +839,28 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
             self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
             self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
 
+    def enable_activation_checkpointing(self, enable: bool = True):
+        if enable:
+            preserve_rng_state = (
+                (self.config.attention_dropout == 0.0)
+                and (self.config.embedding_dropout == 0.0)
+                and (self.config.residual_dropout == 0.0)
+            )
+            from torch.utils.checkpoint import checkpoint
+
+            self._activation_checkpoint_fn = partial(
+                checkpoint,
+                preserve_rng_state=preserve_rng_state,
+                use_reentrant=False,
+            )
+        else:
+            self._activation_checkpoint_fn = pass_through_fn
+
+        # Set up the blocks to use the same function.
+        if self.config.block_group_size != 1:
+            for block_group in self.transformer.block_groups:
+                block_group._activation_checkpoint_fn = self._activation_checkpoint_fn
+
     @property
     def device(self) -> torch.device:
         device: torch.device = self.transformer.wte.weight.device  # type: ignore
@@ -993,7 +1034,9 @@ def forward(
             for block_idx, block in enumerate(self.transformer.blocks):
                 layer_past = None if past_key_values is None else past_key_values[block_idx]
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = self._activation_checkpoint_fn(
+                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                )
                 if attn_key_values is not None:
                     assert cache is not None
                     attn_key_values.append(cache)
diff --git a/olmo/util.py b/olmo/util.py
index c185fa9e3..2b0649021 100644
--- a/olmo/util.py
+++ b/olmo/util.py
@@ -623,3 +623,7 @@ def is_weight_decay_module(module: nn.Module) -> bool:
 
 def default_thread_count() -> int:
     return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
+
+
+def pass_through_fn(fn, *args, **kwargs):
+    return fn(*args, **kwargs)
diff --git a/scripts/train.py b/scripts/train.py
index cd3c59a05..c119431c1 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -4,7 +4,6 @@
 import logging
 import os
 import sys
-from functools import partial
 from pathlib import Path
 from typing import Optional, TextIO
 
@@ -113,6 +112,9 @@ def main(cfg: TrainConfig) -> None:
     log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
     log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")
 
+    if cfg.activation_checkpointing:
+        olmo_model.enable_activation_checkpointing()
+
     # Wrap the model in FSDP.
     log.info("Wrapping model with FDSP...")
     wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
@@ -139,26 +141,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         olmo_model.reset_parameters()
 
     log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
-
-    if cfg.activation_checkpointing:
-        # verify we have FSDP activation support ready by importing:
-        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-            CheckpointImpl,
-            apply_activation_checkpointing,
-            checkpoint_wrapper,
-        )
-
-        non_reentrant_wrapper = partial(
-            checkpoint_wrapper,
-            offload_to_cpu=False,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-        )
-        apply_activation_checkpointing(
-            fsdp_model,
-            checkpoint_wrapper_fn=non_reentrant_wrapper,  # type: ignore
-            check_fn=olmo_model.activation_checkpointing_fn,  # type: ignore
-        )
-
     log.info("Model:")
     log.info(fsdp_model)
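For context, here is a minimal, self-contained sketch (not part of the diff) of the pattern this change introduces: the model keeps a callable slot that defaults to `pass_through_fn` and is swapped for non-reentrant `torch.utils.checkpoint.checkpoint` when activation checkpointing is enabled. The `TinyBlock`/`TinyModel` classes are hypothetical stand-ins for OLMo's transformer blocks, not code from the repository.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def pass_through_fn(fn, *args, **kwargs):
    # Same helper the diff adds to olmo/util.py: just call the wrapped module.
    return fn(*args, **kwargs)


class TinyBlock(nn.Module):
    # Hypothetical stand-in for an OLMo transformer block.
    def __init__(self, d: int):
        super().__init__()
        self.ff = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.relu(self.ff(x))


class TinyModel(nn.Module):
    # Hypothetical stand-in for the Olmo model.
    def __init__(self, d: int = 16, n_layers: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock(d) for _ in range(n_layers))
        self._activation_checkpoint_fn = pass_through_fn  # default: no checkpointing

    def enable_activation_checkpointing(self, enable: bool = True):
        if enable:
            # Non-reentrant checkpointing, mirroring the partial() built in olmo/model.py.
            self._activation_checkpoint_fn = partial(checkpoint, use_reentrant=False)
        else:
            self._activation_checkpoint_fn = pass_through_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Either calls the block directly or recomputes it during backward.
            x = self._activation_checkpoint_fn(block, x)
        return x


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyModel()
    x = torch.randn(2, 8, 16, requires_grad=True)

    # Plain forward/backward.
    model(x).sum().backward()
    grads_plain = [p.grad.clone() for p in model.parameters()]
    model.zero_grad()

    # Checkpointed forward/backward gives the same gradients, trading extra
    # recompute in the backward pass for lower activation memory.
    model.enable_activation_checkpointing()
    model(x).sum().backward()
    grads_ckpt = [p.grad for p in model.parameters()]
    assert all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt))
```

With `activation_checkpointing: true` in the training config, `scripts/train.py` now calls `olmo_model.enable_activation_checkpointing()` before FSDP wrapping, rather than applying `apply_activation_checkpointing` to the wrapped `fsdp_model` afterwards.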