diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 46a2fb8bc80a2..4cf79b7132a7f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,5 +1,6 @@ import contextlib import functools +import importlib from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -486,6 +487,14 @@ def cutlass_scaled_mm(a: torch.Tensor, m = a.shape[0] n = b.shape[1] + + if current_platform.is_rocm(): + scaled_mm_triton_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "scaled_mm_triton") + scaled_mm_triton = scaled_mm_triton_module.scaled_mm_triton + return scaled_mm_triton(a, b, scale_a, scale_b, out_dtype, bias) + out = torch.empty((m, n), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)