diff --git a/olmo/model.py b/olmo/model.py
index b635518cb..4fdc3a4b4 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -45,7 +45,6 @@
 from .exceptions import OlmoConfigurationError
 from .initialization import ModuleType, init_weights
 from .torch_util import ensure_finite_
-from .util import pass_through_fn
 
 __all__ = [
     "LayerNormBase",
@@ -430,7 +429,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
         self.__cache = cache
         assert config.d_model % config.n_heads == 0
 
-        self._activation_checkpoint_fn = pass_through_fn
+        self._activation_checkpoint_fn = None
 
         # Dropout.
         self.dropout = Dropout(config.residual_dropout)
@@ -492,7 +491,7 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
         if strategy == ActivationCheckpointingStrategy.fine_grained:
             self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
         else:
-            self._activation_checkpoint_fn = pass_through_fn
+            self._activation_checkpoint_fn = None
 
     @classmethod
     def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
@@ -673,12 +672,20 @@ def forward(
         # - for regular attn q, k, v: (batch_size, seq_len, d_model)
         # - for multi-query attn q: (batch_size, seq_len, d_model)
         #                     k, v: (batch_size, seq_len, d_model // n_heads)
-        q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(self.fused_dims, dim=-1)
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Add attention scores.
         # shape: (B, T, C)
@@ -687,9 +694,15 @@ def forward(
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -753,23 +766,35 @@ def forward(
         # - for multi-query attn q: (batch_size, seq_len, d_model)
         #                     k, v: (batch_size, seq_len, d_model // n_heads)
         # shape of ff: (batch_size, seq_len, hidden_size)
-        q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
-            self.fused_dims, dim=-1
-        )
+        if self._activation_checkpoint_fn is not None:
+            q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
         # shape: (B, T, C)
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Apply output projections (and activation function) and sum the results.
         # We keep these projections separate because we found that we got better throughput this
         # way compared to fusing them.
-        return (
-            x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
-            cache,
-        )
+        if self._activation_checkpoint_fn is not None:
+            return (
+                x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
+                cache,
+            )
+        else:
+            return (
+                x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
+                cache,
+            )
 
 
 class OlmoLlamaBlock(OlmoBlock):
@@ -874,9 +899,15 @@ def forward(
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -945,7 +976,7 @@ def forward(
                     )
                 ):
                     # shape: (batch_size, seq_len, d_model)
-                    x, cache = self._activation_checkpoint_fn(
+                    x, cache = self._activation_checkpoint_fn(  # type: ignore
                         block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                     )
                 else:
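
Context for reviewers: the sketch below is illustrative only, not code from this PR. It shows the pattern the diff moves to: `_activation_checkpoint_fn` defaults to `None` (sub-modules are called directly) and is only set to a checkpointing wrapper when fine-grained activation checkpointing is requested. It assumes the wrapper ultimately calls `torch.utils.checkpoint.checkpoint` (and a recent PyTorch with `use_reentrant`); `ToyBlock` and `make_checkpoint_fn` are hypothetical names used only for this example, not identifiers from olmo/model.py.

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def make_checkpoint_fn() -> Callable:
    # Assumed stand-in for activation_checkpoint_function(config): run the given
    # callable under activation checkpointing, so its activations are recomputed
    # during the backward pass instead of being stored.
    def checkpointed_call(fn, *args, **kwargs):
        return checkpoint(fn, *args, use_reentrant=False, **kwargs)

    return checkpointed_call


class ToyBlock(nn.Module):
    def __init__(self, d_model: int = 64):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ff = nn.Linear(d_model, d_model)
        # None now means "no fine-grained checkpointing" (replacing the old
        # pass_through_fn default), so forward() can branch on it cheaply.
        self._activation_checkpoint_fn: Optional[Callable] = None

    def set_activation_checkpointing(self, enabled: bool) -> None:
        self._activation_checkpoint_fn = make_checkpoint_fn() if enabled else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same branching the diff introduces: only go through the wrapper when it is set.
        if self._activation_checkpoint_fn is not None:
            h = self._activation_checkpoint_fn(self.norm, x)
        else:
            h = self.norm(x)
        return x + self.ff(h)


block = ToyBlock()
block.set_activation_checkpointing(True)
out = block(torch.randn(2, 8, 64, requires_grad=True))
out.sum().backward()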