diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index c9504c20af..a210019dc1 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -44,3 +44,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

 .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
+
+.. autoapifunction:: transformer_engine.pytorch.moe_permute
+
+.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index e6ccf3b82f..b69aed6648 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
+pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
index 99bd706b45..ed25b96955 100644
--- a/tests/pytorch/test_permutation.py
+++ b/tests/pytorch/test_permutation.py
@@ -220,9 +220,7 @@ def _test_permutation(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()

-    te_permute_output, row_id_map = te_permute(
-        te_permute_fwd_input, te_dtype, indices, num_out_tokens
-    )
+    te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

     te_probs = None
@@ -233,7 +231,7 @@
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()

-    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

 ###################################################################################################################################
@@ -305,7 +303,7 @@ def backward_wrapper(
         lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens)
+        lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
     )
     print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
@@ -333,7 +331,7 @@ def backward_wrapper(
         lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+        lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
     )
     print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py
index 0c098830a9..9987db58e0 100644
--- a/transformer_engine/pytorch/permutation.py
+++ b/transformer_engine/pytorch/permutation.py
@@ -8,7 +8,8 @@
 import torch

 import transformer_engine_torch as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from .constants import TE_DType
+from .float8_tensor import Float8Tensor


 __all__ = [
@@ -27,14 +28,13 @@ class _moe_permute(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         indices: torch.Tensor,
         num_out_tokens: int,
         max_token_num: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Empty input check
         if not inp.numel():
-            return inp, None
+            return inp, torch.tensor([], device=inp.device)

         # Device check
         assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -43,16 +43,13 @@ def forward(
         assert inp.size(0) == indices.size(0), "Permute not possible"

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if indices.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `indices` of Permute is {indices.dtype}! "
@@ -78,13 +75,12 @@ def forward(

         if fp8:
             permuted_act = Float8Tensor(
-                data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

         ctx.row_id_map = row_id_map
         ctx.num_tokens = indices.size(0)
         ctx.topK = indices.size(1)
-        ctx.dtype = dtype
         ctx.fp8 = fp8
         return permuted_act, row_id_map
@@ -101,30 +97,27 @@ def backward(
         if not permuted_act_grad.is_contiguous():
             permuted_act_grad = permuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 permuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = permuted_act_grad._fp8_dtype
+            dtype = permuted_act_grad._fp8_dtype
             fp8_scale_inv = permuted_act_grad._scale_inv
             permuted_act_grad = permuted_act_grad._data
-
-        row_id_map = ctx.row_id_map
-        num_tokens = ctx.num_tokens
-        topK = ctx.topK
+        else:
+            dtype = TE_DType[permuted_act_grad.dtype]

         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad = tex.moe_permute_bwd(
-                permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
+                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
             )
-            if fp8:
+            if ctx.fp8:
                 act_grad = Float8Tensor(
-                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
+                    data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
                 )

-        return act_grad, None, None, None, None
+        return act_grad, None, None, None


 class _moe_unpermute(torch.autograd.Function):
@@ -134,7 +127,6 @@
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         row_id_map: torch.Tensor,
         probs: torch.Tensor,
     ) -> torch.Tensor:
@@ -166,16 +158,13 @@ def forward(
         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if row_id_map.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
@@ -187,10 +176,9 @@ def forward(

         if fp8:
             unpermuted_output = Float8Tensor(
-                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

-        ctx.dtype = dtype
         ctx.save_for_backward(inp, row_id_map, probs)
         ctx.fp8 = fp8
         return unpermuted_output
@@ -207,35 +195,33 @@ def backward(
         if not unpermuted_act_grad.is_contiguous():
             unpermuted_act_grad = unpermuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 unpermuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = unpermuted_act_grad._fp8_dtype
+            dtype = unpermuted_act_grad._fp8_dtype
             fp8_scale_inv = unpermuted_act_grad._scale_inv
             unpermuted_act_grad = unpermuted_act_grad._data
+        else:
+            dtype = TE_DType[unpermuted_act_grad.dtype]

         inp, row_id_map, probs = ctx.saved_tensors

         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad, prob_grad = tex.moe_unpermute_bwd(
-                unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
+                unpermuted_act_grad, inp, dtype, row_id_map, probs
             )
-            if fp8:
-                act_grad = Float8Tensor(
-                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
-                )
-        if not ctx.needs_input_grad[3]:
+            if ctx.fp8:
+                act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)
+        if not ctx.needs_input_grad[2]:
             prob_grad = None
-        return act_grad, None, None, prob_grad
+        return act_grad, None, prob_grad


 def moe_permute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     indices: torch.Tensor,
     num_out_tokens: int = -1,
     max_token_num: int = -1,
@@ -247,8 +233,6 @@
     ----------
     inp: torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    dtype: tex.DType
-        Data type of the input tensor.
     indices: torch.Tensor
         The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'.
     num_out_tokens: int, default = -1
         The effective output token count, representing the number of tokens not dropped.
         By default, set to '-1', meaning no token is dropped.
     max_token_num: int, default = -1
         The maximum number of tokens, used for workspace allocation.
@@ -259,12 +243,11 @@
         By default, set to '-1', meaning the calculation of the size of workspace is automatically
         taken over by the operator.
     """
-    return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num)
+    return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num)


 def moe_unpermute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     row_id_map: torch.Tensor,
     probs: torch.Tensor = None,
 ) -> torch.Tensor:
@@ -276,8 +259,6 @@
     ----------
     inp: torch.Tensor
         Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
-    dtype: tex.DType
-        Data type of the input tensor.
     row_id_map: torch.Tensor
         The tensor of a mapping table for sorted indices used to unpermute the tokens, which is
         the second output tensor of `Permute`.
     probs (optional): torch.Tensor
         The tensor of probabilities corresponding to the permuted tokens. If provided,
         the unpermuted tokens will be merged with their respective probabilities. By default,
         set to an empty tensor, which means that the tokens are directly merged by accumulation.
     """
-    return _moe_unpermute.apply(inp, dtype, row_id_map, probs)
+    return _moe_unpermute.apply(inp, row_id_map, probs)