[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements #1100
Conversation
Overall I like this approach. The current FP8 MHA implementation is brittle since it expects the modules to pass specific combinations of Float8Tensor/torch.Tensor. Adding logic so the modules can do casts internally makes this more flexible.
This is similar to how I envision the operation-based API to work. See how we cast inputs in the linear operation:
TransformerEngine/transformer_engine/pytorch/ops/basic/basic_linear.py
Lines 465 to 488 in ec49a52
if with_fp8_compute and not is_float8_tensor(x_local):
    fp8_dtype = get_fp8_te_dtype(
        input_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    x_fp8 = Float8Tensor(
        data=torch.empty_like(x_local, dtype=torch.uint8),
        fp8_meta=input_fp8_meta,
        fp8_meta_forward=True,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
        fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
        dtype=dtype,
    )
    with_cast_transpose = weight.requires_grad
    if tensor_parallel_mode == "column" and sequence_parallel:
        with_cast_transpose = False
    if with_cast_transpose:
        x_fp8.cast_transpose_(x_local)
    else:
        x_fp8.copy_(x_local)
    x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
    x_local = x_local.from_float8()
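Distilled, the boundary logic above is a type check plus a cast in each direction. Here is a toy, self-contained sketch of the same pattern; FakeFP8 and its helpers are illustrative stand-ins, not TE's Float8Tensor API:

import torch
from dataclasses import dataclass

@dataclass
class FakeFP8:
    """Toy stand-in for Float8Tensor: a low-precision payload plus a scale."""
    data: torch.Tensor   # simulated quantized payload (int8 here)
    scale: torch.Tensor  # dequantization scale

    def from_float8(self) -> torch.Tensor:
        return self.data.float() * self.scale

def to_fake_fp8(x: torch.Tensor) -> FakeFP8:
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    return FakeFP8((x / scale).round().clamp(-127, 127).to(torch.int8), scale)

def normalize_input(x, with_fp8_compute: bool):
    """Cast-at-the-boundary pattern: quantize plain tensors when computing
    in FP8, dequantize quantized tensors when not."""
    is_fp8 = isinstance(x, FakeFP8)  # stands in for is_float8_tensor(x)
    if with_fp8_compute and not is_fp8:
        return to_fake_fp8(x)
    if not with_fp8_compute and is_fp8:
        return x.from_float8()
    return x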
We're not there yet, but the goal is to be able to implement FP8 MHA with something like:
model = te.Sequential(
te.ops.LayerNorm(...), # fp8 output
te.ops.Linear(...),
te.ops.RoPE(...), # fp8 output
te.ops.SelfAttention(...), # fp8 output
te.ops.Linear(...),
)
with te.fp8_autocast():
y = model(x)
Regarding further optimizations: removing the select operations would be helpful if it's not too difficult. I've observed that they add non-trivial CPU overhead in other cases, so I recommend looking at #865. You should also be aware that I've recently made significant changes in this area of the code.
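To make the select-op overhead concrete, here is a small demonstration (not TE code): indexing a scale tensor with a Python int dispatches an aten::select per call, and that dispatch cost is pure CPU time, independent of tensor size.

import torch
from torch.profiler import profile, ProfilerActivity

scales = torch.ones(16)  # e.g., one FP8 scale slot per quantized tensor
x = torch.randn(1024)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for i in range(1000):
        y = x * scales[i % 16]  # one aten::select dispatched per iteration

# aten::select shows up as a large share of CPU time in the profile.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))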
Thanks for your explanation.
/te-ci pytorch
LGTM
LGTM
@yaox12 do we have a test that specifically exercises the functionality of FP8 MHA + RoPE? Such a test should be able to answer your question above as well.
Thanks. Added RoPE tests.
/te-ci pytorch
@cyanguwa I found that Flash Attention 3 is not installed in our CI container, so I skip the FP8 DPA/MHA tests when FA3 is not available; otherwise they throw the error "no attention backends available". There is another CI failure as well.
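A hedged sketch of such a skip guard (the FA3 module name and the test name are assumptions for illustration, not the PR's actual test code):

import pytest
import torch

try:
    import flash_attn_interface  # FA3's Python module; name assumed here
    HAS_FA3 = True
except ImportError:
    HAS_FA3 = False

@pytest.mark.skipif(not HAS_FA3, reason="Flash Attention 3 is not installed")
def test_fp8_dpa_smoke():
    # Placeholder body: the real tests exercise the FP8 DPA/MHA paths.
    assert torch.cuda.is_available()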
Yes.
@timmoon10 Can you review the unresolved comments above?
LGTM
@timmoon10 @cyanguwa Can you trigger the CI?
/te-ci pytorch
/te-ci pytorch
As Tim and Charlene have approved, all comments have been resolved, and the CI has passed, I'll merge this PR.
Description
With this PR:
- DotProductAttention accepts Float8Tensor inputs depending on the dtype, instead of relying on the fp8_mha flag. fp8_mha still ensures the output of DPA is in FP8.
- Rename is_first_module_in_mha to fp8_output and add this flag to LayerNormLinear; otherwise even the LayerNormLinear in MLP (after MHA) would produce FP8 outputs when fp8_mha=True.
- Remove index_select ops in cast_to_fp8.
- Remove index_select ops in FP8 DPA. Only the forward functions are modified, because the CPU overheads are not exposed in backward.
- Support Float8Tensor inputs in Linear (see the usage sketch below).
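A hedged usage sketch of the behavior described above; the recipe fields fp8_dpa/fp8_mha come from TE's DelayedScaling, but exact shapes, arguments, and defaults may differ:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# With fp8_mha enabled on the recipe, attention can emit FP8 outputs, and
# (per this PR) downstream Linear layers accept Float8Tensor inputs directly.
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)

mha = te.MultiheadAttention(
    hidden_size=1024, num_attention_heads=16, params_dtype=torch.bfloat16
).cuda()
proj = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()

x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = mha(x)    # may be a Float8Tensor when fp8_mha=True
    z = proj(y)   # consumed without an extra cast after this PR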
Timeline
As the before/after profiler timelines show (attached as images in the original PR, not reproduced here), this PR greatly reduces the CPU overheads highlighted in the red boxes.