Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New activation checkpointing #343

Merged
merged 8 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ class ModelConfig(BaseConfig):
See :data:`TrainConfig.precision` instead.
"""

activation_checkpointing: bool = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this gets set to true during training this could cause issues later when loading the model for inference since. Instead maybe we have a method on the model like Olmo.enable_activation_checkpointing()? The trainer calls that when TrainConfig.activation_checkpointing is true, so we don't need to add a configuration option to ModelConfig`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see why it would cause issues for inference, as the resulting checkpoint files should be 100% identical.

But I like this other design anyways.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I mean is that it would enable activation checkpointing when the model is loaded for inference

"""
Use activation checkpointing on transformer blocks. You shouldn't set this directly.
See :data:`TrainConfig.activation_checkpointing` instead.
"""


class OptimizerType(StrEnum):
lionw = "lionw"
Expand Down Expand Up @@ -808,11 +814,6 @@ class TrainConfig(BaseConfig):
Settings for compiling the model with ``torch.compile()``.
"""

activation_checkpointing: bool = False
"""
Use activation checkpointing on transformer blocks.
"""

fsdp: FSDPConfig = field(default_factory=FSDPConfig)
"""
Fully sharded data parallel settings.
Expand Down Expand Up @@ -853,6 +854,11 @@ class TrainConfig(BaseConfig):
Stop at a specific step.
"""

activation_checkpointing: bool = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have this field

"""
Use activation checkpointing on transformer blocks.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
Expand Down
29 changes: 27 additions & 2 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import math
from abc import abstractmethod
from collections.abc import MutableMapping
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
from functools import partial
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast

import torch
import torch.backends.cuda
Expand Down Expand Up @@ -741,6 +742,27 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
)

self.__activation_checkpoint_fn: Callable
if self.config.activation_checkpointing:
preserve_rng_state = (
(self.config.attention_dropout == 0.0)
and (self.config.embedding_dropout == 0.0)
and (self.config.residual_dropout == 0.0)
)
from torch.utils.checkpoint import checkpoint

self.__activation_checkpoint_fn = partial(
checkpoint,
preserve_rng_state=preserve_rng_state,
use_reentrant=False,
)
else:

def pass_through_fn(fn, *args, **kwargs):
return fn(*args, **kwargs)

self.__activation_checkpoint_fn = pass_through_fn

torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it

Expand Down Expand Up @@ -949,7 +971,10 @@ def forward(
past_key_values or [None] * self.config.n_layers, # type: ignore
):
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
x, cache = self.__activation_checkpoint_fn(
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)

if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
Expand Down
22 changes: 1 addition & 21 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import os
import sys
from functools import partial
from pathlib import Path
from typing import Optional, TextIO

Expand Down Expand Up @@ -59,6 +58,7 @@ def main(cfg: TrainConfig) -> None:

# Fill some configuration options.
cfg.model.precision = cfg.precision
cfg.model.activation_checkpointing = cfg.activation_checkpointing
cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
assert cfg.device_train_batch_size is not None # for mypy
cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
Expand Down Expand Up @@ -146,26 +146,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
olmo_model.reset_parameters()

log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")

if cfg.activation_checkpointing:
# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)

non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
fsdp_model,
checkpoint_wrapper_fn=non_reentrant_wrapper, # type: ignore
check_fn=olmo_model.activation_checkpointing_fn, # type: ignore
)

log.info("Model:")
log.info(fsdp_model)

Expand Down
Loading