Add option to FSDP wrap by groups of blocks

epwalsh committed Oct 26, 2023
1 parent c1a4519 commit 22187ab
Showing 4 changed files with 195 additions and 41 deletions.
13 changes: 13 additions & 0 deletions olmo/config.py
@@ -258,6 +258,13 @@ class ModelConfig(BaseConfig):
The transformer block implementation.
"""

block_group_size: int = 1
"""
The number of blocks to group together into a single parent block.
This has no effect on the number of parameters in the model and is only used to wrap groups
of blocks together with a single FSDP wrapper during training.
"""

alibi: bool = False
"""
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
@@ -532,6 +539,12 @@ class FSDPWrapStrategy(StrEnum):
Wrap each OLMo block with its own FSDP instance.
"""

by_block_group = "by_block_group"
"""
Wrap each block group together into its own FSDP instance.
This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
"""

size_based = "size_based"
"""
Use PyTorch's default size-based auto wrap policy.
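A minimal sketch of how the two new config options fit together (the surrounding training-config plumbing is assumed, not shown in this diff):

```python
# Assumed usage sketch: pair a block_group_size > 1 with the new wrap strategy.
from olmo.config import FSDPWrapStrategy, ModelConfig

model_cfg = ModelConfig(d_model=128, n_heads=2, n_layers=8, block_group_size=4)
wrap_strategy = FSDPWrapStrategy.by_block_group  # requires block_group_size > 1
```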
191 changes: 159 additions & 32 deletions olmo/model.py
@@ -10,20 +10,22 @@
import math
from abc import abstractmethod
from collections.abc import MutableMapping
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
from typing import Dict, Generator, List, NamedTuple, Optional, Sequence, Tuple, cast

import torch
import torch.backends.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from .aliases import PathOrStr
from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
from .config import (
ActivationType,
BlockType,
CheckpointType,
FSDPWrapStrategy,
LayerNormType,
ModelConfig,
)
@@ -532,6 +534,8 @@ def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
raise NotImplementedError

@@ -694,6 +698,28 @@ class OlmoGenerateOutput(NamedTuple):
"""


class OlmoBlockGroup(nn.ModuleList):
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
for block_idx, block in enumerate(self):
layer_past = None if layers_past is None else layers_past[block_idx]
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
return x, attn_key_values

def reset_parameters(self):
for block in self:
block.reset_parameters()

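`OlmoBlockGroup` is just an `nn.ModuleList` with its own `forward`, which is what lets FSDP treat a whole group of blocks as one wrappable module. A toy, self-contained sketch of that pattern (all names here are hypothetical, not OLMo code):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.proj = nn.Linear(d, d)

    def forward(self, x, use_cache: bool = False):
        # Mirror the (hidden_states, cache) return shape of an OLMo block.
        return self.proj(x), ((x, x) if use_cache else None)

class ToyBlockGroup(nn.ModuleList):
    def forward(self, x, use_cache: bool = False):
        caches = [] if use_cache else None
        for block in self:
            x, cache = block(x, use_cache=use_cache)
            if caches is not None:
                caches.append(cache)
        return x, caches

group = ToyBlockGroup([ToyBlock(8) for _ in range(3)])
y, caches = group(torch.randn(2, 4, 8), use_cache=True)
assert y.shape == (2, 4, 8) and len(caches) == 3
```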

def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
@@ -741,6 +767,12 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
)

if not (
0 < self.config.block_group_size < self.config.n_layers
and self.config.n_layers % self.config.block_group_size == 0
):
raise OlmoConfigurationError("n layers must be divisible by block group size")
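
As a quick sanity check of the constraint above (hypothetical helper, not part of the diff):

```python
# Illustrative check of the block-grouping constraint enforced above.
def _block_grouping_ok(block_group_size: int, n_layers: int) -> bool:
    return 0 < block_group_size < n_layers and n_layers % block_group_size == 0

assert _block_grouping_ok(1, 9)       # default: no grouping
assert _block_grouping_ok(3, 9)       # 3 groups of 3 blocks
assert not _block_grouping_ok(4, 9)   # 9 is not divisible by 4
```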

torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it

@@ -750,10 +782,20 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
),
emb_drop=Dropout(config.embedding_dropout),
blocks=nn.ModuleList([OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]),
ln_f=LayerNorm.build(config),
)
)

blocks = [OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
if self.config.block_group_size > 1:
block_groups = [
OlmoBlockGroup(blocks[i : i + config.block_group_size])
for i in range(0, config.n_layers, config.block_group_size)
]
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})

if not (self.config.alibi or self.config.rope):
self.transformer.update(
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
@@ -806,8 +848,12 @@ def reset_parameters(self):
init_weights(self.config, self.transformer.ff_out) # type: ignore

# Let the blocks handle themselves.
for block in self.transformer.blocks: # type: ignore
block.reset_parameters() # type: ignore
if self.config.block_group_size == 1:
for block in self.transformer.blocks:
block.reset_parameters()
else:
for block_group in self.transformer.block_groups:
block_group.reset_parameters()

def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
@@ -944,15 +990,29 @@ def forward(
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

# Apply blocks one-by-one.
for block, layer_past in zip(
self.transformer.blocks, # type: ignore
past_key_values or [None] * self.config.n_layers, # type: ignore
):
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
if self.config.block_group_size == 1:
for block_idx, block in enumerate(self.transformer.blocks):
layer_past = None if past_key_values is None else past_key_values[block_idx]
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
else:
for group_idx, block_group in enumerate(self.transformer.block_groups):
layers_past = (
None
if past_key_values is None
else past_key_values[
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
]
)
x, cache = block_group(
x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
)
if attn_key_values is not None:
assert cache is not None
attn_key_values.extend(cache)

if last_logits_only:
# shape: (batch_size, 1, d_model)
@@ -973,11 +1033,37 @@ def forward(

return OlmoOutput(logits=logits, attn_key_values=attn_key_values) # type: ignore[arg-type]
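
To make the cache bookkeeping in the block-group branch concrete: group `g` sees the slice `past_key_values[g * block_group_size : (g + 1) * block_group_size]`, so the per-layer ordering of `attn_key_values` is preserved. A trivial check with illustrative values:

```python
# Illustrative: with block_group_size = 3 and 9 layers, group 1 sees the caches for blocks 3-5.
block_group_size = 3
past_key_values = [f"cache_{i}" for i in range(9)]  # stand-ins for (key, value) tuples
group_idx = 1
assert past_key_values[group_idx * block_group_size : (group_idx + 1) * block_group_size] == [
    "cache_3", "cache_4", "cache_5",
]
```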

def fsdp_wrap_fn(self, module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse
return isinstance(module, OlmoBlock)
def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
return None
if wrap_strategy == FSDPWrapStrategy.by_block:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlock)

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group:
if self.config.block_group_size <= 1:
raise OlmoConfigurationError(
"'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
)

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlockGroup)

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.size_based:
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

return size_based_auto_wrap_policy
else:
raise NotImplementedError(wrap_strategy)

def activation_checkpointing_fn(self, module):
return isinstance(module, OlmoBlock)
@@ -1185,20 +1271,61 @@ def from_checkpoint(
return model.eval()

def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
import re
from fnmatch import fnmatch

# Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
# not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
# fine without the prefixes. This also simplifies the other steps below.
for key in list(state_dict.keys()):
state_dict[key.replace("_fsdp_wrapped_module.", "")] = state_dict.pop(key)

# For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
prefix = ""
if next(iter(state_dict.keys())).startswith((fsdp_prefix := "_fsdp_wrapped_module.")):
prefix = fsdp_prefix
if self.config.block_type == BlockType.sequential:
for block_idx in range(self.config.n_layers):
norm_w_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight"
norm_b_key = f"{prefix}transformer.blocks.{block_idx}.norm.bias"
if norm_w_key in state_dict:
norm_w = state_dict.pop(norm_w_key)
state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
if norm_b_key in state_dict:
norm_b = state_dict.pop(norm_b_key)
state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b
state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone()
for key in list(state_dict.keys()):
if fnmatch(key, "transformer.*.norm.weight"):
tensor = state_dict.pop(key)
state_dict[key.replace("norm.weight", "attn_norm.weight")] = tensor
state_dict[key.replace("norm.weight", "ff_norm.weight")] = tensor.clone()
elif fnmatch(key, "transformer.*.norm.bias"):
tensor = state_dict.pop(key)
state_dict[key.replace("norm.bias", "attn_norm.bias")] = tensor
state_dict[key.replace("norm.bias", "ff_norm.bias")] = tensor.clone()

# For loading a state dict that was saved with a different `block_group_size`.
if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
state_dict_block_group_size = len(
[k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
)
else:
state_dict_block_group_size = 1
if self.config.block_group_size != state_dict_block_group_size:
log.info(
f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
f"group size {self.config.block_group_size}"
)
# For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
# and then (re-)group them into the right block sizes.
if state_dict_block_group_size > 1:
for key in list(state_dict.keys()):
if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
state_dict[
key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")
] = state_dict.pop(key)

if self.config.block_group_size > 1:
# Group the state dict blocks into the right block size.
for key in list(state_dict.keys()):
if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
block_idx = int(m.group(1))
group_idx, group_block_idx = (
block_idx // self.config.block_group_size,
block_idx % self.config.block_group_size,
)
state_dict[
key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
] = state_dict.pop(key)

return state_dict
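
To make the regrouping concrete: a checkpoint saved with group size 3 stores, e.g., `transformer.block_groups.1.2.attn_out.weight`, which flattens to `transformer.blocks.5.attn_out.weight` and would regroup back to the same key. A small sketch of the index arithmetic (illustrative only):

```python
# Illustrative only: the index arithmetic behind the flatten / regroup steps above.
state_dict_block_group_size = 3

# Flatten: block_groups.{g}.{b}.* -> blocks.{g * group_size + b}.*
group_idx, group_block_idx = 1, 2
block_idx = group_idx * state_dict_block_group_size + group_block_idx
assert block_idx == 5

# Regroup (target group size 3): blocks.{i}.* -> block_groups.{i // 3}.{i % 3}.*
assert (block_idx // 3, block_idx % 3) == (1, 2)
```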
9 changes: 1 addition & 8 deletions scripts/train.py
@@ -13,7 +13,6 @@
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from olmo.config import CheckpointType, FSDPWrapStrategy, TrainConfig
from olmo.data import build_train_dataloader
@@ -116,12 +115,7 @@ def main(cfg: TrainConfig) -> None:

# Wrap the model in FSDP.
log.info("Wrapping model with FSDP...")
wrap_policy = None
if cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.by_block:
wrap_policy = olmo_model.fsdp_wrap_fn
elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
wrap_policy = size_based_auto_wrap_policy

wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
if version.parse(torch.__version__) >= version.parse("2.1.0"):
# This prevents any parameters from being initialized twice
def dummy_init_fn(module: torch.nn.Module) -> None:
@@ -130,7 +124,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
param_init_fn = dummy_init_fn
else:
param_init_fn = None

fsdp_model = FSDP(
olmo_model,
sharding_strategy=cfg.fsdp.sharding_strategy,
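The rest of the `FSDP(...)` call is collapsed in this view; presumably the policy returned by `get_fsdp_wrap_policy` is passed as `auto_wrap_policy`, roughly along these lines (a sketch using the names above, with the remaining arguments assumed rather than copied from the file):

```python
# Assumed sketch (not the verbatim call, which is truncated in this diff): the policy from
# get_fsdp_wrap_policy() would be handed to FSDP as auto_wrap_policy.
fsdp_model = FSDP(
    olmo_model,
    sharding_strategy=cfg.fsdp.sharding_strategy,
    auto_wrap_policy=wrap_policy,
    param_init_fn=param_init_fn,
)
```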
23 changes: 22 additions & 1 deletion tests/model_test.py
@@ -4,7 +4,7 @@
from torch.nn import CrossEntropyLoss

from olmo import BlockType, LayerNorm, Olmo, Tokenizer, TrainConfig
from olmo.config import PaddingDirection
from olmo.config import ModelConfig, PaddingDirection
from olmo.data import DataCollator
from olmo.model import AMDLayerNorm

@@ -432,3 +432,24 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include

y_actual = amd_ln(x)
torch.testing.assert_close(y_actual, y_expected)


def test_block_groups():
model_with_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=3)).eval()
model_without_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=1)).eval()

# We should be able to load the state dict from one model into the other, and vice-versa.
model_with_block_groups.load_state_dict(
model_with_block_groups._make_state_dict_compatible(model_without_block_groups.state_dict())
)
model_without_block_groups.load_state_dict(
model_without_block_groups._make_state_dict_compatible(model_with_block_groups.state_dict())
)

# Check that output is exactly the same.
input_ids = torch.randint(0, model_with_block_groups.config.vocab_size, (2, 16))
with torch.no_grad():
block_groups_output = model_with_block_groups(input_ids)
no_block_groups_output = model_without_block_groups(input_ids)

torch.testing.assert_close(block_groups_output, no_block_groups_output)
