
SDPA is_causal=False has no effect due to LlamaModel._prepare_4d_causal_attention_mask_with_cache_position #36150

ringohoffman opened this issue Feb 12, 2025 · 4 comments

ringohoffman commented Feb 12, 2025

System Info

  • transformers version: 4.48.3
  • Platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
  • Python version: 3.9.21
  • Huggingface_hub version: 0.28.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.3.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Observe that is_causal=False has no effect when using attn_implementation="sdpa" with an attention_mask containing at least one False element:

import torch
import transformers

device = torch.device("cuda:0")
input_ids = torch.tensor(
    [
        [
            128000, 128006,   9125, 128007,    271,     34,   7747,    553,    279,
            2768,   1495,    439,   1694,   5552,    311,   5557,     11,  17452,
            11,  10034,     11,    477,  11759,     13, 128009, 128006,    882,
            128007,    271,    791,    502,  77355,   3280,    690,  10536,   1022,
            449,    264,  72097,   2489,   1990,  35812,    323,  64921,     13,
            128009, 128006,  78191, 128007,    271,  42079, 128009, 128004, 128004,
            128004, 128004
        ]
    ],
    device=device,
)
attention_mask = torch.tensor(
    [
        [
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True, False, False, False, False
        ]
    ],
    device=device,
)

with device:
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "/models/meta-llama/Llama-3.2-1B-Instruct",  # https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
    )

causal_logits = model(input_ids, attention_mask=attention_mask, is_causal=True).logits
noncausal_logits = model(input_ids, attention_mask=attention_mask, is_causal=False).logits

torch.testing.assert_close(causal_logits, noncausal_logits)  # shouldn't be true, otherwise what is_causal controlling?
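
As a stand-alone sanity check of what the kwarg is expected to control (random toy tensors, independent of the model above): at the raw SDPA level, is_causal clearly changes the output.

import torch
import torch.nn.functional as F

# Stand-alone illustration with random tensors: raw SDPA produces different outputs
# for is_causal=True vs. is_causal=False (no explicit attn_mask here, since PyTorch
# does not allow passing an attn_mask together with is_causal=True).
q = k = v = torch.randn(1, 1, 4, 8)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_full = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(torch.allclose(out_causal, out_full))  # False: the causal mask changes attention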

Observe that mocking LlamaModel._prepare_4d_causal_attention_mask_with_cache_position with an implementation that just replicates the attention_mask also has no effect when using is_causal=True:

from unittest import mock

def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    min_dtype = torch.tensor(torch.finfo(dtype).min, dtype=dtype, device=attention_mask.device)
    # Padding-only bias: 0 where attention is allowed, dtype-min where the key position is padding.
    padding_mask = ~attention_mask.view(batch_size, 1, 1, sequence_length)
    return padding_mask.expand(batch_size, 1, sequence_length, sequence_length) * min_dtype

with mock.patch.object(model.model, "_prepare_4d_causal_attention_mask_with_cache_position", _prepare_4d_causal_attention_mask_with_cache_position):
    sdpa_causal_logits = model(input_ids, attention_mask=attention_mask, is_causal=True).logits

hf_causal_logits = model(input_ids, attention_mask=attention_mask, is_causal=True).logits

torch.testing.assert_close(sdpa_causal_logits, hf_causal_logits)  # shouldn't be true, otherwise what is _prepare_4d_causal_attention_mask_with_cache_position doing?

Expected behavior

  1. At the very least, LlamaModel._prepare_4d_causal_attention_mask_with_cache_position should respect is_causal=False. Right now it always returns a causal mask when using sdpa with sequence_length > 1 and an attention_mask containing at least one False element.
  2. It is not really clear to me why we aren't relying purely on SDPA's own is_causal parameter. My second example demonstrates that the current implementation of LlamaModel._prepare_4d_causal_attention_mask_with_cache_position is definitely not always necessary... so when is it necessary? Or which parts of it are necessary? Looking at the equivalent implementation that PyTorch describes for scaled_dot_product_attention (sketched below), it seems like we are replicating part of its attn_mask handling. Notably, the current implementation also performs 4 separate CUDA allocations (torch.full, torch.triu, torch.arange, Tensor.clone) compared to 1 in my proposed version.
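
For reference, the mask handling in the equivalent implementation that the PyTorch docs give for torch.nn.functional.scaled_dot_product_attention looks roughly like the following paraphrased sketch (simplified: no dropout, no GQA, not the verbatim docs code):

import math
import torch

# Paraphrased sketch of the reference implementation described in the PyTorch docs
# for scaled_dot_product_attention (simplified).
def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean masks: True means "attend", False means "mask out".
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            # Float masks are added to the attention scores directly.
            attn_bias = attn_mask + attn_bias
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor + attn_bias, dim=-1)
    return attn_weight @ value
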
ringohoffman (Contributor, Author) commented:

@gante

@ringohoffman ringohoffman changed the title SDPA is_causal=True has no affect due to LlamaModel._prepare_4d_causal_attention_mask_with_cache_position SDPA is_causal=True has no effect due to LlamaModel._prepare_4d_causal_attention_mask_with_cache_position Feb 12, 2025
@ringohoffman ringohoffman changed the title SDPA is_causal=True has no effect due to LlamaModel._prepare_4d_causal_attention_mask_with_cache_position SDPA is_causal=False has no effect due to LlamaModel._prepare_4d_causal_attention_mask_with_cache_position Feb 12, 2025
zucchini-nlp (Member) commented:

Kinda related to #36049. Not sure if the PR will solve the problem

ringohoffman (Contributor, Author) commented:

> Kinda related to #36049. Not sure if the PR will solve the problem

#36049 deals purely with typing. It will definitely not solve this issue.

To fix this issue, at a minimum LlamaModel._prepare_4d_causal_attention_mask_with_cache_position needs to be passed is_causal and return a different result depending on its value.

But I am asking which parts of the current implementation of LlamaModel._prepare_4d_causal_attention_mask_with_cache_position are actually necessary, since scaled_dot_product_attention already seems to handle a lot of what it is doing, as shown by my second code example. Maybe we should keep some of the cache_position handling?
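
For illustration only, here is a hypothetical sketch of the minimal shape such a change could take; the helper name and signature are simplified stand-ins, not the actual transformers API: when is_causal=False it returns a padding-only bias, otherwise it also masks future positions.

import torch

# Hypothetical sketch only: a simplified mask helper that respects is_causal.
# The name and signature are illustrative, not the real transformers API.
def prepare_4d_mask_sketch(attention_mask, sequence_length, dtype, batch_size, is_causal=True):
    device = attention_mask.device
    min_dtype = torch.finfo(dtype).min
    # Padding bias: dtype-min wherever the key position is a padding token.
    bias = torch.zeros(batch_size, 1, sequence_length, sequence_length, dtype=dtype, device=device)
    padding = ~attention_mask.view(batch_size, 1, 1, sequence_length)
    bias = bias.masked_fill(padding, min_dtype)
    if is_causal:
        # Additionally mask out future key positions (strict upper triangle).
        future = torch.ones(sequence_length, sequence_length, dtype=torch.bool, device=device).triu(diagonal=1)
        bias = bias.masked_fill(future, min_dtype)
    return bias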

zucchini-nlp (Member) commented:

@ringohoffman yep, sorry! I meant related in the sense that is_causal was never an accepted kwarg for the model's forward until that PR is merged, so we don't expect the arg to have any effect.

The _prepare_4d_causal_attention_mask_with_cache_position helper does indeed do the same thing as SDPA's causal mask in the simplest cases. But the SDPA mask doesn't handle padding ids, which we add to the mask, as you mentioned. We also have a variety of key-value caching methods, like StaticCache; when these are used, we need to prepare a special mask that masks out the extra positions at the right end.
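
As a rough toy illustration of that last point (assumed sizes, not the actual transformers code): with a static cache longer than the number of tokens processed so far, the prepared mask also has to hide the not-yet-written slots at the right end of the cache.

import torch

# Toy illustration: a static cache of length 8, but only 5 tokens processed so far.
# Key positions 5..7 have never been written, so the mask must exclude them in
# addition to enforcing causality.
sequence_length, target_length = 5, 8            # query length vs. static cache length
cache_position = torch.arange(sequence_length)   # cache slots 0..4 written this step
min_val = torch.finfo(torch.float32).min

mask = torch.full((sequence_length, target_length), min_val)
# Allow (set to 0) key positions at or before each query's cache position;
# everything else, including the unused right-end slots, stays at dtype-min.
mask = mask * (torch.arange(target_length) > cache_position.view(-1, 1))
print(mask)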
