diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index b2952bbfa917c..a25f7abca5498 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -26,7 +26,8 @@
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
-if TYPE_CHECKING:
+# neuron has torch version that doesn't even have impl_abstract
+if TYPE_CHECKING or current_platform.is_neuron():
 
     def register_fake(fn):
         return lambda name: fn
diff --git a/vllm/config.py b/vllm/config.py
index 00dd047e6d058..12935e77c2aa7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -17,8 +17,7 @@
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_neuron, is_openvino, is_xpu,
-                        print_warning_once)
+                        is_hip, is_openvino, is_xpu, print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -215,8 +214,10 @@ def __init__(self,
         self.is_attention_free = self._init_attention_free()
         self.has_inner_state = self._init_has_inner_state()
 
-        self.override_neuron_config = override_neuron_config if is_neuron(
-        ) else None
+        if current_platform.is_neuron():
+            self.override_neuron_config = override_neuron_config
+        else:
+            self.override_neuron_config = None
 
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
@@ -368,7 +369,7 @@ def _verify_quantization(self) -> None:
                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                     " is not set, enabling VLLM_USE_TRITON_AWQ.")
                 envs.VLLM_USE_TRITON_AWQ = True
-            if is_neuron(
+            if current_platform.is_neuron(
             ) and self.quantization not in neuron_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
@@ -1112,7 +1113,7 @@ def __init__(self, device: str = "auto") -> None:
             # Automated device type detection
             if current_platform.is_cuda_alike():
                 self.device_type = "cuda"
-            elif is_neuron():
+            elif current_platform.is_neuron():
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index c648862b2d757..58912158139bd 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -58,6 +58,13 @@
 except Exception:
     pass
 
+is_neuron = False
+try:
+    import transformers_neuronx  # noqa: F401
+    is_neuron = True
+except ImportError:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -75,6 +82,9 @@
 elif is_cpu:
     from .cpu import CpuPlatform
     current_platform = CpuPlatform()
+elif is_neuron:
+    from .neuron import NeuronPlatform
+    current_platform = NeuronPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 00742a290e42a..d36367f2bc9c1 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
     TPU = enum.auto()
     XPU = enum.auto()
     CPU = enum.auto()
+    NEURON = enum.auto()
     UNSPECIFIED = enum.auto()
 
 
@@ -48,6 +49,9 @@ def is_xpu(self) -> bool:
     def is_cpu(self) -> bool:
         return self._enum == PlatformEnum.CPU
 
+    def is_neuron(self) -> bool:
+        return self._enum == PlatformEnum.NEURON
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
new file mode 100644
index 0000000000000..07d8398eda525
--- /dev/null
+++ b/vllm/platforms/neuron.py
@@ -0,0 +1,9 @@
+from .interface import Platform, PlatformEnum
+
+
+class NeuronPlatform(Platform):
+    _enum = PlatformEnum.NEURON
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return "neuron"
diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py
index ce46082247639..ef7ca149266b6 100644
--- a/vllm/triton_utils/importing.py
+++ b/vllm/triton_utils/importing.py
@@ -1,10 +1,13 @@
 from importlib.util import find_spec
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-HAS_TRITON = find_spec("triton") is not None
+# neuron has too old torch
+HAS_TRITON = find_spec(
+    "triton") is not None and not current_platform.is_neuron()
 
 if not HAS_TRITON:
     logger.info("Triton not installed; certain GPU-related functions"
diff --git a/vllm/utils.py b/vllm/utils.py
index 428c2095dcd5d..797c1bcfd5342 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -327,15 +327,6 @@ def is_openvino() -> bool:
     return False
 
 
-@lru_cache(maxsize=None)
-def is_neuron() -> bool:
-    try:
-        import transformers_neuronx
-    except ImportError:
-        transformers_neuronx = None
-    return transformers_neuronx is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version
@@ -786,7 +777,7 @@ def is_pin_memory_available() -> bool:
     elif is_xpu():
         print_warning_once("Pin memory is not supported on XPU.")
         return False
-    elif is_neuron():
+    elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
     elif current_platform.is_cpu() or is_openvino():
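
Note: the sketch below is not part of the patch; it only illustrates how the hook added above is consumed. `current_platform`, `is_neuron()`, and `is_cuda_alike()` come from the diff itself, while the `device_type` variable is illustrative.

    # Illustrative sketch, not part of the patch. `current_platform` is
    # resolved once at import time in vllm/platforms/__init__.py: the
    # try/except probe for transformers_neuronx sets is_neuron, which
    # selects NeuronPlatform.
    from vllm.platforms import current_platform

    if current_platform.is_neuron():
        # Neuron ships an older torch, so Triton and fake-op registration
        # are skipped (see the _custom_ops.py and importing.py hunks).
        device_type = "neuron"
    elif current_platform.is_cuda_alike():  # CUDA or ROCm
        device_type = "cuda"
    else:
        device_type = "cpu"

Centralizing the check on `current_platform` replaces the `lru_cache`-decorated `is_neuron()` helper removed from vllm/utils.py, so the transformers_neuronx import probe runs exactly once at startup instead of at every call site.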