fix compile without act chkpting
epwalsh committed Dec 15, 2023
1 parent bd81f46 commit 466dba6
Showing 1 changed file with 53 additions and 22 deletions.
olmo/model.py: 75 changes (53 additions, 22 deletions)
@@ -45,7 +45,6 @@
 from .exceptions import OlmoConfigurationError
 from .initialization import ModuleType, init_weights
 from .torch_util import ensure_finite_
-from .util import pass_through_fn
 
 __all__ = [
     "LayerNormBase",
@@ -430,7 +429,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
         self.__cache = cache
         assert config.d_model % config.n_heads == 0
 
-        self._activation_checkpoint_fn = pass_through_fn
+        self._activation_checkpoint_fn = None
 
         # Dropout.
         self.dropout = Dropout(config.residual_dropout)
@@ -492,7 +491,7 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
         if strategy == ActivationCheckpointingStrategy.fine_grained:
             self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
         else:
-            self._activation_checkpoint_fn = pass_through_fn
+            self._activation_checkpoint_fn = None
 
     @classmethod
     def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
@@ -673,12 +672,20 @@ def forward(
         # - for regular attn q, k, v: (batch_size, seq_len, d_model)
         # - for multi-query attn q: (batch_size, seq_len, d_model)
         #   k, v: (batch_size, seq_len, d_model // n_heads)
-        q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(self.fused_dims, dim=-1)
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Add attention scores.
         # shape: (B, T, C)
@@ -687,9 +694,15 @@
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -753,23 +766,35 @@ def forward(
         # - for multi-query attn q: (batch_size, seq_len, d_model)
         #   k, v: (batch_size, seq_len, d_model // n_heads)
         # shape of ff: (batch_size, seq_len, hidden_size)
-        q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
-            self.fused_dims, dim=-1
-        )
+        if self._activation_checkpoint_fn is not None:
+            q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
 
         # Get attention scores.
         # shape: (B, T, C)
-        att, cache = self._activation_checkpoint_fn(
-            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
-        )
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Apply output projections (and activation function) and sum the results.
         # We keep these projections separate because we found that we got better throughput this
         # way compared to fusing them.
-        return (
-            x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
-            cache,
-        )
+        if self._activation_checkpoint_fn is not None:
+            return (
+                x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
+                cache,
+            )
+        else:
+            return (
+                x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
+                cache,
+            )
 
 
 class OlmoLlamaBlock(OlmoBlock):
@@ -874,9 +899,15 @@ def forward(
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         og_x = x
-        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
         x = self.ff_proj(x)
-        x = self._activation_checkpoint_fn(self.act, x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
         x = self.ff_out(x)
         x = self.dropout(x)
         x = og_x + x
@@ -945,7 +976,7 @@ def forward(
                 )
             ):
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = self._activation_checkpoint_fn(
+                x, cache = self._activation_checkpoint_fn(  # type: ignore
                     block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                 )
             else:
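Every hunk above applies the same pattern: `_activation_checkpoint_fn` is now `None` when activation checkpointing is disabled, and each call site branches on that instead of always routing through a pass-through wrapper, so a compiled model only ever traces plain module calls. The following is a minimal, standalone sketch of that pattern under PyTorch 2.x, not the OLMo code itself; `ToyBlock` and its boolean `set_activation_checkpointing` are hypothetical stand-ins, and the wrapper shown is simply `torch.utils.checkpoint.checkpoint`.

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # Hypothetical block illustrating the None-vs-wrapper pattern from this commit.
    def __init__(self, d_model: int = 64):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ff = nn.Linear(d_model, d_model)
        # None means "call submodules directly"; a callable means "wrap with checkpointing".
        self._activation_checkpoint_fn: Optional[Callable] = None

    def set_activation_checkpointing(self, enabled: bool) -> None:
        if enabled:
            # Non-reentrant activation checkpointing (use_reentrant=False) in recent PyTorch.
            self._activation_checkpoint_fn = lambda module, *args: checkpoint(
                module, *args, use_reentrant=False
            )
        else:
            self._activation_checkpoint_fn = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branch on None so the disabled path is an ordinary module call.
        if self._activation_checkpoint_fn is not None:
            h = self._activation_checkpoint_fn(self.norm, x)
        else:
            h = self.norm(x)
        return x + self.ff(h)


block = ToyBlock()
block.set_activation_checkpointing(False)
compiled = torch.compile(block)  # traces plain norm/ff calls; no checkpoint wrapper in the graph
print(compiled(torch.randn(2, 8, 64)).shape)

Keeping `None` (rather than a pass-through callable) as the disabled state keeps the checkpoint indirection out of the forward path entirely, which lines up with the commit title: compilation works when activation checkpointing is off.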
