From e32fe71766afb7948a49ca58662bb66a5228c4d8 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Fri, 20 Dec 2024 16:03:19 +0800
Subject: [PATCH] [platform] Do not use current_platform in global namespace

Signed-off-by: wangxiyuan
---
 vllm/_custom_ops.py                           | 63 ++++++++++++-------
 .../ops/blocksparse_attention/interface.py    | 14 ++---
 vllm/attention/ops/prefix_prefill.py          | 11 ++--
 .../device_communicators/hpu_communicator.py  |  4 +-
 .../device_communicators/tpu_communicator.py  |  4 +-
 vllm/model_executor/layers/fused_moe/layer.py | 19 +++---
 vllm/spec_decode/multi_step_worker.py         |  5 +-
 vllm/spec_decode/spec_decode_worker.py        |  5 +-
 8 files changed, 70 insertions(+), 55 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 19f31b8ec419d..1d6aca24b0c3c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -13,27 +13,42 @@
 
 logger = init_logger(__name__)
 
-if not current_platform.is_tpu() and not current_platform.is_hpu():
-    try:
-        import vllm._C
-    except ImportError as e:
-        logger.warning("Failed to import from vllm._C with %r", e)
+try:
+    import vllm._C
+except ImportError as e:
+    logger.warning("Failed to import from vllm._C with %r", e)
 
 supports_moe_ops = False
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
+register_fake_func = None
 # neuron has torch version that doesn't even have impl_abstract
-if TYPE_CHECKING or current_platform.is_neuron():
+if TYPE_CHECKING:
 
     def register_fake(fn):
         return lambda name: fn
+
+    register_fake_func = register_fake
 else:
     try:
         from torch.library import register_fake
+
+        register_fake_func = register_fake
     except ImportError:
-        from torch.library import impl_abstract as register_fake
+        try:
+            from torch.library import impl_abstract as register_fake
+
+            register_fake_func = register_fake
+        except ImportError:
+            # For a platform whose torch version does not provide
+            # impl_abstract (for example, neuron), fall back to a dummy
+            # register_fake_func.
+            def register_fake(fn):
+                return lambda name: fn
+
+            register_fake_func = register_fake
 
 
 def hint_on_error(fn):
@@ -302,7 +317,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_gemm"):
 
-    @register_fake("_C::gptq_gemm")
+    @register_fake_func("_C::gptq_gemm")
     def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_gptq_qzeros: torch.Tensor,
                         b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -337,7 +352,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
 
-    @register_fake("_C::gptq_marlin_24_gemm")
+    @register_fake_func("_C::gptq_marlin_24_gemm")
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
@@ -346,7 +361,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @register_fake("_C::gptq_marlin_gemm")
+    @register_fake_func("_C::gptq_marlin_gemm")
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                b_q_weight: torch.Tensor,
                                b_scales: torch.Tensor,
@@ -364,7 +379,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                is_zp_float: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @register_fake("_C::marlin_qqq_gemm")
+    @register_fake_func("_C::marlin_qqq_gemm")
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
@@ -374,7 +389,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            dtype=torch.float16,
                            device=a.device)
 
-    @register_fake("_C::marlin_gemm")
+    @register_fake_func("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
                           size_m: torch.SymInt, size_n: torch.SymInt,
@@ -383,7 +398,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           dtype=torch.float16,
                           device=a.device)
 
-    @register_fake("_C::awq_dequantize")
+    @register_fake_func("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                              zeros: torch.Tensor, split_k_iters: torch.SymInt,
                              thx: int, thy: int) -> torch.Tensor:
@@ -394,7 +409,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                            dtype=scales.dtype,
                            device=scales.device)
 
-    @register_fake("_C::awq_gemm")
+    @register_fake_func("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
                        split_k_iters: torch.SymInt) -> torch.Tensor:
@@ -403,7 +418,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                            dtype=input.dtype,
                            device=input.device).sum(0)
 
-    @register_fake("_C::aqlm_gemm")
+    @register_fake_func("_C::aqlm_gemm")
     def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                         codebooks: torch.Tensor, scales: torch.Tensor,
                         codebook_partition_sizes: List[int],
@@ -419,7 +434,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
             output_sizes.append(-1)
         return flat_output.reshape(tuple(output_sizes))
 
-    @register_fake("_C::aqlm_dequant")
+    @register_fake_func("_C::aqlm_dequant")
     def _aqlm_dequant_fake(
             codes: torch.Tensor, codebooks: torch.Tensor,
             codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -429,7 +444,7 @@ def _aqlm_dequant_fake(
                            dtype=codebooks.dtype,
                            device=codebooks.device)
 
-    @register_fake("_C::fp8_marlin_gemm")
+    @register_fake_func("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
                               num_bits: int, size_m: torch.SymInt,
@@ -437,7 +452,7 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               size_k: torch.SymInt) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
-    @register_fake("_C::machete_mm")
+    @register_fake_func("_C::machete_mm")
     def machete_mm_fake(
             a: torch.Tensor,
             # b_q Should be the tensor returned by machete_prepack_B
@@ -455,7 +470,7 @@ def machete_mm_fake(
         n = b_q.size(1)
         return torch.empty((m, n), device=a.device, dtype=a.dtype)
 
-    @register_fake("_C::machete_prepack_B")
+    @register_fake_func("_C::machete_prepack_B")
     def machete_prepack_B_fake(
             b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
             group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
@@ -465,13 +480,13 @@ def machete_prepack_B_fake(
 
 if hasattr(torch.ops._C, "ggml_dequantize"):
 
-    @register_fake("_C::ggml_dequantize")
+    @register_fake_func("_C::ggml_dequantize")
     def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
                               m: torch.SymInt,
                               n: torch.SymInt) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
-    @register_fake("_C::ggml_mul_mat_vec_a8")
+    @register_fake_func("_C::ggml_mul_mat_vec_a8")
     def _ggml_mul_mat_vec_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -480,7 +495,7 @@ def _ggml_mul_mat_vec_a8_fake(
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)
 
-    @register_fake("_C::ggml_mul_mat_a8")
+    @register_fake_func("_C::ggml_mul_mat_a8")
     def _ggml_mul_mat_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -787,7 +802,7 @@ def machete_prepack_B(
 
 if hasattr(torch.ops._C, "permute_cols"):
 
-    @register_fake("_C::permute_cols")
+    @register_fake_func("_C::permute_cols")
     def _permute_cols_fake(a: torch.Tensor,
                            perm: torch.Tensor) -> torch.Tensor:
         return torch.empty_like(a)
 
@@ -993,7 +1008,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
 
 if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
 
-    @register_fake("_moe_C::marlin_gemm_moe")
+    @register_fake_func("_moe_C::marlin_gemm_moe")
     def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                              sorted_ids: torch.Tensor,
                              topk_weights: torch.Tensor,
diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py
index 350f88c8f9740..e19b15dc44b73 100644
--- a/vllm/attention/ops/blocksparse_attention/interface.py
+++ b/vllm/attention/ops/blocksparse_attention/interface.py
@@ -4,17 +4,15 @@
 
 from vllm.platforms import current_platform
 
+from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
-IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
-
-if IS_COMPUTE_8_OR_ABOVE:
-    from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
-
 
 class LocalStridedBlockSparseAttn(torch.nn.Module):
 
+    IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
+
     def __init__(
         self,
         n_heads,
@@ -33,12 +31,12 @@ def __init__(
         if use_spda is None:
             use_spda = current_platform.is_rocm() or \
                     current_platform.is_cpu() or not \
-                        IS_COMPUTE_8_OR_ABOVE
+                        self.IS_COMPUTE_8_OR_ABOVE
         device = device or (torch.cuda.current_device()
                             if current_platform.is_cuda_alike() else "cpu")
         device = torch.device(device)
         # NOTE: vllm CPU backend support BF16 instead of FP16.
-        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
+        dtype = dtype or (torch.bfloat16 if self.IS_COMPUTE_8_OR_ABOVE
                           or device.type == "cpu" else torch.half)
 
         self.n_heads = n_heads
@@ -122,7 +120,7 @@ def varlen_attn(self,
         return: tensor of shape as q.
         """
         assert (
-            IS_COMPUTE_8_OR_ABOVE
+            self.IS_COMPUTE_8_OR_ABOVE
         ), "Requires compute capability of 8 or above (Ampere or newer) to use \
             Triton kernel."
 
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 9c11a8df55278..140e798e75166 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -7,13 +7,8 @@
 
 from vllm.platforms import current_platform
 
-# Static kernels parameters
-BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
 NUM_WARPS = 8
 
-# To check compatibility
-IS_TURING = current_platform.get_device_capability() == (7, 5)
-
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
@@ -719,11 +714,17 @@ def context_attention_fwd(q,
                               sliding_window=None):
 
         q_dtype_is_f32 = q.dtype is torch.float32
+        # Static kernels parameters
+        BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
+
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         # if q.dtype is torch.float32:
         BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
 
+        # To check compatibility
+        IS_TURING = current_platform.get_device_capability() == (7, 5)
+
         # Turing does have tensor core for float32 multiplication
         # use ieee as fallback for triton kernels work. There is also
         # warning on vllm/config.py to inform users this fallback
diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py
index cc9b19ce022b5..183c617709279 100644
--- a/vllm/distributed/device_communicators/hpu_communicator.py
+++ b/vllm/distributed/device_communicators/hpu_communicator.py
@@ -1,10 +1,12 @@
+import contextlib
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
-if current_platform.is_hpu():
+with contextlib.suppress(ImportError):
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 765a0f9cb1c87..14e86f94ee9cc 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 
 import torch
@@ -6,7 +7,8 @@
 
 from vllm.platforms import current_platform
 
-if current_platform.is_tpu():
+with contextlib.suppress(ImportError):
+    import habana_frameworks.torch as htorch  # noqa: F401
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8c6f7c6e06515..d5003b59aaa15 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -14,14 +14,6 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
-if current_platform.is_cuda_alike():
-    from .fused_moe import fused_experts
-else:
-    fused_experts = None  # type: ignore
-if current_platform.is_tpu():
-    from .moe_pallas import fused_moe as fused_moe_pallas
-else:
-    fused_moe_pallas = None  # type: ignore
 
 logger = init_logger(__name__)
 
@@ -115,6 +107,11 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
+        if current_platform.is_cuda_alike():
+            from .fused_moe import fused_experts
+        else:
+            fused_experts = None  # type: ignore
+
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
                              w2=layer.w2_weight,
@@ -142,6 +139,12 @@ def forward_tpu(
         assert num_expert_group is None
         assert topk_group is None
         assert custom_routing_function is None
+
+        if current_platform.is_tpu():
+            from .moe_pallas import fused_moe as fused_moe_pallas
+        else:
+            fused_moe_pallas = None  # type: ignore
+
         return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 676ac5eb3609d..4e36854060b66 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -8,10 +8,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
-
-if current_platform.is_cuda_alike():
-    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
-
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 2689802161987..e269bb02f43f1 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -20,10 +20,7 @@
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-
-if current_platform.is_cuda_alike():
-    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
-
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.medusa_worker import MedusaWorker
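
The patch converges on two patterns for keeping current_platform out of the global namespace: optional backend imports are wrapped in contextlib.suppress(ImportError) at module scope, and the platform probe plus the heavy kernel import are deferred into the method that needs them (forward_cuda / forward_tpu). The following is a minimal, self-contained sketch of both patterns, not part of the patch itself; the stdlib-based _is_linux() probe and the sqrt placeholder are illustrative stand-ins for current_platform.is_cuda_alike() and fused_experts.

import contextlib
import platform

# Pattern 1 (hpu/tpu_communicator.py): import an optional backend and
# tolerate its absence, instead of querying the platform at import time.
with contextlib.suppress(ImportError):
    import resource  # noqa: F401  # stdlib on POSIX, missing on Windows


def _is_linux() -> bool:
    # Stand-in for current_platform.is_cuda_alike(); it is only consulted
    # from inside forward(), never while this module is being imported.
    return platform.system() == "Linux"


def forward(x: float) -> float:
    # Pattern 2 (fused_moe/layer.py): resolve the kernel lazily at call time.
    if _is_linux():
        from math import sqrt as kernel  # placeholder for fused_experts
    else:
        kernel = None  # type: ignore
    if kernel is None:
        raise NotImplementedError("no kernel available on this platform")
    return kernel(x)

Either way the module itself imports cleanly on every platform; any platform-specific failure surfaces only when the functionality is actually called.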