fix custom kernel
charlifu committed Mar 21, 2024
1 parent 42324b6 commit 9b1388c
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions vllm/model_executor/layers/linear.py
@@ -74,17 +74,25 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = weights["weight"]
-        if is_hip() and x.shape[0] == 1:
-            m, n, k = weight.shape[0], x.shape[0], x.shape[1]
-            out = torch.empty(x.shape[0], weight.shape[0], dtype=x.dtype)
+        if is_hip() and x.view(-1, x.size(-1)).shape[0] == 1:
+            batched = False
+            if x.dim() == 3:
+                inp = x.view(-1, x.size(-1))
+                batched = True
+            else:
+                inp = x
+            m, n, k = weight.shape[0], inp.shape[0], inp.shape[1]
+            out = torch.empty(inp.shape[0], weight.shape[0], dtype=inp.dtype, device='cuda')
             if k == 8192 and (m == 1280 or m == 7168):
-                custom_ops.LLMM1(weight, x, out, 8)
+                custom_ops.LLMM1(weight, inp, out, 8)
             elif k == 3584 and m == 8192:
-                custom_ops.LLMM1(weight, x, out, 8)
+                custom_ops.LLMM1(weight, inp, out, 8)
             elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
-                custom_ops.LLMM1(weight, x, out, 4)
+                custom_ops.LLMM1(weight, inp, out, 4)
             else:
-                out = F.linear(x, weight)
+                out = F.linear(inp, weight)
+            if batched:
+                out = out.view(x.shape[0], x.shape[1], weight.shape[0])
             if bias != None:
                 out = out + bias
             return out
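For context, here is a minimal standalone sketch of the shape-handling pattern this commit introduces: flatten a possibly 3D activation to 2D before the matmul, then restore the batch and sequence dimensions afterwards. The gate change from `x.shape[0] == 1` to `x.view(-1, x.size(-1)).shape[0] == 1` makes the check count flattened rows, so a 3D input of shape (1, 1, k) still takes the custom-kernel path. `apply_weights_sketch` is a hypothetical helper name, and `F.linear` stands in for the ROCm `custom_ops.LLMM1` kernel so the sketch runs on any device.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def apply_weights_sketch(weight: torch.Tensor,
                         x: torch.Tensor,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Flatten a 3D (batch, seq_len, hidden) activation to 2D so a GEMM
    # kernel that only understands (rows, hidden) inputs can consume it.
    batched = x.dim() == 3
    inp = x.view(-1, x.size(-1)) if batched else x

    # Stand-in for the custom-kernel dispatch in the real code, which
    # calls custom_ops.LLMM1 on ROCm for skinny (single-row) inputs.
    out = F.linear(inp, weight)

    # Restore the original (batch, seq_len, out_features) shape.
    if batched:
        out = out.view(x.shape[0], x.shape[1], weight.shape[0])
    if bias is not None:
        out = out + bias
    return out


# Example: a single decode-step token, shaped (batch=1, seq=1, hidden=8192).
w = torch.randn(7168, 8192)
y = apply_weights_sketch(w, torch.randn(1, 1, 8192))
assert y.shape == (1, 1, 7168)
```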
