diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 09b7949f2e..059149a408 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -16,7 +16,6 @@
     "fp8_gemm",
     "grouped_gemm",
     "fp8_grouped_gemm",
-    "fp8_grouped_gemm_single_output",
 ]


@@ -386,16 +385,17 @@ def grouped_gemm(

 def fp8_grouped_gemm(
     A: List[torch.Tensor],
-    A_scale_inv: torch.Tensor,
+    A_scale_inv: Union[torch.Tensor, List[torch.Tensor]],
     A_fp8_tensor_offset: int,
     A_dtype: tex.DType,
     B: List[torch.Tensor],
     B_scale_inv: torch.Tensor,
     B_fp8_tensor_offset: int,
     B_dtype: tex.DType,
-    out: List[torch.Tensor],
+    out: Union[torch.Tensor, List[torch.Tensor]],
     out_dtype: torch.dtype,
     workspaces: List[torch.Tensor],
+    m_splits: Optional[List[int]] = None,
     out_offset: Optional[int] = None,
     fp8_meta_tensor: tex.FP8TensorMeta = None,
     gelu: bool = False,
@@ -407,93 +407,18 @@ def fp8_grouped_gemm(
 ) -> Tuple[Union[List[torch.Tensor], None], ...]:
     """
     TN layout Grouped GEMM with fp8 inputs.
-    This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor.
-      scale: [ ...A_scale... | ...B_scale... | ...out_scale...]
-      scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...]
-      amax: [ ...A_amax... | ...B_amax... | ...out_amax...]
+    This function accepts two combinations of inputs:
+    1. A_scale_inv is a list of tensors, out is a single tensor, and m_splits is not None.
+       This is used for the calculation of output (fwd) and dgrad (bwd).
+    2. A_scale_inv is a single tensor, out is a list of tensors. This is used for the
+       calculation of wgrad.
     """
-
-    num_gemms = len(A)
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-        assert fp8_meta_tensor is not None and out_offset is not None
-    for a, b in zip(A, B):
-        assert_dim_for_fp8_exec(a)
-        assert_dim_for_fp8_exec(b)
-    assert A[0].dtype == torch.uint8
-    assert B[0].dtype == torch.uint8
-
-    # Use bfloat16 as default bias_dtype
-    bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]
+    if isinstance(A_scale_inv, list):
+        assert isinstance(out, torch.Tensor) and m_splits is not None
+    elif isinstance(A_scale_inv, torch.Tensor):
+        assert isinstance(out, list)
     else:
-        gelu_input = empty_tensors
-    bias_dtype = TE_DType[bias_dtype]
-
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-
-    torch.ops.tex_ts.te_grouped_gemm_ts(
-        A,
-        A_scale_inv,
-        A_fp8_tensor_offset,
-        A_dtype,
-        True,  # transa
-        B,
-        B_scale_inv,
-        B_fp8_tensor_offset,
-        B_dtype,
-        False,  # transb
-        out,
-        0 if out_offset is None else out_offset,
-        empty_tensor if out_offset is None else fp8_meta_tensor.scale,
-        out_dtype,
-        empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
-        bias if use_bias else empty_tensors,
-        bias_dtype,
-        gelu_input,  # this is pre_gelu_out
-        False,  # grad
-        workspaces,
-        workspaces[0].shape[0],
-        accumulate,
-        use_split_accumulator,
-    )
-
-    return out, gelu_input
-
-
-def fp8_grouped_gemm_single_output(
-    A: List[torch.Tensor],
-    A_scale_inv: List[torch.Tensor],
-    A_fp8_tensor_offset: int,
-    A_dtype: tex.DType,
-    B: List[torch.Tensor],
-    B_scale_inv: torch.Tensor,
-    B_fp8_tensor_offset: int,
-    B_dtype: tex.DType,
-    m_splits: List[int],
-    out: torch.Tensor,
-    out_dtype: torch.dtype,
-    workspaces: List[torch.Tensor],
-    out_offset: Optional[int] = None,
-    fp8_meta_tensor: tex.FP8TensorMeta = None,
-    gelu: bool = False,
-    accumulate: bool = False,
-    bias: Optional[List[torch.Tensor]] = None,
-    use_bias: bool = False,
-    use_split_accumulator: bool = False,
-    D_dtype: Optional[tex.DType] = None,
-) -> Tuple[Union[List[torch.Tensor], None], ...]:
-    """
-    TN layout Grouped GEMM with two lists of fp8 inputs, and a single contiguous output implicitly
-    splitted by m_splits.
-    This method assumes the scale_inv of A is a list of tensors.
-    Used for the calculation of output (fwd) and dgrad (bwd).
-    """
+        raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

     num_gemms = len(A)
     empty_tensor = _empty_tensor()
@@ -508,39 +433,72 @@ def fp8_grouped_gemm_single_output(

     # Use bfloat16 as default bias_dtype
     bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
-    if gelu:
-        gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
-    else:
-        gelu_input = empty_tensors
     bias_dtype = TE_DType[bias_dtype]
-
-    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
-
-    torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
-        A,
-        A_scale_inv,
-        A_fp8_tensor_offset,
-        A_dtype,
-        True,  # transa
-        B,
-        B_scale_inv,
-        B_fp8_tensor_offset,
-        B_dtype,
-        False,  # transb
-        m_splits,
-        out,
-        0 if out_offset is None else out_offset,
-        empty_tensor if out_offset is None else fp8_meta_tensor.scale,
-        out_dtype,
-        empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
-        bias if use_bias else empty_tensors,
-        bias_dtype,
-        gelu_input,  # this is pre_gelu_out
-        False,  # grad
-        workspaces,
-        workspaces[0].shape[0],
-        accumulate,
-        use_split_accumulator,
-    )
+    gelu_input = empty_tensors
+
+    if isinstance(out, list):
+        if gelu:
+            gelu_input = [
+                torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+                for o in out
+            ]
+        out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
+
+        torch.ops.tex_ts.te_grouped_gemm_ts(
+            A,
+            A_scale_inv,
+            A_fp8_tensor_offset,
+            A_dtype,
+            True,  # transa
+            B,
+            B_scale_inv,
+            B_fp8_tensor_offset,
+            B_dtype,
+            False,  # transb
+            out,
+            0 if out_offset is None else out_offset,
+            empty_tensor if out_offset is None else fp8_meta_tensor.scale,
+            out_dtype,
+            empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
+            bias if use_bias else empty_tensors,
+            bias_dtype,
+            gelu_input,  # this is pre_gelu_out
+            False,  # grad
+            workspaces,
+            workspaces[0].shape[0],
+            accumulate,
+            use_split_accumulator,
+        )
+    else:
+        if gelu:
+            gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits]
+        out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
+
+        torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
+            A,
+            A_scale_inv,
+            A_fp8_tensor_offset,
+            A_dtype,
+            True,  # transa
+            B,
+            B_scale_inv,
+            B_fp8_tensor_offset,
+            B_dtype,
+            False,  # transb
+            m_splits,
+            out,
+            0 if out_offset is None else out_offset,
+            empty_tensor if out_offset is None else fp8_meta_tensor.scale,
+            out_dtype,
+            empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
+            bias if use_bias else empty_tensors,
+            bias_dtype,
+            gelu_input,  # this is pre_gelu_out
+            False,  # grad
+            workspaces,
+            workspaces[0].shape[0],
+            accumulate,
+            use_split_accumulator,
+        )

     return out, gelu_input
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index da61e94686..b1c9f5670e 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -36,7 +36,6 @@
     fp8_cast_transpose_bgrad_fused,
     fp8_multi_cast_transpose_fused,
     fp8_grouped_gemm,
-    fp8_grouped_gemm_single_output,
     grouped_gemm,
 )
 from ..constants import GemmParallelModes, dist_group_type
@@ -169,7 +168,7 @@ def forward(
                 device=inputmats[0].device,
             )

-            _ = fp8_grouped_gemm_single_output(
+            _ = fp8_grouped_gemm(
                 [w._data for w in weights_fp8],
                 [w._scale_inv for w in weights_fp8],
                 0,  # weight offset is 0 for the newly created _scale_inv
@@ -178,10 +177,10 @@ def forward(
                 inputmat_scale_inv,
                 0,
                 fp8_dtype_forward,
-                m_splits,
                 out,
                 activation_dtype,
                 get_multi_stream_cublas_workspace(),
+                m_splits=m_splits,
                 bias=biases,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
@@ -359,7 +358,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     dtype=ctx.activation_dtype,
                     device=grad_output.device,
                 )
-                fp8_grouped_gemm_single_output(
+                fp8_grouped_gemm(
                     [w.transpose_2d() for w in weights_fp8],
                     [w._scale_inv for w in weights_fp8],
                     0,  # weight offset is 0 for the newly created _scale_inv
@@ -368,10 +367,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     ctx.fp8_meta["scaling_bwd"].scale_inv,
                     _GRAD_OUTPUT,
                     fp8_dtype_backward,
-                    ctx.m_splits,
                     dgrad,
                     ctx.activation_dtype,
                     get_multi_stream_cublas_workspace(),
+                    m_splits=ctx.m_splits,
                     use_split_accumulator=_2X_ACC_DGRAD,
                 )
             else:
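Usage note (illustrative, not part of the patch): the sketch below shows the two input combinations that the merged fp8_grouped_gemm dispatches on. The shapes, placeholder uint8 data, scale_inv values, workspace size/count, and the transformer_engine_torch alias are assumptions made for the example; in-tree callers obtain workspaces via get_multi_stream_cublas_workspace() as in grouped_linear.py, and running this requires a CUDA device with Transformer Engine installed.

# Sketch of the two call patterns accepted by the merged fp8_grouped_gemm.
# Placeholder operands only; real callers pass FP8-cast data and proper scale_invs.
import torch
import transformer_engine_torch as tex  # assumed extension alias, as used in gemm.py
from transformer_engine.pytorch.cpp_extensions import fp8_grouped_gemm

num_gemms, in_features, out_features = 4, 128, 256
m_splits = [64, 64, 128, 256]
fp8_dtype = tex.DType.kFloat8E4M3

# FP8 operands are carried as uint8 storage plus float32 scale_inv values.
weights = [
    torch.zeros(out_features, in_features, dtype=torch.uint8, device="cuda")
    for _ in range(num_gemms)
]
inputs = [torch.zeros(m, in_features, dtype=torch.uint8, device="cuda") for m in m_splits]
w_scale_inv = [torch.ones(1, device="cuda") for _ in range(num_gemms)]
x_scale_inv = torch.ones(num_gemms, device="cuda")
# Placeholder cuBLAS workspaces (size and count are assumptions).
workspaces = [torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda") for _ in range(4)]

# Combination 1 (fwd/dgrad): A_scale_inv is a list, out is one contiguous tensor, m_splits is set.
out = torch.empty(sum(m_splits), out_features, dtype=torch.bfloat16, device="cuda")
fp8_grouped_gemm(
    weights, w_scale_inv, 0, fp8_dtype,
    inputs, x_scale_inv, 0, fp8_dtype,
    out, torch.bfloat16, workspaces,
    m_splits=m_splits,
)

# Combination 2 (wgrad-style dispatch): A_scale_inv is a single tensor, out is a list of tensors.
# Operands are reused placeholders here; the real wgrad path feeds transposed grads and inputs.
outs = [torch.empty(m, out_features, dtype=torch.bfloat16, device="cuda") for m in m_splits]
fp8_grouped_gemm(
    weights, x_scale_inv, 0, fp8_dtype,
    inputs, x_scale_inv, 0, fp8_dtype,
    outs, torch.bfloat16, workspaces,
)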