diff --git a/olmo/model.py b/olmo/model.py index e4ab0eb62..6eae70995 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -972,7 +972,7 @@ def forward( ): # shape: (batch_size, seq_len, d_model) x, cache = self.__activation_checkpoint_fn( - block, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache ) if attn_key_values is not None: