SDPA `is_causal=False` has no effect due to `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position` #36150
System Info

transformers version: 4.48.3

Who can help?
@ArthurZucker @Cyrilvallez
Reproduction
Observe that `is_causal=False` has no effect when using `attn_implementation="sdpa"` with an `attention_mask` with at least one `False` element:
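(The original snippet is not preserved in this copy of the issue. Below is a sketch of the kind of check described, assuming transformers ~4.48, where extra kwargs such as `is_causal` are forwarded from the model call down to the SDPA attention function; the tiny randomly initialised config is only there so the script runs without downloading weights.)

```python
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

torch.manual_seed(0)

# Tiny random Llama so the sketch is self-contained (assumption: any Llama config behaves the same).
config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa").eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones(1, 8, dtype=torch.long)
attention_mask[0, -1] = 0  # at least one masked (False) position

with torch.no_grad():
    default_logits = model(input_ids, attention_mask=attention_mask).logits
    # Explicitly request non-causal attention.
    non_causal_logits = model(input_ids, attention_mask=attention_mask, is_causal=False).logits

# If is_causal=False were respected, attention would become bidirectional and the logits
# would change. In practice the two outputs are identical, because the 4D mask built by
# _prepare_4d_causal_attention_mask_with_cache_position is always causal.
print(torch.allclose(default_logits, non_causal_logits))  # True
```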
Observe that mocking `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position` with an implementation that just replicates the `attention_mask` also has no effect when using `is_causal=True`:
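(Again, the original snippet is not preserved here. The sketch below mocks the mask-preparation helper as described, with one deviation: instead of relying on SDPA's `is_causal=True`, it folds the causal structure into a plain boolean mask, because `torch.nn.functional.scaled_dot_product_attention` rejects an explicit `attn_mask` together with `is_causal=True`. It assumes transformers ~4.48 internals and is meant to run on CPU, where the CUDA-only `_unmask_unattended` post-processing is skipped.)

```python
from unittest import mock

import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaModel

torch.manual_seed(0)

config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa").eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones(1, 8, dtype=torch.long)
attention_mask[0, -1] = 0  # at least one masked (False) position


def simple_mask(attention_mask, sequence_length, target_length, **kwargs):
    # Replicate the 2D padding mask to 4D and AND it with a lower-triangular pattern.
    causal = torch.ones(
        sequence_length, target_length, dtype=torch.bool, device=attention_mask.device
    ).tril(diagonal=target_length - sequence_length)
    return causal[None, None, :, :] & attention_mask[:, None, None, :target_length].bool()


with torch.no_grad():
    default_logits = model(input_ids, attention_mask=attention_mask).logits
    with mock.patch.object(
        LlamaModel,
        "_prepare_4d_causal_attention_mask_with_cache_position",
        staticmethod(simple_mask),
    ):
        mocked_logits = model(input_ids, attention_mask=attention_mask).logits

# Replacing the elaborate float mask with a simple boolean mask leaves the output
# unchanged, which is the "no effect" observation above.
print(torch.allclose(default_logits, mocked_logits))  # True
```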
Expected behavior
`LlamaModel._prepare_4d_causal_attention_mask_with_cache_position` should respect `is_causal=False`. Right now, it always returns a causal mask when using sdpa with `sequence_length > 1` and an `attention_mask` with at least one `False` element, regardless of the `is_causal` parameter.
My 2nd example demonstrates that the current implementation of `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position` definitely isn't always necessary... so when is it necessary? Or what parts are necessary? Looking at the equivalent implementation that PyTorch describes for `scaled_dot_product_attention`, it seems like we are replicating a bit of their handling of `attn_mask`.
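(For reference, the `attn_mask` / `is_causal` handling in the "equivalent implementation" shown in the `torch.nn.functional.scaled_dot_product_attention` documentation looks roughly like this; it is a paraphrase, not a verbatim copy.)

```python
import torch

def sdpa_attn_bias(query, key, attn_mask=None, is_causal=False):
    # Build the additive bias that the SDPA reference implementation adds to the scores.
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        # SDPA does not accept an explicit attn_mask together with is_causal=True.
        assert attn_mask is None
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(keep.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias
    # The bias is then added to query @ key.transpose(-2, -1) before the softmax.
    return attn_bias
```

This is essentially the same additive-bias construction that `_prepare_4d_causal_attention_mask_with_cache_position` performs ahead of time.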
Also, notably, there are 4 separate CUDA allocations happening in the current implementation (`torch.full`, `torch.triu`, `torch.arange`, `Tensor.clone`) compared to my proposed 1.
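(The proposed implementation referenced above is not preserved in this copy of the issue. Purely as a hypothetical illustration of the kind of simplification meant, the check below shows that, at the SDPA level, a boolean mask built in one shot from the padding mask and a `tril` pattern gives the same output as an additive dtype-min mask in the style of the current helper.)

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

padding = torch.ones(batch, seq, dtype=torch.bool)
padding[0, -1] = False  # one masked key position; keeps every query row non-empty

# True where attention is allowed: causal structure AND-ed with the padding mask.
allowed = torch.tril(torch.ones(seq, seq, dtype=torch.bool)) & padding[:, None, None, :]

# Additive float mask in the style of the current helper: 0 where allowed, dtype-min elsewhere.
float_mask = torch.zeros(batch, 1, seq, seq, dtype=q.dtype).masked_fill(
    ~allowed, torch.finfo(q.dtype).min
)

out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
print(torch.allclose(out_float, out_bool, atol=1e-6))  # True
```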