Add option to FSDP wrap by groups of blocks #340

Merged Oct 26, 2023 · 3 commits · Changes from 2 commits
13 changes: 13 additions & 0 deletions olmo/config.py
@@ -258,6 +258,13 @@ class ModelConfig(BaseConfig):
The transformer block implementation.
"""

block_group_size: int = 1
"""
The number of blocks to group together into a single parent block.
This is has no affect on the number of parameters in the model and is only used to wrap groups
Review comment (Collaborator): nit: is has -> has
of blocks together with a single FSDP wrapper during training.
"""

alibi: bool = False
"""
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
@@ -532,6 +539,12 @@ class FSDPWrapStrategy(StrEnum):
Wrap each OLMo block with its own FSDP instance.
"""

by_block_group = "by_block_group"
"""
Wrap each block group together into its own FSDP instance.
This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
"""

size_based = "size_based"
"""
Use PyTorch's default size-based auto wrap policy.
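To make the new option concrete, here is a minimal standalone sketch of the grouping arithmetic, using nn.Linear modules as stand-ins for OlmoBlock (the numbers and variable names below are illustrative, not part of the PR):

import torch.nn as nn

n_layers = 8          # ModelConfig.n_layers
block_group_size = 4  # ModelConfig.block_group_size; must divide n_layers evenly

# One stand-in module per transformer layer. Grouping only changes how the
# modules are nested, not how many parameters the model has.
blocks = [nn.Linear(16, 16) for _ in range(n_layers)]

# Partition consecutive blocks into groups, mirroring the logic added to Olmo.__init__.
block_groups = nn.ModuleList(
    nn.ModuleList(blocks[i : i + block_group_size])
    for i in range(0, n_layers, block_group_size)
)

assert len(block_groups) == n_layers // block_group_size
assert sum(len(g) for g in block_groups) == n_layers
print([len(g) for g in block_groups])  # [4, 4]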
188 changes: 157 additions & 31 deletions olmo/model.py
@@ -24,6 +24,7 @@
ActivationType,
BlockType,
CheckpointType,
FSDPWrapStrategy,
LayerNormType,
ModelConfig,
)
@@ -532,6 +533,8 @@ def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
raise NotImplementedError

@@ -694,6 +697,28 @@ class OlmoGenerateOutput(NamedTuple):
"""


class OlmoBlockGroup(nn.ModuleList):
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
for block_idx, block in enumerate(self):
layer_past = None if layers_past is None else layers_past[block_idx]
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
return x, attn_key_values

def reset_parameters(self):
for block in self:
block.reset_parameters()


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
@@ -741,6 +766,12 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
)

if not (
0 < self.config.block_group_size <= self.config.n_layers
and self.config.n_layers % self.config.block_group_size == 0
):
raise OlmoConfigurationError("n layers must be divisible by block group size")

torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it

@@ -750,10 +781,20 @@
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
),
emb_drop=Dropout(config.embedding_dropout),
blocks=nn.ModuleList([OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]),
ln_f=LayerNorm.build(config),
)
)

blocks = [OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
if self.config.block_group_size > 1:
block_groups = [
OlmoBlockGroup(blocks[i : i + config.block_group_size])
for i in range(0, config.n_layers, config.block_group_size)
]
self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})

if not (self.config.alibi or self.config.rope):
self.transformer.update(
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
@@ -806,8 +847,12 @@ def reset_parameters(self):
init_weights(self.config, self.transformer.ff_out) # type: ignore

# Let the blocks handle themselves.
for block in self.transformer.blocks: # type: ignore
block.reset_parameters() # type: ignore
if self.config.block_group_size == 1:
for block in self.transformer.blocks:
block.reset_parameters()
else:
for block_group in self.transformer.block_groups:
block_group.reset_parameters()

def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
@@ -944,15 +989,29 @@ def forward(
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

# Apply blocks one-by-one.
for block, layer_past in zip(
self.transformer.blocks, # type: ignore
past_key_values or [None] * self.config.n_layers, # type: ignore
):
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
if self.config.block_group_size == 1:
for block_idx, block in enumerate(self.transformer.blocks):
layer_past = None if past_key_values is None else past_key_values[block_idx]
# shape: (batch_size, seq_len, d_model)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
else:
for group_idx, block_group in enumerate(self.transformer.block_groups):
layers_past = (
None
if past_key_values is None
else past_key_values[
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
]
)
x, cache = block_group(
x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
)
if attn_key_values is not None:
assert cache is not None
attn_key_values.extend(cache)

if last_logits_only:
# shape: (batch_size, 1, d_model)
@@ -973,11 +1032,37 @@

return OlmoOutput(logits=logits, attn_key_values=attn_key_values) # type: ignore[arg-type]

def fsdp_wrap_fn(self, module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse
return isinstance(module, OlmoBlock)
def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
return None
if wrap_strategy == FSDPWrapStrategy.by_block:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlock)

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group:
Review comment (Member): With this strategy, the input and output embeddings are never wrapped. I think that's fine at this point in time, but we should experiment with it.

Reply (Member Author): Agreed. I didn't want to change too many things.
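For illustration only, a hypothetical variant of the experiment suggested above (not part of this PR) would extend the wrap condition so that embedding modules also get their own FSDP instance; this assumes OlmoBlockGroup is importable from olmo.model once this change lands:

import torch.nn as nn

from olmo.model import OlmoBlockGroup  # class added by this PR

# Hypothetical wrap function: wraps each block group and each embedding separately.
# Whether this helps throughput or memory is exactly what would need to be measured.
def fsdp_wrap_fn_with_embeddings(module, recurse: bool = True, nonwrapped_numel: int = 0):
    del nonwrapped_numel
    if recurse:
        return True  # always recurse for simplicity
    return isinstance(module, (OlmoBlockGroup, nn.Embedding))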

if self.config.block_group_size <= 1:
raise OlmoConfigurationError(
"'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
)

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
Review comment on lines +1055 to +1056 (Member): I don't understand why this works. When does this ever get called with recurse == False?

Reply (Member Author): From what I understand, this function basically gets called twice on every module: once with recurse=False to check whether that module itself should be wrapped, and once with recurse=True to check whether the traversal should go deeper into the module's submodules to potentially wrap any of those.

Reply (Member Author): So the meaning of the return value changes depending on the value of recurse, which is confusing... and is the reason we had the wrapping bug in the first place, where we thought we were wrapping by block but were actually just wrapping the whole model.
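A small self-contained simulation of the behavior described in the reply above (an illustration, not PyTorch's actual implementation): the policy callback is queried once per module with recurse=True to decide whether to descend, and once with recurse=False to decide whether to wrap that module itself. nn.Linear stands in for OlmoBlockGroup here:

import torch.nn as nn

def example_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
    del nonwrapped_numel
    if recurse:
        return True  # keep descending into children
    return isinstance(module, nn.Linear)  # nn.Linear plays the role of OlmoBlockGroup

def simulate_auto_wrap(module, policy, prefix="root"):
    # Simplified traversal: ask the policy whether to recurse, then whether to wrap.
    if policy(module, recurse=True):
        for name, child in module.named_children():
            simulate_auto_wrap(child, policy, prefix=f"{prefix}.{name}")
    if policy(module, recurse=False):
        print(f"would wrap: {prefix} ({type(module).__name__})")

model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 4))
simulate_auto_wrap(model, example_wrap_fn)
# would wrap: root.0.0 (Linear)
# would wrap: root.1 (Linear)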

return isinstance(module, OlmoBlockGroup)

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.size_based:
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

return size_based_auto_wrap_policy
else:
raise NotImplementedError(wrap_strategy)

def activation_checkpointing_fn(self, module):
return isinstance(module, OlmoBlock)
@@ -1185,20 +1270,61 @@ def from_checkpoint(
return model.eval()

def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
import re
from fnmatch import fnmatch

# Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
# not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
# fine without the prefixes. This also simplifies the other steps below.
for key in list(state_dict.keys()):
state_dict[key.replace("_fsdp_wrapped_module.", "")] = state_dict.pop(key)

# For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
prefix = ""
if next(iter(state_dict.keys())).startswith((fsdp_prefix := "_fsdp_wrapped_module.")):
prefix = fsdp_prefix
if self.config.block_type == BlockType.sequential:
for block_idx in range(self.config.n_layers):
norm_w_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight"
norm_b_key = f"{prefix}transformer.blocks.{block_idx}.norm.bias"
if norm_w_key in state_dict:
norm_w = state_dict.pop(norm_w_key)
state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
if norm_b_key in state_dict:
norm_b = state_dict.pop(norm_b_key)
state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b
state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone()
for key in list(state_dict.keys()):
if fnmatch(key, "transformer.*.norm.weight"):
tensor = state_dict.pop(key)
state_dict[key.replace("norm.weight", "attn_norm.weight")] = tensor
state_dict[key.replace("norm.weight", "ff_norm.weight")] = tensor.clone()
elif fnmatch(key, "transformer.*.norm.bias"):
tensor = state_dict.pop(key)
state_dict[key.replace("norm.bias", "attn_norm.bias")] = tensor
state_dict[key.replace("norm.bias", "ff_norm.bias")] = tensor.clone()

# For loading a state dict that was saved with a different `block_group_size`.
if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
state_dict_block_group_size = len(
[k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
)
else:
state_dict_block_group_size = 1
if self.config.block_group_size != state_dict_block_group_size:
log.info(
f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
f"group size {self.config.block_group_size}"
)
# For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
# and then (re-)group them into the right block sizes.
if state_dict_block_group_size > 1:
for key in list(state_dict.keys()):
if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
state_dict[
key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")
] = state_dict.pop(key)

if self.config.block_group_size > 1:
# Group the state dict blocks into the right block size.
for key in list(state_dict.keys()):
if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
block_idx = int(m.group(1))
group_idx, group_block_idx = (
block_idx // self.config.block_group_size,
block_idx % self.config.block_group_size,
)
state_dict[
key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
] = state_dict.pop(key)

return state_dict
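As a worked example of the regrouping arithmetic above (the specific key and group size are illustrative):

import re

block_group_size = 3  # target ModelConfig.block_group_size

# A flat (group size 1) checkpoint key.
flat_key = "transformer.blocks.4.attn_out.weight"

m = re.match(r"transformer.blocks\.(\d+)\..*", flat_key)
assert m is not None
block_idx = int(m.group(1))
group_idx, group_block_idx = divmod(block_idx, block_group_size)  # same as // and %
grouped_key = flat_key.replace(
    f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
)
print(grouped_key)  # transformer.block_groups.1.1.attn_out.weight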
11 changes: 2 additions & 9 deletions scripts/train.py
@@ -13,9 +13,8 @@
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from olmo.config import CheckpointType, FSDPWrapStrategy, TrainConfig
from olmo.config import CheckpointType, TrainConfig
from olmo.data import build_train_dataloader
from olmo.eval import build_evaluators
from olmo.exceptions import OlmoCliError, OlmoConfigurationError
@@ -116,12 +115,7 @@ def main(cfg: TrainConfig) -> None:

# Wrap the model in FSDP.
log.info("Wrapping model with FDSP...")
wrap_policy = None
if cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.by_block:
wrap_policy = olmo_model.fsdp_wrap_fn
elif cfg.fsdp.wrapping_strategy == FSDPWrapStrategy.size_based:
wrap_policy = size_based_auto_wrap_policy

wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
if version.parse(torch.__version__) >= version.parse("2.1.0"):
# This prevents any parameters from being initialized twice
def dummy_init_fn(module: torch.nn.Module) -> None:
@@ -130,7 +124,6 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
param_init_fn = dummy_init_fn
else:
param_init_fn = None

fsdp_model = FSDP(
olmo_model,
sharding_strategy=cfg.fsdp.sharding_strategy,
23 changes: 22 additions & 1 deletion tests/model_test.py
@@ -4,7 +4,7 @@
from torch.nn import CrossEntropyLoss

from olmo import BlockType, LayerNorm, Olmo, Tokenizer, TrainConfig
from olmo.config import PaddingDirection
from olmo.config import ModelConfig, PaddingDirection
from olmo.data import DataCollator
from olmo.model import AMDLayerNorm

@@ -432,3 +432,24 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include

y_actual = amd_ln(x)
torch.testing.assert_close(y_actual, y_expected)


def test_block_groups():
model_with_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=3)).eval()
model_without_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=1)).eval()

# We should be able to load the state dict from one model into the other, and vice-versa.
model_with_block_groups.load_state_dict(
model_with_block_groups._make_state_dict_compatible(model_without_block_groups.state_dict())
)
model_without_block_groups.load_state_dict(
model_without_block_groups._make_state_dict_compatible(model_with_block_groups.state_dict())
)

# Check that output is exactly the same.
input_ids = torch.randint(0, model_with_block_groups.config.vocab_size, (2, 16))
with torch.no_grad():
block_groups_output = model_with_block_groups(input_ids)
no_block_groups_output = model_without_block_groups(input_ids)

torch.testing.assert_close(block_groups_output, no_block_groups_output)