Description
System Info
transformers==4.48.1
python=3.11.11
Who can help?
Expected behavior
Mixed-precision training via `torch.autocast` is broken for most models inspired by the HF Llama code (which is a lot of models) when using `attn_implementation="flash_attention_2"`, and potentially not working as intended in general.

Snippet to reproduce on `transformers==4.48.1`:
```python
import torch
from transformers import AutoTokenizer, AutoModel

torch.set_default_device("cuda:0")

model_name = "meta-llama/Llama-3.2-3B"  # many others, e.g. "allenai/OLMo-2-1124-7B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
    outputs = model(**inputs)
```
Indeed, after calling `AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)`, we get

> Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. ...

but the snippet fails even though we do use `torch.autocast` as suggested. For mixed-precision training, we do actually want to load the weights in `float32`.
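(For context, the setup I am after is the standard bf16 mixed-precision recipe sketched below. It continues the snippet above, but assumes the model is instead loaded with `AutoModelForCausalLM` so that passing `labels` yields a loss; the optimizer choice is arbitrary.)

```python
# Standard bf16 mixed-precision step: parameters (and optimizer state) stay in fp32,
# only the forward pass runs in bf16 under autocast.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

with torch.autocast("cuda", dtype=torch.bfloat16):
    loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()                       # gradients are fp32, matching the fp32 params
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```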
Concretely, the source is two different issues (a minimal sketch reproducing both dtype behaviours follows this list):
- The common Llama-inspired implementation of RMSNorm silently upcasts its output to `float32` even within an `autocast` context, which is propagated from the `q_norm`/`k_norm` up until passing the projections to the attention function (FA2 fails here).
  - The RMSNorm issue has been discussed in these related issues: here and here and here
  - We have this comment in the FlashAttention integration about `RMSNorm` usually handling silent upcasting correctly, but it seems at some point this broke: `transformers/src/transformers/integrations/flash_attention.py`, lines 32 to 36 at `ec7afad`
- The `cos` and `sin` position embeddings for RoPE are in `float32` even within an `autocast` context, which will again silently upcast the `query_states`/`key_states` to `float32` before passing them to the attention function:
  - The reason they are in `float32` is that `cos` and `sin` are created in e.g. `LlamaModel` from the `hidden_states`, which come from the `nn.Embedding` and are therefore never autocast by `torch.autocast`. So the cast of `cos`/`sin` back to the input dtype at the end of the rotary embedding does not work as intended (?) because the input at that point has not been autocast yet.
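Both behaviours can be reproduced without loading a full model; a minimal sketch (module names taken from `modeling_llama` in 4.48.1, shapes arbitrary):

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding

config = LlamaConfig(hidden_size=64, num_attention_heads=4)
norm = LlamaRMSNorm(config.hidden_size).cuda()
rope = LlamaRotaryEmbedding(config=config).cuda()

# hidden_states as they come out of nn.Embedding: float32, since autocast never touches embeddings
hidden_states = torch.randn(1, 8, config.hidden_size, device="cuda")
position_ids = torch.arange(8, device="cuda").unsqueeze(0)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # Issue 1: the fp32 RMSNorm weight multiplied with bf16 activations promotes the output to fp32
    print(norm(hidden_states.to(torch.bfloat16)).dtype)  # torch.float32

    # Issue 2: cos/sin inherit the dtype of hidden_states, which is still fp32 under autocast
    cos, sin = rope(hidden_states, position_ids)
    print(cos.dtype, sin.dtype)  # torch.float32 torch.float32
```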
One fix is to remove the silent upcasting of the output to `float32` in `RMSNorm` when the input is `bfloat16`, and to directly cast `cos` and `sin` to the `torch.get_autocast_dtype` when inside an autocast context.
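Roughly the kind of change I mean (just a sketch, not a tested patch: `cast_rope_to_autocast_dtype` is a hypothetical helper, the norm below keeps the input dtype unconditionally, and the device-generic autocast API requires PyTorch >= 2.4):

```python
import torch
from torch import nn

class RMSNormNoSilentUpcast(nn.Module):
    """Same math as LlamaRMSNorm, but the fp32 weight is multiplied before casting
    back, so a bf16 input also yields a bf16 output."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast back *after* the weight multiplication instead of before it
        return (self.weight * hidden_states).to(input_dtype)


def cast_rope_to_autocast_dtype(cos, sin, device_type="cuda"):
    # hypothetical helper: make cos/sin follow the autocast dtype so that
    # apply_rotary_pos_emb does not promote bf16 query/key states back to fp32
    if torch.is_autocast_enabled(device_type):
        dtype = torch.get_autocast_dtype(device_type)
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin
```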
In this discussion it seems that this might come with some issues, so a different solution might be needed (I am not quite sure of the exact reasons for the potential issues though).
It's important to note that through all this silent upcasting, we're probably (I haven't benchmarked though) using a lot of extra memory when doing mixed-precision training (regardless of whether we use `attn_implementation="flash_attention_2"` or not).
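If useful, a rough way to check the memory claim (hypothetical helper; it assumes a causal-LM head so that passing `labels` yields a loss) would be to compare peak memory with and without the upcasts, e.g.:

```python
import torch

def peak_memory_mb(model, inputs, autocast_dtype=torch.bfloat16):
    """Peak CUDA memory (MiB) for one forward/backward step under autocast."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.autocast("cuda", dtype=autocast_dtype):
        loss = model(**inputs, labels=inputs["input_ids"]).loss
    loss.backward()
    model.zero_grad(set_to_none=True)
    return torch.cuda.max_memory_allocated() / 2**20
```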