Commit 7b9ac33

let grouped gemm support different input combinations
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 27, 2024
1 parent 8eba144 commit 7b9ac33
Showing 2 changed files with 83 additions and 126 deletions.
200 changes: 79 additions & 121 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -16,7 +16,6 @@
"fp8_gemm",
"grouped_gemm",
"fp8_grouped_gemm",
"fp8_grouped_gemm_single_output",
]


@@ -386,16 +385,17 @@ def grouped_gemm(

def fp8_grouped_gemm(
A: List[torch.Tensor],
A_scale_inv: torch.Tensor,
A_scale_inv: Union[torch.Tensor, List[torch.Tensor]],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
out: List[torch.Tensor],
out: Union[torch.Tensor, List[torch.Tensor]],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
m_splits: Optional[List[int]] = None,
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
@@ -407,93 +407,18 @@ def fp8_grouped_gemm(
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""
TN layout Grouped GEMM with fp8 inputs.
This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
This function accepts two combinations of inputs:
1. A_scale_inv is a list of tensors, out is a single tensor, and m_splits is not None.
This is used for the calculation of output (fwd) and dgrad (bwd).
2. A_scale_inv is a single tensor, out is a list of tensors. This is used for the
calculation of wgrad.
"""

num_gemms = len(A)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_offset is not None
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a)
assert_dim_for_fp8_exec(b)
assert A[0].dtype == torch.uint8
assert B[0].dtype == torch.uint8

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
if isinstance(A_scale_inv, list):
assert isinstance(out, torch.Tensor) and m_splits is not None
elif isinstance(A_scale_inv, torch.Tensor):
assert isinstance(out, list)
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input


def fp8_grouped_gemm_single_output(
A: List[torch.Tensor],
A_scale_inv: List[torch.Tensor],
A_fp8_tensor_offset: int,
A_dtype: tex.DType,
B: List[torch.Tensor],
B_scale_inv: torch.Tensor,
B_fp8_tensor_offset: int,
B_dtype: tex.DType,
m_splits: List[int],
out: torch.Tensor,
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
out_offset: Optional[int] = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
gelu: bool = False,
accumulate: bool = False,
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> Tuple[Union[List[torch.Tensor], None], ...]:
"""
TN layout Grouped GEMM with two lists of fp8 inputs, and a single contiguous output implicitly
splitted by m_splits.
This method assumes the scale_inv of A is a list of tensors.
Used for the calculation of output (fwd) and dgrad (bwd).
"""
raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

num_gemms = len(A)
empty_tensor = _empty_tensor()
@@ -508,39 +433,72 @@ def fp8_grouped_gemm_single_output(

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
else:
gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
gelu_input = empty_tensors

if isinstance(out, list):
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
else:
if gelu:
gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
A,
A_scale_inv,
A_fp8_tensor_offset,
A_dtype,
True, # transa
B,
B_scale_inv,
B_fp8_tensor_offset,
B_dtype,
False, # transb
m_splits,
out,
0 if out_offset is None else out_offset,
empty_tensor if out_offset is None else fp8_meta_tensor.scale,
out_dtype,
empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
bias if use_bias else empty_tensors,
bias_dtype,
gelu_input, # this is pre_gelu_out
False, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)

return out, gelu_input
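
For orientation, here is a minimal sketch (not part of the commit) of the dispatch rule the unified fp8_grouped_gemm now applies, following the new docstring and the isinstance checks in the diff above. The helper name select_grouped_gemm_path and its simplified signature are hypothetical; the real function goes on to call torch.ops.tex_ts.te_grouped_gemm_single_output_ts or torch.ops.tex_ts.te_grouped_gemm_ts, which are omitted here.

from typing import List, Optional, Union

import torch


def select_grouped_gemm_path(
    A_scale_inv: Union[torch.Tensor, List[torch.Tensor]],
    out: Union[torch.Tensor, List[torch.Tensor]],
    m_splits: Optional[List[int]] = None,
) -> str:
    """Mirror the input-combination checks of the unified fp8_grouped_gemm."""
    if isinstance(A_scale_inv, list):
        # Combination 1: per-GEMM scale_inv tensors plus one contiguous output
        # split by m_splits (forward output and dgrad); the real function
        # dispatches to te_grouped_gemm_single_output_ts here.
        assert isinstance(out, torch.Tensor) and m_splits is not None
        return "single_output"
    if isinstance(A_scale_inv, torch.Tensor):
        # Combination 2: one scale_inv tensor plus a list of per-GEMM outputs
        # (wgrad); the real function dispatches to te_grouped_gemm_ts here.
        assert isinstance(out, list)
        return "multi_output"
    raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

In the second file below, GroupedLinear's forward and dgrad call sites use combination 1: they pass [w._scale_inv for w in weights_fp8] together with a single preallocated out / dgrad buffer and m_splits.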
9 changes: 4 additions & 5 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -36,7 +36,6 @@
fp8_cast_transpose_bgrad_fused,
fp8_multi_cast_transpose_fused,
fp8_grouped_gemm,
fp8_grouped_gemm_single_output,
grouped_gemm,
)
from ..constants import GemmParallelModes, dist_group_type
@@ -169,7 +168,7 @@ def forward(
device=inputmats[0].device,
)

_ = fp8_grouped_gemm_single_output(
_ = fp8_grouped_gemm(
[w._data for w in weights_fp8],
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
@@ -178,10 +177,10 @@
inputmat_scale_inv,
0,
fp8_dtype_forward,
m_splits,
out,
activation_dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
@@ -359,7 +358,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=grad_output.device,
)
fp8_grouped_gemm_single_output(
fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8],
[w._scale_inv for w in weights_fp8],
0, # weight offset is 0 for the newly created _scale_inv
@@ -368,10 +367,10 @@
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
fp8_dtype_backward,
ctx.m_splits,
dgrad,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
m_splits=ctx.m_splits,
use_split_accumulator=_2X_ACC_DGRAD,
)
else:

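A small, self-contained illustration (not part of the commit) of the single-contiguous-output convention used by the forward and dgrad call sites above: out and dgrad are one buffer covering all tokens, and the grouped kernel implicitly splits it by m_splits. The sizes below are made up; torch.split is used only to show that per-GEMM views of such a buffer need no copies.

import torch

m_splits = [4, 2, 6]                   # rows produced by each GEMM in the group (made up)
n = 8                                  # common output width (made up)
out = torch.empty(sum(m_splits), n)    # one contiguous buffer, like `out` / `dgrad` above

# torch.split returns views into the same storage, so each GEMM's slice can be
# addressed without extra allocations or copies.
per_gemm_views = torch.split(out, m_splits, dim=0)
assert [tuple(v.shape) for v in per_gemm_views] == [(m, n) for m in m_splits]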