From d6e634f3d78649002460758f3964e4df0d39a546 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 13 Aug 2024 00:30:30 -0700 Subject: [PATCH] [TPU] Suppress import custom_ops warning (#7458) --- vllm/_custom_ops.py | 10 ++++++---- vllm/utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4331db8ee4e82..b6329859830ca 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -6,13 +6,15 @@ from vllm._core_ext import ScalarType from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) -try: - import vllm._C -except ImportError as e: - logger.warning("Failed to import from vllm._C with %r", e) +if not current_platform.is_tpu(): + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) with contextlib.suppress(ImportError): # ruff: noqa: F401 diff --git a/vllm/utils.py b/vllm/utils.py index a758c78dc9c25..753efca3e2a61 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -29,7 +29,6 @@ from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs -from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger logger = init_logger(__name__) @@ -359,6 +358,7 @@ def is_xpu() -> bool: @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops max_shared_mem = ( ops.get_max_shared_memory_per_block_device_attribute(gpu)) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py