diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 2353601d91f2..6078f7d99a29 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -260,6 +260,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -323,6 +328,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -343,21 +349,30 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
+            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """
 
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
             if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
@@ -557,6 +572,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -595,9 +611,20 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
+
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -621,6 +648,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -631,6 +659,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -917,6 +946,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 0e9ad69176cf..d8dc85c73059 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -305,6 +305,11 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -349,6 +354,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -369,21 +375,30 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence + last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing """ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # Flash-attn is a 2D tensor + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) if self.config._attn_implementation == "flash_attention_2": - if past_key_value is not None: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window else: min_dtype = torch.finfo(hidden_states.dtype).min sliding_window_mask = torch.tril( torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. 
offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] residual = hidden_states @@ -443,6 +458,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -481,9 +497,20 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -507,6 +534,7 @@ def forward( output_attentions, use_cache, cache_position, + last_cache_position, ) else: layer_outputs = decoder_layer( @@ -517,6 +545,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -586,6 +615,10 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. 
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 9bc0d278166b..0f585e7bdb20 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -221,9 +221,19 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -277,20 +287,30 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # Flash-attn is a 2D tensor + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) if self.config._attn_implementation == "flash_attention_2": - if past_key_value is not None: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window else: min_dtype = torch.finfo(hidden_states.dtype).min sliding_window_mask = torch.tril( torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. 
offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] residual = hidden_states @@ -306,6 +326,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states @@ -554,6 +575,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -593,6 +615,16 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -628,6 +660,7 @@ def forward( output_attentions, use_cache, cache_position, + last_cache_position, ) else: layer_outputs = decoder_layer( @@ -639,6 +672,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -857,6 +891,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **loss_kwargs, ) hidden_states = outputs[0] @@ -926,6 +961,10 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. 
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f3908b203a58..cce83d204c93 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -265,9 +265,19 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -321,20 +331,30 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # Flash-attn is a 2D tensor + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) if self.config._attn_implementation == "flash_attention_2": - if past_key_value is not None: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window else: min_dtype = torch.finfo(hidden_states.dtype).min sliding_window_mask = torch.tril( torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window ) attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. 
offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] residual = hidden_states @@ -350,6 +370,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states @@ -387,6 +408,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -426,6 +448,16 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) @@ -461,6 +493,7 @@ def forward( output_attentions, use_cache, cache_position, + last_cache_position, ) else: layer_outputs = decoder_layer( @@ -472,6 +505,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + last_cache_position=last_cache_position, **flash_attn_kwargs, ) @@ -589,6 +623,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **loss_kwargs, ) hidden_states = outputs[0] @@ -658,6 +693,10 @@ def prepare_inputs_for_generation( # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 1d201eb1e287..436f1f965e90 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -324,3 +324,36 @@ def test_export_static_cache(self): ) ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)]) + @require_read_token + def test_generation_beyond_sliding_window(self, attn_implementation: str): + """Test that we can correctly generate beyond the sliding window. 
+        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
+        Outputs for every attention functions should be coherent and identical.
+        """
+        model_id = "CohereForAI/c4ai-command-r7b-12-2024"
+        EXPECTED_COMPLETIONS = [
+            " the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
+            ", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
+        ]
+
+        input_text = [
+            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
+            "A list of colors: red, blue",  # This will almost all be padding tokens
+        ]
+        tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
+        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
+        ).to(torch_device)
+
+        # Make sure prefill is larger than sliding window
+        input_size = inputs.input_ids.shape[-1]
+        self.assertTrue(input_size > model.config.sliding_window)
+
+        out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
+        output_text = tokenizer.batch_decode(out)
+
+        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 2bbe0d8e5b74..1fb7bdfa8994 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -394,3 +394,36 @@ def test_model_9b_bf16_flex_attention(self):
         output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
+    @require_read_token
+    def test_generation_beyond_sliding_window(self, attn_implementation: str):
+        """Test that we can correctly generate beyond the sliding window. This is non trivial as
+        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
+        Outputs for every attention functions should be coherent and identical.
+        """
+        model_id = "google/gemma-2-2b"
+        EXPECTED_COMPLETIONS = [
+            " the people, the food, the culture, the history, the music, the art, the architecture",
+            ", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
+        ]
+
+        input_text = [
+            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
+            "A list of colors: red, blue",  # This will almost all be padding tokens
+        ]
+        tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
+        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
+        ).to(torch_device)
+
+        # Make sure prefill is larger than sliding window
+        input_size = inputs.input_ids.shape[-1]
+        self.assertTrue(input_size > model.config.sliding_window)
+
+        out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
+        output_text = tokenizer.batch_decode(out)
+
+        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
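
For readers following the 4D-mask branch of the patch, below is a minimal standalone sketch (not part of the patch; the helper name `slice_sliding_mask` and the toy sizes are illustrative assumptions) of how `offset = max(0, last_cache_position - effective_seq_len)` picks the correct window out of a `[bs, 1, query_len, max_cache_len]` mask once generation has moved past the sliding window.

import torch

def slice_sliding_mask(attention_mask, cache_position, last_cache_position, sliding_window):
    """Mimic the patched 4D-mask branch: mask out tokens older than the window, then slice."""
    min_dtype = torch.finfo(attention_mask.dtype).min
    # In prefill the query can be longer than the sliding window
    effective_seq_len = max(cache_position.shape[0], sliding_window)
    # Positions more than `sliding_window` tokens behind the query get masked out
    sliding_window_mask = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
    )
    attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
    # Offset only becomes positive once we are beyond the sliding window
    offset = max(0, last_cache_position - effective_seq_len)
    return attention_mask[:, :, :, offset : offset + effective_seq_len]

# Decoding one token into static-cache slot 17 with a sliding window of 8 (toy sizes)
mask = torch.zeros(1, 1, 1, 32)  # [bs, 1, query_len, max_cache_len]
cache_position = torch.tensor([17])
sliced = slice_sliding_mask(mask, cache_position, last_cache_position=18, sliding_window=8)
print(sliced.shape)  # torch.Size([1, 1, 1, 8]) -> cache slots 10..17, i.e. the last 8 positions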
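
Similarly, a small sketch of why the flash-attention-2 branch truncates the cached keys/values after the cache update: FA2 receives a 2D `[bs, processed_tokens]` mask, while a static `HybridCache` hands back tensors padded to `max_cache_len`. The cache is stood in for by pre-allocated zero tensors and all sizes below are illustrative assumptions, not values from the patch.

import torch

bs, num_heads, head_dim, max_cache_len = 1, 2, 4, 32
processed_tokens = 5                               # tokens seen so far
attention_mask = torch.ones(bs, processed_tokens)  # FA2-style 2D mask: [bs, processed_tokens]

# A static cache returns key/value tensors padded to max_cache_len
key_states = torch.zeros(bs, num_heads, max_cache_len, head_dim)
value_states = torch.zeros(bs, num_heads, max_cache_len, head_dim)

# The patch slices them so their length matches the 2D mask flash-attention receives
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]
print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 5, 4]) torch.Size([1, 2, 5, 4])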