From ee2bd238c78528ddd880036c860d8d6c78f8808d Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:28:12 -0400 Subject: [PATCH] [Bugfix] Try to handle older versions of pytorch (#9086) --- tests/kernels/test_awq.py | 5 +++ tests/kernels/test_awq_marlin.py | 4 +++ vllm/_custom_ops.py | 53 +++++++++++++++++++------------- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py index e421aca48af2c..aa7a430850f9a 100644 --- a/tests/kernels/test_awq.py +++ b/tests/kernels/test_awq.py @@ -1,11 +1,14 @@ import os +import pytest import torch from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"), + reason="AWQ is not supported on this GPU type.") def test_awq_dequantize_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" qweight = torch.randint(-2000000000, @@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck(): (qweight, scales, zeros, split_k_iters, thx, thy)) +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"), + reason="AWQ is not supported on this GPU type.") def test_awq_gemm_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 0738ea9b97edb..0f0a2b24563fd 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -7,6 +7,7 @@ from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, torch_moe_single) +from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -21,6 +22,9 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.skipif(not (ops.supports_moe_ops + and hasattr(torch.ops._moe_C, "marlin_gemm_moe")), + reason="Marlin is not supported on this GPU type.") def test_fused_marlin_moe_awq( m: int, n: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 24e008dc38022..3a23692285efe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,8 +1,9 @@ import contextlib import functools -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch +import torch.library import vllm.envs as envs from vllm._core_ext import ScalarType @@ -25,6 +26,16 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + def hint_on_error(fn): @@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_gemm"): - @torch.library.register_fake("_C::gptq_gemm") + @register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, @@ -301,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - @torch.library.register_fake("_C::gptq_marlin_24_gemm") + @register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, @@ -309,7 +320,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::gptq_marlin_gemm") + @register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::ggml_dequantize") + @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_vec_a8") + @register_fake("_C::ggml_mul_mat_vec_a8") def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake( ) -> torch.Tensor: return torch.empty((1, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_a8") + @register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake( batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::marlin_qqq_gemm") + @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor, @@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::marlin_gemm") + @register_fake("_C::marlin_gemm") def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, @@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::awq_dequantize") + @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: @@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, dtype=scales.dtype, device=scales.device) - @torch.library.register_fake("_C::awq_gemm") + @register_fake("_C::awq_gemm") def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: @@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, dtype=input.dtype, device=input.device).sum(0) - @torch.library.register_fake("_C::aqlm_gemm") + @register_fake("_C::aqlm_gemm") def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: List[int], @@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, output_sizes.append(-1) return flat_output.reshape(tuple(output_sizes)) - @torch.library.register_fake("_C::aqlm_dequant") + @register_fake("_C::aqlm_dequant") def _aqlm_dequant_fake( codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int]) -> torch.Tensor: @@ -415,14 +426,14 @@ def _aqlm_dequant_fake( dtype=codebooks.dtype, device=codebooks.device) - @torch.library.register_fake("_C::fp8_marlin_gemm") + @register_fake("_C::fp8_marlin_gemm") def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @torch.library.register_fake("_C::machete_gemm") + @register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, # Should be the tensor returned by machete_prepack_B @@ -440,13 +451,13 @@ def machete_gemm_fake( n = b_q.size(1) return torch.empty((m, n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::machete_prepack_B") + @register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - @torch.library.register_fake("_C::causal_conv1d_fwd") + @register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], conv_states: Optional[torch.Tensor], @@ -456,7 +467,7 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, silu_activation: bool) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::causal_conv1d_update") + @register_fake("_C::causal_conv1d_update") def causal_conv1d_update_fake( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool, @@ -464,7 +475,7 @@ def causal_conv1d_update_fake( conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::selective_scan_fwd") + @register_fake("_C::selective_scan_fwd") def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], @@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "permute_cols"): - @torch.library.register_fake("_C::permute_cols") + @register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - @torch.library.register_fake("_moe_C::marlin_gemm_moe") + @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor,