Merge pull request #343 from allenai/ActivationCheckpointing
New activation checkpointing
dirkgr authored Oct 27, 2023
2 parents 558102e + c366e42 commit 5c64338
Showing 4 changed files with 59 additions and 30 deletions.
10 changes: 5 additions & 5 deletions olmo/config.py
@@ -821,11 +821,6 @@ class TrainConfig(BaseConfig):
Settings for compiling the model with ``torch.compile()``.
"""

activation_checkpointing: bool = False
"""
Use activation checkpointing on transformer blocks.
"""

fsdp: FSDPConfig = field(default_factory=FSDPConfig)
"""
Fully sharded data parallel settings.
@@ -866,6 +861,11 @@ class TrainConfig(BaseConfig):
Stop at a specific step.
"""

activation_checkpointing: bool = False
"""
Use activation checkpointing on transformer blocks.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
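
Note on the config.py hunks above: they only relocate the activation_checkpointing field within TrainConfig, from just before fsdp to just after stop_at. Since the field keeps its default, moving a defaulted dataclass field changes nothing but declaration order; keyword construction and defaults behave the same. A standalone sketch of that point (the field names below are illustrative stand-ins, not OLMo code):

from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class TrainConfigSketch:
    # Stand-ins for the neighbouring TrainConfig fields; only the position of
    # activation_checkpointing differs from the pre-change layout.
    stop_at: Optional[int] = None
    activation_checkpointing: bool = False  # relocated field, same default

cfg = TrainConfigSketch(activation_checkpointing=True)
print([f.name for f in fields(cfg)])  # declaration order is all that changed
print(cfg.activation_checkpointing)   # True
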
51 changes: 47 additions & 4 deletions olmo/model.py
@@ -10,7 +10,18 @@
import math
from abc import abstractmethod
from collections.abc import MutableMapping
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
cast,
)

import torch
import torch.backends.cuda
@@ -30,7 +41,7 @@
)
from .exceptions import OlmoConfigurationError
from .initialization import init_weights
from .util import ensure_finite_
from .util import ensure_finite_, pass_through_fn

__all__ = [
"LayerNormBase",
@@ -698,6 +709,10 @@ class OlmoGenerateOutput(NamedTuple):


class OlmoBlockGroup(nn.ModuleList):
def __init__(self, modules: Optional[Iterable[nn.Module]] = None):
super().__init__(modules)
self.__activation_checkpoint_fn: Callable = pass_through_fn

def forward(
self,
x: torch.Tensor,
@@ -708,7 +723,9 @@ def forward(
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
for block_idx, block in enumerate(self):
layer_past = None if layers_past is None else layers_past[block_idx]
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
x, cache = self.__activation_checkpoint_fn(
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
@@ -766,6 +783,8 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
)

self.__activation_checkpoint_fn: Callable = pass_through_fn

if not (
0 < self.config.block_group_size <= self.config.n_layers
and self.config.n_layers % self.config.block_group_size == 0
@@ -820,6 +839,28 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

def enable_activation_checkpointing(self, enable: bool = True):
if enable:
preserve_rng_state = (
(self.config.attention_dropout == 0.0)
and (self.config.embedding_dropout == 0.0)
and (self.config.residual_dropout == 0.0)
)
from torch.utils.checkpoint import checkpoint

self.__activation_checkpoint_fn = partial(
checkpoint,
preserve_rng_state=preserve_rng_state,
use_reentrant=False,
)
else:
self.__activation_checkpoint_fn = pass_through_fn

# Set up the blocks to use the same function.
if self.config.block_group_size != 1:
for block_group in self.transformer.block_groups:
block_group.__activation_checkpoint_fn = self.__activation_checkpoint_fn

@property
def device(self) -> torch.device:
device: torch.device = self.transformer.wte.weight.device # type: ignore
@@ -993,7 +1034,9 @@ def forward(
for block_idx, block in enumerate(self.transformer.blocks):
layer_past = None if past_key_values is None else past_key_values[block_idx]
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
x, cache = self.__activation_checkpoint_fn(
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
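
The core mechanism added in model.py above is the swappable __activation_checkpoint_fn: when checkpointing is enabled it becomes partial(torch.utils.checkpoint.checkpoint, ..., use_reentrant=False), so a block's activations are discarded after the forward pass and recomputed during backward, trading compute for memory. A minimal standalone sketch of that call pattern (the toy block and sizes are illustrative, not OLMo code; the diff additionally derives preserve_rng_state from the model's dropout settings):

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a transformer block.
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

# Same shape of wrapper the model installs when checkpointing is enabled.
activation_checkpoint_fn = partial(checkpoint, use_reentrant=False)

x = torch.randn(8, 64, requires_grad=True)
y = activation_checkpoint_fn(block, x)  # output matches block(x)
y.sum().backward()                      # block is re-run here to rebuild its activations
print(x.grad.shape)                     # torch.Size([8, 64])
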
4 changes: 4 additions & 0 deletions olmo/util.py
@@ -623,3 +623,7 @@ def is_weight_decay_module(module: nn.Module) -> bool:

def default_thread_count() -> int:
return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))


def pass_through_fn(fn, *args, **kwargs):
return fn(*args, **kwargs)
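
pass_through_fn above is the fallback used while checkpointing is disabled: it shares the (fn, *args, **kwargs) calling convention with partial(checkpoint, ...), which is what lets the forward pass call self.__activation_checkpoint_fn(block, x, ...) identically in either mode. A small standalone sketch of that interchangeability (toy module, not OLMo code):

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def pass_through_fn(fn, *args, **kwargs):
    return fn(*args, **kwargs)


block = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Either callable can sit behind __activation_checkpoint_fn; outputs agree.
for apply_fn in (pass_through_fn, partial(checkpoint, use_reentrant=False)):
    torch.testing.assert_close(apply_fn(block, x), block(x))
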
24 changes: 3 additions & 21 deletions scripts/train.py
@@ -4,7 +4,6 @@
import logging
import os
import sys
from functools import partial
from pathlib import Path
from typing import Optional, TextIO

@@ -113,6 +112,9 @@ def main(cfg: TrainConfig) -> None:
log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")

if cfg.activation_checkpointing:
olmo_model.enable_activation_checkpointing()

# Wrap the model in FSDP.
log.info("Wrapping model with FDSP...")
wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
@@ -139,26 +141,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
olmo_model.reset_parameters()

log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")

if cfg.activation_checkpointing:
# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)

non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
fsdp_model,
checkpoint_wrapper_fn=non_reentrant_wrapper, # type: ignore
check_fn=olmo_model.activation_checkpointing_fn, # type: ignore
)

log.info("Model:")
log.info(fsdp_model)
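
Net effect of the train.py change above: when cfg.activation_checkpointing is set, checkpointing is now enabled directly on the model via enable_activation_checkpointing() before FSDP wrapping, replacing the old post-wrap apply_activation_checkpointing pass. The trade remains memory for recomputation: a checkpointed block's forward runs a second time during backward. A standalone sketch that makes the recomputation visible with a forward hook (toy module, not OLMo code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
calls = {"n": 0}
block.register_forward_hook(lambda mod, inp, out: calls.update(n=calls["n"] + 1))

x = torch.randn(4, 64, requires_grad=True)

block(x).sum().backward()
print("plain:", calls["n"], "forward call(s)")  # 1

calls["n"] = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
print("checkpointed:", calls["n"], "forward call(s)")  # 2: re-run during backward
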
