add unit test for gmm
zhangzefeng92 committed Jul 25, 2024
1 parent 9c98622 commit 0a9046b
Showing 4 changed files with 27 additions and 14 deletions.
12 changes: 12 additions & 0 deletions csrc/extensions.cpp
@@ -402,6 +402,15 @@ void extGroupedGemm(at::Tensor& out, const at::Tensor& a, const at::Tensor& b,
  callDiopi(diopiGroupedGemm, out, a, b, batchSizes, transA, transB);
}

void extGroupedGemmBackward(at::Tensor& gradA, at::Tensor& gradB,
                            const at::Tensor& a, const at::Tensor& b,
                            const at::Tensor& batchSizes,
                            const at::Tensor& grad, bool transA,
                            bool transB) {
  callDiopi(diopiGroupedGemmBackward, gradA, gradB, a, b, batchSizes, grad,
            transA, transB);
}

// Check whether a corresponding diopi implementation exists:
// if so, bind it directly via pybind;
// otherwise, don't register it here and let the Python layer handle it.
@@ -507,6 +516,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiGroupedGemm != nullptr) {
    m.def("gmm", &extGroupedGemm, "deeplink extGroupedGemm");
  }
  if (&diopiGroupedGemmBackward != nullptr) {
    m.def("gmm_backward", &extGroupedGemmBackward,
          "deeplink extGroupedGemmBackward");
  }
}

} // namespace dipu::dipu_ext
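
Note (not part of the commit): a minimal sketch of how the new gmm_backward binding could be exercised from Python, mirroring the way gmm.py below calls into the extension. The ext import path and the tensor shapes are assumptions, not taken from this diff.

import torch

# Assumption: the compiled extension is importable the way gmm.py imports it;
# the exact module path is not shown in this diff.
from deeplink_ext import cpp_extensions as ext  # hypothetical import path

num_groups, k, n = 2, 128, 64
batch_sizes = torch.tensor([128, 128], dtype=torch.int64)
tokens = int(batch_sizes.sum())

a = torch.randn(tokens, k, device="cuda")         # forward input a: (tokens, k)
b = torch.randn(num_groups, k, n, device="cuda")  # trans_b=False layout
grad = torch.randn(tokens, n, device="cuda")      # gradient of the loss w.r.t. out

gradA = torch.empty_like(a)
gradB = torch.empty_like(b)
# gradA/gradB are written in place; the argument order mirrors extGroupedGemmBackward.
ext.gmm_backward(gradA, gradB, a, b, batch_sizes.cuda(), grad, False, False)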
2 changes: 0 additions & 2 deletions deeplink_ext/internevo_ops/__init__.py
@@ -39,8 +39,6 @@
    from .gmm import GroupedGemm
except:
    print(_not_impl.format(op_name="grouped gemm"))
    from .gmm_fallback import gmm_fallback as gmm_forward
    from .gmm_fallback import gmm_fallback as GroupedGemm


__all__ = [
12 changes: 8 additions & 4 deletions deeplink_ext/internevo_ops/gmm.py
@@ -26,7 +26,7 @@ def gmm_forward(a, b, batch_sizes, trans_a, trans_b):
    if batch_sizes.is_cuda:
        ext.gmm(out, a, b, batch_sizes, trans_a, trans_b)
    else:
        ext.gmm(out, a, b, batch_sizes.cuda, trans_a, trans_b)
        ext.gmm(out, a, b, batch_sizes.cuda(), trans_a, trans_b)

    return out

@@ -43,9 +43,13 @@ def backward(ctx, grad):
        a, b, batch_sizes = ctx.saved_tensors
        trans_b = ctx.trans_b

        gradA = gmm_forward(grad, b, batch_sizes, False, trans_b)
        # gradA = gmm_forward(grad, b, batch_sizes, False, not trans_b)
        # lhs, rhs = (grad, a) if trans_b else (a, grad)
        # gradB = gmm_forward(lhs, rhs, batch_sizes, True, False)

        gradA = torch.empty_like(a)
        gradB = torch.empty_like(b)
        ext.gmm_backward(gradA, gradB, a, b, batch_sizes.cuda(), grad, False, trans_b)

        lhs, rhs = (grad, a) if trans_b else (a, grad)
        gradB = gmm_forward(lhs, rhs, batch_sizes, True, False)

        return gradA, gradB, None, None
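
Note (not part of the commit): for reference, the per-group math that the fused ext.gmm_backward call replaces, spelled out in plain PyTorch. It matches the commented-out gmm_forward formulation kept in the diff; the shape conventions are assumptions where the diff does not show them.

import torch

def gmm_backward_reference(a, b, batch_sizes, grad, trans_b):
    # a: (sum(batch_sizes), k); grad: (sum(batch_sizes), n)
    # b: (num_groups, n, k) if trans_b else (num_groups, k, n)
    gradA = torch.empty_like(a)
    gradB = torch.empty_like(b)
    start = 0
    for i, size in enumerate(batch_sizes.tolist()):
        end = start + size
        # out[start:end] = a[start:end] @ (b[i].t() if trans_b else b[i]), so:
        gradA[start:end] = grad[start:end] @ (b[i] if trans_b else b[i].t())
        if trans_b:
            gradB[i] = grad[start:end].t() @ a[start:end]
        else:
            gradB[i] = a[start:end].t() @ grad[start:end]
        start = end
    return gradA, gradB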
15 changes: 7 additions & 8 deletions tests/internevo/test_gmm.py
@@ -27,11 +27,11 @@ def add_transpose_flags(x):


_TEST_PROBLEMS = add_transpose_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    # (1, 128, 128, 128),
    # (8, 128, 128, 128),
    # (16, 128, 128, 128),
    # (1, 128, 256, 512),
    # (8, 128, 256, 512),
    (16, 128, 256, 512),
))
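
Note (not part of the commit): the hunk header references add_transpose_flags without showing its body. Under the assumption that it simply pairs each (z, m, k, n) problem with both trans_b settings, a minimal sketch would be:

def add_transpose_flags(problems):
    # Assumed behavior, not taken from this diff: expand each (z, m, k, n)
    # problem into variants with trans_b = False and trans_b = True.
    return tuple(p + (trans_b,) for p in problems for trans_b in (False, True))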

@@ -67,7 +67,7 @@ def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b):
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = GroupedGemm(a, b, batch_sizes, False, trans_b)
        out = GroupedGemm.apply(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out.cpu(), expected_out.cpu()))

@@ -94,7 +94,7 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = GroupedGemm(a, b, batch_sizes, False, trans_b)
        out = GroupedGemm.apply(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out.cpu(), expected_out.cpu()))

@@ -105,6 +105,5 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
        self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu()))



if __name__ == '__main__':
    unittest.main()
