Commit

fix bugs of gmm
zhangzefeng92 committed Jul 22, 2024
1 parent a38923b commit 9c98622
Showing 2 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion deeplink_ext/internevo_ops/__init__.py
@@ -40,7 +40,7 @@
 except:
     print(_not_impl.format(op_name="grouped gemm"))
-    from .gmm_fallback import gmm_fallback as gmm_forward
+    from .gmm_fallback import gmm_fallback as GroupedGemm


 __all__ = [
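For context, the hunk above sits in an import-with-fallback block: when the extension-backed op cannot be imported, the pure-PyTorch gmm_fallback is re-exported under the public name so downstream code can import it unconditionally. A sketch of the assumed surrounding pattern (the try-side import and the _not_impl message are inferred; only the except branch is visible in this hunk):

    # Sketch of the import-with-fallback pattern assumed around the hunk above
    # (illustrative; the try-side import is not shown in the diff).
    _not_impl = "[deeplink_ext] {op_name} is not implemented. Falling back to a slower torch implementation."

    try:
        from .gmm import GroupedGemm  # extension-backed autograd op
    except:
        print(_not_impl.format(op_name="grouped gemm"))
        from .gmm_fallback import gmm_fallback as GroupedGemm  # torch fallback, same public name

This keeps the public API stable for InternEvo regardless of whether the fast kernel is available.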
52 changes: 26 additions & 26 deletions deeplink_ext/internevo_ops/gmm.py
@@ -7,45 +7,45 @@
 assert hasattr(ext, "gmm")


-__all__ = ["gmm", "GroupedGemmFunc"]
+__all__ = ["gmm_forward", "GroupedGemm"]


-def gmm_forward(a, b, batchSizes, transA, transB):
-    assert not (transA and transB), "'transA' and 'transB' can't both be true"
-    assert batchSizes.ndim == 1, "Expected 1d tensor for batchSizes"
+def gmm_forward(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b), "'trans_a' and 'trans_b' can't both be true"
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
     assert a.ndim == 2, "Expected 2d tensor for 'a'"
-    assert b.ndim == (2 if transA else 3)
+    assert b.ndim == (2 if trans_a else 3)

     shape = (
-        (batchSizes.shape[0], a.shape[1], b.shape[1])
-        if transA
-        else (a.shape[0], (b.shape[1] if transB else b.shape[2]))
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a
+        else (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
     )
     out = torch.empty(*shape, device=a.device, dtype=a.dtype)

-    if batchSizes.is_cuda:
-        ext.gmm(out, a, b, batchSizes, transA, transB)
+    if batch_sizes.is_cuda:
+        ext.gmm(out, a, b, batch_sizes, trans_a, trans_b)
     else:
-        ext.gmm(out, a, b, batchSizes.cuda(), transA, transB)
+        ext.gmm(out, a, b, batch_sizes.cuda(), trans_a, trans_b)

     return out


-class GroupedGemm(torch.autograd.Fuction):
+class GroupedGemm(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, a, b, batchSizes, transB):
-        ctx.save_for_backward(a, b, batchSizes)
-        ctx.transB = transB
-        return gmm_forward(a, b, batchSizes, False, transB)
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return gmm_forward(a, b, batch_sizes, False, trans_b)

     @staticmethod
     def backward(ctx, grad):
-        a, b, batchSizes = ctx.saved_tensors
-        transB = ctx.transB
-        gradA = gmm_forward(grad, b, batchSizes, False, transB)
-        lhs, rhs = (grad, a) if transB else (a, grad)
-        gradB = gmm_forward(lhs, rhs, batchSizes, True, False)
-        return gradA, gradB, None, None
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        gradA = gmm_forward(grad, b, batch_sizes, False, trans_b)
+
+        lhs, rhs = (grad, a) if trans_b else (a, grad)
+        gradB = gmm_forward(lhs, rhs, batch_sizes, True, False)
+
+        return gradA, gradB, None, None
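The shape logic in gmm_forward encodes the grouped GEMM contract used by MoE-style layers: the rows of a are partitioned into groups by batch_sizes and each group is multiplied by its own matrix from b; trans_a=True instead accumulates per-group a_i^T @ b_i, which is what backward uses for the weight gradient. A loop-based pure-PyTorch reference of that contract, as a sanity-check sketch (gmm_reference is an illustrative name, not part of the repo):

    import torch

    def gmm_reference(a, b, batch_sizes, trans_a=False, trans_b=False):
        # Loop-based reference for the grouped GEMM contract assumed above:
        #   trans_a=False: a is (sum(batch_sizes), k); b is (num_groups, k, n),
        #                  or (num_groups, n, k) with trans_b=True; out is (sum, n).
        #   trans_a=True:  a and b are 2-D with matching row groups;
        #                  out is (num_groups, a.shape[1], b.shape[1]).
        bounds = [0] + torch.cumsum(batch_sizes.cpu(), dim=0).tolist()
        outs = []
        for i in range(batch_sizes.shape[0]):
            rows = slice(bounds[i], bounds[i + 1])
            if trans_a:
                outs.append(a[rows].t() @ b[rows])  # (k, rows_i) @ (rows_i, n)
            else:
                w = b[i].t() if trans_b else b[i]
                outs.append(a[rows] @ w)  # (rows_i, k) @ (k, n)
        return torch.stack(outs) if trans_a else torch.cat(outs)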

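A hypothetical smoke test tying the two together (assumes the extension path imported successfully, a device visible as "cuda", and the gmm_reference sketch above):

    import torch
    from deeplink_ext.internevo_ops.gmm import gmm_forward

    k, n = 64, 128
    batch_sizes = torch.tensor([3, 5, 2, 6])  # rows per expert/group, sums to 16
    a = torch.randn(int(batch_sizes.sum()), k, device="cuda")
    b = torch.randn(batch_sizes.numel(), k, n, device="cuda")

    out = gmm_forward(a, b, batch_sizes, False, False)  # expected shape: (16, n)
    ref = gmm_reference(a.cpu(), b.cpu(), batch_sizes)  # loop-based check from the sketch above
    print(out.shape, torch.allclose(out.cpu(), ref, atol=1e-3))

    # GroupedGemm.apply(a, b, batch_sizes, trans_b) is the autograd entry point;
    # its backward routes both gradients through the same gmm_forward kernel.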