[PyTorch] Remove dtype from args of permutation #1145

Merged · 5 commits · Aug 29, 2024

Changes from all commits

docs/api/pytorch.rst (4 additions, 0 deletions)

@@ -44,3 +44,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

 .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
+
+.. autoapifunction:: transformer_engine.pytorch.moe_permute
+
+.. autoapifunction:: transformer_engine.pytorch.moe_unpermute

qa/L0_pytorch_unittest/test.sh (1 addition, 0 deletions)

@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
+pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py

tests/pytorch/test_permutation.py (4 additions, 6 deletions)

@@ -220,9 +220,7 @@ def _test_permutation(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()

-    te_permute_output, row_id_map = te_permute(
-        te_permute_fwd_input, te_dtype, indices, num_out_tokens
-    )
+    te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

     te_probs = None
@@ -233,7 +231,7 @@ def _test_permutation(
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()

-    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

     ###################################################################################################################################
@@ -305,7 +303,7 @@ def backward_wrapper(
         lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens)
+        lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
     )
     print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

@@ -333,7 +331,7 @@ def backward_wrapper(
         lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+        lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
     )
     print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

transformer_engine/pytorch/permutation.py (32 additions, 51 deletions)

@@ -8,7 +8,8 @@
 import torch

 import transformer_engine_torch as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from .constants import TE_DType
+from .float8_tensor import Float8Tensor


 __all__ = [
@@ -27,14 +28,13 @@ class _moe_permute(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         indices: torch.Tensor,
         num_out_tokens: int,
         max_token_num: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Empty input check
         if not inp.numel():
-            return inp, None
+            return inp, torch.tensor([], device=inp.device)

         # Device check
         assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -43,16 +43,13 @@ def forward(
         assert inp.size(0) == indices.size(0), "Permute not possible"

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if indices.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `indices` of Permute is {indices.dtype}! "
@@ -78,13 +75,12 @@

         if fp8:
             permuted_act = Float8Tensor(
-                data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

         ctx.row_id_map = row_id_map
         ctx.num_tokens = indices.size(0)
         ctx.topK = indices.size(1)
-        ctx.dtype = dtype
         ctx.fp8 = fp8
         return permuted_act, row_id_map

@@ -101,30 +97,27 @@ def backward(
         if not permuted_act_grad.is_contiguous():
             permuted_act_grad = permuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 permuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = permuted_act_grad._fp8_dtype
+            dtype = permuted_act_grad._fp8_dtype
             fp8_scale_inv = permuted_act_grad._scale_inv
             permuted_act_grad = permuted_act_grad._data
-
-        row_id_map = ctx.row_id_map
-        num_tokens = ctx.num_tokens
-        topK = ctx.topK
+        else:
+            dtype = TE_DType[permuted_act_grad.dtype]

         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad = tex.moe_permute_bwd(
-                permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
+                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
             )
-        if fp8:
+        if ctx.fp8:
             act_grad = Float8Tensor(
-                data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
+                data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
             )

-        return act_grad, None, None, None, None
+        return act_grad, None, None, None


 class _moe_unpermute(torch.autograd.Function):
@@ -134,7 +127,6 @@ class _moe_unpermute(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         row_id_map: torch.Tensor,
         probs: torch.Tensor,
     ) -> torch.Tensor:
@@ -166,16 +158,13 @@ def forward(
         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if row_id_map.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
@@ -187,10 +176,9 @@

         if fp8:
             unpermuted_output = Float8Tensor(
-                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

-        ctx.dtype = dtype
         ctx.save_for_backward(inp, row_id_map, probs)
         ctx.fp8 = fp8
         return unpermuted_output
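
Both forward passes now resolve the Transformer Engine dtype from the input tensor itself rather than trusting a caller-supplied argument: a Float8Tensor input carries its own FP8 dtype and scale-inverse, while any other tensor maps through the TE_DType table. A standalone sketch of that idiom (the helper name infer_te_dtype is ours for illustration; TE_DType and Float8Tensor are the real imports used in the hunks above):

import torch
import transformer_engine_torch as tex

from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor


def infer_te_dtype(x: torch.Tensor) -> tex.DType:
    # Mirrors the inference added in this PR: FP8 tensors carry their own
    # tex.DType; plain torch tensors map through the torch.dtype table.
    if isinstance(x, Float8Tensor):
        return x._fp8_dtype
    return TE_DType[x.dtype]

One consequence visible above: the old assertion that the passed dtype agreed with the tensor type is gone, since with a single source of truth the mismatch can no longer occur.
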
@@ -207,35 +195,33 @@ def backward(
         if not unpermuted_act_grad.is_contiguous():
             unpermuted_act_grad = unpermuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 unpermuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = unpermuted_act_grad._fp8_dtype
+            dtype = unpermuted_act_grad._fp8_dtype
             fp8_scale_inv = unpermuted_act_grad._scale_inv
             unpermuted_act_grad = unpermuted_act_grad._data
+        else:
+            dtype = TE_DType[unpermuted_act_grad.dtype]

         inp, row_id_map, probs = ctx.saved_tensors

         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad, prob_grad = tex.moe_unpermute_bwd(
-                unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
+                unpermuted_act_grad, inp, dtype, row_id_map, probs
             )
-        if fp8:
-            act_grad = Float8Tensor(
-                data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
-            )
-        if not ctx.needs_input_grad[3]:
+        if ctx.fp8:
+            act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)
+        if not ctx.needs_input_grad[2]:
             prob_grad = None

-        return act_grad, None, None, prob_grad
+        return act_grad, None, prob_grad


 def moe_permute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     indices: torch.Tensor,
     num_out_tokens: int = -1,
     max_token_num: int = -1,
@@ -247,8 +233,6 @@
     ----------
     inp: torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    dtype: tex.DType
-        Data type of the input tensor.
     indices: torch.Tensor
         The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'.
     num_out_tokens: int, default = -1
@@ -259,12 +243,11 @@
         By default, set to '-1', meaning the calculation of the size of workspace is
         automatically taken over by the operator.
     """
-    return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num)
+    return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num)


 def moe_unpermute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     row_id_map: torch.Tensor,
     probs: torch.Tensor = None,
 ) -> torch.Tensor:
@@ -276,8 +259,6 @@
     ----------
     inp: torch.Tensor
         Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
-    dtype: tex.DType
-        Data type of the input tensor.
     row_id_map: torch.Tensor
         The tensor of a mapping table for sorted indices used to unpermute the tokens,
         which is the second output tensor of `Permute`.
@@ -286,4 +267,4 @@
         the unpermuted tokens will be merged with their respective probabilities.
         By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
     """
-    return _moe_unpermute.apply(inp, dtype, row_id_map, probs)
+    return _moe_unpermute.apply(inp, row_id_map, probs)
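
For callers, migration is simply dropping the dtype argument. A minimal usage sketch under the new signatures (tensor shapes, the expert count, and the num_out_tokens choice are illustrative, not taken from the PR):

import torch

from transformer_engine.pytorch import moe_permute, moe_unpermute

num_tokens, hidden_size, num_experts, topK = 16, 128, 4, 2
tokens = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(
    0, num_experts, (num_tokens, topK), dtype=torch.int32, device="cuda"
)
probs = torch.rand(num_tokens, topK, dtype=torch.bfloat16, device="cuda")

# No tex.DType argument: the dtype is inferred from `tokens` (or from the
# Float8Tensor wrapper when the input is FP8).
permuted, row_id_map = moe_permute(tokens, indices, num_out_tokens=num_tokens * topK)
unpermuted = moe_unpermute(permuted, row_id_map, probs)

For FP8, the input would instead be a Float8Tensor, which both functions now detect by type.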