Skip to content

Commit

Permalink
[PyTorch] Remove dtype from args of permutation (#1145)
Browse files Browse the repository at this point in the history
* remove dtype from args
* update docs with permutation ops

---------

Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Aug 29, 2024
1 parent 4ddb0a7 commit 8ddac3d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 57 deletions.
4 changes: 4 additions & 0 deletions docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context

.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
10 changes: 4 additions & 6 deletions tests/pytorch/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ def _test_permutation(
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()

te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, te_dtype, indices, num_out_tokens
)
te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

te_probs = None
Expand All @@ -233,7 +231,7 @@ def _test_permutation(
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()

te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

###################################################################################################################################
Expand Down Expand Up @@ -305,7 +303,7 @@ def backward_wrapper(
lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens)
lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

Expand Down Expand Up @@ -333,7 +331,7 @@ def backward_wrapper(
lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

Expand Down
83 changes: 32 additions & 51 deletions transformer_engine/pytorch/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from .constants import TE_DType
from .float8_tensor import Float8Tensor


__all__ = [
Expand All @@ -27,14 +28,13 @@ class _moe_permute(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
dtype: tex.DType,
indices: torch.Tensor,
num_out_tokens: int,
max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Empty input check
if not inp.numel():
return inp, None
return inp, torch.tensor([], device=inp.device)

# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
Expand All @@ -43,16 +43,13 @@ def forward(
assert inp.size(0) == indices.size(0), "Permute not possible"

# Data type check
fp8 = False
if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
fp8 = True
fp8 = isinstance(inp, Float8Tensor)
if fp8:
assert isinstance(
inp, Float8Tensor
), "Input must be in Float8Tensor type for FP8 moe_permute."
fp8_dtype = inp._fp8_dtype
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
if indices.dtype != torch.int32:
warnings.warn(
f"The data type of the input `indices` of Permute is {indices.dtype}! "
Expand All @@ -78,13 +75,12 @@ def forward(

if fp8:
permuted_act = Float8Tensor(
data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
)

ctx.row_id_map = row_id_map
ctx.num_tokens = indices.size(0)
ctx.topK = indices.size(1)
ctx.dtype = dtype
ctx.fp8 = fp8
return permuted_act, row_id_map

Expand All @@ -101,30 +97,27 @@ def backward(
if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous()

fp8 = ctx.fp8
if fp8:
if ctx.fp8:
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
fp8_dtype = permuted_act_grad._fp8_dtype
dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
permuted_act_grad = permuted_act_grad._data

row_id_map = ctx.row_id_map
num_tokens = ctx.num_tokens
topK = ctx.topK
else:
dtype = TE_DType[permuted_act_grad.dtype]

act_grad = None
if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd(
permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
)
if fp8:
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
)

return act_grad, None, None, None, None
return act_grad, None, None, None


class _moe_unpermute(torch.autograd.Function):
Expand All @@ -134,7 +127,6 @@ class _moe_unpermute(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
dtype: tex.DType,
row_id_map: torch.Tensor,
probs: torch.Tensor,
) -> torch.Tensor:
Expand Down Expand Up @@ -166,16 +158,13 @@ def forward(
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

# Data type check
fp8 = False
if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
fp8 = True
fp8 = isinstance(inp, Float8Tensor)
if fp8:
assert isinstance(
inp, Float8Tensor
), "Input must be in Float8Tensor type for FP8 moe_unpermute."
fp8_dtype = inp._fp8_dtype
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
if row_id_map.dtype != torch.int32:
warnings.warn(
f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
Expand All @@ -187,10 +176,9 @@ def forward(

if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
)

ctx.dtype = dtype
ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output
Expand All @@ -207,35 +195,33 @@ def backward(
if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous()

fp8 = ctx.fp8
if fp8:
if ctx.fp8:
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
fp8_dtype = unpermuted_act_grad._fp8_dtype
dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data
else:
dtype = TE_DType[unpermuted_act_grad.dtype]

inp, row_id_map, probs = ctx.saved_tensors

act_grad = None
if ctx.needs_input_grad[0]:
act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
unpermuted_act_grad, inp, dtype, row_id_map, probs
)
if fp8:
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
)
if not ctx.needs_input_grad[3]:
if ctx.fp8:
act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)
if not ctx.needs_input_grad[2]:
prob_grad = None

return act_grad, None, None, prob_grad
return act_grad, None, prob_grad


def moe_permute(
inp: torch.Tensor,
dtype: tex.DType,
indices: torch.Tensor,
num_out_tokens: int = -1,
max_token_num: int = -1,
Expand All @@ -247,8 +233,6 @@ def moe_permute(
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
dtype: tex.DType
Data type of the input tensor.
indices: torch.Tensor
The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'.
num_out_tokens: int, default = -1
Expand All @@ -259,12 +243,11 @@ def moe_permute(
By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator.
"""
return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num)
return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num)


def moe_unpermute(
inp: torch.Tensor,
dtype: tex.DType,
row_id_map: torch.Tensor,
probs: torch.Tensor = None,
) -> torch.Tensor:
Expand All @@ -276,8 +259,6 @@ def moe_unpermute(
----------
inp: torch.Tensor
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
dtype: tex.DType
Data type of the input tensor.
row_id_map: torch.Tensor
The tensor of a mapping table for sorted indices used to unpermute the tokens,
which is the second output tensor of `Permute`.
Expand All @@ -286,4 +267,4 @@ def moe_unpermute(
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
"""
return _moe_unpermute.apply(inp, dtype, row_id_map, probs)
return _moe_unpermute.apply(inp, row_id_map, probs)

0 comments on commit 8ddac3d

Please sign in to comment.