From b67e92c7a85a8d68396c683fcf56863350948016 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Tue, 24 Oct 2023 23:15:58 -0700 Subject: [PATCH 1/7] Do activation checkpointing in a different way --- olmo/config.py | 16 +++++++++++----- olmo/model.py | 29 +++++++++++++++++++++++++++-- scripts/train.py | 22 +--------------------- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index b514aeb8a..81e48b9db 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -404,6 +404,12 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ + activation_checkpointing: bool = False + """ + Use activation checkpointing on transformer blocks. You shouldn't set this directly. + See :data:`TrainConfig.activation_checkpointing` instead. + """ + class OptimizerType(StrEnum): lionw = "lionw" @@ -808,11 +814,6 @@ class TrainConfig(BaseConfig): Settings for compiling the model with ``torch.compile()``. """ - activation_checkpointing: bool = False - """ - Use activation checkpointing on transformer blocks. - """ - fsdp: FSDPConfig = field(default_factory=FSDPConfig) """ Fully sharded data parallel settings. @@ -853,6 +854,11 @@ class TrainConfig(BaseConfig): Stop at a specific step. """ + activation_checkpointing: bool = False + """ + Use activation checkpointing on transformer blocks. + """ + @property def autocast_precision(self) -> torch.dtype: if self.precision == "amp_bf16": diff --git a/olmo/model.py b/olmo/model.py index e873e5aa1..e4ab0eb62 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -10,7 +10,8 @@ import math from abc import abstractmethod from collections.abc import MutableMapping -from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast +from functools import partial +from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast, Callable import torch import torch.backends.cuda @@ -741,6 +742,27 @@ def __init__(self, config: ModelConfig, init_params: bool = True): "Embedding size is not a multiple of 128! 
This could hurt throughput performance.", UserWarning ) + self.__activation_checkpoint_fn: Callable + if self.config.activation_checkpointing: + preserve_rng_state = ( + (self.config.attention_dropout == 0.0) + and (self.config.embedding_dropout == 0.0) + and (self.config.residual_dropout == 0.0) + ) + import torch.utils.checkpoint + + self.__activation_checkpoint_fn = partial( + torch.utils.checkpoint.checkpoint, + preserve_rng_state=preserve_rng_state, + use_reentrant=False, + ) + else: + + def pass_through_fn(fn, *args, **kwargs): + return fn(*args, **kwargs) + + self.__activation_checkpoint_fn = pass_through_fn + torch.backends.cuda.enable_flash_sdp(self.config.flash_attention) torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it @@ -949,7 +971,10 @@ def forward( past_key_values or [None] * self.config.n_layers, # type: ignore ): # shape: (batch_size, seq_len, d_model) - x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + x, cache = self.__activation_checkpoint_fn( + block, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + if attn_key_values is not None: assert cache is not None attn_key_values.append(cache) diff --git a/scripts/train.py b/scripts/train.py index 9d6349a70..6381fe6ae 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,7 +4,6 @@ import logging import os import sys -from functools import partial from pathlib import Path from typing import Optional, TextIO @@ -59,6 +58,7 @@ def main(cfg: TrainConfig) -> None: # Fill some configuration options. cfg.model.precision = cfg.precision + cfg.model.activation_checkpointing = cfg.activation_checkpointing cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size() assert cfg.device_train_batch_size is not None # for mypy cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size @@ -146,26 +146,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None: olmo_model.reset_parameters() log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}") - - if cfg.activation_checkpointing: - # verify we have FSDP activation support ready by importing: - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, - apply_activation_checkpointing, - checkpoint_wrapper, - ) - - non_reentrant_wrapper = partial( - checkpoint_wrapper, - offload_to_cpu=False, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - apply_activation_checkpointing( - fsdp_model, - checkpoint_wrapper_fn=non_reentrant_wrapper, # type: ignore - check_fn=olmo_model.activation_checkpointing_fn, # type: ignore - ) - log.info("Model:") log.info(fsdp_model) From c8c6c68f4988562ded53107aa479bb9cbe407562 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Tue, 24 Oct 2023 23:19:44 -0700 Subject: [PATCH 2/7] Forgot the X --- olmo/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index e4ab0eb62..6eae70995 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -972,7 +972,7 @@ def forward( ): # shape: (batch_size, seq_len, d_model) x, cache = self.__activation_checkpoint_fn( - block, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache ) if attn_key_values is not None: From 1e2c7e02c1c98e7494536f4b0f014d85f370fa82 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 26 Oct 2023 14:07:47 -0700 
Subject: [PATCH 3/7] Python imports are weird --- olmo/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 6eae70995..04eecfe7a 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -749,10 +749,10 @@ def __init__(self, config: ModelConfig, init_params: bool = True): and (self.config.embedding_dropout == 0.0) and (self.config.residual_dropout == 0.0) ) - import torch.utils.checkpoint + from torch.utils.checkpoint import checkpoint self.__activation_checkpoint_fn = partial( - torch.utils.checkpoint.checkpoint, + checkpoint, preserve_rng_state=preserve_rng_state, use_reentrant=False, ) From c527ff4bbaba36c8d6c5acf4f224e4f8850ce3e2 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 26 Oct 2023 14:13:02 -0700 Subject: [PATCH 4/7] Productivity through formatting --- olmo/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index 04eecfe7a..f6d5c9b98 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -11,7 +11,7 @@ from abc import abstractmethod from collections.abc import MutableMapping from functools import partial -from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast, Callable +from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast import torch import torch.backends.cuda From 3326f93bcafcee697f06e50c1481c7bf1e7ddb20 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 26 Oct 2023 17:43:44 -0700 Subject: [PATCH 5/7] Make activation checkpointing enablable on the fly --- olmo/config.py | 6 ------ olmo/model.py | 42 ++++++++++++++++++++++-------------------- scripts/train.py | 4 +++- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 81e48b9db..893c22e23 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -404,12 +404,6 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ - activation_checkpointing: bool = False - """ - Use activation checkpointing on transformer blocks. You shouldn't set this directly. - See :data:`TrainConfig.activation_checkpointing` instead. - """ - class OptimizerType(StrEnum): lionw = "lionw" diff --git a/olmo/model.py b/olmo/model.py index f6d5c9b98..bccc7a925 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -742,26 +742,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True): "Embedding size is not a multiple of 128! 
This could hurt throughput performance.", UserWarning ) - self.__activation_checkpoint_fn: Callable - if self.config.activation_checkpointing: - preserve_rng_state = ( - (self.config.attention_dropout == 0.0) - and (self.config.embedding_dropout == 0.0) - and (self.config.residual_dropout == 0.0) - ) - from torch.utils.checkpoint import checkpoint - - self.__activation_checkpoint_fn = partial( - checkpoint, - preserve_rng_state=preserve_rng_state, - use_reentrant=False, - ) - else: - - def pass_through_fn(fn, *args, **kwargs): - return fn(*args, **kwargs) - - self.__activation_checkpoint_fn = pass_through_fn + self.__activation_checkpoint_fn: Callable = Olmo.pass_through_fn torch.backends.cuda.enable_flash_sdp(self.config.flash_attention) torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it @@ -801,6 +782,27 @@ def pass_through_fn(fn, *args, **kwargs): self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + @staticmethod + def pass_through_fn(fn, *args, **kwargs): + return fn(*args, **kwargs) + + def enable_activation_checkpointing(self, enable: bool = True): + if enable: + preserve_rng_state = ( + (self.config.attention_dropout == 0.0) + and (self.config.embedding_dropout == 0.0) + and (self.config.residual_dropout == 0.0) + ) + from torch.utils.checkpoint import checkpoint + + self.__activation_checkpoint_fn = partial( + checkpoint, + preserve_rng_state=preserve_rng_state, + use_reentrant=False, + ) + else: + self.__activation_checkpoint_fn = Olmo.pass_through_fn + @property def device(self) -> torch.device: device: torch.device = self.transformer.wte.weight.device # type: ignore diff --git a/scripts/train.py b/scripts/train.py index 6381fe6ae..964209aac 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -58,7 +58,6 @@ def main(cfg: TrainConfig) -> None: # Fill some configuration options. cfg.model.precision = cfg.precision - cfg.model.activation_checkpointing = cfg.activation_checkpointing cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size() assert cfg.device_train_batch_size is not None # for mypy cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size @@ -114,6 +113,9 @@ def main(cfg: TrainConfig) -> None: log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}") log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}") + if cfg.activation_checkpointing: + olmo_model.enable_activation_checkpointing() + # Wrap the model in FSDP. log.info("Wrapping model with FDSP...") wrap_policy = None From 5dde7741f68b98ab72c485ca4019c14ee749ed80 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 26 Oct 2023 18:00:53 -0700 Subject: [PATCH 6/7] Makes checkpointing work with block groups This is untested. 
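Note, since this is untested: the block-group wiring below assigns a double-underscore attribute from inside the Olmo class, and Python name-mangles such attributes per defining class, so the assignment binds _Olmo__activation_checkpoint_fn rather than the _OlmoBlockGroup__activation_checkpoint_fn that OlmoBlockGroup.forward reads. A minimal sketch of that behavior (Group and Owner are hypothetical names, for illustration only):

    class Group:
        def __init__(self):
            self.__fn = None  # stored as _Group__fn because of name mangling

        def get(self):
            return self.__fn  # reads _Group__fn

    class Owner:
        def wire(self, group, fn):
            # Inside Owner this mangles to _Owner__fn, a different attribute.
            group.__fn = fn

    g = Group()
    Owner().wire(g, print)
    assert g.get() is None          # Group's own attribute is untouched
    assert g._Owner__fn is print    # the assignment landed on the Owner-mangled name

A single-underscore or public attribute name sidesteps the mismatch.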
--- olmo/model.py | 25 ++++++++++++++++--------- olmo/util.py | 4 ++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 2b8bb5b79..82a7f5351 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -11,7 +11,7 @@ from abc import abstractmethod from collections.abc import MutableMapping from functools import partial -from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast +from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast, Iterable import torch import torch.backends.cuda @@ -31,7 +31,7 @@ ) from .exceptions import OlmoConfigurationError from .initialization import init_weights -from .util import ensure_finite_ +from .util import ensure_finite_, pass_through_fn __all__ = [ "LayerNormBase", @@ -699,6 +699,10 @@ class OlmoGenerateOutput(NamedTuple): class OlmoBlockGroup(nn.ModuleList): + def __init__(self, modules: Optional[Iterable[nn.Module]] = None): + super().__init__(modules) + self.__activation_checkpoint_fn: Callable = pass_through_fn + def forward( self, x: torch.Tensor, @@ -709,7 +713,9 @@ def forward( attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None for block_idx, block in enumerate(self): layer_past = None if layers_past is None else layers_past[block_idx] - x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + x, cache = self.__activation_checkpoint_fn( + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) if attn_key_values is not None: assert cache is not None attn_key_values.append(cache) @@ -767,7 +773,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True): "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning ) - self.__activation_checkpoint_fn: Callable = Olmo.pass_through_fn + self.__activation_checkpoint_fn: Callable = pass_through_fn if not ( 0 < self.config.block_group_size <= self.config.n_layers @@ -823,10 +829,6 @@ def __init__(self, config: ModelConfig, init_params: bool = True): self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) - @staticmethod - def pass_through_fn(fn, *args, **kwargs): - return fn(*args, **kwargs) - def enable_activation_checkpointing(self, enable: bool = True): if enable: preserve_rng_state = ( @@ -842,7 +844,12 @@ def enable_activation_checkpointing(self, enable: bool = True): use_reentrant=False, ) else: - self.__activation_checkpoint_fn = Olmo.pass_through_fn + self.__activation_checkpoint_fn = pass_through_fn + + # Set up the blocks to use the same function. + if self.config.block_group_size != 1: + for block_group in self.transformer.block_groups: + block_group.__activation_checkpoint_fn = self.__activation_checkpoint_fn @property def device(self) -> torch.device: diff --git a/olmo/util.py b/olmo/util.py index c185fa9e3..2b0649021 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -623,3 +623,7 @@ def is_weight_decay_module(module: nn.Module) -> bool: def default_thread_count() -> int: return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4)) + + +def pass_through_fn(fn, *args, **kwargs): + return fn(*args, **kwargs) From c366e42018b64f2be9c6b7b23368a9f94a3c50bb Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Thu, 26 Oct 2023 20:54:33 -0700 Subject: [PATCH 7/7] I'd rather be sailing. 
---
 olmo/model.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/olmo/model.py b/olmo/model.py
index 82a7f5351..be80cefd8 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -11,7 +11,17 @@
 from abc import abstractmethod
 from collections.abc import MutableMapping
 from functools import partial
-from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast, Iterable
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    cast,
+)
 
 import torch
 import torch.backends.cuda
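For readers following the series, here is a minimal, self-contained sketch of the pattern the patches converge on: every transformer-block call goes through a swappable wrapper that is either a plain pass-through or non-reentrant torch.utils.checkpoint.checkpoint, and the wrapper can be swapped at runtime. TinyModel, its Linear "blocks", and the single-underscore attribute name are illustrative stand-ins, not the real OLMo modules; only torch.utils.checkpoint.checkpoint, functools.partial, and the pass-through fallback mirror what the patches use.

    from functools import partial

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    def pass_through_fn(fn, *args, **kwargs):
        return fn(*args, **kwargs)


    class TinyModel(nn.Module):
        def __init__(self, d_model: int = 64, n_layers: int = 4):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_layers))
            # Swappable per-block wrapper: starts as a plain pass-through.
            self._activation_checkpoint_fn = pass_through_fn

        def enable_activation_checkpointing(self, enable: bool = True):
            if enable:
                # Non-reentrant checkpointing: block activations are recomputed
                # during backward instead of being stored during forward.
                self._activation_checkpoint_fn = partial(
                    checkpoint, preserve_rng_state=False, use_reentrant=False
                )
            else:
                self._activation_checkpoint_fn = pass_through_fn

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for block in self.blocks:
                x = self._activation_checkpoint_fn(block, x)
            return x


    model = TinyModel()
    model.enable_activation_checkpointing()  # can be toggled on the fly
    out = model(torch.randn(2, 8, 64, requires_grad=True))
    out.sum().backward()

In this sketch preserve_rng_state=False assumes the checkpointed blocks use no randomness (no dropout); when they do, RNG state should be preserved so the recomputed forward matches the original. The patches derive this flag from the model's attention, embedding, and residual dropout settings.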