From 9b1388c093a70577eaabb0e77c7fab1da9f03990 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 21 Mar 2024 21:00:56 +0000 Subject: [PATCH] fix custom kernel --- vllm/model_executor/layers/linear.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 6398c53bfda06..edcb448741f7e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -74,17 +74,25 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = weights["weight"] - if is_hip() and x.shape[0] == 1: - m, n, k = weight.shape[0], x.shape[0], x.shape[1] - out = torch.empty(x.shape[0], weight.shape[0], dtype=x.dtype) + if is_hip() and x.view(-1, x.size(-1)).shape[0] == 1: + batched = False + if x.dim() == 3: + inp = x.view(-1, x.size(-1)) + batched = True + else: + inp = x + m, n, k = weight.shape[0], inp.shape[0], inp.shape[1] + out = torch.empty(inp.shape[0], weight.shape[0], dtype=inp.dtype, device='cuda') if k == 8192 and (m == 1280 or m == 7168): - custom_ops.LLMM1(weight, x, out, 8) + custom_ops.LLMM1(weight, inp, out, 8) elif k == 3584 and m == 8192: - custom_ops.LLMM1(weight, x, out, 8) + custom_ops.LLMM1(weight, inp, out, 8) elif k <= 8192 and k % 8 == 0 and m % 4 == 0: - custom_ops.LLMM1(weight, x, out, 4) + custom_ops.LLMM1(weight, inp, out, 4) else: - out = F.linear(x, weight) + out = F.linear(inp, weight) + if batched: + out = out.view(x.shape[0], x.shape[1], weight.shape[0]) if bias != None: out = out + bias return out