From 3fe85438975503b0bfe62979aa56e06a7eabbafa Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 11 Oct 2024 22:22:04 -0700 Subject: [PATCH 1/3] fix bias for 0-dim tensor Signed-off-by: Xin Yao --- transformer_engine/pytorch/csrc/extensions.h | 5 +++++ transformer_engine/pytorch/csrc/extensions/gemm.cu | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c30e583178..7cdf35e794 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -165,6 +165,11 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +/* + * This method is a performance-optimized version for the calculation of fwd and dgrad. + * It's not for general purpose use. Taking the GEMM (m, k) * (k, n) for example, it handles + * the case where m = 0, but not k = 0 or n = 0. For those cases, please use `te_grouped_gemm`. + */ void te_grouped_gemm_single_output( std::vector A, std::vector A_scale_inverse, int A_offset, transformer_engine::DType A_type, bool transa, std::vector B, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index ba9851e7e8..76d3ceca14 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -17,7 +17,8 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType using namespace transformer_engine; if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { if (D.data_ptr() != nullptr && !accumulate) D.zero_(); - if (bias.data_ptr() != nullptr) bias.zero_(); + // torch.sum is able to handle 0-dim tensors + if (bias.data_ptr() != nullptr && grad) bias.copy_(B.sum(0)); if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); return; } @@ -111,7 +112,8 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int for (size_t i = 0; i < A.size(); i++) { if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) { if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_(); - if (bias[i].data_ptr() != nullptr) bias[i].zero_(); + // torch.sum is able to handle 0-dim tensors + if (bias[i].data_ptr() != nullptr && grad) bias[i].copy_(B[i].sum(0)); if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); continue; } From 8035717893300457625b9a095771bfecb17653b9 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 11 Oct 2024 22:36:42 -0700 Subject: [PATCH 2/3] add check Signed-off-by: Xin Yao --- transformer_engine/pytorch/csrc/extensions.h | 5 ----- transformer_engine/pytorch/csrc/extensions/gemm.cu | 2 ++ 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7cdf35e794..c30e583178 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -165,11 +165,6 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); -/* - * This method is a performance-optimized version for the calculation of fwd and dgrad. - * It's not for general purpose use. Taking the GEMM (m, k) * (k, n) for example, it handles - * the case where m = 0, but not k = 0 or n = 0. For those cases, please use `te_grouped_gemm`. - */ void te_grouped_gemm_single_output( std::vector A, std::vector A_scale_inverse, int A_offset, transformer_engine::DType A_type, bool transa, std::vector B, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index 76d3ceca14..a1d170cd48 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -177,6 +177,8 @@ void te_grouped_gemm_single_output( void* d_i_ptr = reinterpret_cast(D.data_ptr()); for (size_t i = 0; i < A.size(); i++) { if (m_splits[i] == 0) continue; + NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); + NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); te_A.emplace_back(make_tensor( From b5fa1c3e254693951c63ede927036958481bc190 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 15 Oct 2024 02:58:39 -0700 Subject: [PATCH 3/3] use numel() instead of nullptr Signed-off-by: Xin Yao --- .../pytorch/csrc/extensions/gemm.cu | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index a1d170cd48..40b96a057f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -15,11 +15,16 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { using namespace transformer_engine; - if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { - if (D.data_ptr() != nullptr && !accumulate) D.zero_(); - // torch.sum is able to handle 0-dim tensors - if (bias.data_ptr() != nullptr && grad) bias.copy_(B.sum(0)); - if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); + if (A.numel() == 0 || B.numel() == 0) { + if (D.numel() != 0 && !accumulate) D.zero_(); + if (bias.numel() != 0 && grad) { + if (B.numel() == 0) { + bias.zero_(); + } else { + bias.copy_(B.sum(0)); + } + } + if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); return; } @@ -110,11 +115,16 @@ void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int return tensor_wrappers.back().data(); }; for (size_t i = 0; i < A.size(); i++) { - if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) { - if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_(); - // torch.sum is able to handle 0-dim tensors - if (bias[i].data_ptr() != nullptr && grad) bias[i].copy_(B[i].sum(0)); - if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_(); + if (A[i].numel() == 0 || B[i].numel() == 0) { + if (D[i].numel() != 0 && !accumulate) D[i].zero_(); + if (bias[i].numel() != 0 && grad) { + if (B[i].numel() == 0) { + bias[i].zero_(); + } else { + bias[i].copy_(B[i].sum(0)); + } + } + if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; }