
[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements #1100

Merged
merged 24 commits into NVIDIA:main from xiny/fp8_mha_with_rope on Sep 5, 2024

Conversation

@yaox12 (Collaborator) commented on Aug 13, 2024

Description

  1. FP8 MHA with RoPE. Handle Float8Tensor inputs in DotProductAttention based on their dtype instead of the fp8_mha flag; fp8_mha still ensures that the output of DPA is in FP8 (see the sketch after this list). With this PR:
    • FP8 DPA workflow (unchanged): LayerNormLinear -> DPA (cast BF16 input to FP8, FP8 DPA, cast output to BF16) -> Linear
    • FP8 MHA workflow:
      • Without RoPE (unchanged): LayerNormLinear (output in FP8) -> DPA (FP8 DPA) -> Linear (FP8 input)
      • With RoPE (new): LayerNormLinear (output in BF16) -> apply RoPE (output in BF16) -> DPA (cast BF16 input to FP8, FP8 DPA) -> Linear (FP8 input)
  2. Rename is_first_module_in_mha to fp8_output and add this flag to LayerNormLinear; otherwise, even the LayerNormLinear in the MLP (after MHA) would produce FP8 outputs when fp8_mha=True.
  3. Avoid index_select ops in cast_to_fp8.
  4. Avoid index_select ops in FP8 DPA. Only the forward functions are modified, because the CPU overheads are not exposed in backward.
  5. Change the way we check the strides of k and v to avoid creating PyTorch tensors.
  6. Move the transpose to backward for Float8Tensor inputs in Linear.
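
To make the dtype-based handling in item 1 concrete, here is a minimal, self-contained sketch. It is not TE code: FP8Stub stands in for Float8Tensor, and plain softmax attention stands in for the fused attention kernels; it only illustrates the idea of dispatching on the input type and letting fp8_mha control the output type.

import torch

class FP8Stub(torch.Tensor):
    """Toy stand-in for Float8Tensor; only the type matters for the dispatch."""

def _attention(q, k, v):
    # Placeholder for the fused attention kernel.
    scores = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return scores @ v

def dpa_forward(q, k, v, fp8_mha=False):
    # Decide the FP8 path from the input type, not from the fp8_mha flag.
    if not all(isinstance(t, FP8Stub) for t in (q, k, v)):
        # BF16 inputs (FP8 DPA, or FP8 MHA with RoPE): "cast" them to FP8 here.
        q, k, v = (t.as_subclass(FP8Stub) for t in (q, k, v))
    out = _attention(q, k, v)
    # fp8_mha keeps the output in FP8 so the following Linear gets FP8 input;
    # otherwise the output is "cast" back to BF16 as before.
    return out if fp8_mha else out.as_subclass(torch.Tensor)

q = k = v = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16)
print(type(dpa_forward(q, k, v, fp8_mha=True)).__name__)   # FP8Stub
print(type(dpa_forward(q, k, v, fp8_mha=False)).__name__)  # Tensor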

Timeline

As the profiler timelines below show, this PR greatly reduces the CPU overheads highlighted in the red boxes.

  • Before: [profiler timeline screenshot]
  • After: [profiler timeline screenshot]

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 (Collaborator) left a comment

Overall I like this approach. The current FP8 MHA impl is brittle since it expects the modules to pass specific combinations of Float8Tensor/torch.Tensor. Adding logic so the modules can do casts internally makes this more flexible.

This is similar to how I envision the operation-based API working. See how we cast inputs in the linear operation:

if with_fp8_compute and not is_float8_tensor(x_local):
    fp8_dtype = get_fp8_te_dtype(
        input_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    x_fp8 = Float8Tensor(
        data=torch.empty_like(x_local, dtype=torch.uint8),
        fp8_meta=input_fp8_meta,
        fp8_meta_forward=True,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
        fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
        dtype=dtype,
    )
    with_cast_transpose = weight.requires_grad
    if tensor_parallel_mode == "column" and sequence_parallel:
        with_cast_transpose = False
    if with_cast_transpose:
        x_fp8.cast_transpose_(x_local)
    else:
        x_fp8.copy_(x_local)
    x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
    x_local = x_local.from_float8()

We're not there yet, but the goal is to be able to implement FP8 MHA with something like:

model = te.Sequential(
    te.ops.LayerNorm(...),  # fp8 output
    te.ops.Linear(...),
    te.ops.RoPE(...),  # fp8 output
    te.ops.SelfAttention(...),  # fp8 output
    te.ops.Linear(...),
)
with te.fp8_autocast():
    y = model(x)

Review thread (outdated, resolved) on transformer_engine/pytorch/module/layernorm_linear.py
@timmoon10 (Collaborator) commented

Regarding further optimizations: removing the select operations would be helpful if it's not too difficult. I've observed that they add non-trivial CPU overhead in other cases, so I recommend looking at #865. You should also be aware that I've made significant changes in the cpp_extensions functions in #1083.

The logic for torch.ops.tex_ts is for ONNX export, which is based on TorchScript. We only register the ops needed for inference, which is why the FP8 cast is registered with TorchScript while the FP8 cast-transpose is a plain Pybind11 function.

@yaox12 (Collaborator, Author) commented on Aug 14, 2024

Regarding further optimizations: removing the select operations would be helpful if it's not too difficult. I've observed that they add non-trivial CPU overhead in other cases, so I recommend looking at #865. You should also be aware that I've made significant changes in the cpp_extensions functions in #1083.

The logic for torch.ops.tex_ts is for ONNX export, which is based on TorchScript. We only register the ops needed for inference, which is why the FP8 cast is registered with TorchScript while the FP8 cast-transpose is a plain Pybind11 function.

Thanks for your explanation. cast_to_fp8 is doing index selection in TorchScript and thus calling PyTorch ops. I'll try to move it to plain C++.
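
To illustrate the overhead being discussed (this is not TE code, just a stand-alone demonstration): selecting a single scale out of the fp8_meta buffers with Python/TorchScript indexing dispatches an aten::select op on every call, which is exactly the kind of per-call CPU cost that goes away once the C++ side reads the element directly from the data pointer.

import torch

scale = torch.ones(16)  # stand-in for an fp8_meta scale buffer
meta_index = 3          # stand-in for the FP8 tensor's meta index

with torch.profiler.profile() as prof:
    for _ in range(100):
        _ = scale[meta_index]  # each iteration dispatches an aten::select

# aten::select shows up ~100 times below; in plain C++ the same read is just
# scale.data_ptr<float>()[meta_index], with no dispatcher round trip.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))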

@yaox12 force-pushed the xiny/fp8_mha_with_rope branch 3 times, most recently from 6e1334d to a1ba977, on August 14, 2024 05:34
@yaox12 changed the title from "[PyTorch] FP8 MHA with RoPE" to "[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements" on Aug 14, 2024
@yaox12 marked this pull request as ready for review on August 14, 2024 05:58
Review threads (outdated, resolved) on:
  • transformer_engine/pytorch/attention.py
  • transformer_engine/pytorch/module/layernorm_linear.py
  • transformer_engine/pytorch/module/linear.py
  • transformer_engine/pytorch/csrc/extensions/attention.cu
@timmoon10 (Collaborator) commented

/te-ci pytorch

@timmoon10 (Collaborator) left a comment

LGTM

Signed-off-by: Xin Yao <[email protected]>
@cyanguwa (Collaborator) left a comment

LGTM

Review thread (outdated, resolved) on transformer_engine/pytorch/attention.py
@cyanguwa (Collaborator) commented on Aug 20, 2024

@yaox12 do we have a test that specifically covers the FP8 MHA + RoPE functionality? That test should also answer your question above regarding FP8GlobalStateManager.get_fp8_recipe().fp8_mha. Thanks.

Signed-off-by: Xin Yao <[email protected]>
@yaox12 (Collaborator, Author) commented on Aug 21, 2024

@yaox12 do we have a test that specifically covers the FP8 MHA + RoPE functionality? That test should also answer your question above regarding FP8GlobalStateManager.get_fp8_recipe().fp8_mha. Thanks.

Thanks. Added RoPE tests.
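
For readers following along, a rough sketch of how FP8 MHA + RoPE could be exercised is below. This is not the test added in this PR; the fp8_dpa/fp8_mha recipe flags, the RotaryPositionEmbedding helper, and the rotary_pos_emb argument are assumptions based on the discussion above, and an FP8-capable attention backend (FA3 or cuDNN fused attention on Hopper) is required.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

seq_len, batch, heads, head_dim = 128, 2, 16, 64
hidden = heads * head_dim

# Assumed module/kwarg names; consult tests/pytorch/fused_attn for the real tests.
mha = te.MultiheadAttention(hidden, heads, attention_dropout=0.0,
                            params_dtype=torch.bfloat16).cuda()
rope = RotaryPositionEmbedding(head_dim)(max_seq_len=seq_len).cuda()
x = torch.randn(seq_len, batch, hidden, dtype=torch.bfloat16, device="cuda")

# fp8_mha=True keeps the DPA output in FP8 for the following Linear; with RoPE,
# the QKV projection now emits BF16 so RoPE can be applied before the cast to FP8.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mha(x, rotary_pos_emb=rope)
out.sum().backward()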

@timmoon10 self-requested a review on August 21, 2024 18:00
@cyanguwa (Collaborator) commented

/te-ci pytorch

@yaox12 (Collaborator, Author) commented on Aug 27, 2024

@cyanguwa I found that Flash Attention 3 is not installed in our CI container, so I just skip the FP8 DPA/MHA tests when FA3 is not available; otherwise they throw the error "no attention backends available".

Another CI failure is tests/pytorch/fused_attn/test_fused_attn.py::test_dpa_mask[mask_9_0-model_configs0-dtype0], and it fails in other PRs too. It seems to be a CI issue.

@cyanguwa (Collaborator) commented on Aug 27, 2024

@cyanguwa I found that Flash Attention 3 is not installed in our CI container, so I just skip the FP8 DPA/MHA tests when FA3 is not available; otherwise they throw the error "no attention backends available".

Another CI failure is tests/pytorch/fused_attn/test_fused_attn.py::test_dpa_mask[mask_9_0-model_configs0-dtype0], and it fails in other PRs too. It seems to be a CI issue.

Yes, mask_9_0 is a cuDNN 9.4.0.47 issue, and the FP8 tests are getting fixed in #1141.

@yaox12 (Collaborator, Author) commented on Aug 30, 2024

@timmoon10 Can you review the unresolved comments above?

@timmoon10 (Collaborator) left a comment

LGTM

@yaox12 (Collaborator, Author) commented on Sep 2, 2024

@timmoon10 @cyanguwa Can you trigger the CI?

@yaox12 (Collaborator, Author) commented on Sep 4, 2024

/te-ci pytorch

Signed-off-by: Xin Yao <[email protected]>
@yaox12 (Collaborator, Author) commented on Sep 4, 2024

/te-ci pytorch

@yaox12 (Collaborator, Author) commented on Sep 5, 2024

Since Tim and Charlene have approved, all comments have been resolved, and the CI has passed, I'll merge this PR.

@yaox12 merged commit 5fafeb0 into NVIDIA:main on Sep 5, 2024
26 checks passed