diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 8502f70491..425796b826 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -11,7 +11,13 @@ from ..utils import assert_dim_for_fp8_exec
 
-__all__ = ["gemm", "fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"]
+__all__ = [
+    "gemm",
+    "fp8_gemm",
+    "grouped_gemm",
+    "fp8_grouped_gemm",
+    "fp8_grouped_gemm_single_output",
+]
 
 
 @functools.lru_cache(maxsize=None)
@@ -458,3 +464,86 @@ def fp8_grouped_gemm(
     )
 
     return out, gelu_input
+
+
+def fp8_grouped_gemm_single_output(
+    A: List[torch.Tensor],
+    A_scale_inv: List[torch.Tensor],
+    A_fp8_tensor_offset: int,
+    A_dtype: tex.DType,
+    B: List[torch.Tensor],
+    B_scale_inv: torch.Tensor,
+    B_fp8_tensor_offset: int,
+    B_dtype: tex.DType,
+    m_splits: List[int],
+    out: torch.Tensor,
+    out_dtype: torch.dtype,
+    workspaces: List[torch.Tensor],
+    out_offset: Optional[int] = None,
+    fp8_meta_tensor: tex.FP8TensorMeta = None,
+    gelu: bool = False,
+    accumulate: bool = False,
+    bias: Optional[List[torch.Tensor]] = None,
+    use_bias: bool = False,
+    use_split_accumulator: bool = False,
+    D_dtype: Optional[tex.DType] = None,
+) -> Tuple[Union[List[torch.Tensor], None], ...]:
+    """
+    TN layout grouped GEMM with two lists of fp8 inputs and a single contiguous
+    output that is implicitly split by m_splits.
+    This method assumes the scale_inv of A is a list of tensors.
+    Used for the calculation of output (fwd) and dgrad (bwd).
+    """
+
+    num_gemms = len(A)
+    empty_tensor = _empty_tensor()
+    empty_tensors = [empty_tensor] * num_gemms
+    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
+        assert fp8_meta_tensor is not None and out_offset is not None
+    for a, b in zip(A, B):
+        assert_dim_for_fp8_exec(a)
+        assert_dim_for_fp8_exec(b)
+    assert A[0].dtype == torch.uint8
+    assert B[0].dtype == torch.uint8
+
+    # Use bfloat16 as default bias_dtype
+    bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
+    if gelu:
+        gelu_input = [
+            torch.empty((m, A[0].size(0)), dtype=bias_dtype, device="cuda")
+            for m in m_splits
+        ]
+    else:
+        gelu_input = empty_tensors
+    bias_dtype = TE_DType[bias_dtype]
+
+    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
+
+    torch.ops.tex_ts.te_grouped_gemm_single_output_ts(
+        A,
+        A_scale_inv,
+        A_fp8_tensor_offset,
+        A_dtype,
+        True,  # transa
+        B,
+        B_scale_inv,
+        B_fp8_tensor_offset,
+        B_dtype,
+        False,  # transb
+        m_splits,
+        out,
+        0 if out_offset is None else out_offset,
+        empty_tensor if out_offset is None else fp8_meta_tensor.scale,
+        out_dtype,
+        empty_tensor if out_offset is None else fp8_meta_tensor.amax_history,
+        bias if use_bias else empty_tensors,
+        bias_dtype,
+        gelu_input,  # this is pre_gelu_out
+        False,  # grad
+        workspaces,
+        workspaces[0].shape[0],
+        accumulate,
+        use_split_accumulator,
+    )
+
+    return out, gelu_input
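For orientation, the shape contract the new wrapper assumes: each `A[i]` is an `[n, k]` fp8 weight, each `B[i]` an `[m_splits[i], k]` fp8 input, and `out` is one contiguous `[sum(m_splits), n]` buffer. A minimal pure-PyTorch sketch of the equivalent (dequantized) math — illustrative only, all sizes and tensor values are made up:

```python
import torch

# Reference semantics of the TN grouped GEMM with a single contiguous output:
# each group computes D_i = B_i @ A_i^T into its slice of `out`.
k, n = 64, 128
m_splits = [32, 48, 16]
A = [torch.randn(n, k) for _ in m_splits]  # per-group weights (fp8 in the real path)
B = [torch.randn(m, k) for m in m_splits]  # per-group inputs, row counts = m_splits
out = torch.empty(sum(m_splits), n)

for out_i, a, b in zip(torch.split(out, m_splits), A, B):
    out_i.copy_(b @ a.t())  # TN layout: transa=True, transb=False
```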
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 1a6f5f157e..e0188460db 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -159,6 +159,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
                      std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                      bool use_split_accumulator, int math_sm_count);
 
+void te_grouped_gemm_single_output(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+    bool use_split_accumulator, int math_sm_count);
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu
index 01fb94cab4..fe6ac7f628 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cu
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -146,3 +146,63 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
                   te_workspace.data(), accumulate, use_split_accumulator, math_sm_count,
                   at::cuda::getCurrentCUDAStream());
 }
+
+
+void te_grouped_gemm_single_output(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+    bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine;
+  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+  std::vector<TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
+                                        transformer_engine::DType dtype, void* amax_dptr,
+                                        void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
+    tensor_wrappers.emplace_back(
+        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+    return tensor_wrappers.back().data();
+  };
+  void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
+  for (size_t i = 0; i < A.size(); i++) {
+    if (m_splits[i] == 0) continue;
+    te_A.emplace_back(make_tensor(
+        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
+    te_B.emplace_back(make_tensor(
+        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+    te_D.emplace_back(make_tensor(
+        d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))},
+        D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+                                     bias_type, nullptr, nullptr, nullptr));
+
+    const auto gelu_shape =
+        pre_gelu_out[i].data_ptr() == nullptr
+            ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+            : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+                                  static_cast<size_t>(pre_gelu_out[i].size(1))};
+    te_pre_gelu_out.emplace_back(make_tensor(
+        pre_gelu_out[i].data_ptr(), gelu_shape,
+        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
+    // Move the D pointer to the next split.
+    char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
+    char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
+    d_i_ptr = reinterpret_cast<void*>(char_ptr);
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
+  }
+
+  // For now, we only have the multi-stream cuBLAS backend.
+  nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+                                te_workspace.data(), accumulate, use_split_accumulator,
+                                math_sm_count, at::cuda::getCurrentCUDAStream());
+}
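The launcher writes every group directly into the shared `D` buffer, advancing the base pointer by `m_splits[i] * n * element_size` bytes per group; a zero-sized split advances it by zero bytes, which is why those groups can be skipped outright. A small CPU-side sketch of the same walk (illustrative, with made-up sizes):

```python
import torch

n = 4
m_splits = [2, 0, 3]  # a zero split occupies zero bytes and is skipped by the kernel
D = torch.empty(sum(m_splits), n, dtype=torch.bfloat16)

ptr = D.data_ptr()
for chunk in torch.split(D, m_splits):
    assert chunk.data_ptr() == ptr  # splits are adjacent views of one contiguous buffer
    ptr += chunk.numel() * D.element_size()  # mirrors the d_i_ptr advance in the C++ loop
```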
diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
index 8515092ae0..b51399f808 100644
--- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
+++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,39 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
   return D;
 }
 
+at::Tensor te_grouped_gemm_single_output_ts(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
+    int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
+    int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits,
+    at::Tensor D, int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
+    std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
+    int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator) {
+  // Cast inputs to the types accepted by te_grouped_gemm_single_output.
+  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+  bool transa_arg = static_cast<bool>(transa);
+  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+  bool transb_arg = static_cast<bool>(transb);
+  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+  bool grad_arg = static_cast<bool>(grad);
+  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+  bool accumulate_arg = static_cast<bool>(accumulate);
+  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+
+  // Set an external SM margin for all the GEMMs.
+  // This comes in handy when DP communication is overlapped with GEMMs.
+  const int device_id = at::cuda::current_device();
+  const int sm_count = transformer_engine::cuda::sm_count(device_id);
+  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+
+  te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
+                                B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
+                                D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
+                                pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+                                accumulate_arg, use_split_accumulator_arg, num_math_sms);
+  return D;
+}
+
 at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
                                     const at::Tensor &bias, double eps, at::Tensor scale,
                                     at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +404,7 @@ TORCH_LIBRARY(tex_ts, m) {
   m.def("srelu_ts", &srelu_ts);
   m.def("te_gemm_ts", &te_gemm_ts);
   m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
+  m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
   m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
   m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
   m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
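The SM-margin logic follows the same convention as `te_gemm_ts`: `NVTE_EXT_MARGIN_SM` holds the number of SMs to set aside, and when it is unset the default equals `sm_count`, so `num_math_sms` comes out as 0 (which the cuBLAS path treats as "no explicit SM-count target"). The arithmetic in Python, for illustration (function name and SM counts are hypothetical):

```python
import os

def math_sm_count(sm_count: int) -> int:
    # Unset margin defaults to sm_count, yielding 0, i.e. let cuBLAS use all SMs.
    margin = int(os.environ.get("NVTE_EXT_MARGIN_SM", sm_count))
    return sm_count - margin

assert math_sm_count(108) == 0       # default: no SMs reserved
os.environ["NVTE_EXT_MARGIN_SM"] = "16"
assert math_sm_count(108) == 92      # reserve 16 SMs for overlapped communication
```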
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 7f56ced8ef..da61e94686 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -36,6 +36,7 @@
     fp8_cast_transpose_bgrad_fused,
     fp8_multi_cast_transpose_fused,
     fp8_grouped_gemm,
+    fp8_grouped_gemm_single_output,
     grouped_gemm,
 )
 from ..constants import GemmParallelModes, dist_group_type
@@ -168,18 +169,17 @@ def forward(
                 device=inputmats[0].device,
             )
 
-            _ = fp8_grouped_gemm(
+            _ = fp8_grouped_gemm_single_output(
                 [w._data for w in weights_fp8],
-                torch.cat(
-                    [w._scale_inv for w in weights_fp8]
-                ),  # avoiding torch.cat requires another interface,
+                [w._scale_inv for w in weights_fp8],
                 0,  # weight offset is 0 for the newly created _scale_inv
                 fp8_dtype_forward,
                 inputmats,
                 inputmat_scale_inv,
                 0,
                 fp8_dtype_forward,
-                torch.split(out, m_splits),
+                m_splits,
+                out,
                 activation_dtype,
                 get_multi_stream_cublas_workspace(),
                 bias=biases,
@@ -359,18 +359,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 dtype=ctx.activation_dtype,
                 device=grad_output.device,
             )
-            fp8_grouped_gemm(
+            fp8_grouped_gemm_single_output(
                 [w.transpose_2d() for w in weights_fp8],
-                torch.cat(
-                    [w._scale_inv for w in weights_fp8]
-                ),  # avoiding torch.cat requires another interface
+                [w._scale_inv for w in weights_fp8],
                 0,  # weight offset is 0 for the newly created _scale_inv
                 weights_fp8[0]._fp8_dtype,
                 grad_output_c,
                 ctx.fp8_meta["scaling_bwd"].scale_inv,
                 _GRAD_OUTPUT,
                 fp8_dtype_backward,
-                torch.split(dgrad, ctx.m_splits),
+                ctx.m_splits,
+                dgrad,
                 ctx.activation_dtype,
                 get_multi_stream_cublas_workspace(),
                 use_split_accumulator=_2X_ACC_DGRAD,
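In short, the `grouped_linear.py` change removes two per-call costs on the hot path: the `torch.cat` over scale_inv tensors (an extra allocation plus a device copy, the very workaround the old inline comment apologized for) and the `torch.split` view list over the output. A schematic before/after with stand-in tensors (not the real module state):

```python
import torch

scale_invs = [torch.ones(1) for _ in range(4)]  # stand-ins for w._scale_inv
m_splits = [32, 16, 40, 8]
out = torch.empty(sum(m_splits), 128)

# Before: pack the scales and pre-split the output to feed fp8_grouped_gemm.
packed_scale_inv = torch.cat(scale_invs)  # allocation + copy on every call
out_views = torch.split(out, m_splits)

# After: fp8_grouped_gemm_single_output takes the list and the contiguous
# buffer as-is; the split happens inside the C++ launcher via m_splits.
```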