[PyTorch] Implement Fp8 padding and unpadding module #1129
Conversation
Hi,
I see quite a bit of code duplication between this file and the existing cast_transpose.cu. Would it be possible to add a padding option to cast_transpose and call it from this multi_pad_cast_transpose_kernel?
Hi @phu0ngng,
I think adding padding to cast_transpose is unnecessary.
Reason: in an MoE model, routing causes the sequence dimension to not be a multiple of 16, so we need to pad it. But for cast_transpose, most sequence dimensions are already multiples of 16 (e.g. 2048, 4096, 8192...).
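To make the reasoning concrete, here is a minimal, hypothetical Python sketch (not code from this PR; the helper name and the alignment value of 16 are assumptions) of rounding the per-expert row counts produced by MoE routing up to a multiple of 16:

def pad_m_splits(m_splits, align=16):
    # Round each per-expert row count up to the nearest multiple of `align`.
    return [(m + align - 1) // align * align for m in m_splits]

# Example: routing yields uneven token counts per expert.
m_splits = [1000, 1048, 1033]
padded_m_splits = pad_m_splits(m_splits)  # [1008, 1056, 1040]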
I mean we can avoid code duplication between cast_transpose and multi_pad_cast_transpose_kernel by templating cast_transpose so that it does padding when needed. Then, from multi_pad_cast_transpose_kernel, one can iterate through the loop and call cast_transpose for each GEMM with padding enabled. A cast_transpose with a padding option could be reused in other future features.
Ack.
I'm refactoring the code, so multi_pad_cast_transpose_kernel may be removed. But your suggestion is valuable. Thanks a lot.
Force-pushed from a6c845d to 63ef882
Generally LGTM.
I feel we could have better names for Fp8Padding/Fp8Unpadding, such as MultiPadding/MultiUnpadding. cc @phu0ngng
transformer_engine/pytorch/utils.py (outdated)
@@ -221,6 +221,14 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)


def cast_if_needed_by_actual_dtype(
This seems to never be used. Should we remove it?
Ack, thx.
@@ -659,6 +659,7 @@ def prepare_forward(
    is_first_microbatch: Union[bool, None],  # pylint: disable=unused-argument
    num_gemms: int = 1,
    allow_non_contiguous: bool = False,
    with_param: bool = True,
This seems to never be used.
Ack, thx.
inputmats = torch.split(inp.view(-1, in_features), m_splits)

# Allocate cast and transpose output tensor
total_row = sum(padded_m_splits)
out = torch.empty([total_row, in_features], dtype=inp.dtype, device="cuda")
out_list = torch.split(out, padded_m_splits)

multi_padding_fused(inputmats, padded_m_splits, out_list)
I think it's better to make multi_padding_fused accept whole tensors as inp/out:

inp = inp.view(-1, in_features)
multi_padding_fused(inp, m_splits, padded_m_splits, out)

and replace torch.split with stepping the pointers in C++ to reduce CPU overheads, as I'm doing in https://github.com/NVIDIA/TransformerEngine/pull/1128/files#diff-342b0e9e5b472b443484b3a2c4a78647cd72431c272ed44c678ecbe636fb7a3aR191-R194
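To illustrate the pointer-stepping idea at a conceptual level, here is a hedged Python sketch (not code from either PR; the helper is hypothetical) of how per-chunk row offsets can be derived from the split sizes alone, which is what a C++ wrapper could do internally instead of receiving lists produced by torch.split:

def chunk_row_offsets(m_splits):
    # Starting row of each chunk inside the flattened (sum(m_splits), in_features) tensor.
    offsets, acc = [], 0
    for m in m_splits:
        offsets.append(acc)
        acc += m
    return offsets

# The same computation applies to m_splits (input rows) and padded_m_splits (output rows);
# in C++ these offsets would become raw pointers via ptr + offset * in_features.
chunk_row_offsets([1000, 1048, 1033])  # [0, 1000, 2048]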
Ack. Your work is helpful to me.
Thx.
Hi @yaox12, thanks for your suggestion. I wonder if ...
Are there any numbers on the E2E performance gain in your MoE model for other non-FP8 types (BF16, for example) with this padding/unpadding?
I don't have non-FP8 performance data right now, but it is a good idea. I will run some benchmarks for non-FP8 types with padding/unpadding. Update: ...
Agree. I think we can keep ...
Force-pushed from e652a43 to a1ae467
1. Add multi-tensor padding kernel
2. Add FP8Padding and Fp8Unpadding module
3. Add padding grouped linear UT case
Signed-off-by: beinggod <[email protected]>
Force-pushed from a1ae467 to c5fd1b4
Signed-off-by: beinggod <[email protected]>
LGTM!
/te-ci pytorch
Description
Currently the FP8 unpadding backward pass is implemented via torch.autograd. It involves num_gemms * aten::fill plus DtoD copies, which hurts MoE model performance in FP8 training. So we implemented Fp8 padding and unpadding modules to eliminate the autograd overhead. The workflow is shown below.
Workflow:
We observe a 2% E2E performance gain in our MoE model.
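A minimal usage sketch of the new modules in an MoE-style grouped linear follows. The module names Fp8Padding/Fp8Unpadding come from this PR, but the constructor arguments, forward signatures, and the GroupedLinear usage shown here are assumptions and may differ from the actual API:

import torch
import transformer_engine.pytorch as te

num_gemms = 4
pad = te.Fp8Padding(num_gemms)        # assumed: pads each split to an FP8-friendly row count
unpad = te.Fp8Unpadding(num_gemms)    # assumed: removes the padding after the GEMMs
linear = te.GroupedLinear(num_gemms, 1024, 4096, bias=False)  # assumed signature

inp = torch.randn(4000, 1024, device="cuda")
m_splits = [1000, 1048, 919, 1033]    # uneven per-expert token counts from routing

padded_inp, padded_m_splits = pad(inp, m_splits)   # assumed to also return the padded splits
out = linear(padded_inp, padded_m_splits)
out = unpad(out, m_splits)                         # back to the original row counts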
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- Add multi_padding_kernel, which fuses multi-tensor padding.
- Add FP8Padding and FP8Unpadding modules.
Checklist: