diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index adac82794..1093adb3d 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -127,6 +127,7 @@ jobs:
         spec: |
           version: v2
           description: GPU Tests
+          budget: ai2/oe-training
           tasks:
             - name: tests
               image:
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 83814c5cf..6a279cb10 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -50,6 +50,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -70,6 +71,7 @@ def forward(
             input_ids=input_ids,
             input_embeddings=inputs_embeds,
             attention_mask=attention_mask,
+            attention_bias=attention_bias,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
diff --git a/olmo/model.py b/olmo/model.py
index 466a37a99..a11eceb71 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -1339,25 +1339,37 @@ def forward(
     def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
         if wrap_strategy is None:
             return None
+
+        # The 'recurse' mode for the wrap function does not behave like you'd expect.
+        # Even if we return False, it may still recurse because PyTorch does what it wants,
+        # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
+        # but not other linear layers within a block.
+        # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
+        # return True in 'recurse' mode for simplicity.
+        size_based_module_to_wrap = {self.transformer.wte}
+        if hasattr(self.transformer, "ff_out"):
+            size_based_module_to_wrap.add(self.transformer.ff_out)
+
         if wrap_strategy == FSDPWrapStrategy.by_block:
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlock)
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlock)
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
 
         elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
 
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, (OlmoBlock,)) or module in size_based_module_to_wrap
                 if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlock)
+                    return True
                 else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding))
+                    return wrap
 
             return fsdp_wrap_fn
 
         elif wrap_strategy == FSDPWrapStrategy.by_block_group:
@@ -1368,9 +1380,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlockGroup)
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlockGroup)
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
 
         elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
@@ -1381,12 +1395,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, (OlmoBlockGroup,)) or module in size_based_module_to_wrap
                 if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlockGroup)
+                    return True
                 else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding))
+                    return wrap
 
             return fsdp_wrap_fn
 
         elif wrap_strategy == FSDPWrapStrategy.size_based:
@@ -1408,9 +1421,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
             def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
                 del nonwrapped_numel
+                wrap = isinstance(module, OlmoBlock) and module.layer_id % c == 0
                 if recurse:
-                    return True  # always recurse for simplicity
-                return isinstance(module, OlmoBlock) and module.layer_id % c == 0
+                    return True
+                else:
+                    return wrap
 
             return fsdp_wrap_fn
 
         else:
diff --git a/olmo/train.py b/olmo/train.py
index f459ad88d..79132f0fc 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -877,7 +877,7 @@ def fit(self):
         if self.cfg.torch_profiling and get_global_rank() == 0:
             from torch.profiler import schedule
 
-            profiling_schedule = schedule(wait=1, warmup=5, active=3)
+            profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)
 
             def on_trace_ready(p):
                 profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
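Note on the `hf_olmo` change: it simply threads a new `attention_bias` keyword through the HF wrapper's `forward` into the underlying OLMo model. A hypothetical call sketch, assuming the usual `hf_olmo` Auto-class registration, the `allenai/OLMo-1B` checkpoint, and an additive bias broadcastable to `(batch, heads, seq_len, seq_len)`; none of these assumptions come from the diff itself.

```python
import torch
import hf_olmo  # noqa: F401 -- assumed to register OLMo with the transformers Auto classes
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-1B")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B")

inputs = tokenizer("Attention bias example", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

# Additive bias: 0.0 where attention is allowed, -inf where it is blocked (causal pattern here).
bias = torch.zeros(1, 1, seq_len, seq_len)
blocked = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
bias = bias.masked_fill(blocked, float("-inf"))

with torch.no_grad():
    outputs = model(**inputs, attention_bias=bias)
print(outputs.logits.shape)
```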
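Note on the `olmo/model.py` change: each `fsdp_wrap_fn` follows PyTorch FSDP's callable `auto_wrap_policy` signature `(module, recurse, nonwrapped_numel) -> bool`, where the `recurse` branch only controls traversal and the non-recurse branch decides what gets its own FSDP unit. A minimal sketch of how such a policy is consumed, using a hypothetical `Block` class as a stand-in for `OlmoBlock`.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Block(nn.Module):  # stand-in for OlmoBlock (hypothetical)
    def __init__(self) -> None:
        super().__init__()
        self.ff = nn.Linear(16, 16)


def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0) -> bool:
    del nonwrapped_numel  # unused, but part of the policy signature
    wrap = isinstance(module, Block)
    if recurse:
        # Keep traversing; FSDP may recurse regardless of what is returned here,
        # so only the non-recurse branch decides what actually gets wrapped.
        return True
    else:
        return wrap


model = nn.Sequential(*[Block() for _ in range(4)])
# Constructing FSDP requires an initialized process group, e.g.:
#   torch.distributed.init_process_group("gloo", rank=0, world_size=1,
#                                        init_method="tcp://127.0.0.1:29500")
# fsdp_model = FSDP(model, auto_wrap_policy=fsdp_wrap_fn)
```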
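Note on the `olmo/train.py` change: `repeat=1` stops the profiler schedule from cycling indefinitely, so only a single wait/warmup/active window is traced instead of a new trace every nine steps. A self-contained sketch of that schedule outside the trainer; the loop body and the `on_trace_ready` handler here are stand-ins, not the trainer's actual code.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# One cycle only: 1 idle step, 5 warmup steps, 3 recorded steps, then the profiler stays idle.
profiling_schedule = schedule(wait=1, warmup=5, active=3, repeat=1)


def on_trace_ready(p):
    # The trainer exports traces under `save_folder/profiler`; here we just print a summary.
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))


with profile(
    activities=[ProfilerActivity.CPU],
    schedule=profiling_schedule,
    on_trace_ready=on_trace_ready,
) as prof:
    for step in range(20):
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for a training step
        prof.step()  # advances the schedule; work is only recorded during 'active' steps
```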