
Commit

Merge branch 'main' into olmo7-ablations
dirkgr authored Feb 23, 2024
2 parents 50a7704 + 922db6a commit ae538ce
Showing 4 changed files with 33 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -127,6 +127,7 @@ jobs:
           spec: |
             version: v2
             description: GPU Tests
+            budget: ai2/oe-training
             tasks:
               - name: tests
                 image:
2 changes: 2 additions & 0 deletions hf_olmo/modeling_olmo.py
@@ -50,6 +50,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -70,6 +71,7 @@ def forward(
             input_ids=input_ids,
             input_embeddings=inputs_embeds,
             attention_mask=attention_mask,
+            attention_bias=attention_bias,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
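
For context, the two additions above expose the underlying model's attention_bias argument through the Hugging Face wrapper's forward. A minimal usage sketch follows; the checkpoint name, trust_remote_code loading, and bias values are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (not from this commit): pass an additive attention bias
# through the HF wrapper's forward. Checkpoint name and bias values are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B", trust_remote_code=True)

inputs = tokenizer("OLMo accepts an explicit attention bias.", return_tensors="pt")
seq_len = inputs.input_ids.shape[1]

# Additive bias broadcastable to (batch, n_heads, seq_len, seq_len):
# 0.0 leaves a position unchanged, a large negative value masks it out.
attention_bias = torch.zeros(1, 1, seq_len, seq_len)
attention_bias[..., :, 0] = -1e4  # e.g. down-weight attention to the first token

outputs = model(input_ids=inputs.input_ids, attention_bias=attention_bias)
print(outputs.logits.shape)
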
43 changes: 29 additions & 14 deletions olmo/model.py
@@ -1339,25 +1339,37 @@ def forward(
     def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
         if wrap_strategy is None:
             return None
+
+        # The 'recurse' mode for the wrap function does not behave like you'd expect.
+        # Even if we return False, it may still recurse because PyTorch does what it wants,
+        # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
+        # but not other linear layers within a block.
+        # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
+        # return True in 'recurse' mode for simplicity.
+        size_based_module_to_wrap = {self.transformer.wte}
+        if hasattr(self.transformer, "ff_out"):
+            size_based_module_to_wrap.add(self.transformer.ff_out)
+
         if wrap_strategy == FSDPWrapStrategy.by_block:
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlock)
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlock)
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
         elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, (OlmoBlock,)) or module in size_based_module_to_wrap
                 if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlock)
+                    return True
                 else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding))
+                    return wrap
 
             return fsdp_wrap_fn
         elif wrap_strategy == FSDPWrapStrategy.by_block_group:
@@ -1368,9 +1380,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlockGroup)
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlockGroup)
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
         elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
@@ -1381,12 +1395,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, (OlmoBlockGroup,)) or module in size_based_module_to_wrap
                 if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlockGroup)
+                    return True
                 else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding))
+                    return wrap
 
             return fsdp_wrap_fn
         elif wrap_strategy == FSDPWrapStrategy.size_based:
@@ -1408,9 +1421,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlock) and module.layer_id % c == 0
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlock) and module.layer_id % c == 0
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
         else:
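
The comment added in this file explains why the wrap functions always return True in 'recurse' mode and list wrappable modules explicitly. A hedged sketch of how such a policy callable is consumed by PyTorch FSDP follows; the wiring, import locations, and FSDP options are illustrative assumptions, not the actual OLMo training script.

# Illustrative sketch (assumptions: a constructed Olmo model, an initialized process
# group, and FSDPWrapStrategy living in olmo.config). Not the actual OLMo trainer code.
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from olmo.config import FSDPWrapStrategy
from olmo.model import Olmo

def shard_model(model: Olmo) -> FSDP:
    assert dist.is_initialized(), "FSDP requires an initialized process group"
    # FSDP calls the policy twice per submodule: once with recurse=True to decide
    # whether to descend into children, and once with recurse=False to decide
    # whether to wrap that submodule in its own FSDP unit.
    wrap_policy = model.get_fsdp_wrap_policy(FSDPWrapStrategy.by_block)
    return FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        use_orig_params=True,  # illustrative; the real trainer sets its own FSDP options
    )
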
2 changes: 1 addition & 1 deletion olmo/train.py
@@ -877,7 +877,7 @@ def fit(self):
         if self.cfg.torch_profiling and get_global_rank() == 0:
             from torch.profiler import schedule
 
-            profiling_schedule = schedule(wait=1, warmup=5, active=3)
+            profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
 
             def on_trace_ready(p):
                 profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
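
The repeat=1 argument makes the profiler run its wait/warmup/active cycle exactly once instead of restarting it for the rest of training. Below is a self-contained sketch of how such a schedule drives torch.profiler in a training loop; the workload and trace path are placeholders, not taken from olmo/train.py.

# Minimal sketch of the schedule semantics; the workload and output path are
# placeholders, not OLMo's actual profiling setup.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# wait=1: skip 1 step; warmup=5: run the profiler but discard 5 steps;
# active=3: record 3 steps; repeat=1: run this cycle once rather than indefinitely.
profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)

def train_step() -> None:
    # Placeholder workload standing in for one training iteration.
    a = torch.randn(256, 256)
    b = torch.randn(256, 256)
    (a @ b).sum()

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when a GPU is available
    schedule=profiling_schedule,
    on_trace_ready=lambda prof: prof.export_chrome_trace("profiler_trace.json"),
) as prof:
    for _ in range(20):
        train_step()
        prof.step()  # advances the schedule; the trace callback fires after the active window
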

