Add option to FSDP wrap by groups of blocks #340
@@ -24,6 +24,7 @@
    ActivationType,
    BlockType,
    CheckpointType,
    FSDPWrapStrategy,
    LayerNormType,
    ModelConfig,
)

@@ -532,6 +533,8 @@ def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        raise NotImplementedError

@@ -694,6 +697,28 @@ class OlmoGenerateOutput(NamedTuple):
    """


class OlmoBlockGroup(nn.ModuleList):
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        for block_idx, block in enumerate(self):
            layer_past = None if layers_past is None else layers_past[block_idx]
            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
        return x, attn_key_values

    def reset_parameters(self):
        for block in self:
            block.reset_parameters()


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    att_bias = torch.triu(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),

@@ -741,6 +766,12 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
                "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
            )

        if not (
            0 < self.config.block_group_size <= self.config.n_layers
            and self.config.n_layers % self.config.block_group_size == 0
        ):
            raise OlmoConfigurationError("n layers must be divisible by block group size")

        torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it

@@ -750,10 +781,20 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
                ),
                emb_drop=Dropout(config.embedding_dropout),
                blocks=nn.ModuleList([OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]),
                ln_f=LayerNorm.build(config),
            )
        )

        blocks = [OlmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
        if self.config.block_group_size > 1:
            block_groups = [
                OlmoBlockGroup(blocks[i : i + config.block_group_size])
                for i in range(0, config.n_layers, config.block_group_size)
            ]
            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not (self.config.alibi or self.config.rope):
            self.transformer.update(
                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}

@@ -806,8 +847,12 @@ def reset_parameters(self):
        init_weights(self.config, self.transformer.ff_out)  # type: ignore

        # Let the blocks handle themselves.
        for block in self.transformer.blocks:  # type: ignore
            block.reset_parameters()  # type: ignore
        if self.config.block_group_size == 1:
            for block in self.transformer.blocks:
                block.reset_parameters()
        else:
            for block_group in self.transformer.block_groups:
                block_group.reset_parameters()

    def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[

@@ -944,15 +989,29 @@ def forward(
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # Apply blocks one-by-one.
        for block, layer_past in zip(
            self.transformer.blocks,  # type: ignore
            past_key_values or [None] * self.config.n_layers,  # type: ignore
        ):
            # shape: (batch_size, seq_len, d_model)
            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
        if self.config.block_group_size == 1:
            for block_idx, block in enumerate(self.transformer.blocks):
                layer_past = None if past_key_values is None else past_key_values[block_idx]
                # shape: (batch_size, seq_len, d_model)
                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.append(cache)
        else:
            for group_idx, block_group in enumerate(self.transformer.block_groups):
                layers_past = (
                    None
                    if past_key_values is None
                    else past_key_values[
                        group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
                    ]
                )
                x, cache = block_group(
                    x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
                )
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.extend(cache)

        if last_logits_only:
            # shape: (batch_size, 1, d_model)

@@ -973,11 +1032,37 @@ def forward(

        return OlmoOutput(logits=logits, attn_key_values=attn_key_values)  # type: ignore[arg-type]

    def fsdp_wrap_fn(self, module, recurse: bool = True, nonwrapped_numel: int = 0):
        del nonwrapped_numel
        if recurse:
            return True  # always recurse
        return isinstance(module, OlmoBlock)
    def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
        if wrap_strategy is None:
            return None
        if wrap_strategy == FSDPWrapStrategy.by_block:

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                if recurse:
                    return True  # always recurse for simplicity
                return isinstance(module, OlmoBlock)

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.by_block_group:
            if self.config.block_group_size <= 1:
                raise OlmoConfigurationError(
                    "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
                )

            def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                del nonwrapped_numel
                if recurse:
                    return True  # always recurse for simplicity
Comment on lines +1055 to +1056

I don't understand why this works. When does this ever get called with recurse=False?

From what I understand this function basically gets called twice on every module. Once with recurse=True and once with recurse=False.

So the meaning of the return value changes depending on the value of recurse.
                return isinstance(module, OlmoBlockGroup)

            return fsdp_wrap_fn
        elif wrap_strategy == FSDPWrapStrategy.size_based:
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

            return size_based_auto_wrap_policy
        else:
            raise NotImplementedError(wrap_strategy)

    def activation_checkpointing_fn(self, module):
        return isinstance(module, OlmoBlock)

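To illustrate the recurse semantics discussed in the comment thread above, here is a standalone sketch (not part of this PR) of how FSDP's recursive auto-wrap drives a policy callable like fsdp_wrap_fn: each module is visited twice, once with recurse=True ("should we descend into this module's children?") and once with recurse=False ("should this module itself become an FSDP unit?"). The toy model, by_type_policy, and simulate_auto_wrap below are made-up stand-ins, not OLMo or torch APIs.

import torch.nn as nn

def by_type_policy(module: nn.Module, recurse: bool = True, nonwrapped_numel: int = 0) -> bool:
    # Same shape as fsdp_wrap_fn above, but keyed on nn.Linear for the toy model.
    del nonwrapped_numel
    if recurse:
        return True  # always keep traversing into children
    return isinstance(module, nn.Linear)  # wrap every Linear as its own unit

def simulate_auto_wrap(module: nn.Module, policy, prefix: str = "root") -> None:
    # Loosely mirrors torch's recursive auto-wrap traversal: descend first, then
    # decide whether the module itself should be wrapped.
    if policy(module, recurse=True):
        for name, child in module.named_children():
            simulate_auto_wrap(child, policy, f"{prefix}.{name}")
    if policy(module, recurse=False):
        print(f"would wrap {prefix} ({type(module).__name__})")

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
simulate_auto_wrap(model, by_type_policy)
# would wrap root.0 (Linear)
# would wrap root.2 (Linear)

Under the by_block_group strategy, the same traversal returns True (at recurse=False) only for OlmoBlockGroup modules, so each group of blocks becomes a single FSDP unit instead of one unit per block.
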
@@ -1185,20 +1270,61 @@ def from_checkpoint(
        return model.eval()

    def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        import re
        from fnmatch import fnmatch

        # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
        # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
        # fine without the prefixes. This also simplifies the other steps below.
        for key in list(state_dict.keys()):
            state_dict[key.replace("_fsdp_wrapped_module.", "")] = state_dict.pop(key)

        # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
        prefix = ""
        if next(iter(state_dict.keys())).startswith((fsdp_prefix := "_fsdp_wrapped_module.")):
            prefix = fsdp_prefix
        if self.config.block_type == BlockType.sequential:
            for block_idx in range(self.config.n_layers):
                norm_w_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight"
                norm_b_key = f"{prefix}transformer.blocks.{block_idx}.norm.bias"
                if norm_w_key in state_dict:
                    norm_w = state_dict.pop(norm_w_key)
                    state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
                    state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
                if norm_b_key in state_dict:
                    norm_b = state_dict.pop(norm_b_key)
                    state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b
                    state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone()
            for key in list(state_dict.keys()):
                if fnmatch(key, "transformer.*.norm.weight"):
                    tensor = state_dict.pop(key)
                    state_dict[key.replace("norm.weight", "attn_norm.weight")] = tensor
                    state_dict[key.replace("norm.weight", "ff_norm.weight")] = tensor.clone()
                elif fnmatch(key, "transformer.*.norm.bias"):
                    tensor = state_dict.pop(key)
                    state_dict[key.replace("norm.bias", "attn_norm.bias")] = tensor
                    state_dict[key.replace("norm.bias", "ff_norm.bias")] = tensor.clone()

        # For loading a state dict that was saved with a different `block_group_size`.
        if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
            state_dict_block_group_size = len(
                [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
            )
        else:
            state_dict_block_group_size = 1
        if self.config.block_group_size != state_dict_block_group_size:
            log.info(
                f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
                f"group size {self.config.block_group_size}"
            )
            # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
            # and then (re-)group them into the right block sizes.
            if state_dict_block_group_size > 1:
                for key in list(state_dict.keys()):
                    if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
                        group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
                        block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
                        state_dict[
                            key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")
                        ] = state_dict.pop(key)

            if self.config.block_group_size > 1:
                # Group the state dict blocks into the right block size.
                for key in list(state_dict.keys()):
                    if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
                        block_idx = int(m.group(1))
                        group_idx, group_block_idx = (
                            block_idx // self.config.block_group_size,
                            block_idx % self.config.block_group_size,
                        )
                        state_dict[
                            key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
                        ] = state_dict.pop(key)

        return state_dict

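To make the regrouping arithmetic above concrete, here is a standalone sketch (not the PR's code) that maps flat per-block parameter keys into grouped keys, using made-up key names, an 8-layer model, and block_group_size=4:

import re

block_group_size = 4
flat_keys = [f"transformer.blocks.{i}.attn_out.weight" for i in range(8)]

regrouped = {}
for key in flat_keys:
    m = re.match(r"transformer\.blocks\.(\d+)\..*", key)
    assert m is not None
    block_idx = int(m.group(1))
    # Same integer arithmetic as the PR: floor-divide for the group index, modulo for the position within it.
    group_idx, group_block_idx = divmod(block_idx, block_group_size)
    regrouped[key] = key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")

print(regrouped["transformer.blocks.5.attn_out.weight"])
# transformer.block_groups.1.1.attn_out.weight

Flattening goes the other way (block_idx = group_idx * block_group_size + group_block_idx), which is why the method can regroup a checkpoint saved with any group size.
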
With this strategy, the input and output embeddings are never wrapped. I think that's fine at this point in time, but we should experiment with it.
Agreed. I didn't want to change too many things.
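For reference, one hedged way such an experiment might look (this is not part of the PR, and FSDPWrapStrategy here has no such option): extend the policy callable so that, besides OlmoBlockGroup, the embedding modules also become their own FSDP units. The function name below is hypothetical, and it assumes the surrounding module's imports (torch.nn as nn).

def fsdp_wrap_fn_with_embeddings(module, recurse: bool = True, nonwrapped_numel: int = 0):
    # Hypothetical variant: give the token/position embeddings their own FSDP unit too.
    del nonwrapped_numel
    if recurse:
        return True  # always recurse for simplicity
    return isinstance(module, (OlmoBlockGroup, nn.Embedding))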