Mixed-precision with torch.autocast is broken for many models when using attn_implementation="flash_attention_2" #35945

Open · konstantinjdobler opened this issue Jan 28, 2025 · 2 comments

konstantinjdobler (Contributor) commented Jan 28, 2025

System Info

transformers==4.48.1
python==3.11.11

Who can help?

@ArthurZucker

Expected behavior

Mixed-precision training via torch.autocast is broken for most models inspired by the HF Llama code (which is a lot of models) when using attn_implementation="flash_attention_2", and is potentially not working as intended in general.

Snippet to reproduce on transformers==4.48.1:

import torch
from transformers import AutoModel, AutoTokenizer

torch.set_default_device("cuda:0")
model_name = "meta-llama/Llama-3.2-3B"  # many others, e.g. "allenai/OLMo-2-1124-7B"
inputs = AutoTokenizer.from_pretrained(model_name)("I ❤️ 🤗", return_tensors="pt")
model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)

with torch.autocast("cuda", dtype=torch.bfloat16):
    # errors with -> RuntimeError: FlashAttention only support fp16 and bf16 data type
    outputs = model(**inputs)

Indeed after calling AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32), we get

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. ...

But the snippet fails even though we do use torch.autocast as suggested. For mixed-precision training, we actually do want to load the weights in float32.

Concretely, the source is two different issues (a minimal dtype sketch follows the list):

  • The common Llama-inspired implementation of RMSNorm silently upcasts its output to float32 even within an autocast context, and this float32 dtype is propagated from the q_norm / k_norm up until the projections are passed to the attention function (FA2 fails here)
    • The RMSNorm issue has been discussed in these related issues: here, here, and here
    • We have this comment in the FlashAttention integration about RMSNorm usually handling silent upcasting correctly but it seems at some point this broke:
      # In PEFT, usually we cast the layer norms in float32 for training stability reasons
      # therefore the input hidden states gets silently casted in float32. Hence, we need
      # cast them back in the correct dtype just to be sure everything works as expected.
      # This might slowdown training & inference so it is recommended to not cast the LayerNorms
      # in fp32. (usually our RMSNorm modules handle it correctly)
  • The cos and sin position embeddings for RoPE are in float32 even within an autocast context, which will again silently upcast the query_states/key_states to float32 before passing to the attention function:
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
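
As a minimal sketch of the dtype promotion behind both points (plain PyTorch rather than transformers code; assumes a CUDA device is available):

import torch

with torch.autocast("cuda", dtype=torch.bfloat16):
    x = torch.randn(4, 8, device="cuda")         # float32 activations
    h = torch.nn.Linear(8, 8, device="cuda")(x)  # linear layers are autocast -> bfloat16 output
    gamma = torch.ones(8, device="cuda")         # float32, like an fp32 RMSNorm weight
    print((gamma * h).dtype)                     # torch.float32: elementwise mul is not autocast, type promotion wins
    cos = torch.randn(4, 8, device="cuda")       # float32, like the RoPE cos/sin tensors
    print((h * cos).dtype)                       # torch.float32: the same silent upcast hits the query/key states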

One fix is to remove the silent upcasting of the RMSNorm output to float32 when the input is bfloat16, and to cast cos and sin directly to torch.get_autocast_dtype when autocast is active.
In this discussion it seems that this might come with some issues, so a different solution might be needed (I am not quite sure of the exact reasons for the potential issues, though).

It's important to note that, through all this silent upcasting, we're probably (I haven't benchmarked it, though) using a lot of extra memory when doing mixed-precision training, regardless of whether we use attn_implementation="flash_attention_2" or not.
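
To make the "where does fp32 leak in" question easier to check, here is a small hypothetical debugging helper (not part of the issue; the name report_fp32_outputs and the hook-based approach are my own sketch) that prints every module whose forward output contains a float32 tensor:

import torch

def report_fp32_outputs(model):
    # Flag modules whose outputs end up in float32, so we can see where upcasting happens under autocast.
    def hook(module, inputs, output):
        tensors = output if isinstance(output, tuple) else (output,)
        if any(torch.is_tensor(t) and t.dtype == torch.float32 for t in tensors):
            print(f"fp32 output from {module.__class__.__name__}")
    for module in model.modules():
        module.register_forward_hook(hook)

Running this on the repro model before the autocast forward pass should list the RMSNorm and rotary embedding modules among the fp32 producers.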

konstantinjdobler (Contributor, Author) commented:

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding


def patch_broken_autocast_llama():
    class FixedLlamaRMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            # fix silent upcasting of output (note that we keep the inner .to() to *only* cast the output and no inner computation)
            return (self.weight * hidden_states.to(input_dtype)).to(input_dtype)
            # return self.weight * hidden_states.to(input_dtype)

    transformers.models.llama.modeling_llama.LlamaRMSNorm = FixedLlamaRMSNorm

    class FixedLlamaRotaryEmbedding(LlamaRotaryEmbedding):
        @torch.no_grad()
        def forward(self, x, position_ids):
            if "dynamic" in self.rope_type:
                self._dynamic_frequency_update(position_ids, device=x.device)

            # Core RoPE block
            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
            position_ids_expanded = position_ids[:, None, :].float()
            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
            device_type = x.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos()
                sin = emb.sin()

            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
            cos = cos * self.attention_scaling
            sin = sin * self.attention_scaling

            # fix silent upcasting of output
            output_dtype = x.dtype
            if torch.is_autocast_enabled():
                output_dtype = torch.get_autocast_dtype(device_type=device_type)
            return cos.to(dtype=output_dtype), sin.to(dtype=output_dtype)
            # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding

The snippet above applies a patch for this problem for Llama (it applies 1-to-1 to many other models); running it before the error repro snippet solves the RuntimeError. Happy to submit a PR, but this definitely needs some discussion on what you think is the correct way to solve it, as I know these precision issues have been a problem in the past as well.
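
For clarity, the intended usage is along these lines (same names as in the repro snippet above; the key point is that the patch has to run before the model is instantiated so the patched classes are picked up):

patch_broken_autocast_llama()

model = AutoModel.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch.float32)
with torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model(**inputs)  # no longer raises the FlashAttention dtype RuntimeError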

konstantinjdobler (Contributor, Author) commented:

As a side note, is it intended that we always keep the residual stream in float32, even in mixed-precision training? This happens because of these lines:

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

While the computations in self.self_attn and self.mlp are autocast and their outputs are bfloat16, adding them back onto the float32 hidden states silently upcasts the result to float32. As a consequence, the residual stream is never autocast, since the initial hidden_states at layer 0 are still float32, as they (usually) come from an nn.Embedding (which is not autocast by torch.autocast).
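
A tiny standalone illustration of that promotion (my own sketch, not code from the modeling file):

import torch

residual = torch.randn(2, 4, dtype=torch.float32)    # embedding outputs stay fp32 under autocast
block_out = torch.randn(2, 4, dtype=torch.bfloat16)  # attention / MLP outputs are bf16 under autocast
print((residual + block_out).dtype)                  # torch.float32: the residual stream is silently kept in fp32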

Can also open a separate issue if you prefer.
