From c6aff8bbde6f4580014561375f192ce309e275d2 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 22 Aug 2024 02:15:16 -0700 Subject: [PATCH 1/7] propagate scale_inv modification to GroupedLinear Signed-off-by: Xin Yao --- .../pytorch/cpp_extensions/transpose.py | 3 +- .../pytorch/module/grouped_linear.py | 35 ++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 37a1b59da2..ddc3b67e9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused( amax_indices: List[int], scale_inv_indices: List[int], otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Cast + Transpose with FP8 output""" @@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused( input_list, fp8_meta_tensor.scale, fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, + scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, scale_indices, amax_indices, scale_inv_indices, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a91ff5c361..7f56ced8ef 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -42,6 +42,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +from ..export import is_in_onnx_export_mode __all__ = ["GroupedLinear"] @@ -102,10 +103,12 @@ def forward( inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] inputmats_t = [] + inputmat_scale_inv = None global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) if ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled @@ -121,6 +124,7 @@ def forward( indices, # amax_indices indices, # scale_inv_indices fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) else: # FP8 input for forward @@ -130,9 +134,22 @@ def forward( fp8_meta["scaling_fwd"], _GEMM_INPUT + i, fp8_dtype_forward, + scale_inv=inputmat_scale_inv, ) for i in range(num_gemms) ] + + # Hack for ONNX export + # Note: ONNX models are represented as a graph of tensor + # operations, so the in-place scale-inv update doesn't fit + # very well. We work around this by making it look like + # the scale-inv tensor is initialized with a copy. + # Note: ONNX export expects FP8 scales can be represented + # with constant ops. However, copying into a buffer + # involves an expand op for array broadcasting. We work + # around this by filling the buffer instead. + if is_in_onnx_export_mode(): + inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 @@ -153,12 +170,14 @@ def forward( _ = fp8_grouped_gemm( [w._data for w in weights_fp8], - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_WEIGHT, + torch.cat( + [w._scale_inv for w in weights_fp8] + ), # avoiding torch.cat requires another interface, + 0, # weight offset is 0 for the newly created _scale_inv fp8_dtype_forward, inputmats, - fp8_meta["scaling_fwd"].scale_inv, - _GEMM_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, torch.split(out, m_splits), activation_dtype, @@ -230,7 +249,7 @@ def forward( t.activation_offloading = True ctx.save_for_backward( - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + inputmat_scale_inv, *saved_inputmats, *saved_inputmats_t, *weights, @@ -270,7 +289,7 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_GroupedLinear_backward"): ( - fwd_scale_inverses, + inputmat_scale_inv, *saved_tensors, ) = ctx.saved_tensors inputmats = saved_tensors[: ctx.num_gemms] @@ -396,8 +415,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inp._data if isinstance(inp, Float8Tensor) else inp for inp in inputmats_t ], - fwd_scale_inverses, - _GEMM_INPUT, + inputmat_scale_inv, + 0, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, From 8eba144d4906ae33cdc49a4398ef737fca5a90a6 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 25 Aug 2024 23:50:12 -0700 Subject: [PATCH 2/7] optimization for separate scale_inv of weights and single output Signed-off-by: Xin Yao --- .../pytorch/cpp_extensions/gemm.py | 88 ++++++++++++++++++- transformer_engine/pytorch/csrc/extensions.h | 10 +++ .../pytorch/csrc/extensions/gemm.cu | 58 ++++++++++++ transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 36 ++++++++ .../pytorch/module/grouped_linear.py | 19 ++-- 5 files changed, 200 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 8502f70491..09b7949f2e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,7 +11,13 @@ from ..utils import assert_dim_for_fp8_exec -__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +__all__ = [ + "gemm", + "fp8_gemm", + "grouped_gemm", + "fp8_grouped_gemm", + "fp8_grouped_gemm_single_output", +] @functools.lru_cache(maxsize=None) @@ -458,3 +464,83 @@ def fp8_grouped_gemm( ) return out, gelu_input + + +def fp8_grouped_gemm_single_output( + A: List[torch.Tensor], + A_scale_inv: List[torch.Tensor], + A_fp8_tensor_offset: int, + A_dtype: tex.DType, + B: List[torch.Tensor], + B_scale_inv: torch.Tensor, + B_fp8_tensor_offset: int, + B_dtype: tex.DType, + m_splits: List[int], + out: torch.Tensor, + out_dtype: torch.dtype, + workspaces: List[torch.Tensor], + out_offset: Optional[int] = None, + fp8_meta_tensor: tex.FP8TensorMeta = None, + gelu: bool = False, + accumulate: bool = False, + bias: Optional[List[torch.Tensor]] = None, + use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] = None, +) -> Tuple[Union[List[torch.Tensor], None], ...]: + """ + TN layout Grouped GEMM with two lists of fp8 inputs, and a single contiguous output implicitly + splitted by m_splits. + This method assumes the scale_inv of A is a list of tensors. + Used for the calculation of output (fwd) and dgrad (bwd). + """ + + num_gemms = len(A) + empty_tensor = _empty_tensor() + empty_tensors = [empty_tensor] * num_gemms + if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + assert fp8_meta_tensor is not None and out_offset is not None + for a, b in zip(A, B): + assert_dim_for_fp8_exec(a) + assert_dim_for_fp8_exec(b) + assert A[0].dtype == torch.uint8 + assert B[0].dtype == torch.uint8 + + # Use bfloat16 as default bias_dtype + bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + else: + gelu_input = empty_tensors + bias_dtype = TE_DType[bias_dtype] + + out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + + return out, gelu_input diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1a6f5f157e..e0188460db 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -159,6 +159,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 01fb94cab4..53fafe40cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -146,3 +146,61 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int te_workspace.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); } + +void te_grouped_gemm_single_output( + std::vector A, std::vector A_scale_inverse, int A_offset, + transformer_engine::DType A_type, bool transa, std::vector B, + at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, + std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, + transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, + std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype, void* amax_dptr, + void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { + tensor_wrappers.emplace_back( + makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); + return tensor_wrappers.back().data(); + }; + void* d_i_ptr = reinterpret_cast(D.data_ptr()); + for (size_t i = 0; i < A.size(); i++) { + if (m_splits[i] == 0) continue; + te_A.emplace_back(make_tensor( + A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, + A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); + te_B.emplace_back(make_tensor( + B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, + B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); + te_D.emplace_back(make_tensor( + d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, + getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); + te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, + bias_type, nullptr, nullptr, nullptr)); + + const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out[i].size(0))} + : std::vector{static_cast(pre_gelu_out[i].size(0)), + static_cast(pre_gelu_out[i].size(1))}; + te_pre_gelu_out.emplace_back(make_tensor( + pre_gelu_out[i].data_ptr(), gelu_shape, + GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); + // Move the D pointer to the next split. + char* char_ptr = reinterpret_cast(d_i_ptr); + char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); + d_i_ptr = reinterpret_cast(char_ptr); + } + for (size_t i = 0; i < workspace.size(); i++) { + te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, + nullptr, nullptr, nullptr)); + } + + // For now, we only have multi-stream cublas backend. + nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, + te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8515092ae0..7d7ac8b14d 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -305,6 +305,41 @@ std::vector te_grouped_gemm_ts( return D; } +at::Tensor te_grouped_gemm_single_output_ts( + std::vector A, std::vector A_scale_inverse, int64_t A_offset, + int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, + int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, + int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, + std::vector bias, int64_t bias_type, std::vector pre_gelu_out, + int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, + int64_t use_split_accumulator) { + // cast inputs to types accepted by te_gemm + transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); + bool transa_arg = static_cast(transa); + transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); + bool transb_arg = static_cast(transb); + transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); + transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); + bool grad_arg = static_cast(grad); + size_t workspaceSize_arg = static_cast(workspaceSize); + bool accumulate_arg = static_cast(accumulate); + bool use_split_accumulator_arg = static_cast(use_split_accumulator); + + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, + B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, + D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, + pre_gelu_out, grad_arg, workspace, workspaceSize_arg, + accumulate_arg, use_split_accumulator_arg, num_math_sms); + return D; +} + at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, double eps, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, @@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) { m.def("srelu_ts", &srelu_ts); m.def("te_gemm_ts", &te_gemm_ts); m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); + m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 7f56ced8ef..da61e94686 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -36,6 +36,7 @@ fp8_cast_transpose_bgrad_fused, fp8_multi_cast_transpose_fused, fp8_grouped_gemm, + fp8_grouped_gemm_single_output, grouped_gemm, ) from ..constants import GemmParallelModes, dist_group_type @@ -168,18 +169,17 @@ def forward( device=inputmats[0].device, ) - _ = fp8_grouped_gemm( + _ = fp8_grouped_gemm_single_output( [w._data for w in weights_fp8], - torch.cat( - [w._scale_inv for w in weights_fp8] - ), # avoiding torch.cat requires another interface, + [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv fp8_dtype_forward, inputmats, inputmat_scale_inv, 0, fp8_dtype_forward, - torch.split(out, m_splits), + m_splits, + out, activation_dtype, get_multi_stream_cublas_workspace(), bias=biases, @@ -359,18 +359,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=grad_output.device, ) - fp8_grouped_gemm( + fp8_grouped_gemm_single_output( [w.transpose_2d() for w in weights_fp8], - torch.cat( - [w._scale_inv for w in weights_fp8] - ), # avoiding torch.cat requires another interface + [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv weights_fp8[0]._fp8_dtype, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, _GRAD_OUTPUT, fp8_dtype_backward, - torch.split(dgrad, ctx.m_splits), + ctx.m_splits, + dgrad, ctx.activation_dtype, get_multi_stream_cublas_workspace(), use_split_accumulator=_2X_ACC_DGRAD, From 7b9ac333102a2eb6d7c18ad4cdb5200d02550dad Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 26 Aug 2024 23:27:13 -0700 Subject: [PATCH 3/7] let grouped gemm support different input combinations Signed-off-by: Xin Yao --- .../pytorch/cpp_extensions/gemm.py | 200 +++++++----------- .../pytorch/module/grouped_linear.py | 9 +- 2 files changed, 83 insertions(+), 126 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 09b7949f2e..059149a408 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -16,7 +16,6 @@ "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm", - "fp8_grouped_gemm_single_output", ] @@ -386,16 +385,17 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: torch.Tensor, + A_scale_inv: Union[torch.Tensor, List[torch.Tensor]], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], B_scale_inv: torch.Tensor, B_fp8_tensor_offset: int, B_dtype: tex.DType, - out: List[torch.Tensor], + out: Union[torch.Tensor, List[torch.Tensor]], out_dtype: torch.dtype, workspaces: List[torch.Tensor], + m_splits: Optional[List[int]] = None, out_offset: Optional[int] = None, fp8_meta_tensor: tex.FP8TensorMeta = None, gelu: bool = False, @@ -407,93 +407,18 @@ def fp8_grouped_gemm( ) -> Tuple[Union[List[torch.Tensor], None], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. - scale: [ ...A_scale... | ...B_scale... | ...out_scale...] - scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] - amax: [ ...A_amax... | ...B_amax... | ...out_amax...] + This function accepts two combinations of inputs: + 1. A_scale_inv is a list of tensors, out is a single tensor, and m_splits is not None. + This is used for the calculation of output (fwd) and dgrad (bwd). + 2. A_scale_inv is a single tensor, out is a list of tensors. This is used for the + calculation of wgrad. """ - - num_gemms = len(A) - empty_tensor = _empty_tensor() - empty_tensors = [empty_tensor] * num_gemms - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_offset is not None - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a) - assert_dim_for_fp8_exec(b) - assert A[0].dtype == torch.uint8 - assert B[0].dtype == torch.uint8 - - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] + if isinstance(A_scale_inv, list): + assert isinstance(out, torch.Tensor) and m_splits is not None + elif isinstance(A_scale_inv, torch.Tensor): + assert isinstance(out, list) else: - gelu_input = empty_tensors - bias_dtype = TE_DType[bias_dtype] - - out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - - return out, gelu_input - - -def fp8_grouped_gemm_single_output( - A: List[torch.Tensor], - A_scale_inv: List[torch.Tensor], - A_fp8_tensor_offset: int, - A_dtype: tex.DType, - B: List[torch.Tensor], - B_scale_inv: torch.Tensor, - B_fp8_tensor_offset: int, - B_dtype: tex.DType, - m_splits: List[int], - out: torch.Tensor, - out_dtype: torch.dtype, - workspaces: List[torch.Tensor], - out_offset: Optional[int] = None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - gelu: bool = False, - accumulate: bool = False, - bias: Optional[List[torch.Tensor]] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: - """ - TN layout Grouped GEMM with two lists of fp8 inputs, and a single contiguous output implicitly - splitted by m_splits. - This method assumes the scale_inv of A is a list of tensors. - Used for the calculation of output (fwd) and dgrad (bwd). - """ + raise ValueError("A_scale_inv should be a list of tensors or a single tensor.") num_gemms = len(A) empty_tensor = _empty_tensor() @@ -508,39 +433,72 @@ def fp8_grouped_gemm_single_output( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] - else: - gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] - - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype - - torch.ops.tex_ts.te_grouped_gemm_single_output_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - m_splits, - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) + gelu_input = empty_tensors + + if isinstance(out, list): + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] + out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype + + torch.ops.tex_ts.te_grouped_gemm_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + else: + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) return out, gelu_input diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index da61e94686..b1c9f5670e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -36,7 +36,6 @@ fp8_cast_transpose_bgrad_fused, fp8_multi_cast_transpose_fused, fp8_grouped_gemm, - fp8_grouped_gemm_single_output, grouped_gemm, ) from ..constants import GemmParallelModes, dist_group_type @@ -169,7 +168,7 @@ def forward( device=inputmats[0].device, ) - _ = fp8_grouped_gemm_single_output( + _ = fp8_grouped_gemm( [w._data for w in weights_fp8], [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv @@ -178,10 +177,10 @@ def forward( inputmat_scale_inv, 0, fp8_dtype_forward, - m_splits, out, activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=m_splits, bias=biases, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, @@ -359,7 +358,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=grad_output.device, ) - fp8_grouped_gemm_single_output( + fp8_grouped_gemm( [w.transpose_2d() for w in weights_fp8], [w._scale_inv for w in weights_fp8], 0, # weight offset is 0 for the newly created _scale_inv @@ -368,10 +367,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.fp8_meta["scaling_bwd"].scale_inv, _GRAD_OUTPUT, fp8_dtype_backward, - ctx.m_splits, dgrad, ctx.activation_dtype, get_multi_stream_cublas_workspace(), + m_splits=ctx.m_splits, use_split_accumulator=_2X_ACC_DGRAD, ) else: From e42a6d4386a7362bfa2ef03e82b88c615764bd74 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 27 Aug 2024 00:35:50 -0700 Subject: [PATCH 4/7] fix type Signed-off-by: Xin Yao --- transformer_engine/pytorch/cpp_extensions/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 059149a408..58764b6bcc 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -404,7 +404,7 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]: """ TN layout Grouped GEMM with fp8 inputs. This function accepts two combinations of inputs: @@ -416,7 +416,7 @@ def fp8_grouped_gemm( if isinstance(A_scale_inv, list): assert isinstance(out, torch.Tensor) and m_splits is not None elif isinstance(A_scale_inv, torch.Tensor): - assert isinstance(out, list) + assert isinstance(out, (list, tuple)) else: raise ValueError("A_scale_inv should be a list of tensors or a single tensor.") @@ -436,7 +436,7 @@ def fp8_grouped_gemm( bias_dtype = TE_DType[bias_dtype] gelu_input = empty_tensors - if isinstance(out, list): + if not isinstance(out, torch.Tensor): if gelu: gelu_input = [ torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) From d7441f0ea1ac6e1bb12f964fb422b4af3d8ebc79 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 29 Aug 2024 19:49:29 -0700 Subject: [PATCH 5/7] add contiguous check Signed-off-by: Xin Yao --- transformer_engine/pytorch/csrc/extensions/gemm.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 53fafe40cf..29cabdb53f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -166,9 +166,12 @@ void te_grouped_gemm_single_output( makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); return tensor_wrappers.back().data(); }; + NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); void* d_i_ptr = reinterpret_cast(D.data_ptr()); for (size_t i = 0; i < A.size(); i++) { if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); + NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); te_A.emplace_back(make_tensor( A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); From e8b4c0163a34c5ff33af0664e186f1ceb8548f00 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 4 Sep 2024 19:39:41 -0700 Subject: [PATCH 6/7] use len() instead of isinstance Signed-off-by: Xin Yao --- tests/pytorch/test_numerics.py | 29 +++++++++++---- .../pytorch/cpp_extensions/gemm.py | 37 ++++++++++--------- .../pytorch/module/grouped_linear.py | 6 +-- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 85cd4fc256..998f30abe3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1229,12 +1229,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ) inp_hidden_states.retain_grad() - m = config.seq_len // 16 - dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() - dist.append(dist[-1]) # Manually add a zero - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + if num_gemms > 1: + m = config.seq_len // 16 + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + m_splits = m_splits * 16 + assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + else: + m_splits = torch.tensor([config.seq_len]) with fp8_autocast(enabled=fp8): if isinstance(block, GroupedLinear): @@ -1316,7 +1319,7 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) def test_grouped_linear_accuracy_parallel_mode(parallel_mode): - """Split the tests to reduce CI time""" + """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, num_gemms=6, @@ -1328,6 +1331,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ) +def test_grouped_linear_accuracy_single_gemm(): + """Split the tests to save CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=1, + bs=2, + model=list(model_configs.keys())[0], + fp8=True, + fp8_model_params=True, + ) + + def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 58764b6bcc..fd1eb4a810 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -318,7 +318,7 @@ def grouped_gemm( layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """Non FP8 Grouped GEMM.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." @@ -385,14 +385,14 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: Union[torch.Tensor, List[torch.Tensor]], + A_scale_inv: List[torch.Tensor], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], B_scale_inv: torch.Tensor, B_fp8_tensor_offset: int, B_dtype: tex.DType, - out: Union[torch.Tensor, List[torch.Tensor]], + out: List[torch.Tensor], out_dtype: torch.dtype, workspaces: List[torch.Tensor], m_splits: Optional[List[int]] = None, @@ -404,23 +404,25 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This function accepts two combinations of inputs: - 1. A_scale_inv is a list of tensors, out is a single tensor, and m_splits is not None. + Input requirements: + 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. This is used for the calculation of output (fwd) and dgrad (bwd). - 2. A_scale_inv is a single tensor, out is a list of tensors. This is used for the + 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the calculation of wgrad. """ - if isinstance(A_scale_inv, list): - assert isinstance(out, torch.Tensor) and m_splits is not None - elif isinstance(A_scale_inv, torch.Tensor): - assert isinstance(out, (list, tuple)) + num_gemms = len(A) + if num_gemms > 1 and len(A_scale_inv) == num_gemms: + assert len(out) == 1 and m_splits is not None + elif num_gemms > 1 and len(A_scale_inv) == 1: + assert len(out) == num_gemms + elif num_gemms == 1: + assert len(A_scale_inv) == 1 and len(out) == 1 else: - raise ValueError("A_scale_inv should be a list of tensors or a single tensor.") + raise ValueError("Invalid input combinations of A_scale_inv and out.") - num_gemms = len(A) empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: @@ -435,18 +437,18 @@ def fp8_grouped_gemm( bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype bias_dtype = TE_DType[bias_dtype] gelu_input = empty_tensors + out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - if not isinstance(out, torch.Tensor): + if len(A_scale_inv) == 1: if gelu: gelu_input = [ torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) for o in out ] - out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype torch.ops.tex_ts.te_grouped_gemm_ts( A, - A_scale_inv, + A_scale_inv[0], A_fp8_tensor_offset, A_dtype, True, # transa @@ -472,7 +474,6 @@ def fp8_grouped_gemm( else: if gelu: gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype torch.ops.tex_ts.te_grouped_gemm_single_output_ts( A, @@ -486,7 +487,7 @@ def fp8_grouped_gemm( B_dtype, False, # transb m_splits, - out, + out[0], 0 if out_offset is None else out_offset, empty_tensor if out_offset is None else fp8_meta_tensor.scale, out_dtype, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b1c9f5670e..ca100392c7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -177,7 +177,7 @@ def forward( inputmat_scale_inv, 0, fp8_dtype_forward, - out, + [out], activation_dtype, get_multi_stream_cublas_workspace(), m_splits=m_splits, @@ -367,7 +367,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.fp8_meta["scaling_bwd"].scale_inv, _GRAD_OUTPUT, fp8_dtype_backward, - dgrad, + [dgrad], ctx.activation_dtype, get_multi_stream_cublas_workspace(), m_splits=ctx.m_splits, @@ -413,7 +413,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inp._data if isinstance(inp, Float8Tensor) else inp for inp in inputmats_t ], - inputmat_scale_inv, + [inputmat_scale_inv], 0, fp8_dtype_forward, grad_output_t, From 09befbe84b721bb4fd39cc67521046062289a5e6 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 4 Sep 2024 22:38:47 -0700 Subject: [PATCH 7/7] fix ut Signed-off-by: Xin Yao --- tests/pytorch/test_numerics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 998f30abe3..f36e1e33e9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1860,7 +1860,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): fp8_grouped_gemm( A_fp8, - scale_inv, + [scale_inv], 0, # A_offset tex.DType.kFloat8E4M3, B_fp8,