TransformerEngine FP8 is slower & more memory intensive than FlashAttention FP16? #1119
Comments
I tested the scripts you posted. The reason why TE is slower in your example is due to the fact that

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from time import perf_counter
from tqdm import tqdm

b = 32
n = 1024  # w * h
nhead = 32
nhead_k = 8
head_dim = 128
hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)
# seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda().bfloat16()
mha = te.TransformerLayer(
    hidden_dim,
    ffn_hidden_size=hidden_dim * 4,
    num_attention_heads=nhead,
    num_gqa_groups=nhead_k,
    hidden_dropout=0,
    attention_dropout=0,
    kv_channels=head_dim,
    self_attn_mask_type='no_mask',
    bias=True,
).cuda()
mha = mha.bfloat16()
params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)  # , amax_history_len=32, amax_compute_algo="max"
num_params = sum(p.numel() for p in mha.parameters())
print("Number of parameters: %f M" % (num_params / 1e6))
start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
    loss = out.sum()
    loss.backward()
    opt.step()
torch.cuda.synchronize()
print('perf: %f' % (1000 * (perf_counter() - start) / n_iters))
```

(I added the missing imports and cast the model and input to bf16.) I get the following performance results on an H100 PCIe (forward/backward only, omitting the optimizer, which adds a constant overhead in all cases):
One additional note is that in your FA script you did not include LayerNorm, but that is a small difference in this case.
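Since the original report also concerns memory, peak usage for the loop above can be checked with PyTorch's built-in counters. This is a generic addition on my part, not part of the original scripts:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run the fwd/bwd loop from the script above ...
torch.cuda.synchronize()
# max_memory_allocated reports the peak GPU memory held by tensors since the reset
print("peak allocated: %.2f GiB" % (torch.cuda.max_memory_allocated() / 1024**3))
```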
Very helpful, thank you!
hi @ptrendx I have a follow-up question: what is actually cast to FP8 with TE in the code above? First, we cast the model to bfloat16, so the model weights are 16-bit. Then, we use fp8_autocast to convert the TE Linear layer activations to FP8. Is that correct? So the activations and gradients for the TE linear layers are in FP8, but the weights and optimizer states are all bfloat16? On the other hand, if we don't cast to bfloat16 first, the model weights are in FP32 and stay that way throughout training? And what happens if we use torch.cuda.amp.autocast with bfloat16?
I am trying to wrap my head around how FP8 is implemented in TE.
Functionality-wise, you can think of it as

```python
x = x.to(fp8).to(fp32)
weight = weight.to(fp8).to(fp32)
y = linear(x, weight).to(original_input_type)  # cast back to the original input dtype
y = y + bias
```

What is different in the actual execution is that the linear layer takes the FP8 data directly and performs the matrix multiplication in FP8, so the intermediate higher-precision copies are never materialized.
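To see this in isolation, here is a minimal sketch (my illustration, not from the thread) with a standalone te.Linear: the stored weights stay in bfloat16, and only the GEMM inside fp8_autocast runs on FP8-cast data.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

linear = te.Linear(4096, 4096, bias=True).cuda().bfloat16()
x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)           # the GEMM consumes FP8-cast input and weight internally

print(linear.weight.dtype)  # torch.bfloat16 - the stored parameters are untouched
print(y.dtype)              # torch.bfloat16 - output comes back in the input dtype
```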
We also have an additional context manager
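The context manager being referred to here is presumably fp8_model_init (my inference; the rest of the comment was not quoted), which creates module weights directly in FP8 so that no higher-precision copy is stored. A minimal sketch, assuming a recent TE release:

```python
import transformer_engine.pytorch as te

# Inside fp8_model_init the parameters themselves are created as FP8 tensors,
# trading some accuracy for lower weight memory (no bf16/fp32 master weights kept).
with te.fp8_model_init(enabled=True):
    proj = te.Linear(4096, 4096, bias=True, device="cuda")
```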
I'm running some benchmarks on TransformerEngine MHA FP8 versus FlashAttention MHA FP16. However, I'm consistently finding that TE FP8 is not only 50-60% slower than FlashAttention, it also uses much more memory (11 GB vs 27 GB).
I'm scratching my head because FP8 should use less memory for the same sequence length and MHA parameters. I'm using the latest TE built from source on a single H100, with cuDNN installed. Here's the benchmarking code:
TransformerEngine:
FlashAttention Benchmark:
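(The original poster's scripts were not included above. The following is a hypothetical sketch of such a 16-bit FlashAttention MHA benchmark with the same shapes as the TE script, assuming the flash-attn package's flash_attn_func; note it covers attention only, with no LayerNorm or FFN, consistent with the comment above.)

```python
import torch
import torch.nn as nn
from time import perf_counter
from flash_attn import flash_attn_func

b, n = 32, 1024
nhead, nhead_k, head_dim = 32, 8, 128
hidden_dim = nhead * head_dim  # 4096
n_iters = 1000

class FlashMHA(nn.Module):
    """Plain MHA with grouped-query KV heads, attention computed by flash-attn."""
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(hidden_dim, nhead * head_dim)
        self.kv_proj = nn.Linear(hidden_dim, 2 * nhead_k * head_dim)
        self.out_proj = nn.Linear(nhead * head_dim, hidden_dim)

    def forward(self, x):                       # x: (b, n, hidden_dim)
        q = self.q_proj(x).view(b, n, nhead, head_dim)
        kv = self.kv_proj(x).view(b, n, 2, nhead_k, head_dim)
        k, v = kv[:, :, 0], kv[:, :, 1]
        out = flash_attn_func(q, k, v)           # (b, n, nhead, head_dim)
        return self.out_proj(out.reshape(b, n, nhead * head_dim))

mha = FlashMHA().cuda().bfloat16()
x = torch.randn(b, n, hidden_dim, device="cuda", dtype=torch.bfloat16)
opt = torch.optim.AdamW(mha.parameters(), lr=1e-4)

start = perf_counter()
for _ in range(n_iters):
    opt.zero_grad()
    mha(x).sum().backward()
    opt.step()
torch.cuda.synchronize()
print("ms/iter: %f" % (1000 * (perf_counter() - start) / n_iters))
```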
Any ideas?