
Commit: fix bias for 0-dim tensor

Signed-off-by: Xin Yao <[email protected]>
yaox12 committed on Oct 12, 2024
1 parent b36bd0a, commit 73a1749

Showing 2 changed files with 9 additions and 2 deletions.
transformer_engine/pytorch/csrc/extensions.h (5 additions, 0 deletions)

@@ -165,6 +165,11 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
                      std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                      bool use_split_accumulator, int math_sm_count);
 
+/*
+ * This method is a performance-optimized version of the fwd and dgrad calculations. It is not
+ * for general-purpose use. Taking the GEMM (m, k) * (k, n) as an example, it handles the case
+ * where m = 0, but not k = 0 or n = 0. For those cases, please use `te_grouped_gemm`.
+ */
 void te_grouped_gemm_single_output(
     std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
     transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
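
To make the constraint in the new comment concrete, here is a minimal libtorch sketch (an illustration, not code from the commit) of why m = 0 is benign while k = 0 is not for D = A (m, k) * B (k, n):

#include <torch/torch.h>

int main() {
  // m = 0: the output has zero rows, so there is nothing to compute
  // and nothing that needs to be zeroed.
  auto D_m0 = torch::empty({0, 8}).mm(torch::empty({8, 16}));  // shape (0, 16)

  // k = 0: the output is a full (m, n) tensor whose entries are empty
  // sums, i.e. zeros that must actually be written out. This is the
  // case the comment routes to `te_grouped_gemm` instead.
  auto D_k0 = torch::empty({4, 0}).mm(torch::empty({0, 16}));  // shape (4, 16), all zeros
  return 0;
}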
transformer_engine/pytorch/csrc/extensions/gemm.cu (4 additions, 2 deletions)

@@ -17,7 +17,8 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
   using namespace transformer_engine;
   if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) {
     if (D.data_ptr() != nullptr && !accumulate) D.zero_();
-    if (bias.data_ptr() != nullptr) bias.zero_();
+    // torch.sum is able to handle tensors with a zero-sized dimension
+    if (bias.data_ptr() != nullptr && grad) bias.copy_(B.sum(0));
     if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
     return;
   }
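
The replacement line leans on sum over an empty dimension producing zeros of the right shape. A standalone libtorch sketch of that behavior (plain tensors with made-up shapes, not the Transformer Engine entry point):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Mimic the skipped-GEMM case: the gradient operand B has m = 0 rows
  // and n = 4 columns.
  at::Tensor B = torch::empty({0, 4});

  // Summing over the empty dimension yields a length-4 zero vector,
  // the correct bias gradient when there are no rows to reduce.
  at::Tensor dbias = B.sum(0);
  std::cout << dbias << std::endl;  // prints a zero vector of size 4
  return 0;
}

The `&& grad` guard matters too: the old code zeroed bias unconditionally, which also overwrote the tensor when it presumably held an input bias rather than a gradient; the new code only writes in the gradient path.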
@@ -111,7 +112,8 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
   for (size_t i = 0; i < A.size(); i++) {
     if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
       if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
-      if (bias[i].data_ptr() != nullptr) bias[i].zero_();
+      // torch.sum is able to handle tensors with a zero-sized dimension
+      if (bias[i].data_ptr() != nullptr && grad) bias[i].copy_(B[i].sum(0));
       if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
       continue;
     }
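
The grouped hunk applies the same fix per element of the group. A plain-libtorch sketch of that loop pattern (hypothetical shapes, std::vector in place of the TE argument lists):

#include <torch/torch.h>
#include <vector>

int main() {
  // Two gradient operands: one with rows, one empty.
  std::vector<at::Tensor> B = {torch::randn({3, 4}), torch::empty({0, 4})};
  std::vector<at::Tensor> dbias;

  for (const auto& b : B) {
    // b.sum(0) covers both cases: a real column sum for the first
    // operand, a zero vector for the empty one.
    dbias.push_back(b.sum(0));
  }
  return 0;
}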
