[platform] Do not use current_platform in global namespace #11362

Closed
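The whole diff applies a single pattern: `current_platform` is no longer consulted while a module is being imported. Optional native extensions are imported behind plain `try`/`except ImportError` (or `contextlib.suppress(ImportError)`) guards, and decisions that genuinely depend on the platform move into the function or method that needs them. Below is a minimal sketch of the before/after shape of that change, using a hypothetical `my_backend` package (not a real vLLM dependency) purely for illustration:

```python
import contextlib

# Before (simplified): the platform was probed as an import-time side effect,
# so merely importing the module forced platform detection to run first.
#
#     from vllm.platforms import current_platform
#     if current_platform.is_cuda_alike():
#         import my_backend
#     else:
#         my_backend = None

# After (simplified): import the optional dependency tolerantly ...
my_backend = None  # hypothetical optional package, illustration only
with contextlib.suppress(ImportError):
    import my_backend  # noqa: F401,F811


def fused_op(x):
    # ... and resolve the platform question at call time, not import time.
    from vllm.platforms import current_platform
    if my_backend is not None and current_platform.is_cuda_alike():
        return my_backend.fused_op(x)  # hypothetical fast path
    return x  # generic fallback
```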
63 changes: 39 additions & 24 deletions vllm/_custom_ops.py
@@ -12,27 +12,42 @@

logger = init_logger(__name__)

if not current_platform.is_tpu() and not current_platform.is_hpu():
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True

register_fake_func = None
# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
if TYPE_CHECKING:

def register_fake(fn):
return lambda name: fn

register_fake_func = register_fake
else:
try:
from torch.library import register_fake

register_fake_func = register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
try:
from torch.library import impl_abstract as register_fake

register_fake_func = register_fake
except ImportError:
# For platforms whose torch version does not even have
# impl_abstract (for example, the neuron platform), fall back to a
# dummy register_fake_func.
def register_fake(fn):
return lambda name: fn

register_fake_func = register_fake


# activation ops
@@ -273,7 +288,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_gemm"):

@register_fake("_C::gptq_gemm")
@register_fake_func("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -308,7 +323,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

@register_fake("_C::gptq_marlin_24_gemm")
@register_fake_func("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
@@ -317,7 +332,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::gptq_marlin_gemm")
@register_fake_func("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
@@ -335,7 +350,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::marlin_qqq_gemm")
@register_fake_func("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
@@ -345,7 +360,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@register_fake("_C::marlin_gemm")
@register_fake_func("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
@@ -354,7 +369,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@register_fake("_C::awq_dequantize")
@register_fake_func("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: torch.SymInt,
thx: int, thy: int) -> torch.Tensor:
@@ -365,7 +380,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
dtype=scales.dtype,
device=scales.device)

@register_fake("_C::awq_gemm")
@register_fake_func("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: torch.SymInt) -> torch.Tensor:
@@ -374,7 +389,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
dtype=input.dtype,
device=input.device).sum(0)

@register_fake("_C::aqlm_gemm")
@register_fake_func("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
@@ -390,7 +405,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))

@register_fake("_C::aqlm_dequant")
@register_fake_func("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -400,15 +415,15 @@ def _aqlm_dequant_fake(
dtype=codebooks.dtype,
device=codebooks.device)

@register_fake("_C::fp8_marlin_gemm")
@register_fake_func("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

@register_fake("_C::machete_mm")
@register_fake_func("_C::machete_mm")
def machete_mm_fake(
a: torch.Tensor,
# b_q Should be the tensor returned by machete_prepack_B
@@ -426,7 +441,7 @@ def machete_mm_fake(
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@register_fake("_C::machete_prepack_B")
@register_fake_func("_C::machete_prepack_B")
def machete_prepack_B_fake(
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
@@ -436,13 +451,13 @@ def machete_prepack_B_fake(

if hasattr(torch.ops._C, "ggml_dequantize"):

@register_fake("_C::ggml_dequantize")
@register_fake_func("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake_func("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -451,7 +466,7 @@ def _ggml_mul_mat_vec_a8_fake(
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)

@register_fake("_C::ggml_mul_mat_a8")
@register_fake_func("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -758,7 +773,7 @@ def machete_prepack_B(

if hasattr(torch.ops._C, "permute_cols"):

@register_fake("_C::permute_cols")
@register_fake_func("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
@@ -964,7 +979,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

@register_fake("_moe_C::marlin_gemm_moe")
@register_fake_func("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
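The `register_fake_func` indirection introduced above resolves, at import time, to whichever fake-op registration API the installed torch exposes: `torch.library.register_fake` on newer releases, the older `torch.library.impl_abstract` otherwise, and a no-op stand-in when neither exists (the neuron case called out in the comment). A compressed, self-contained sketch of the same fallback chain, assuming only that torch may or may not provide these two entry points:

```python
from typing import Callable

try:
    from torch.library import register_fake as register_fake_func
except ImportError:
    try:
        # Older torch releases shipped the same functionality under this name.
        from torch.library import impl_abstract as register_fake_func
    except ImportError:
        # Neither API exists: make registration a harmless no-op so call
        # sites like `@register_fake_func("_C::gptq_gemm")` still work.
        def register_fake_func(op_name: str) -> Callable:
            def decorator(fn):
                return fn

            return decorator
```

Each `@register_fake_func("_C::...")` site then registers a shape-only implementation that meta-tensor tracing and `torch.compile` can use without running the real kernel.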
14 changes: 6 additions & 8 deletions vllm/attention/ops/blocksparse_attention/interface.py
@@ -4,17 +4,15 @@

from vllm.platforms import current_platform

from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)

IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd


class LocalStridedBlockSparseAttn(torch.nn.Module):

IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)

def __init__(
self,
n_heads,
@@ -33,12 +31,12 @@ def __init__(
if use_spda is None:
use_spda = current_platform.is_rocm() or \
current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
self.IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)
# NOTE: vllm CPU backend supports BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
dtype = dtype or (torch.bfloat16 if self.IS_COMPUTE_8_OR_ABOVE
or device.type == "cpu" else torch.half)

self.n_heads = n_heads
@@ -122,7 +120,7 @@ def varlen_attn(self,
return: tensor of shape as q.
"""
assert (
IS_COMPUTE_8_OR_ABOVE
self.IS_COMPUTE_8_OR_ABOVE
), "Requires compute capability of 8 or above (Ampere or newer) to use \
Triton kernel."

11 changes: 6 additions & 5 deletions vllm/attention/ops/prefix_prefill.py
@@ -7,13 +7,8 @@

from vllm.platforms import current_platform

# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)

if triton.__version__ >= "2.1.0":

@triton.jit
@@ -719,11 +714,17 @@ def context_attention_fwd(q,
sliding_window=None):

q_dtype_is_f32 = q.dtype is torch.float32
# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)

# Turing does not have tensor cores for float32 multiplication;
# use ieee as a fallback so the triton kernels work. There is also a
# warning in vllm/config.py to inform users that this fallback
4 changes: 3 additions & 1 deletion vllm/distributed/device_communicators/hpu_communicator.py
@@ -1,10 +1,12 @@
import contextlib

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_hpu():
with contextlib.suppress(ImportError):
import habana_frameworks.torch as htorch # noqa: F401


3 changes: 2 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,3 +1,4 @@
import contextlib
import os

import torch
@@ -6,7 +7,7 @@

from vllm.platforms import current_platform

if current_platform.is_tpu():
with contextlib.suppress(ImportError):
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
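Both communicator modules replace an import-time `current_platform.is_hpu()` / `is_tpu()` check with `contextlib.suppress(ImportError)`: the file now imports cleanly on every platform, and the backend-specific names are simply left unbound when the corresponding package is absent. A small standalone sketch of that guard style, with `fancy_accel` as a made-up placeholder for a backend package such as `habana_frameworks.torch` or `torch_xla`:

```python
import contextlib

fancy_accel = None  # placeholder for an optional, platform-specific package
with contextlib.suppress(ImportError):
    import fancy_accel  # noqa: F401,F811


def maybe_mark_step() -> None:
    # Backend-specific work runs only if the package actually imported;
    # everywhere else this is a cheap no-op.
    if fancy_accel is not None:
        fancy_accel.mark_step()  # hypothetical API, illustration only
```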
19 changes: 11 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -14,14 +14,6 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)


@@ -135,6 +127,11 @@ def forward_cuda(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore

return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -170,6 +167,12 @@ def forward_tpu(
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU.")

if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore

return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
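In `fused_moe/layer.py` the kernel imports move from module scope into `forward_cuda` / `forward_tpu`. The per-call overhead is negligible: after the first import Python serves the module from `sys.modules`, so later calls pay only a dictionary lookup. A sketch of that deferred-import shape (an illustrative class, not the real vLLM `FusedMoE` layer, though the imported module paths are the ones shown in the diff):

```python
import torch


class LazyFusedMoE(torch.nn.Module):
    """Illustrative sketch of deferring platform-specific kernel imports."""

    def forward_cuda(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Loaded only when the CUDA/ROCm path actually runs, never as a side
        # effect of importing this layer module.
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(hidden_states=x, **kwargs)

    def forward_tpu(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Same idea for the Pallas kernels on TPU.
        from vllm.model_executor.layers.fused_moe.moe_pallas import (
            fused_moe as fused_moe_pallas)
        return fused_moe_pallas(hidden_states=x, **kwargs)
```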
5 changes: 1 addition & 4 deletions vllm/spec_decode/multi_step_worker.py
@@ -8,10 +8,7 @@
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
5 changes: 1 addition & 4 deletions vllm/spec_decode/spec_decode_worker.py
@@ -20,10 +20,7 @@
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker