From a40b24e9a32ee9d7abe6a88c902737d904c7d4b6 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Thu, 18 Jul 2024 16:25:31 +0800 Subject: [PATCH 1/4] add extGroupedGemm --- csrc/extensions.cpp | 8 ++++ deeplink_ext/internevo_ops/__init__.py | 12 +++++ deeplink_ext/internevo_ops/gmm.py | 51 ++++++++++++++++++++++ deeplink_ext/internevo_ops/gmm_fallback.py | 13 ++++++ 4 files changed, 84 insertions(+) create mode 100644 deeplink_ext/internevo_ops/gmm.py create mode 100644 deeplink_ext/internevo_ops/gmm_fallback.py diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index bb46f71..998127d 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -397,6 +397,11 @@ void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim); } +void extGroupedGemm(at::Tensor& out, const at::Tensor& a, const at::Tensor& b, + const at::Tensor& batchSizes, bool transA, bool transB) { + callDiopi(diopiGroupedGemm, out, a, b, batchSizes, transA, transB); +} + // 判断是否有对应的 diopi 实现: // 如果有, 则直接 pybind 上去; // 否则不注册, 等到 python 层处理. @@ -499,6 +504,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2"); } + if (&diopiGroupedGemm != nullptr) { + m.def("gmm", &extGroupedGemm, "deeplink extGroupedGemm"); + } } } // namespace dipu::dipu_ext diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index ac76456..c11f053 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -8,6 +8,7 @@ print(_not_impl.format(op_name="adamw")) from torch.optim import AdamW as AdamW + try: from .flash_attention import FlashSelfAttention, FlashCrossAttention except Exception as e: @@ -33,6 +34,15 @@ from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_ +try: + from .gmm import gmm_forward + from .gmm import GroupedGemm +except: + print(_not_impl.format(op_name="grouped gemm")) + from .gmm_fallback import gmm_fallback as gmm_forward + from .gmm_fallback import gmm_fallback as GroupedGemm + + __all__ = [ "AdamW", "FlashSelfAttention", @@ -40,4 +50,6 @@ "MixedFusedRMSNorm", "ApplyRotaryEmb", "ApplyRotaryEmbQKV_", + "gmm_forward", + "GroupedGemm", ] diff --git a/deeplink_ext/internevo_ops/gmm.py b/deeplink_ext/internevo_ops/gmm.py new file mode 100644 index 0000000..0d7985b --- /dev/null +++ b/deeplink_ext/internevo_ops/gmm.py @@ -0,0 +1,51 @@ +import numbers +import torch +from torch.nn import init + +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "gmm") + + +__all__ = ["gmm", "GroupedGemmFunc"] + + +def gmm_forward(a, b, batchSizes, transA, transB): + assert not (transA and transB), "'transA' and 'transB' can't both be true" + assert batchSizes.ndim == 1, "Expected 1d tensor for batchSizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if transA else 3) + + shape = ( + (batchSizes.shape[0], a.shape[1], b.shape[1]) + if transA + else (a.shape[0], (b.shape[1] if transB else b.shape[2])) + ) + out = torch.empty(*shape, device=a.device, dtype=a.dtype) + + if batchSizes.is_cuda: + ext.gmm(out, a, b, batchSizes, transA, transB) + else: + ext.gmm(out, a, b, batchSizes.cuda, transA, transB) + + return out + + +class GroupedGemm(torch.autograd.Fuction): + @staticmethod + def forward(ctx, a, b, batchSizes, transB): + ctx.save_for_backward(a, b, batchSizes) + ctx.transB = transB + return gmm_forward(a, b, batchSizes, False, transB) + + @staticmethod + def backward(ctx, grad): + a, b, batchSizes = ctx.saved_tensors + transB = ctx.transB + + gradA = gmm_forward(grad, b, batchSizes, False, transB) + + lhs, rhs = (grad, a) if transB else (a, grad) + gradB = gmm_forward(lhs, rhs, batchSizes, True, False) + + return gradA, gradB, None, None \ No newline at end of file diff --git a/deeplink_ext/internevo_ops/gmm_fallback.py b/deeplink_ext/internevo_ops/gmm_fallback.py new file mode 100644 index 0000000..be4bea1 --- /dev/null +++ b/deeplink_ext/internevo_ops/gmm_fallback.py @@ -0,0 +1,13 @@ +import torch + + +def gmm_fallback(a, b, batch_sizes, trans_a=False, trans_b=False): + batch_sizes = batch_sizes.numpy() + + out = [] + start = 0 + for i, size in enumerate(batch_sizes): + rhs = b[i, :, :].t() if trans_b else b[i, :, :] + out.append(a[start:start + size, :] @ rhs) + start += size + return torch.cat(out) \ No newline at end of file From a38923ba19c41c71aeac3442a93f252c1a5e6405 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Mon, 22 Jul 2024 10:35:48 +0800 Subject: [PATCH 2/4] add unit test for gmm --- tests/internevo/test_gmm.py | 110 ++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 tests/internevo/test_gmm.py diff --git a/tests/internevo/test_gmm.py b/tests/internevo/test_gmm.py new file mode 100644 index 0000000..76b61eb --- /dev/null +++ b/tests/internevo/test_gmm.py @@ -0,0 +1,110 @@ +import unittest +import itertools + +from absl.testing import parameterized +# from grouped_gemm import ops +from deeplink_ext.internevo_ops import GroupedGemm +import numpy as np +import torch + + +def allclose(x, y, pct=2.0): + mask = torch.isclose(x, y, rtol=1e-5) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print(x[torch.logical_not(mask)], y[torch.logical_not(mask)]) + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +def add_transpose_flags(x): + out = [] + for y in x: + for f in [(False,), (True,)]: + out.append(y + f) + return out + + +_TEST_PROBLEMS = add_transpose_flags(( + (1, 128, 128, 128), + (8, 128, 128, 128), + (16, 128, 128, 128), + (1, 128, 256, 512), + (8, 128, 256, 512), + (16, 128, 256, 512), +)) + + +def randn(bs, x, y): + out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x) + return out.cuda().to(torch.bfloat16) + + +def gmm(a, b, batch_sizes, trans_b=False): + batch_sizes = batch_sizes.numpy() + + out = [] + start = 0 + for i, size in enumerate(batch_sizes): + rhs = b[i, :, :].t() if trans_b else b[i, :, :] + out.append(a[start:start + size, :] @ rhs) + start += size + return torch.cat(out) + + +@parameterized.parameters(*_TEST_PROBLEMS) +class OpsTest(parameterized.TestCase): + + def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b): + torch.manual_seed(0) + a = randn(z, m, k).view(-1, k) + b = randn(z, n, k) if trans_b else randn(z, k, n) + batch_sizes = torch.tensor([m] * z) + + a.requires_grad_(True) + b.requires_grad_(True) + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + out = GroupedGemm(a, b, batch_sizes, False, trans_b) + expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) + self.assertTrue(allclose(out.cpu(), expected_out.cpu())) + + # Check gradients. + out.sum().backward() + expected_out.sum().backward() + self.assertTrue(allclose(a.grad.cpu(), a_ref.grad.cpu())) + self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu())) + + def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b): + torch.manual_seed(0) + a = randn(z, m, k).view(-1, k) + b = randn(z, n, k) if trans_b else randn(z, k, n) + + dist = torch.rand(z, ) + dist /= dist.sum() + batch_sizes = (dist * m).to(torch.long) + error = m * z - batch_sizes.sum() + batch_sizes[-1] += error + assert batch_sizes.sum() == (m * z) + + a.requires_grad_(True) + b.requires_grad_(True) + a_ref = a.detach().clone().requires_grad_(True) + b_ref = b.detach().clone().requires_grad_(True) + + out = GroupedGemm(a, b, batch_sizes, False, trans_b) + expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) + self.assertTrue(allclose(out.cpu(), expected_out.cpu())) + + # Check gradients. + out.sum().backward() + expected_out.sum().backward() + self.assertTrue(allclose(a.grad.cpu(), a_ref.grad.cpu())) + self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu())) + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 9c9862226568a5a06348c0dbc8635688aa35edb6 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Mon, 22 Jul 2024 14:36:53 +0800 Subject: [PATCH 3/4] fix bugs of gmm --- deeplink_ext/internevo_ops/__init__.py | 2 +- deeplink_ext/internevo_ops/gmm.py | 52 +++++++++++++------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index c11f053..7c52f7b 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -40,7 +40,7 @@ except: print(_not_impl.format(op_name="grouped gemm")) from .gmm_fallback import gmm_fallback as gmm_forward - from .gmm_fallback import gmm_fallback as GroupedGemm + from .gmm_fallback import gmm_fallback as GroupedGemm __all__ = [ diff --git a/deeplink_ext/internevo_ops/gmm.py b/deeplink_ext/internevo_ops/gmm.py index 0d7985b..e4e54c9 100644 --- a/deeplink_ext/internevo_ops/gmm.py +++ b/deeplink_ext/internevo_ops/gmm.py @@ -7,45 +7,45 @@ assert hasattr(ext, "gmm") -__all__ = ["gmm", "GroupedGemmFunc"] +__all__ = ["gmm_forward", "GroupedGemm"] -def gmm_forward(a, b, batchSizes, transA, transB): - assert not (transA and transB), "'transA' and 'transB' can't both be true" - assert batchSizes.ndim == 1, "Expected 1d tensor for batchSizes" +def gmm_forward(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b), "'trans_a' and 'trans_b' can't both be true" + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if transA else 3) + assert b.ndim == (2 if trans_a else 3) shape = ( - (batchSizes.shape[0], a.shape[1], b.shape[1]) - if transA - else (a.shape[0], (b.shape[1] if transB else b.shape[2])) + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a + else (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) ) out = torch.empty(*shape, device=a.device, dtype=a.dtype) - if batchSizes.is_cuda: - ext.gmm(out, a, b, batchSizes, transA, transB) + if batch_sizes.is_cuda: + ext.gmm(out, a, b, batch_sizes, trans_a, trans_b) else: - ext.gmm(out, a, b, batchSizes.cuda, transA, transB) + ext.gmm(out, a, b, batch_sizes.cuda, trans_a, trans_b) return out -class GroupedGemm(torch.autograd.Fuction): +class GroupedGemm(torch.autograd.Function): @staticmethod - def forward(ctx, a, b, batchSizes, transB): - ctx.save_for_backward(a, b, batchSizes) - ctx.transB = transB - return gmm_forward(a, b, batchSizes, False, transB) - + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return gmm_forward(a, b, batch_sizes, False, trans_b) + @staticmethod def backward(ctx, grad): - a, b, batchSizes = ctx.saved_tensors - transB = ctx.transB - - gradA = gmm_forward(grad, b, batchSizes, False, transB) - - lhs, rhs = (grad, a) if transB else (a, grad) - gradB = gmm_forward(lhs, rhs, batchSizes, True, False) - - return gradA, gradB, None, None \ No newline at end of file + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + gradA = gmm_forward(grad, b, batch_sizes, False, trans_b) + + lhs, rhs = (grad, a) if trans_b else (a, grad) + gradB = gmm_forward(lhs, rhs, batch_sizes, True, False) + + return gradA, gradB, None, None From 0a9046b8339c71652a7dd1564a6a6b8b8c545aa1 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Thu, 25 Jul 2024 14:14:38 +0800 Subject: [PATCH 4/4] add unit test for gmm --- csrc/extensions.cpp | 12 ++++++++++++ deeplink_ext/internevo_ops/__init__.py | 2 -- deeplink_ext/internevo_ops/gmm.py | 12 ++++++++---- tests/internevo/test_gmm.py | 15 +++++++-------- 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 998127d..ba3ca9d 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -402,6 +402,15 @@ void extGroupedGemm(at::Tensor& out, const at::Tensor& a, const at::Tensor& b, callDiopi(diopiGroupedGemm, out, a, b, batchSizes, transA, transB); } +void extGroupedGemmBackward(at::Tensor& gradA, at::Tensor& gradB, + const at::Tensor& a, const at::Tensor& b, + const at::Tensor& batchSizes, + const at::Tensor& grad, bool transA, + bool transB) { + callDiopi(diopiGroupedGemmBackward, gradA, gradB, a, b, batchSizes, grad, + transA, transB); +} + // 判断是否有对应的 diopi 实现: // 如果有, 则直接 pybind 上去; // 否则不注册, 等到 python 层处理. @@ -507,6 +516,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { if (&diopiGroupedGemm != nullptr) { m.def("gmm", &extGroupedGemm, "deeplink extGroupedGemm"); } + if (&diopiGroupedGemmBackward != nullptr) { + m.def("gmm_backward", &extGroupedGemmBackward, "deeplink extGroupedGemm"); + } } } // namespace dipu::dipu_ext diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index 7c52f7b..6feb992 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -39,8 +39,6 @@ from .gmm import GroupedGemm except: print(_not_impl.format(op_name="grouped gemm")) - from .gmm_fallback import gmm_fallback as gmm_forward - from .gmm_fallback import gmm_fallback as GroupedGemm __all__ = [ diff --git a/deeplink_ext/internevo_ops/gmm.py b/deeplink_ext/internevo_ops/gmm.py index e4e54c9..ed65b90 100644 --- a/deeplink_ext/internevo_ops/gmm.py +++ b/deeplink_ext/internevo_ops/gmm.py @@ -26,7 +26,7 @@ def gmm_forward(a, b, batch_sizes, trans_a, trans_b): if batch_sizes.is_cuda: ext.gmm(out, a, b, batch_sizes, trans_a, trans_b) else: - ext.gmm(out, a, b, batch_sizes.cuda, trans_a, trans_b) + ext.gmm(out, a, b, batch_sizes.cuda(), trans_a, trans_b) return out @@ -43,9 +43,13 @@ def backward(ctx, grad): a, b, batch_sizes = ctx.saved_tensors trans_b = ctx.trans_b - gradA = gmm_forward(grad, b, batch_sizes, False, trans_b) + # gradA = gmm_forward(grad, b, batch_sizes, False, not trans_b) + # lhs, rhs = (grad, a) if trans_b else (a, grad) + # gradB = gmm_forward(lhs, rhs, batch_sizes, True, False) + + gradA = torch.empty_like(a) + gradB = torch.empty_like(b) + ext.gmm_backward(gradA, gradB, a, b, batch_sizes.cuda(), grad, False, trans_b) - lhs, rhs = (grad, a) if trans_b else (a, grad) - gradB = gmm_forward(lhs, rhs, batch_sizes, True, False) return gradA, gradB, None, None diff --git a/tests/internevo/test_gmm.py b/tests/internevo/test_gmm.py index 76b61eb..a7c817c 100644 --- a/tests/internevo/test_gmm.py +++ b/tests/internevo/test_gmm.py @@ -27,11 +27,11 @@ def add_transpose_flags(x): _TEST_PROBLEMS = add_transpose_flags(( - (1, 128, 128, 128), - (8, 128, 128, 128), - (16, 128, 128, 128), - (1, 128, 256, 512), - (8, 128, 256, 512), + # (1, 128, 128, 128), + # (8, 128, 128, 128), + # (16, 128, 128, 128), + # (1, 128, 256, 512), + # (8, 128, 256, 512), (16, 128, 256, 512), )) @@ -67,7 +67,7 @@ def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b): a_ref = a.detach().clone().requires_grad_(True) b_ref = b.detach().clone().requires_grad_(True) - out = GroupedGemm(a, b, batch_sizes, False, trans_b) + out = GroupedGemm.apply(a, b, batch_sizes, trans_b) expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) self.assertTrue(allclose(out.cpu(), expected_out.cpu())) @@ -94,7 +94,7 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b): a_ref = a.detach().clone().requires_grad_(True) b_ref = b.detach().clone().requires_grad_(True) - out = GroupedGemm(a, b, batch_sizes, False, trans_b) + out = GroupedGemm.apply(a, b, batch_sizes, trans_b) expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) self.assertTrue(allclose(out.cpu(), expected_out.cpu())) @@ -105,6 +105,5 @@ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b): self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu())) - if __name__ == '__main__': unittest.main() \ No newline at end of file