
[PyTorch] Implement Fp8 padding and unpadding module #1129

Merged

Conversation

Contributor

@BeingGod BeingGod commented Aug 22, 2024

Description

Currently, the FP8 unpadding backward pass is implemented via torch.autograd. It involves num_gemms * aten::fill plus device-to-device (DtoD) copy calls, which hurts MoE model performance in FP8 training. So we implemented Fp8 padding and unpadding modules to eliminate the autograd overhead. The workflow is shown below.

Workflow:
[workflow diagram]

We see a 2% E2E performance gain in our MoE model.
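A minimal usage sketch of the new modules around a grouped GEMM. The module names follow this PR; the exact constructor/forward signatures and the GroupedLinear usage are assumptions based on this PR and the related #1128, not the authoritative API:

# Illustrative sketch only: pad routed splits before the FP8 grouped GEMM,
# then strip the padded rows again afterwards.
import torch
import transformer_engine.pytorch as te

num_gemms = 4                      # number of experts / grouped GEMMs
m_splits = [1000, 1017, 1023, 56]  # routed token counts per expert (not multiples of 16)
hidden = 1024

inp = torch.randn(sum(m_splits), hidden, dtype=torch.bfloat16, device="cuda")

padding = te.Fp8Padding(num_gemms)
unpadding = te.Fp8Unpadding(num_gemms)
linear = te.GroupedLinear(num_gemms, hidden, hidden, bias=False)  # assumed grouped GEMM module

# Pad each split up to a multiple of 16 so the FP8 GEMMs see aligned row counts.
# (In real FP8 training the GEMM would run under te.fp8_autocast().)
padded_inp, padded_m_splits = padding(inp, m_splits)
out = linear(padded_inp, padded_m_splits)
# Drop the padded rows so downstream code sees the original token counts.
out = unpadding(out, m_splits)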

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Implement a fused multi-tensor padding kernel, multi_padding_kernel (see the unfused reference sketch below).
  • Implement the FP8Padding and FP8Unpadding modules.
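For reference, an unfused pure-PyTorch sketch of what the fused multi-tensor padding computes (the function and variable names here are illustrative, not the kernel's actual interface):

import torch

def multi_padding_reference(inp, m_splits, padded_m_splits):
    # Copy each split of rows into a zero-initialized buffer whose row count is
    # rounded up to a multiple of 16; rows past the original split stay zero.
    # (Whether the fused kernel zero-fills or leaves padded rows undefined is an
    # implementation detail; zeros are used here for clarity.)
    out = torch.zeros(sum(padded_m_splits), inp.shape[1], dtype=inp.dtype, device=inp.device)
    for src, dst in zip(torch.split(inp, m_splits), torch.split(out, padded_m_splits)):
        dst[: src.shape[0]].copy_(src)
    return out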

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

Hi,
I see quite a bit of code duplication between this file and the existing cast_transpose.cu. Would it be possible to add a padding option to cast_transpose and call it from this multi_pad_cast_transpose_kernel?

Contributor Author

Hi @phu0ngng,
I don't think cast_transpose needs to implement padding.
Reason: in an MoE model, routing causes the sequence dimension to not be a multiple of 16, so we have to pad it. But for cast_transpose, the sequence dimension is usually already a multiple of 16 (e.g. 2048, 4096, 8192...).
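For illustration (split sizes here are made up), rounding each routed split up to the next multiple of 16 looks like:

m_splits = [1000, 1017, 56]                                # routed token counts per expert
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]
print(padded_m_splits)                                     # [1008, 1024, 64]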

Collaborator

@phu0ngng phu0ngng Aug 27, 2024

I mean we can avoid code duplication between cast_transpose and multi_pad_cast_transpose_kernel by templating cast_transpose so that it does padding when needed. Then, from multi_pad_cast_transpose_kernel, one can iterate through the loop and call cast_transpose for each GEMM with padding enabled.

cast_transpose with a padding option could be reused in other future features.

Contributor Author

Ack.

I'm refactoring the code, so multi_pad_cast_transpose_kernel may be removed. But your suggestion is valuable. Thanks a lot.

@BeingGod BeingGod changed the title [PyTorch] Implement of fused padding, cast and transpose kernel [WIP][PyTorch] Implement of fused padding, cast and transpose kernel Aug 27, 2024
@BeingGod BeingGod force-pushed the dev/zhangrb/fused_multi_pad_cast_transpose branch from a6c845d to 63ef882 Compare August 28, 2024 06:55
@BeingGod BeingGod changed the title [WIP][PyTorch] Implement of fused padding, cast and transpose kernel [PyTorch] Implement of Fp8 padding and unpadding module Aug 28, 2024
@BeingGod BeingGod changed the title [PyTorch] Implement of Fp8 padding and unpadding module [PyTorch] Implement Fp8 padding and unpadding module Aug 28, 2024
Collaborator

@yaox12 yaox12 left a comment

Generally LGTM.
I feel we could have better names for Fp8Padding/Fp8Unpadding, such as MultiPadding/MultiUnpadding. cc @phu0ngng

@@ -221,6 +221,14 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)


def cast_if_needed_by_actual_dtype(
Collaborator

This seems to never be used. Should we remove it?

Contributor Author

Ack, thx.

@@ -659,6 +659,7 @@ def prepare_forward(
is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument
num_gemms: int = 1,
allow_non_contiguous: bool = False,
with_param: bool = True,
Collaborator

This also seems to never be used.

Contributor Author

@BeingGod BeingGod Aug 28, 2024

Ack, thx.

Comment on lines 35 to 42
inputmats = torch.split(inp.view(-1, in_features), m_splits)

# Allocate cast and transpose output tensor
total_row = sum(padded_m_splits)
out = torch.empty([total_row, in_features], dtype=inp.dtype, device="cuda")
out_list = torch.split(out, padded_m_splits)

multi_padding_fused(inputmats, padded_m_splits, out_list)
Collaborator

I think it's better to make multi_padding_fused accept whole tensors as inp/out,

inp = inp.view(-1, in_features)
multi_padding_fused(inp, m_splits, padded_m_splits, out)

and replace torch.split with stepping the pointers in C++ to reduce CPU overheads, as I'm doing in https://github.com/NVIDIA/TransformerEngine/pull/1128/files#diff-342b0e9e5b472b443484b3a2c4a78647cd72431c272ed44c678ecbe636fb7a3aR191-R194
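(For illustration only: the bookkeeping that would move to the C++ side amounts to deriving each GEMM's row offsets from the split sizes instead of building torch.split views in Python; the names and numbers below are made up.)

import itertools

m_splits = [1000, 1017, 56]
padded_m_splits = [(m + 15) // 16 * 16 for m in m_splits]            # [1008, 1024, 64]

# Row offsets of each GEMM's block in the packed input / padded output.
in_row_offsets = [0, *itertools.accumulate(m_splits[:-1])]           # [0, 1000, 2017]
out_row_offsets = [0, *itertools.accumulate(padded_m_splits[:-1])]   # [0, 1008, 2032]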

Contributor Author

Ack. Your work is helpful to me.

Thx.

@BeingGod
Contributor Author

BeingGod commented Aug 28, 2024

Generally LGTM. I feel we could have better names for Fp8Padding/Fp8Unpadding, such as MultiPadding/MultiUnpadding. cc @phu0ngng

Hi @yaox12, thanks for your suggestion. I wonder if MultiPadding/MultiUnpadding should support customized padding numbers (e.g. 8, 32...)?
The current padding and unpadding module only supports padding to 16 (for FP8). So I added the Fp8 prefix for padding and unpadding modules.

@phu0ngng
Collaborator

Hi @yaox12, thanks for your suggestion. I wonder if MultiPadding/MultiUnpadding should support customized padding numbers (e.g. 8, 32...)? The current padding and unpadding module only supports padding to 16 (for FP8). So I added the Fp8 prefix for padding and unpadding modules.

Are there any numbers on E2E performance gain in your MoE model for other non-FP8 types (BF16 for example) with this padding/unpadding?

@BeingGod
Contributor Author

BeingGod commented Aug 28, 2024

Hi @yaox12, thanks for your suggestion. I wonder if MultiPadding/MultiUnpadding should support customized padding numbers (e.g. 8, 32...)? The current padding and unpadding module only supports padding to 16 (for FP8). So I added the Fp8 prefix for padding and unpadding modules.

Are there any numbers on E2E performance gain in your MoE model for other non-FP8 types (BF16 for example) with this padding/unpadding?

I don't have performance data for non-FP8 types right now, but it is a good idea. I will run some benchmarks for non-FP8 types with padding/unpadding.

Update:
Hi @phu0ngng. I have done some benchmarks for BF16 with padding/unpadding, setting the padding multiple to 4, 8, and 16. Using padding/unpadding for BF16 seems to hurt E2E performance, so supporting other padding multiples is probably not worthwhile for now.

[benchmark results image]

@phu0ngng
Collaborator

Hi @phu0ngng. I have done some benchmarks for BF16 with padding/unpadding, setting the padding multiple to 4, 8, and 16. Using padding/unpadding for BF16 seems to hurt E2E performance, so supporting other padding multiples is probably not worthwhile for now.

Agree. I think we can keep FP8 in the name for now. Thanks.

@BeingGod BeingGod force-pushed the dev/zhangrb/fused_multi_pad_cast_transpose branch from e652a43 to a1ae467 Compare August 30, 2024 03:15
 1. Add multi-tensor padding kernel
 2. Add FP8Padding and Fp8Unpadding module
 3. Add padding grouped linear UT case

Signed-off-by: beinggod <[email protected]>
@BeingGod BeingGod force-pushed the dev/zhangrb/fused_multi_pad_cast_transpose branch from a1ae467 to c5fd1b4 Compare August 30, 2024 03:33
Collaborator

@phu0ngng phu0ngng left a comment

LGTM!

@phu0ngng
Collaborator

phu0ngng commented Sep 4, 2024

/te-ci pytorch

@phu0ngng phu0ngng merged commit 215db88 into NVIDIA:main Sep 5, 2024
14 checks passed