Skip to content

Commit

Permalink
Add platform pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
  • Loading branch information
wangxiyuan committed Dec 23, 2024
1 parent 048fc57 commit d6f685a
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 159 deletions.
2 changes: 1 addition & 1 deletion docs/source/design/plugin_system.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Every plugin has three parts:
What Can Plugins Do?
--------------------

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
Currently, the primary use case for plugins is to register custom, out-of-the-tree models or platforms into vLLM. This is done by calling ``ModelRegistry.register_model`` or ``PlatformRegistry.register_platform`` to register the model or platform, respectively. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

Guidelines for Writing Plugins
------------------------------
Expand Down
9 changes: 9 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup

# Entry points exposed in the 'vllm.general_plugins' group. The target is the
# `register` callable of the `vllm_add_dummy_platform` package.
# NOTE(review): the entry-point name says "model" although the callable
# registers a platform -- confirm whether the name is referenced elsewhere
# before renaming it.
plugin_entry_points = {
    'vllm.general_plugins': [
        "register_dummy_model = vllm_add_dummy_platform:register",
    ],
}

setup(
    name='vllm_add_dummy_platform',
    version='0.1',
    packages=['vllm_add_dummy_platform'],
    entry_points=plugin_entry_points,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from vllm import PlatformRegistry


def register():
    """Register the dummy platform and select it as the current platform.

    ``DummyPlatform`` is registered under the name ``"my_platform"`` by its
    qualified ``module:Class`` path, and that same name is then set as the
    current platform in ``PlatformRegistry``.
    """
    platform_name = "my_platform"
    platform_path = "vllm_add_dummy_platform.my_platform:DummyPlatform"

    PlatformRegistry.register_platform(platform_name, platform_path)
    PlatformRegistry.set_current_platform(platform_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class DummyAttentionImpl:
    """Placeholder attention implementation whose ``forward`` is a no-op."""

    def forward(self):
        """Do nothing and return ``None``."""
        return None


class DummyAttentionBackend:
    """Placeholder attention backend that hands out ``DummyAttentionImpl``."""

    def __init__(self):
        """Create the backend; there is no state to initialise."""
        return None

    def get_impl_cls(self):
        """Return the implementation class itself (not an instance)."""
        impl_cls = DummyAttentionImpl
        return impl_cls
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Import through the package: the sibling module lives inside the
# `vllm_add_dummy_platform` package, so a bare top-level `my_attention`
# import is not resolvable once the package is installed.
from vllm_add_dummy_platform.my_attention import DummyAttentionBackend


class DummyModelRunner:
    """Placeholder model runner that owns a ``DummyAttentionBackend``."""

    def __init__(self):
        # Each runner gets its own backend instance.
        self.attn_backend = DummyAttentionBackend()
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from vllm.config import VllmConfig
from vllm.platforms import Platform


class DummyPlatform(Platform):
    """Minimal ``Platform`` subclass that reports the device name "dummy"."""

    device_name = "dummy"

    def __init__(self):
        super().__init__()

    @classmethod
    def get_device_name(cls) -> str:
        """Return the fixed device name string."""
        return "dummy"

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Point the parallel config at the dummy worker class.

        ``vllm_config.parallel_config.worker_cls`` is overwritten in place
        with the dotted path of ``DummyWorker``; nothing is returned.
        """
        worker_path = "vllm_add_dummy_platform.my_worker.DummyWorker"
        vllm_config.parallel_config.worker_cls = worker_path
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List

# Import through the package: the sibling module lives inside the
# `vllm_add_dummy_platform` package, so a bare top-level `my_model_runner`
# import is not resolvable once the package is installed.
from vllm_add_dummy_platform.my_model_runner import DummyModelRunner


class DummyCacheEngine:
    """Placeholder cache engine with no behavior."""


class DummyWorker:
    """Placeholder worker holding a list of cache engines and a model runner."""

    def __init__(self):
        # Start with an empty list of engines. The element type belongs in
        # the annotation; the value itself must be a real list rather than
        # the `List[...]` typing alias.
        self.cache_engine: List[DummyCacheEngine] = []
        self.model_runner = DummyModelRunner()
7 changes: 7 additions & 0 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.platforms.registry import PlatformRegistry
from vllm.plugins import load_general_plugins
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

from .version import __version__, __version_tuple__

# Load general plugins as soon as the module is imported so that all
# necessary global variables, such as `current_platform`, are set.
load_general_plugins()

__all__ = [
"__version__",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PlatformRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
Expand Down
40 changes: 24 additions & 16 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op

CUSTOM_OPS_REGISTERED = False


class Attention(nn.Module):
"""Attention layer.
Expand Down Expand Up @@ -129,6 +131,7 @@ def forward(
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
_register_custom_ops()

if self.use_direct_call:
return self.impl.forward(query,
Expand Down Expand Up @@ -263,15 +266,6 @@ def unified_attention_fake(
return torch.empty_like(query).contiguous()


direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -307,10 +301,24 @@ def unified_attention_with_output_fake(
return


direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
def _register_custom_ops():
    """Register the attention custom ops, guarded by a module-level flag.

    The first call registers ``unified_attention`` and
    ``unified_attention_with_output`` (each with its fake implementation and
    the current platform's dispatch key) and then sets
    ``CUSTOM_OPS_REGISTERED``. Later calls see the flag and do nothing.
    """
    global CUSTOM_OPS_REGISTERED
    if not CUSTOM_OPS_REGISTERED:
        # (op name, real implementation, mutated argument names, fake impl)
        pending_ops = (
            ("unified_attention", unified_attention, ["kv_cache"],
             unified_attention_fake),
            ("unified_attention_with_output", unified_attention_with_output,
             ["kv_cache", "output"], unified_attention_with_output_fake),
        )
        for op_name, op_func, mutated_args, fake_impl in pending_ops:
            direct_register_custom_op(
                op_name=op_name,
                op_func=op_func,
                mutates_args=mutated_args,
                fake_impl=fake_impl,
                dispatch_key=current_platform.dispatch_key,
            )
        # Only mark as done after both registrations went through.
        CUSTOM_OPS_REGISTERED = True
146 changes: 40 additions & 106 deletions vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,57 @@
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .registry import PlatformRegistry, detect_current_platform

current_platform: Platform
_current_platform: Platform = UnspecifiedPlatform()

# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.

is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass
def initialize_current_platform():
    """Populate the module-level ``_current_platform``.

    This function is called when loading the vllm plugin. A platform already
    selected in ``PlatformRegistry`` takes precedence; otherwise the result
    of ``detect_current_platform()`` is used.
    """
    global _current_platform
    if PlatformRegistry.current_platform is None:
        # Nothing was selected explicitly, so fall back to detection.
        _current_platform = detect_current_platform()
        return
    _current_platform = PlatformRegistry.get_current_platform_cls()

is_cuda = False

try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
# CUDA is supported on Jetson, but NVML may not be.
import os
def update_current_platform(device_name: str):
    """Manually switch the current platform to ``device_name``.

    The name is first recorded in ``PlatformRegistry``; the module-level
    ``_current_platform`` is then reloaded from the registry.
    """
    global _current_platform
    registry = PlatformRegistry
    registry.set_current_platform(device_name)
    _current_platform = registry.get_current_platform_cls()

def cuda_is_jetson() -> bool:
    """Report whether either Tegra marker path is present on this machine."""
    if os.path.isfile("/etc/nv_tegra_release"):
        return True
    return os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True
class CurrentPlatform:
"""A wrapper that provides an interface to the current platform.
`current_platform` is imported into many modules as soon as vLLM is
imported, so rebinding the `current_platform` name afterwards would not be
seen by those modules. This wrapper is therefore needed to provide a
dynamic platform loading mechanism.
is_rocm = False
This class can make sure that the `current_platform` is always up-to-date.
"""

try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
def __init__(self):
    """Snapshot the module-level ``_current_platform`` at construction."""
    initial_platform = _current_platform
    self.platform = initial_platform

is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass
def _refresh_current_platform(self):
    """Re-sync ``self.platform`` with the module-level ``_current_platform``."""
    latest = _current_platform
    if self.platform is latest:
        # Already pointing at the current object; nothing to update.
        return
    self.platform = latest

is_xpu = False
def __getattr__(self, name):
"""Go pass to the current platform."""
self._refresh_current_platform()
return getattr(self.platform, name)

try:
# installed IPEX if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass

is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass

is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass

is_openvino = False
try:
from importlib.metadata import version
is_openvino = "openvino" in version("vllm")
except Exception:
pass

if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
# The global variable for other modules to use.
current_platform: CurrentPlatform = CurrentPlatform()

__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
Loading

0 comments on commit d6f685a

Please sign in to comment.