diff --git a/olmo/model.py b/olmo/model.py index a52411c43..10f539afd 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -948,7 +948,7 @@ def forward( ): # shape: (batch_size, seq_len, d_model) x, cache = self.__activation_checkpoint_fn( - block, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache ) if attn_key_values is not None: