Mixed-precision with torch.autocast is broken for many models when using attn_implementation="flash_attention_2" #35945
Comments
import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding


def patch_broken_autocast_llama():
    class FixedLlamaRMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            # fix silent upcasting of output (note that we keep the inner .to() to *only* cast the output and no inner computation)
            return (self.weight * hidden_states.to(input_dtype)).to(input_dtype)
            # return self.weight * hidden_states.to(input_dtype)

    transformers.models.llama.modeling_llama.LlamaRMSNorm = FixedLlamaRMSNorm

    class FixedLlamaRotaryEmbedding(LlamaRotaryEmbedding):
        @torch.no_grad()
        def forward(self, x, position_ids):
            if "dynamic" in self.rope_type:
                self._dynamic_frequency_update(position_ids, device=x.device)

            # Core RoPE block
            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
            position_ids_expanded = position_ids[:, None, :].float()
            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
            device_type = x.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos()
                sin = emb.sin()

            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
            cos = cos * self.attention_scaling
            sin = sin * self.attention_scaling

            # fix silent upcasting of output
            output_dtype = x.dtype
            if torch.is_autocast_enabled():
                output_dtype = torch.get_autocast_dtype(device_type=device_type)
            return cos.to(dtype=output_dtype), sin.to(dtype=output_dtype)
            # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding

Snippet applying a patch solution to this problem.
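A sketch of how the patch above might be applied (hypothetical usage, not part of the original comment; the model name is a placeholder):

import torch
from transformers import AutoModel

patch_broken_autocast_llama()                      # swap in the fixed classes before instantiating the model

model = AutoModel.from_pretrained(
    "meta-llama/Llama-3.1-8B",                     # placeholder model name
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float32,                     # keep master weights in float32
).cuda()
# mixed-precision forward/backward can then run under torch.autocast as usual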
As a sidenote, is it intended that we always keep the residual stream in float32 (transformers/src/transformers/models/llama/modeling_llama.py, lines 331 to 353 in ec7afad) while the computation inside the layers runs in the autocast dtype? Can also open a separate issue if you prefer.
System Info
transformers==4.48.1
python==3.11.11
Who can help?
@ArthurZucker
Expected behavior
Mixed-precision training via torch.autocast is broken for most models inspired by the HF Llama code (which is a lot of models) when using attn_implementation="flash_attention_2", and potentially not working as intended in general. Snippet to reproduce on transformers==4.48.1:
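A minimal sketch of the kind of setup being described (reconstructed from the description that follows; not the original snippet, and the model name is a placeholder):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "meta-llama/Llama-3.1-8B",                     # placeholder model name
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float32,                     # weights intentionally loaded in float32
).cuda()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(input_ids)                         # fails: FA2 ends up receiving float32 tensors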
Indeed, after calling AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32) we get a warning suggesting the use of torch.autocast, but the snippet fails even though we do use torch.autocast as suggested. For mixed-precision training, we do actually want to load the weights in float32.

Concretely, the source is two different issues:
1. The RMSNorm output is in float32 even within an autocast context, which is propagated from the q_norm/v_norm up until passing the projections to the attention function (FA2 fails here), with RMSNorm usually handling silent upcasting correctly but it seems at some point this broke: transformers/src/transformers/integrations/flash_attention.py, lines 32 to 36 in ec7afad. (A small standalone demo of this silent promotion follows the list.)
2. The cos and sin position embeddings for RoPE are in float32 even within an autocast context, which will again silently upcast the query_states/key_states to float32 before passing them to the attention function: transformers/src/transformers/models/llama/modeling_llama.py, line 275 in ec7afad.
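To make the first point concrete, here is a small standalone sketch (not from the issue; assumes a CUDA device and bfloat16 autocast) of the promotion behaviour behind the silent upcast: autocast only casts the inputs of selected ops such as linear/matmul, while elementwise multiplication follows ordinary type promotion, so a float32 norm weight promotes bfloat16 activations back to float32.

import torch

weight = torch.ones(16, device="cuda", dtype=torch.float32)        # norm weight kept in float32
x = torch.randn(2, 4, 16, device="cuda", dtype=torch.float32)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    proj = torch.nn.functional.linear(x, torch.eye(16, device="cuda"))
    print(proj.dtype)       # torch.bfloat16 -- linear runs in the autocast dtype
    normed = weight * proj  # like the final `self.weight * hidden_states.to(input_dtype)` in RMSNorm
    print(normed.dtype)     # torch.float32 -- silently promoted back by the float32 weight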
The reason they are in float32 is that cos and sin are created in e.g. LlamaModel (transformers/src/transformers/models/llama/modeling_llama.py, line 571 in ec7afad), where the hidden_states come from the nn.Embedding, which is never autocasted by torch.autocast. So transformers/src/transformers/models/llama/modeling_llama.py, line 141 in ec7afad does not work as intended (?) because the input at that point has not been autocasted yet.
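A small sketch (not from the issue; assumes a CUDA device) of why x.dtype is still float32 at that point: embedding lookups are not on autocast's cast list, so the hidden_states handed to the rotary embedding keep the float32 weight dtype even inside the autocast context.

import torch

emb = torch.nn.Embedding(100, 16).cuda()           # weights in float32, as with torch_dtype=torch.float32
ids = torch.randint(0, 100, (1, 8), device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    hidden_states = emb(ids)
    print(hidden_states.dtype)                     # torch.float32 -- autocast does not touch nn.Embedding
    # so cos/sin computed from this tensor and cast back to x.dtype stay in float32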
One fix is to remove the silent upcasting of the output to float32 in RMSNorm if the input is bfloat16, and to directly cast cos and sin to the torch.get_autocast_dtype if in autocast. In this discussion it seems that this might come with some issues, so there might have to be some different solution (I am not quite sure of the exact reasons for potential issues though).
It's important to note that through all this silent upcasting, we're probably (I haven't benchmarked though) using a lot of extra memory when doing mixed-precision training (regardless of whether we use attn_implementation="flash_attention_2" or not).