Skip to content

Commit

Permalink
add implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Randall Smith <[email protected]>
  • Loading branch information
rasmith committed Oct 30, 2024
1 parent c2cd1a2 commit 4624680
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import functools
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -486,6 +487,14 @@ def cutlass_scaled_mm(a: torch.Tensor,

m = a.shape[0]
n = b.shape[1]

if current_platform.is_rocm():
scaled_mm_triton_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"scaled_mm_triton")
scaled_mm_triton = scaled_mm_triton_module.scaled_mm_triton
return scaled_mm_triton(a, b, scale_a, scale_b, out_dtype, bias)

out = torch.empty((m, n), dtype=out_dtype, device=a.device)

torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
Expand Down

0 comments on commit 4624680

Please sign in to comment.