Skip to content

Commit

Permalink
[platform] Do not use current_platform in global namespace
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 20, 2024
1 parent 1ecc645 commit e32fe71
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 55 deletions.
63 changes: 39 additions & 24 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,42 @@

logger = init_logger(__name__)

if not current_platform.is_tpu() and not current_platform.is_hpu():
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True

register_fake_func = None
# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
if TYPE_CHECKING:

def register_fake(fn):
return lambda name: fn

register_fake_func = register_fake
else:
try:
from torch.library import register_fake

register_fake_func = register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
try:
from torch.library import impl_abstract as register_fake

register_fake_func = register_fake
except ImportError:
# For platforms whose torch version doesn't even have
# impl_abstract (for example, the neuron platform), set a dummy
# register_fake_func.
def register_fake(fn):
return lambda name: fn

register_fake_func = register_fake


def hint_on_error(fn):
Expand Down Expand Up @@ -302,7 +317,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_gemm"):

@register_fake("_C::gptq_gemm")
@register_fake_func("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
Expand Down Expand Up @@ -337,7 +352,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

@register_fake("_C::gptq_marlin_24_gemm")
@register_fake_func("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
Expand All @@ -346,7 +361,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::gptq_marlin_gemm")
@register_fake_func("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
Expand All @@ -364,7 +379,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::marlin_qqq_gemm")
@register_fake_func("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
Expand All @@ -374,7 +389,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@register_fake("_C::marlin_gemm")
@register_fake_func("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
Expand All @@ -383,7 +398,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@register_fake("_C::awq_dequantize")
@register_fake_func("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: torch.SymInt,
thx: int, thy: int) -> torch.Tensor:
Expand All @@ -394,7 +409,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
dtype=scales.dtype,
device=scales.device)

@register_fake("_C::awq_gemm")
@register_fake_func("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: torch.SymInt) -> torch.Tensor:
Expand All @@ -403,7 +418,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
dtype=input.dtype,
device=input.device).sum(0)

@register_fake("_C::aqlm_gemm")
@register_fake_func("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
Expand All @@ -419,7 +434,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))

@register_fake("_C::aqlm_dequant")
@register_fake_func("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
Expand All @@ -429,15 +444,15 @@ def _aqlm_dequant_fake(
dtype=codebooks.dtype,
device=codebooks.device)

@register_fake("_C::fp8_marlin_gemm")
@register_fake_func("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

@register_fake("_C::machete_mm")
@register_fake_func("_C::machete_mm")
def machete_mm_fake(
a: torch.Tensor,
# b_q Should be the tensor returned by machete_prepack_B
Expand All @@ -455,7 +470,7 @@ def machete_mm_fake(
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@register_fake("_C::machete_prepack_B")
@register_fake_func("_C::machete_prepack_B")
def machete_prepack_B_fake(
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
Expand All @@ -465,13 +480,13 @@ def machete_prepack_B_fake(

if hasattr(torch.ops._C, "ggml_dequantize"):

@register_fake("_C::ggml_dequantize")
@register_fake_func("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake_func("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
Expand All @@ -480,7 +495,7 @@ def _ggml_mul_mat_vec_a8_fake(
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)

@register_fake("_C::ggml_mul_mat_a8")
@register_fake_func("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
Expand Down Expand Up @@ -787,7 +802,7 @@ def machete_prepack_B(

if hasattr(torch.ops._C, "permute_cols"):

@register_fake("_C::permute_cols")
@register_fake_func("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
Expand Down Expand Up @@ -993,7 +1008,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

@register_fake("_moe_C::marlin_gemm_moe")
@register_fake_func("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
Expand Down
14 changes: 6 additions & 8 deletions vllm/attention/ops/blocksparse_attention/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@

from vllm.platforms import current_platform

from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd


class LocalStridedBlockSparseAttn(torch.nn.Module):

IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

def __init__(
self,
n_heads,
Expand All @@ -33,12 +31,12 @@ def __init__(
if use_spda is None:
use_spda = current_platform.is_rocm() or \
current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
self.IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)
# NOTE: vllm CPU backend supports BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
dtype = dtype or (torch.bfloat16 if self.IS_COMPUTE_8_OR_ABOVE
or device.type == "cpu" else torch.half)

self.n_heads = n_heads
Expand Down Expand Up @@ -122,7 +120,7 @@ def varlen_attn(self,
return: tensor of shape as q.
"""
assert (
IS_COMPUTE_8_OR_ABOVE
self.IS_COMPUTE_8_OR_ABOVE
), "Requires compute capability of 8 or above (Ampere or newer) to use \
Triton kernel."

Expand Down
11 changes: 6 additions & 5 deletions vllm/attention/ops/prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@

from vllm.platforms import current_platform

# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)

if triton.__version__ >= "2.1.0":

@triton.jit
Expand Down Expand Up @@ -719,11 +714,17 @@ def context_attention_fwd(q,
sliding_window=None):

q_dtype_is_f32 = q.dtype is torch.float32
# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)

# Turing does not have tensor cores for float32 multiplication;
# use ieee as a fallback so the triton kernels work. There is also
# a warning in vllm/config.py to inform users that this fallback
Expand Down
4 changes: 3 additions & 1 deletion vllm/distributed/device_communicators/hpu_communicator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_hpu():
with contextlib.suppress(ImportError):
import habana_frameworks.torch as htorch # noqa: F401


Expand Down
4 changes: 3 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os

import torch
Expand All @@ -6,7 +7,8 @@

from vllm.platforms import current_platform

if current_platform.is_tpu():
with contextlib.suppress(ImportError):
import habana_frameworks.torch as htorch # noqa: F401
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
Expand Down
19 changes: 11 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)


Expand Down Expand Up @@ -115,6 +107,11 @@ def forward_cuda(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)

if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore

return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
Expand Down Expand Up @@ -142,6 +139,12 @@ def forward_tpu(
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None

if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore

return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
Expand Down
5 changes: 1 addition & 4 deletions vllm/spec_decode/multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
Expand Down
5 changes: 1 addition & 4 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
Expand Down

0 comments on commit e32fe71

Please sign in to comment.