From 047a50722780e7b647f9107783e210021190edc3 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Mon, 9 Sep 2024 22:30:48 +0800
Subject: [PATCH] [PyTorch] Propagate fp8 scale-inverse modification to
 `GroupedLinear` (#1128)

* propagate scale_inv modification to GroupedLinear

Signed-off-by: Xin Yao

* optimization for separate scale_inv of weights and single output

Signed-off-by: Xin Yao

* let grouped gemm support different input combinations

Signed-off-by: Xin Yao

* fix type

Signed-off-by: Xin Yao

* add contiguous check

Signed-off-by: Xin Yao

* use len() instead of isinstance

Signed-off-by: Xin Yao

* fix ut

Signed-off-by: Xin Yao

---------

Signed-off-by: Xin Yao
Co-authored-by: Kirthi Shankar Sivamani
---
 tests/pytorch/test_numerics.py                |  31 +++--
 .../pytorch/cpp_extensions/gemm.py            | 129 ++++++++++++------
 .../pytorch/cpp_extensions/transpose.py       |   3 +-
 transformer_engine/pytorch/csrc/extensions.h  |  10 ++
 .../pytorch/csrc/extensions/gemm.cu           |  61 +++++++++
 transformer_engine/pytorch/csrc/ts_fp8_op.cpp |  36 +++++
 .../pytorch/module/grouped_linear.py          |  43 ++++--
 7 files changed, 249 insertions(+), 64 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 723f68369b..ad34b4996f 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1266,12 +1266,15 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
     )
     inp_hidden_states.retain_grad()

-    m = config.seq_len // 16
-    dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
-    dist.append(dist[-1])  # Manually add a zero
-    m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
-    m_splits = m_splits * 16
-    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+    if num_gemms > 1:
+        m = config.seq_len // 16
+        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
+        dist.append(dist[-1])  # Manually add a zero
+        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
+        m_splits = m_splits * 16
+        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+    else:
+        m_splits = torch.tensor([config.seq_len])

     with fp8_autocast(enabled=fp8):
         if isinstance(block, GroupedLinear):
@@ -1353,7 +1356,7 @@ def test_grouped_linear_accuracy(

 @pytest.mark.parametrize("parallel_mode", ["column", "row"])
 def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
-    """Split the tests to reduce CI time"""
+    """Split the tests to save CI time"""
     test_grouped_linear_accuracy(
         dtype=torch.float32,
         num_gemms=6,
@@ -1365,6 +1368,18 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
     )


+def test_grouped_linear_accuracy_single_gemm():
+    """Split the tests to save CI time"""
+    test_grouped_linear_accuracy(
+        dtype=torch.float32,
+        num_gemms=1,
+        bs=2,
+        model=list(model_configs.keys())[0],
+        fp8=True,
+        fp8_model_params=True,
+    )
+
+
 def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):

     def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
@@ -2034,7 +2049,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):

     fp8_grouped_gemm(
         A_fp8,
-        scale_inv,
+        [scale_inv],
         0,  # A_offset
         tex.DType.kFloat8E4M3,
         B_fp8,
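[Annotation] The test above builds m_splits by drawing random cut points, duplicating one of them so that at least one split is empty, and scaling everything to multiples of 16. Below is a standalone sketch of the same scheme, not part of the patch; it assumes num_gemms is either 1 or at least 3 (the duplication step needs at least one drawn cut point) and that seq_len is a multiple of 16.

import torch

def make_m_splits(seq_len: int, num_gemms: int, multiple: int = 16) -> torch.Tensor:
    # Split seq_len into num_gemms chunks, each a multiple of `multiple`,
    # with at least one zero-sized chunk when num_gemms > 1.
    if num_gemms == 1:
        return torch.tensor([seq_len])
    m = seq_len // multiple
    cuts = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
    cuts.append(cuts[-1])  # duplicated cut point -> one empty split
    splits = (torch.tensor(cuts + [m]) - torch.tensor([0] + cuts)) * multiple
    assert splits.sum() == seq_len and len(splits) == num_gemms
    return splits

# e.g. make_m_splits(512, 6) -> six multiples of 16 summing to 512, one of them 0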
"fp8_gemm", "grouped_gemm", "fp8_grouped_gemm"] +__all__ = [ + "gemm", + "fp8_gemm", + "grouped_gemm", + "fp8_grouped_gemm", +] @functools.lru_cache(maxsize=None) @@ -313,7 +318,7 @@ def grouped_gemm( layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """Non FP8 Grouped GEMM.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." @@ -380,7 +385,7 @@ def grouped_gemm( def fp8_grouped_gemm( A: List[torch.Tensor], - A_scale_inv: torch.Tensor, + A_scale_inv: List[torch.Tensor], A_fp8_tensor_offset: int, A_dtype: tex.DType, B: List[torch.Tensor], @@ -390,6 +395,7 @@ def fp8_grouped_gemm( out: List[torch.Tensor], out_dtype: torch.dtype, workspaces: List[torch.Tensor], + m_splits: Optional[List[int]] = None, out_offset: Optional[int] = None, fp8_meta_tensor: tex.FP8TensorMeta = None, gelu: bool = False, @@ -398,16 +404,25 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. - This method assumes the scale/scale_inv/amax of A/B/out is contiguous in the meta tensor. - scale: [ ...A_scale... | ...B_scale... | ...out_scale...] - scale_inv: [ ...A_scale_inv... | ...B_scale_inv... | ...out_scale_inv...] - amax: [ ...A_amax... | ...B_amax... | ...out_amax...] + Input requirements: + 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. + This is used for the calculation of output (fwd) and dgrad (bwd). + 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the + calculation of wgrad. 
""" - num_gemms = len(A) + if num_gemms > 1 and len(A_scale_inv) == num_gemms: + assert len(out) == 1 and m_splits is not None + elif num_gemms > 1 and len(A_scale_inv) == 1: + assert len(out) == num_gemms + elif num_gemms == 1: + assert len(A_scale_inv) == 1 and len(out) == 1 + else: + raise ValueError("Invalid input combinations of A_scale_inv and out.") + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: @@ -420,41 +435,71 @@ def fp8_grouped_gemm( # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - else: - gelu_input = empty_tensors bias_dtype = TE_DType[bias_dtype] - + gelu_input = empty_tensors out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) + if len(A_scale_inv) == 1: + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] + + torch.ops.tex_ts.te_grouped_gemm_ts( + A, + A_scale_inv[0], + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + out, + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) + else: + if gelu: + gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] + + torch.ops.tex_ts.te_grouped_gemm_single_output_ts( + A, + A_scale_inv, + A_fp8_tensor_offset, + A_dtype, + True, # transa + B, + B_scale_inv, + B_fp8_tensor_offset, + B_dtype, + False, # transb + m_splits, + out[0], + 0 if out_offset is None else out_offset, + empty_tensor if out_offset is None else fp8_meta_tensor.scale, + out_dtype, + empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, + bias if use_bias else empty_tensors, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspaces, + workspaces[0].shape[0], + accumulate, + use_split_accumulator, + ) return out, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 37a1b59da2..ddc3b67e9e 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused( amax_indices: List[int], scale_inv_indices: List[int], otype: tex.DType, + scale_inv: Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Cast + Transpose with FP8 output""" @@ -182,7 +183,7 
diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py
index 37a1b59da2..ddc3b67e9e 100644
--- a/transformer_engine/pytorch/cpp_extensions/transpose.py
+++ b/transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -175,6 +175,7 @@ def fp8_multi_cast_transpose_fused(
     amax_indices: List[int],
     scale_inv_indices: List[int],
     otype: tex.DType,
+    scale_inv: Optional[torch.Tensor] = None,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     """Cast + Transpose with FP8 output"""

@@ -182,7 +183,7 @@ def fp8_multi_cast_transpose_fused(
         input_list,
         fp8_meta_tensor.scale,
         fp8_meta_tensor.amax_history,
-        fp8_meta_tensor.scale_inv,
+        scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv,
         scale_indices,
         amax_indices,
         scale_inv_indices,
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 31103cbe8e..c797208e06 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -165,6 +165,16 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
                      std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                      bool use_split_accumulator, int math_sm_count);

+void te_grouped_gemm_single_output(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+    bool use_split_accumulator, int math_sm_count);
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/
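[Annotation] The cast/transpose helper can now write the computed scale-inverses into a caller-owned buffer instead of the shared fp8_meta tensor. A toy illustration of that buffer-selection pattern, with hypothetical names and no real FP8 kernel involved:

import torch
from typing import Optional

def pick_scale_inv_buffer(shared: torch.Tensor,
                          private: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same selection as the changed call above: prefer a caller-owned buffer when
    # one is supplied, otherwise fall back to the shared fp8_meta scale_inv.
    # A private buffer keeps the values used at cast time stable even if the
    # shared buffer is updated in place afterwards.
    return private if private is not None else shared

# shared = torch.ones(16); private = torch.empty(4)
# dest = pick_scale_inv_buffer(shared, private)  # -> private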
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu
index 7405914a0e..ba9851e7e8 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cu
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu
@@ -151,3 +151,64 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 math_sm_count, at::cuda::getCurrentCUDAStream());
 }
+
+void te_grouped_gemm_single_output(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
+    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
+    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
+    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
+    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+    bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine;
+  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+  std::vector<TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
+                                        transformer_engine::DType dtype, void* amax_dptr,
+                                        void* scale_dptr, void* scale_inv_dptr) -> NVTETensor {
+    tensor_wrappers.emplace_back(
+        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+    return tensor_wrappers.back().data();
+  };
+  NVTE_CHECK(D.is_contiguous(), "D must be contiguous.");
+  void* d_i_ptr = reinterpret_cast<void*>(D.data_ptr());
+  for (size_t i = 0; i < A.size(); i++) {
+    if (m_splits[i] == 0) continue;
+    NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous.");
+    NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous.");
+    te_A.emplace_back(make_tensor(
+        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset)));
+    te_B.emplace_back(make_tensor(
+        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+    te_D.emplace_back(make_tensor(
+        d_i_ptr, {static_cast<size_t>(m_splits[i]), static_cast<size_t>(A[i].size(0))}, D_type,
+        getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+                                     bias_type, nullptr, nullptr, nullptr));
+
+    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
+                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
+    te_pre_gelu_out.emplace_back(make_tensor(
+        pre_gelu_out[i].data_ptr(), gelu_shape,
+        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
+    // Move the D pointer to the next split.
+    char* char_ptr = reinterpret_cast<char*>(d_i_ptr);
+    char_ptr += m_splits[i] * A[i].size(0) * D.element_size();
+    d_i_ptr = reinterpret_cast<void*>(char_ptr);
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
+  }
+
+  // For now, we only have multi-stream cublas backend.
+  nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+                                te_workspace.data(), accumulate, use_split_accumulator,
+                                math_sm_count, at::cuda::getCurrentCUDAStream());
+}
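[Annotation] te_grouped_gemm_single_output walks one contiguous D buffer, carving out an (m_splits[i], N) block per GEMM and skipping empty splits. The same bookkeeping expressed in PyTorch terms, for illustration only; n plays the role of A[i].size(0) above:

import torch

def walk_output_blocks(d: torch.Tensor, m_splits, n: int):
    # d must be contiguous with shape (sum(m_splits), n). One view is produced per
    # non-empty split; advancing `row` corresponds to the pointer bump
    # `char_ptr += m_splits[i] * A[i].size(0) * D.element_size()` in the C++ code.
    assert d.is_contiguous() and d.shape == (sum(m_splits), n)
    views, row = [], 0
    for m in m_splits:
        if m == 0:
            continue  # mirrors `if (m_splits[i] == 0) continue;`
        views.append(d.narrow(0, row, m))  # same storage, no copy
        row += m
    return views

# d = torch.zeros(48, 8); [v.shape for v in walk_output_blocks(d, [16, 0, 32], 8)]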
diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
index 8c480e8343..9f31dba669 100644
--- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
+++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -305,6 +305,41 @@ std::vector<at::Tensor> te_grouped_gemm_ts(
   return D;
 }

+at::Tensor te_grouped_gemm_single_output_ts(
+    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int64_t A_offset,
+    int64_t A_type, int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse,
+    int64_t B_offset, int64_t B_type, int64_t transb, std::vector<int64_t> m_splits, at::Tensor D,
+    int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax,
+    std::vector<at::Tensor> bias, int64_t bias_type, std::vector<at::Tensor> pre_gelu_out,
+    int64_t grad, std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator) {
+  // cast inputs to types accepted by te_gemm
+  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+  bool transa_arg = static_cast<bool>(transa);
+  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+  bool transb_arg = static_cast<bool>(transb);
+  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+  bool grad_arg = static_cast<bool>(grad);
+  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+  bool accumulate_arg = static_cast<bool>(accumulate);
+  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+
+  // Set an external SM Margin to all the GEMMs.
+  // This comes in handy when DP is overlapped with GEMMs
+
+  const int device_id = at::cuda::current_device();
+  const int sm_count = transformer_engine::cuda::sm_count(device_id);
+  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+
+  te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B,
+                                B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D,
+                                D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg,
+                                pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+                                accumulate_arg, use_split_accumulator_arg, num_math_sms);
+  return D;
+}
+
 at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight,
                                     const at::Tensor &bias, double eps, at::Tensor scale,
                                     at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor,
@@ -371,6 +406,7 @@ TORCH_LIBRARY(tex_ts, m) {
   m.def("srelu_ts", &srelu_ts);
   m.def("te_gemm_ts", &te_gemm_ts);
   m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts);
+  m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts);
  m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
  m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
  m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index a91ff5c361..ca100392c7 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -42,6 +42,7 @@
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
+from ..export import is_in_onnx_export_mode

 __all__ = ["GroupedLinear"]

@@ -102,10 +103,12 @@ def forward(
         inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
         inputmats = []
         inputmats_t = []
+        inputmat_scale_inv = None

         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+            inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
             if (
                 not fp8_meta["recipe"].override_linear_precision.wgrad
                 and is_grad_enabled
@@ -121,6 +124,7 @@ def forward(
                     indices,  # amax_indices
                     indices,  # scale_inv_indices
                     fp8_dtype_forward,
+                    scale_inv=inputmat_scale_inv,
                 )
             else:
                 # FP8 input for forward
@@ -130,9 +134,22 @@ def forward(
                         fp8_meta["scaling_fwd"],
                         _GEMM_INPUT + i,
                         fp8_dtype_forward,
+                        scale_inv=inputmat_scale_inv,
                     )
                     for i in range(num_gemms)
                 ]
+
+            # Hack for ONNX export
+            # Note: ONNX models are represented as a graph of tensor
+            # operations, so the in-place scale-inv update doesn't fit
+            # very well. We work around this by making it look like
+            # the scale-inv tensor is initialized with a copy.
+            # Note: ONNX export expects FP8 scales can be represented
+            # with constant ops. However, copying into a buffer
+            # involves an expand op for array broadcasting. We work
+            # around this by filling the buffer instead.
+            if is_in_onnx_export_mode():
+                inputmat_scale_inv.fill_(inputmat_scale_inv.item())
         else:
             inputmats = inputmats_no_fp8

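[Annotation] GroupedLinear now allocates one fp32 scale-inverse slot per GEMM input and has the cast kernels fill it directly, so the values used in the forward pass remain available to backward and cannot be clobbered by later updates to the shared fp8_meta buffers. A simplified stand-in for that bookkeeping, with a toy cast and hypothetical names:

import torch

def cast_inputs_with_private_scale_inv(inputmats, scales):
    # One slot per GEMM, matching torch.empty([num_gemms], dtype=torch.float32, ...) above.
    scale_inv = torch.empty(len(inputmats), dtype=torch.float32, device=inputmats[0].device)
    casted = []
    for i, (x, s) in enumerate(zip(inputmats, scales)):
        casted.append(x * s)      # stand-in for cast_to_fp8(...)
        scale_inv[i] = 1.0 / s    # recorded at cast time, reused in backward
    return casted, scale_inv

# xs = [torch.randn(4, 8) for _ in range(3)]
# casted, scale_inv = cast_inputs_with_private_scale_inv(xs, [2.0, 4.0, 8.0])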
@@ -153,16 +170,17 @@ def forward(

             _ = fp8_grouped_gemm(
                 [w._data for w in weights_fp8],
-                fp8_meta["scaling_fwd"].scale_inv,
-                _GEMM_WEIGHT,
+                [w._scale_inv for w in weights_fp8],
+                0,  # weight offset is 0 for the newly created _scale_inv
                 fp8_dtype_forward,
                 inputmats,
-                fp8_meta["scaling_fwd"].scale_inv,
-                _GEMM_INPUT,
+                inputmat_scale_inv,
+                0,
                 fp8_dtype_forward,
-                torch.split(out, m_splits),
+                [out],
                 activation_dtype,
                 get_multi_stream_cublas_workspace(),
+                m_splits=m_splits,
                 bias=biases,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
@@ -230,7 +248,7 @@ def forward(
                     t.activation_offloading = True

             ctx.save_for_backward(
-                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
+                inputmat_scale_inv,
                 *saved_inputmats,
                 *saved_inputmats_t,
                 *weights,
@@ -270,7 +288,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
-                fwd_scale_inverses,
+                inputmat_scale_inv,
                 *saved_tensors,
             ) = ctx.saved_tensors
             inputmats = saved_tensors[: ctx.num_gemms]
@@ -342,18 +360,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 )
                 fp8_grouped_gemm(
                     [w.transpose_2d() for w in weights_fp8],
-                    torch.cat(
-                        [w._scale_inv for w in weights_fp8]
-                    ),  # avoiding torch.cat requires another interface
+                    [w._scale_inv for w in weights_fp8],
                     0,  # weight offset is 0 for the newly created _scale_inv
                     weights_fp8[0]._fp8_dtype,
                     grad_output_c,
                     ctx.fp8_meta["scaling_bwd"].scale_inv,
                     _GRAD_OUTPUT,
                     fp8_dtype_backward,
-                    torch.split(dgrad, ctx.m_splits),
+                    [dgrad],
                     ctx.activation_dtype,
                     get_multi_stream_cublas_workspace(),
+                    m_splits=ctx.m_splits,
                     use_split_accumulator=_2X_ACC_DGRAD,
                 )
             else:
@@ -396,8 +413,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                             inp._data if isinstance(inp, Float8Tensor) else inp
                             for inp in inputmats_t
                         ],
-                        fwd_scale_inverses,
-                        _GEMM_INPUT,
+                        [inputmat_scale_inv],
+                        0,
                         fp8_dtype_forward,
                         grad_output_t,
                         ctx.fp8_meta["scaling_bwd"].scale_inv,