Fix mask slicing for models with HybridCache (#35681)
* correctly slice

* check mask

* Update modular_gemma2.py

* fix

* add tests

* fix typo

* finally fix mask slicing

* Finally correctly slice in all cases!!

* add test for all attention functions

* small fix in tests

* trick around dynamo tracing issue

* last update

* more robust

* kwargs propagation

* make it explicit for checkpointing

* apply modular
Cyrilvallez authored Jan 28, 2025
1 parent b764c20 commit 3f860db
Showing 6 changed files with 232 additions and 22 deletions.
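
Note: the fix distinguishes two attention-mask layouts that the changed code has to slice differently. A minimal sketch of the two shapes involved, with toy values that do not come from the diff:

```python
import torch

batch_size, query_len, max_cache_len = 2, 1, 16  # toy sizes
processed_tokens = 7                             # tokens seen so far

# flash_attention_2 path: the mask is 2D, [batch_size, processed_tokens]
fa2_mask = torch.ones(batch_size, processed_tokens, dtype=torch.long)

# sdpa / eager path: the mask is 4D, [batch_size, 1, query_len, max_cache_len],
# already padded to the full static-cache length; 0 = attend, large negative = masked
min_dtype = torch.finfo(torch.float32).min
sdpa_mask = torch.full((batch_size, 1, query_len, max_cache_len), min_dtype)
sdpa_mask[..., :processed_tokens] = 0.0

print(fa2_mask.shape)   # torch.Size([2, 7])
print(sdpa_mask.shape)  # torch.Size([2, 1, 1, 16])
```

The hunks below slice the 2D mask from the right and the 4D mask from the left, with an offset once generation has moved past the sliding window.
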
43 changes: 38 additions & 5 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -260,6 +260,11 @@ def forward(
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
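
The K/V slicing added above is needed because the cache is pre-allocated at the full static length while flash-attention only receives a 2D mask covering the tokens processed so far. A rough sketch of the idea with toy tensors (shapes and names are illustrative, not taken from the modeling code):

```python
import torch

bsz, num_kv_heads, max_cache_len, head_dim = 1, 2, 16, 8   # toy static-cache shapes
key_states = torch.randn(bsz, num_kv_heads, max_cache_len, head_dim)
value_states = torch.randn(bsz, num_kv_heads, max_cache_len, head_dim)

# With FA2 the mask is 2D and only as long as the tokens processed so far,
# so the statically allocated cache entries beyond that point must be dropped
attention_mask = torch.ones(bsz, 7, dtype=torch.long)
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]

print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 7, 8]) torch.Size([1, 2, 7, 8])
```
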
@@ -323,6 +328,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -343,21 +349,30 @@ def forward(
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allows indexing without breaking dynamo tracing
"""

if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

residual = hidden_states
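
For SDPA and eager attention the mask stays 4D at the full static-cache width, so a sliding-window layer has to slice it from the left, with an offset once decoding has moved past the window. A self-contained sketch of that arithmetic with made-up sizes (the `tril` re-masking above, which hides tokens older than the window during prefill, is omitted here):

```python
import torch

sliding_window, max_cache_len, query_len = 4, 16, 1   # toy values; single-token decoding step
min_dtype = torch.finfo(torch.float32).min

# 4D mask of shape [bs, 1, query_len, max_cache_len]; 0 = attend, min_dtype = masked
attention_mask = torch.zeros(1, 1, query_len, max_cache_len)
attention_mask[..., 10:] = min_dtype   # cache slots 10..15 are not filled yet

# The token being decoded sits at absolute position 9, i.e. 10 tokens processed so far
last_cache_position = 10
effective_seq_len = max(query_len, sliding_window)        # 4 (a long prefill can exceed the window)
offset = max(0, last_cache_position - effective_seq_len)  # 6 -> we are beyond the sliding window

# Keep exactly the window's worth of columns: cache slots 6..9
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
print(attention_mask.shape)  # torch.Size([1, 1, 1, 4])
```
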

@@ -557,6 +572,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -595,9 +611,20 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there is no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
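
The `last_cache_position` integer is computed here, before the layer loop, because slicing later with a value read out of the `cache_position` tensor would be data-dependent and break dynamo tracing. A small sketch of the pattern with a hypothetical helper (not part of the library):

```python
import torch

def slice_window_mask(mask_4d: torch.Tensor, last_cache_position: int, sliding_window: int) -> torch.Tensor:
    """Slice a [bs, 1, query_len, max_cache_len] mask down to the sliding window.

    `last_cache_position` is a plain Python int (derived from `attention_mask.shape[-1]`),
    so the slice bounds are static from dynamo's point of view instead of depending on a
    tensor value such as `cache_position[-1]`.
    """
    effective_seq_len = max(mask_4d.shape[-2], sliding_window)
    offset = max(0, last_cache_position - effective_seq_len)
    return mask_4d[:, :, :, offset : offset + effective_seq_len]

mask = torch.zeros(1, 1, 1, 16)
print(slice_window_mask(mask, last_cache_position=10, sliding_window=4).shape)  # torch.Size([1, 1, 1, 4])
```

During `generate`, `prepare_inputs_for_generation` precomputes the same int as `attention_mask.shape[-1]`, so the `cache_position[-1].item()` fallback only runs when a 4D mask is passed in by hand.
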
@@ -621,6 +648,7 @@ def forward(
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -631,6 +659,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -917,6 +946,10 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
43 changes: 38 additions & 5 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -305,6 +305,11 @@ def forward(
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -349,6 +354,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -369,21 +375,30 @@ def forward(
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allows indexing without breaking dynamo tracing
"""

if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

residual = hidden_states

@@ -443,6 +458,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -481,9 +497,20 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there is no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)

hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
@@ -507,6 +534,7 @@ def forward(
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -517,6 +545,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -586,6 +615,10 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
51 changes: 45 additions & 6 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -221,9 +221,19 @@ def forward(

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
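
The extra `sliding_window` entry in `cache_kwargs` matters because, for the sliding layers of a `HybridCache`, only the last `sliding_window` key/value states are kept, so cache slots stop lining up with absolute token positions once the window is full. A toy rolling cache, not the `HybridCache` implementation, just to illustrate the effect:

```python
import torch

class ToySlidingCache:
    """Keeps at most `sliding_window` key states (illustrative only, not HybridCache)."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window
        self.keys = None  # [bs, heads, <= sliding_window, head_dim]

    def update(self, key_states: torch.Tensor) -> torch.Tensor:
        self.keys = key_states if self.keys is None else torch.cat([self.keys, key_states], dim=2)
        # Once full, the oldest entries are dropped: slot 0 no longer holds token 0
        self.keys = self.keys[:, :, -self.sliding_window :, :]
        return self.keys

cache = ToySlidingCache(sliding_window=4)
for step in range(6):  # six single-token steps, each key filled with its absolute position
    keys = cache.update(torch.full((1, 1, 1, 2), float(step)))

print(keys[0, 0, :, 0])  # tensor([2., 3., 4., 5.])
```

Slot `j` of the sliding layer's cache ends up holding absolute position `offset + j`, which is exactly why the 4D mask is sliced with an `offset` instead of from the right.
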
@@ -277,20 +287,30 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

residual = hidden_states

@@ -306,6 +326,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
@@ -554,6 +575,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -593,6 +615,16 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there is no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -628,6 +660,7 @@ def forward(
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -639,6 +672,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -857,6 +891,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)

hidden_states = outputs[0]
@@ -926,6 +961,10 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
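
The commit's test additions are among the changed files not rendered above. As a rough idea of the kind of end-to-end check this fix enables, one can compare greedy generations across attention implementations; the checkpoint name and token counts below are placeholders, not taken from the tests, and a real regression check needs the prompt plus generated tokens to exceed the model's sliding window:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # placeholder; any model relying on HybridCache would do (gated weights need license acceptance)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The sliding-window mask should not change the result:", return_tensors="pt")

generations = {}
for impl in ("eager", "sdpa"):  # "flash_attention_2" additionally needs flash-attn installed on CUDA
    model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=impl, torch_dtype=torch.float32)
    generations[impl] = model.generate(**inputs, max_new_tokens=32, do_sample=False)

# With the mask sliced consistently in every code path, greedy outputs should match
assert torch.equal(generations["eager"], generations["sdpa"])
```
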
