
TransformerEngine FP8 is slower & more memory intensive than FlashAttention FP16? #1119

Closed
darius-lam opened this issue Aug 18, 2024 · 4 comments

darius-lam commented Aug 18, 2024

I'm running some benchmarks comparing TransformerEngine MHA in FP8 against FlashAttention MHA in FP16. However, I'm consistently finding that TE FP8 is not only 50-60% slower than FlashAttention, it also uses much more memory (11 GB for FlashAttention vs 27 GB for TE).

I'm scratching my head because FP8 should use less memory for the same sequence length and MHA parameters. I'm using the latest TE built from source, with cuDNN installed, on a single H100. Here's the benchmarking code:

TransformerEngine:

b = 32
n = 1024 #w * h
nhead = 32
nhead_k = 8
head_dim = 128

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

#seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda()
mha = te.TransformerLayer(hidden_dim, ffn_hidden_size = hidden_dim*4, num_attention_heads = nhead, num_gqa_groups = nhead_k, hidden_dropout=0, attention_dropout=0, kv_channels=head_dim, self_attn_mask_type='no_mask', bias=True).cuda()
params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)#, amax_history_len=32, amax_compute_algo="max")

num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()
    
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
        loss = out.sum()

    loss.backward()
    opt.step()

torch.cuda.synchronize()
print('perf: %f' % (1000*(perf_counter() - start)/n_iters))

FlashAttention Benchmark:

import torch
import torch.nn as nn
from time import perf_counter
from tqdm import tqdm

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.ops.fused_dense import FusedDense
from flash_attn.modules.mha import MHA

# Baseline flash_attn test

b = 32
n = 1024
nhead = 32
nhead_k = 8
head_dim = 64

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

seq = torch.randn((b, n, hidden_dim)).cuda().to(torch.bfloat16)
class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim, nhead, nhead_k, ffn_hidden_size):
        super().__init__()
        self.mha = MHA(hidden_dim, nhead, nhead_k, use_flash_attn=True, fused_bias_fc=True, causal=False)

        self.ffn = nn.Sequential(
            FusedDense(hidden_dim, ffn_hidden_size, bias=True),
            nn.GELU(),
            FusedDense(ffn_hidden_size, hidden_dim, bias=True),
        )

    def forward(self, x):
        x = x + self.mha(x)
        x = x + self.ffn(x)
        return x 
    
#mha = MHA(hidden_dim, nhead, nhead_k, use_flash_attn=True, fused_bias_fc=True, causal=False).cuda()
mha = TransformerLayer(hidden_dim, nhead, nhead_k, ffn_hidden_size = hidden_dim*4).cuda()

params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)
num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

a = perf_counter()
for _ in tqdm(range(n_iters)):
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        opt.zero_grad()
        out = mha(seq)
        loss = out.sum()
        
        loss.backward()
        opt.step()

torch.cuda.synchronize()  # wait for queued GPU work to finish before reading the timer
print('flash_attn_func: %f' % (1000*(perf_counter() - a)/n_iters))

Any ideas?

ptrendx (Member) commented Aug 19, 2024

I tested the scripts you posted. TE is slower in your example because fp8_autocast only affects the FP8 execution; the precision of the other parts of the model is preserved. In this case that is FP32, since you did not use AMP or cast the model to FP16/BF16. With the default recipe, TE does not compute attention in FP8, so without casting the model the attention runs in the original precision (FP32). Neither the flash-attention nor the cuDNN attention backend in TE supports FP32 execution, so you got the slowest, unfused attention (which also uses the most memory), hence the poor performance you observed.
Changing your TE script to this:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from time import perf_counter
from tqdm import tqdm

b = 32
n = 1024 #w * h
nhead = 32
nhead_k = 8
head_dim = 128

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

#seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda().bfloat16()
mha = te.TransformerLayer(hidden_dim, ffn_hidden_size = hidden_dim*4, num_attention_heads = nhead, num_gqa_groups = nhead_k, hidden_dropout=0, attention_dropout=0, kv_channels=head_dim, self_attn_mask_type='no_mask', bias=True).cuda()
mha = mha.bfloat16()
params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)#, amax_history_len=32, amax_compute_algo="max")

num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()
    
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
        loss = out.sum()

    loss.backward()
    opt.step()

torch.cuda.synchronize()
print('perf: %f' % (1000*(perf_counter() - start)/n_iters))

(I added the missing imports and cast the model and input to BF16.) I get the following performance results on an H100 PCIe (forward/backward only, omitting the optimizer, which adds a constant overhead in all cases):

  • TE (original script): 104 ms
  • Flash: 95 ms
  • TE (new script): 44.6 ms
  • TE (new script, fp8 disabled, so pure BF16 execution for comparison): 65.6 ms

One additional note: your FA script does not include LayerNorm (te.TransformerLayer does), but that is a small difference in this case.
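
For reference, a minimal sketch (not from your benchmark; the module wiring is assumed) of adding the LayerNorms to the FA baseline, to bring it closer to te.TransformerLayer's default pre-LayerNorm structure:

import torch.nn as nn

class TransformerLayerWithLN(nn.Module):
    # Pre-LN residual blocks: LayerNorm before attention and before the MLP.
    def __init__(self, hidden_dim, mha, ffn):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.mha = mha  # e.g. flash_attn.modules.mha.MHA(...) from your script
        self.ffn = ffn  # e.g. the FusedDense/GELU/FusedDense stack from your script

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x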

darius-lam (Author) commented:

Very helpful, thank you

darius-lam (Author) commented Aug 23, 2024

Hi @ptrendx, I have a follow-up question: what is actually cast to FP8 with TE in the code above? First we cast the model to BF16, so the model weights are 16-bit. Then fp8_autocast converts the TE Linear layers' activations to FP8. Is that correct? So the activations and gradients for the TE Linear layers are in FP8, but the weights and the optimizer states are all BF16?

On the other hand, if we don't cast to BF16, do the model weights stay in FP32 throughout training? And what happens if we use torch.cuda.amp.autocast(dtype=torch.bfloat16)?

|                                     | GEMM | Weight | Gradient | Optimizer State | Activation Comm | Gradient Comm |
| ----------------------------------- | ---- | ------ | -------- | --------------- | --------------- | ------------- |
| Model BFloat16 / AMP + cast FP8     | fp8  | bf16   | fp32 (?) | fp32 (?)        | bf16 (?)        | bf16 (?)      |
| Model no BFloat16 + cast FP8        | fp8  | fp32   | fp8 (?)  | fp32 (?)        | fp32 (?)        | fp32 (?)      |
| Model no BFloat16 + no cast FP8     | fp32 | fp32   | fp32     | fp32            | fp32            | fp32          |

I am trying to wrap my head around how FP8 is implemented in TE.

ptrendx (Member) commented Aug 23, 2024

Functionality-wise, you can think of fp8_autocast as changing just the internal execution of the operators, so a functionally equivalent execution of the forward pass would be:

x = x.to(fp8).to(fp32)
weight = weight.to(fp8).to(fp32)
y = linear(x, weight).to(original input type)
y = y + bias

What is different in the actual execution is that the linear layer takes x and weight in FP8 and only the internal accumulator is in FP32, but the result is the same.
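
For intuition, here is a rough, self-contained emulation of that equivalence in plain PyTorch (a sketch, assuming a build with the torch.float8_e4m3fn dtype; 448 is the largest normal E4M3 value, used here as the scaling target):

import torch

def fake_fp8(t, scale):
    # Scale, round to the nearest representable FP8 (E4M3) value, cast back, unscale.
    return (t * scale).to(torch.float8_e4m3fn).to(torch.float32) / scale

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

scale_x = 448.0 / x.abs().max().float()
scale_w = 448.0 / w.abs().max().float()

# The GEMM itself runs in higher precision here; the real FP8 kernel keeps the
# inputs in FP8 and only accumulates in FP32, but the numerics are equivalent.
y = (fake_fp8(x, scale_x) @ fake_fp8(w, scale_w).t()).to(x.dtype)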
So in your table:

  • with AMP:
    • weights and weight gradients are FP32;
    • activations and data gradients are BF16;
    • optimizer state is FP32;
    • inside fp8_autocast the GEMM casts its inputs to FP8 just in time, but its outputs are BF16/FP32.
  • with the BF16 cast:
    • weights and weight gradients are BF16;
    • activations and data gradients are BF16;
    • optimizer state depends on your optimizer implementation: a regular optimizer would keep it in BF16, and you need master-weights support to keep the state in FP32 (see the sketch below);
    • inside fp8_autocast the GEMM casts its inputs to FP8 just in time, but its outputs are BF16.
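
A rough sketch of the master-weights idea in plain PyTorch (illustrative only, not TE-specific): the model and its gradients are BF16, while the optimizer owns an FP32 copy of every parameter and the BF16 weights are refreshed from it after each step.

import torch

model = torch.nn.Linear(4096, 4096).cuda().bfloat16()
master = [p.detach().clone().float() for p in model.parameters()]  # FP32 master copy
opt = torch.optim.AdamW(master, lr=1e-4)

x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()

with torch.no_grad():
    for p_master, p_model in zip(master, model.parameters()):
        p_master.grad = p_model.grad.float()  # accumulate the BF16 gradient in FP32
    opt.step()                                # update and Adam state stay in FP32
    opt.zero_grad(set_to_none=True)
    model.zero_grad(set_to_none=True)
    for p_master, p_model in zip(master, model.parameters()):
        p_model.copy_(p_master)               # cast updated FP32 weights back to BF16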

We also have an additional context manager, fp8_model_init, which makes the layers actually hold FP8 parameters rather than higher-precision ones. It is not the default behavior, though, since the user needs to make sure they keep a high-precision copy of the parameters somewhere else (or that those parameters are not actually trainable, as in inference or LoRA); otherwise convergence would suffer.
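
A minimal usage sketch (check the TE docs for the exact semantics; the shapes, device argument, and recipe here are placeholders):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Layers created inside fp8_model_init hold their parameters directly in FP8,
# so no higher-precision copy exists unless you maintain one yourself
# (fine for inference or frozen/LoRA base weights, risky for full training).
with te.fp8_model_init(enabled=True):
    layer = te.Linear(4096, 4096, bias=True, device="cuda")

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    y = layer(x)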
