[Bugfix] Fix bias for 0-dim tensors in gemm #1246
Conversation
```cpp
// torch.sum is able to handle 0-dim tensors
if (bias.data_ptr() != nullptr && grad) bias.copy_(B.sum(0));
```
We should handle the case where `B` has no data:

```cpp
if (bias.data_ptr() != nullptr && grad) {
  if (B.data_ptr() == nullptr) {
    bias.zero_();
  } else {
    bias.copy_(B.sum(0));
  }
}
```
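For intuition, here is a minimal standalone libtorch sketch, assuming `B` is a *defined* tensor of shape `[0, n]` (an assumption; the no-backing-storage case is what the suggested branch guards against). Under that assumption both branches yield the same zeroed bias, since summing over a zero-length dimension produces zeros:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor B = at::empty({0, 3});    // defined tensor with zero elements
  at::Tensor bias = at::ones({3});

  bias.copy_(B.sum(0));                // B.sum(0) is zeros({3}) even though B is empty
  std::cout << bias << "\n";           // prints 0 0 0

  at::Tensor bias2 = at::ones({3});
  bias2.zero_();                       // the explicit branch from the suggestion
  std::cout << bias2 << "\n";          // also prints 0 0 0
}
```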
It seems we are checking `data_ptr() == nullptr` in order to check for empty tensors. As an alternative, it may be better to make our intention explicit:

```cpp
if (A.numel() == 0 || B.numel() == 0) {
  if (D.numel() != 0 && !accumulate) D.zero_();
  // torch.sum is able to handle 0-dim tensors
  if (bias.numel() != 0 && grad) bias.copy_(B.sum(0));
  if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_();
  return;
}
```
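For reference, a small standalone libtorch sketch (not TE code; names are illustrative) contrasting the two emptiness checks:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor b = at::empty({0, 4});  // defined tensor with zero elements
  // Explicit emptiness check: reads as "this tensor has no elements".
  std::cout << std::boolalpha << (b.numel() == 0) << "\n";  // true
  // data_ptr() happens to be null for zero-element CPU tensors, but that
  // is an implementation detail rather than a stated intent.
  std::cout << (b.data_ptr() == nullptr) << "\n";           // typically true
}
```

Both checks evaluate the same way for a defined zero-element tensor, but `numel() == 0` states the intent directly instead of relying on allocator behavior.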
Updated.
/te-ci pytorch
LGTM, thanks!
/te-ci pytorch
I'll merge this PR since the CI passes except for a doc test failing with an irrelevant error.
* fix bias for 0-dim tensor
* add check
* use numel() instead of nullptr

Signed-off-by: Xin Yao <[email protected]>
Description

For `te_gemm` and `te_grouped_gemm`:

* When `grad == false`, we should do nothing to `bias`.
* When `grad == true`, we should calculate `grad_bias` here.
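A minimal standalone sketch of this behavior; `handle_empty_inputs` is a hypothetical helper name, not the actual `te_gemm` signature, and `B` is assumed to have shape `[0, n]` in the empty case:

```cpp
#include <torch/torch.h>
#include <iostream>

// Hypothetical helper (illustrative only): shows how bias is treated
// when A or B has no elements.
void handle_empty_inputs(const at::Tensor& A, const at::Tensor& B,
                         at::Tensor& bias, bool grad) {
  if (A.numel() == 0 || B.numel() == 0) {
    if (grad) {
      // Backward: bias holds grad_bias; summing B (assumed shape [0, n])
      // over dim 0 yields zeros, the correct gradient for an empty batch.
      if (bias.numel() != 0) bias.copy_(B.sum(0));
    }
    // Forward (grad == false): bias is an input; leave it untouched.
    return;
  }
  // ... regular GEMM path would go here ...
}

int main() {
  at::Tensor A = at::empty({0, 8});
  at::Tensor B = at::empty({0, 16});
  at::Tensor bias = at::ones({16});
  handle_empty_inputs(A, B, bias, /*grad=*/true);
  std::cout << bias.sum().item<float>() << "\n";  // 0: grad_bias is zero
}
```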
Type of change

Checklist: