feat/grouped gmm #123

Closed
wants to merge 4 commits
20 changes: 20 additions & 0 deletions csrc/extensions.cpp
@@ -1,4 +1,4 @@
// Copyright (c) 2023, DeepLink.

Check notice on line 1 in csrc/extensions.cpp (GitHub Actions / clang-format): File csrc/extensions.cpp does not conform to Custom style guidelines. (line 408)

#include <cstdint>
#include <string>
@@ -397,6 +397,20 @@
  callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
}

void extGroupedGemm(at::Tensor& out, const at::Tensor& a, const at::Tensor& b,
                    const at::Tensor& batchSizes, bool transA, bool transB) {
  callDiopi(diopiGroupedGemm, out, a, b, batchSizes, transA, transB);
}

void extGroupedGemmBackward(at::Tensor& gradA, at::Tensor& gradB,
                            const at::Tensor& a, const at::Tensor& b,
                            const at::Tensor& batchSizes,
                            const at::Tensor& grad, bool transA,
                            bool transB) {
  callDiopi(diopiGroupedGemmBackward, gradA, gradB, a, b, batchSizes, grad,
            transA, transB);
}

// Check whether a corresponding diopi implementation exists:
// if so, bind it directly via pybind;
// otherwise do not register it and let the Python layer handle the fallback.
@@ -499,6 +513,12 @@
m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
"deeplink extRotaryEmbeddingV2");
}
if (&diopiGroupedGemm != nullptr) {
m.def("gmm", &extGroupedGemm, "deeplink extGroupedGemm");
}
if (&diopiGroupedGemmBackward != nullptr) {
m.def("gmm_backward", &extGroupedGemmBackward, "deeplink extGroupedGemm");
}
}

} // namespace dipu::dipu_ext
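
The two new bindings are only registered when the corresponding diopi symbols are resolved, so their presence can be probed from Python before use. A minimal sketch, assuming the compiled extension is importable as deeplink_ext.cpp_extensions (as gmm.py below does):

import deeplink_ext.cpp_extensions as ext

# "gmm" / "gmm_backward" exist only if diopiGroupedGemm / diopiGroupedGemmBackward
# were found at load time; otherwise the Python layer must fall back.
print("gmm available:", hasattr(ext, "gmm"))
print("gmm_backward available:", hasattr(ext, "gmm_backward"))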
10 changes: 10 additions & 0 deletions deeplink_ext/internevo_ops/__init__.py
@@ -8,6 +8,7 @@
    print(_not_impl.format(op_name="adamw"))
    from torch.optim import AdamW as AdamW


try:
    from .flash_attention import FlashSelfAttention, FlashCrossAttention
except Exception as e:
Expand All @@ -33,11 +34,20 @@
    from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_


try:
    from .gmm import gmm_forward
    from .gmm import GroupedGemm
except Exception:
    print(_not_impl.format(op_name="grouped gemm"))


__all__ = [
    "AdamW",
    "FlashSelfAttention",
    "FlashCrossAttention",
    "MixedFusedRMSNorm",
    "ApplyRotaryEmb",
    "ApplyRotaryEmbQKV_",
    "gmm_forward",
    "GroupedGemm",
]
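
With gmm_forward and GroupedGemm exported here, the op can be imported like the other internevo ops. A minimal usage sketch (shapes are illustrative; assumes a CUDA device and an available diopi grouped-GEMM implementation):

import torch
from deeplink_ext.internevo_ops import GroupedGemm

# Two groups of 4 and 6 rows share the inner dim k=8; each group has its own (8, 16) weight.
a = torch.randn(10, 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(2, 8, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
batch_sizes = torch.tensor([4, 6])

out = GroupedGemm.apply(a, b, batch_sizes, False)  # shape (10, 16)
out.sum().backward()  # populates a.grad and b.grad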
55 changes: 55 additions & 0 deletions deeplink_ext/internevo_ops/gmm.py
@@ -0,0 +1,55 @@
import torch

import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "gmm") and hasattr(ext, "gmm_backward")


__all__ = ["gmm_forward", "GroupedGemm"]


def gmm_forward(a, b, batch_sizes, trans_a, trans_b):
    assert not (trans_a and trans_b), "'trans_a' and 'trans_b' can't both be true"
    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
    assert a.ndim == 2, "Expected 2d tensor for 'a'"
    assert b.ndim == (2 if trans_a else 3), "Expected 2d 'b' when trans_a, otherwise 3d"

    shape = (
        (batch_sizes.shape[0], a.shape[1], b.shape[1])
        if trans_a
        else (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
    )
    out = torch.empty(*shape, device=a.device, dtype=a.dtype)

    # The extension expects batch_sizes on the device; move it there if needed.
    if batch_sizes.is_cuda:
        ext.gmm(out, a, b, batch_sizes, trans_a, trans_b)
    else:
        ext.gmm(out, a, b, batch_sizes.cuda(), trans_a, trans_b)

    return out


class GroupedGemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, batch_sizes, trans_b):
        ctx.save_for_backward(a, b, batch_sizes)
        ctx.trans_b = trans_b
        return gmm_forward(a, b, batch_sizes, False, trans_b)

    @staticmethod
    def backward(ctx, grad):
        a, b, batch_sizes = ctx.saved_tensors
        trans_b = ctx.trans_b

        # Equivalent formulation via two forward grouped GEMMs:
        #   gradA = gmm_forward(grad, b, batch_sizes, False, not trans_b)
        #   lhs, rhs = (grad, a) if trans_b else (a, grad)
        #   gradB = gmm_forward(lhs, rhs, batch_sizes, True, False)
        gradA = torch.empty_like(a)
        gradB = torch.empty_like(b)
        ext.gmm_backward(gradA, gradB, a, b, batch_sizes.cuda(), grad, False, trans_b)

        return gradA, gradB, None, None
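
For reference, the gradients requested from ext.gmm_backward correspond to the standard per-group GEMM derivatives; the commented-out lines above express the same computation as two forward grouped GEMMs. A plain-PyTorch sketch of those derivatives (gmm_backward_reference is a hypothetical helper for illustration, not part of this PR):

import torch

def gmm_backward_reference(grad, a, b, batch_sizes, trans_b=False):
    # Per group g: out_g = a_g @ b_g (or a_g @ b_g.T when trans_b), hence
    #   dA_g = dOut_g @ b_g.T  (or dOut_g @ b_g)
    #   dB_g = a_g.T @ dOut_g  (or dOut_g.T @ a_g)
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)
    start = 0
    for i, size in enumerate(batch_sizes.tolist()):
        g, a_g = grad[start : start + size], a[start : start + size]
        if trans_b:
            grad_a[start : start + size] = g @ b[i]
            grad_b[i] = g.t() @ a_g
        else:
            grad_a[start : start + size] = g @ b[i].t()
            grad_b[i] = a_g.t() @ g
        start += size
    return grad_a, grad_b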
13 changes: 13 additions & 0 deletions deeplink_ext/internevo_ops/gmm_fallback.py
@@ -0,0 +1,13 @@
import torch


def gmm_fallback(a, b, batch_sizes, trans_a=False, trans_b=False):
    # Pure-PyTorch fallback; note that trans_a is accepted but not supported here.
    batch_sizes = batch_sizes.numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)
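
The fallback is equivalent to splitting a by batch_sizes and applying each group's weight matrix separately; a quick CPU equivalence check (illustrative, assuming the module path added in this PR):

import torch
from deeplink_ext.internevo_ops.gmm_fallback import gmm_fallback

a = torch.randn(10, 8)
b = torch.randn(2, 8, 16)
batch_sizes = torch.tensor([4, 6])

chunks = torch.split(a, batch_sizes.tolist())
ref = torch.cat([chunk @ b[i] for i, chunk in enumerate(chunks)])
assert torch.allclose(gmm_fallback(a, b, batch_sizes), ref)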
109 changes: 109 additions & 0 deletions tests/internevo/test_gmm.py
@@ -0,0 +1,109 @@
import unittest

from absl.testing import parameterized
from deeplink_ext.internevo_ops import GroupedGemm
import torch


def allclose(x, y, pct=2.0):
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True


def add_transpose_flags(x):
    out = []
    for y in x:
        for f in [(False,), (True,)]:
            out.append(y + f)
    return out


_TEST_PROBLEMS = add_transpose_flags((
    # (1, 128, 128, 128),
    # (8, 128, 128, 128),
    # (16, 128, 128, 128),
    # (1, 128, 256, 512),
    # (8, 128, 256, 512),
    (16, 128, 256, 512),
))


def randn(bs, x, y):
    # Small zero-centered values in roughly [-1, 1) / (y * x).
    out = ((torch.rand(bs, x, y) - 0.5) * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)


@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = GroupedGemm.apply(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out.cpu(), expected_out.cpu()))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad.cpu(), a_ref.grad.cpu()))
        self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu()))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        dist = torch.rand(z)
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = GroupedGemm.apply(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out.cpu(), expected_out.cpu()))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad.cpu(), a_ref.grad.cpu()))
        self.assertTrue(allclose(b.grad.cpu(), b_ref.grad.cpu()))


if __name__ == '__main__':
    unittest.main()