
Mixed-precision with torch.autocast is broken for many models when using attn_implementation="flash_attention_2" #35945

Closed
@konstantinjdobler

Description


System Info

transformers==4.48.1
python=3.11.11

Who can help?

@ArthurZucker

Expected behavior

Mixed-precision training via torch.autocast is broken for most models inspired by the HF Llama code (which is a lot of models) when using attn_implementation="flash_attention_2", and it is potentially not working as intended in general.

Snippet to reproduce on transformers==4.48.1:

import torch
from transformers import AutoModel, AutoTokenizer

torch.set_default_device("cuda:0")

model_name = "meta-llama/Llama-3.2-3B"  # many others, e.g. "allenai/OLMo-2-1124-7B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
    outputs = model(**inputs)

Indeed, after calling AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32), we get the following warning:

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. ...`

but the snippet fails even though we use torch.autocast exactly as suggested. For mixed-precision training, we do actually want to load the weights in float32 (the optimizer should update full-precision master weights).

Concretely, the failure comes from two separate issues:

  • The common Llama-inspired implementation of RMSNorm silently upcasts its output to float32 even within an autocast context, and this float32 propagates from the q_norm / k_norm all the way to the projections handed to the attention function, which is where FA2 fails.
    • The RMSNorm issue has been discussed in these related issues: here and here and here
    • There is this comment in the FlashAttention integration, suggesting that RMSNorm used to handle the silent upcast correctly, but at some point this seems to have broken:
      # In PEFT, usually we cast the layer norms in float32 for training stability reasons
      # therefore the input hidden states gets silently casted in float32. Hence, we need
      # cast them back in the correct dtype just to be sure everything works as expected.
      # This might slowdown training & inference so it is recommended to not cast the LayerNorms
      # in fp32. (usually our RMSNorm modules handle it correctly)
  • The cos and sin position embeddings for RoPE are kept in float32 even within an autocast context, which again silently upcasts query_states/key_states to float32 before they are passed to the attention function (see the sketch below this list):
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
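
For illustration, here is a minimal standalone sketch of both promotions (illustrative only: it assumes a CUDA device, and the tensors/shapes are made up rather than the actual transformers internals):

import torch

torch.set_default_device("cuda:0")

with torch.autocast("cuda", dtype=torch.bfloat16):
    # 1) Llama-style RMSNorm math is done explicitly in float32; with an fp32-loaded
    #    model the incoming hidden states are float32, so the normalized output
    #    leaves the norm in float32 even inside autocast.
    hidden = torch.randn(2, 8, 64)             # float32 hidden states
    weight = torch.ones(64)                    # float32 norm weight
    variance = hidden.pow(2).mean(-1, keepdim=True)
    normed = weight * (hidden * torch.rsqrt(variance + 1e-6))
    print(normed.dtype)                        # torch.float32

    # 2) Linear layers are autocast to bfloat16, but multiplying by the float32
    #    RoPE cos/sin tables promotes q/k back to float32, which is effectively
    #    what apply_rotary_pos_emb does here.
    q = torch.nn.Linear(64, 64)(normed)        # bfloat16 under autocast
    cos = torch.randn(2, 8, 64)                # float32 RoPE table
    print(q.dtype, (q * cos).dtype)            # torch.bfloat16 torch.float32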

One fix is to remove the silent upcasting of the RMSNorm output to float32 when the input is bfloat16, and to cast cos and sin directly to torch.get_autocast_dtype when running under autocast.
In this discussion it seems this might come with its own issues, so a different solution may be needed (I am not sure of the exact reasons for the potential issues, though).
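
A rough sketch of that direction (not an actual transformers patch; the class/function names are made up, and it assumes PyTorch >= 2.4, where torch.is_autocast_enabled / torch.get_autocast_dtype accept a device_type argument):

import torch
import torch.nn as nn

class RMSNormNoSilentUpcast(nn.Module):
    """Illustrative RMSNorm: still computes in float32 for stability, but always
    returns the input dtype so no float32 leaks out under autocast."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # cast back after the weight multiply so the output dtype matches the input dtype
        return (self.weight * x).to(input_dtype)

def cast_rope_tables(cos: torch.Tensor, sin: torch.Tensor, device_type: str = "cuda"):
    """Illustrative helper: cast cos/sin to the active autocast dtype so they do not
    promote bf16 query/key states back to float32."""
    if torch.is_autocast_enabled(device_type):
        autocast_dtype = torch.get_autocast_dtype(device_type)
        cos, sin = cos.to(autocast_dtype), sin.to(autocast_dtype)
    return cos, sin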

It's also worth noting that all this silent upcasting probably (I haven't benchmarked it, though) costs a lot of extra memory during mixed-precision training, regardless of whether attn_implementation="flash_attention_2" is used.
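
A quick (untested, illustrative) way to check would be to compare the peak allocated memory of a forward/backward pass with and without the upcasts, e.g.:

import torch

def peak_forward_backward_memory_mb(model, inputs) -> float:
    # rough harness: peak CUDA memory of one autocast forward/backward pass
    torch.cuda.reset_peak_memory_stats()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        out = model(**inputs)
        out.last_hidden_state.float().mean().backward()
    return torch.cuda.max_memory_allocated() / 2**20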
