diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst index 15f0c8ccf77ee..58a888b17ba53 100644 --- a/docs/source/design/class_hierarchy.rst +++ b/docs/source/design/class_hierarchy.rst @@ -1,3 +1,5 @@ +.. _class_hierarchy: + vLLM's Class Hierarchy ======================= diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst new file mode 100644 index 0000000000000..bfca702b9267a --- /dev/null +++ b/docs/source/design/plugin_system.rst @@ -0,0 +1,62 @@ +.. _plugin_system: + +vLLM's Plugin System +==================== + +The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM. + +How Plugins Work in vLLM +------------------------ + +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`class_hierarchy`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. + +How vLLM Discovers Plugins +-------------------------- + +vLLM's plugin system uses the standard Python ``entry_points`` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin: + +.. code-block:: python + + # inside `setup.py` file + from setuptools import setup + + setup(name='vllm_add_dummy_model', + version='0.1', + packages=['vllm_add_dummy_model'], + entry_points={ + 'vllm.general_plugins': + ["register_dummy_model = vllm_add_dummy_model:register"] + }) + + # inside `vllm_add_dummy_model.py` file + def register(): + from vllm import ModelRegistry + + if "MyLlava" not in ModelRegistry.get_supported_archs(): + ModelRegistry.register_model("MyLlava", + "vllm_add_dummy_model.my_llava:MyLlava") + +For more information on adding entry points to your package, please check the `official documentation `__. + +Every plugin has three parts: + +1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group ``vllm.general_plugins`` to register general plugins. This is the key of ``entry_points`` in the ``setup.py`` file. Always use ``vllm.general_plugins`` for vLLM's general plugins. + +2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the ``entry_points`` dictionary. In the example above, the plugin name is ``register_dummy_model``. Plugins can be filtered by their names using the ``VLLM_PLUGINS`` environment variable. To load only a specific plugin, set ``VLLM_PLUGINS`` to the plugin name. + +3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is ``vllm_add_dummy_model:register``, which refers to a function named ``register`` in the ``vllm_add_dummy_model`` module. + +What Can Plugins Do? +-------------------- + +Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. + +Guidelines for Writing Plugins +------------------------------ + +- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. + +Compatibility Guarantee +----------------------- + +vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index a2abd2995b1cc..3b2698a8845ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -158,6 +158,7 @@ Documentation design/class_hierarchy design/huggingface_integration + design/plugin_system design/input_processing/model_inputs_index design/kernel/paged_attention design/multimodal/multimodal_index diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index c6d88cc38e99b..a70ebf99c746f 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -102,11 +102,11 @@ This method should load the weights from the HuggingFace's checkpoint file and a Finally, register your :code:`*ForCausalLM` class to the :code:`_VLLM_MODELS` in `vllm/model_executor/models/registry.py `_. 6. Out-of-Tree Model Integration --------------------------------------------- +-------------------------------- -We also provide a way to integrate a model without modifying the vLLM codebase. Step 2, 3, 4 are still required, but you can skip step 1 and 5. +You can integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5. Instead, write a plugin to register your model. For general introduction of the plugin system, see :ref:`plugin_system`. -Just add the following lines in your code: +To register the model, use the following code: .. code-block:: python @@ -114,7 +114,7 @@ Just add the following lines in your code: from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) -If your model imports modules that initialize CUDA, consider instead lazy-importing it to avoid an error like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`: +If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`: .. code-block:: python @@ -123,19 +123,8 @@ If your model imports modules that initialize CUDA, consider instead lazy-import ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM") .. important:: - If your model is a multimodal model, make sure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. + If your model is a multimodal model, ensure the model class implements the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. Read more about that :ref:`here `. -If you are running api server with :code:`vllm serve `, you can wrap the entrypoint with the following code: - -.. code-block:: python - - from vllm import ModelRegistry - from your_code import YourModelForCausalLM - ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) - - if __name__ == '__main__': - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') - -Save the above code in a file and run it with :code:`python your_file.py `. +.. note:: + Although you can directly put these code snippets in your script using ``vllm.LLM``, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server. diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 8373e11cfff9f..9fca724599012 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -27,16 +27,24 @@ def load_general_plugins(): allowed_plugins = envs.VLLM_PLUGINS discovered_plugins = entry_points(group='vllm.general_plugins') + logger.info("Available plugins:") + for plugin in discovered_plugins: + logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value, + plugin.group) + if allowed_plugins is None: + logger.info("all available plugins will be loaded.") + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + else: + logger.info("plugins to load: %s", allowed_plugins) for plugin in discovered_plugins: - logger.info("Found general plugin: %s", plugin.name) if allowed_plugins is None or plugin.name in allowed_plugins: try: func = plugin.load() func() - logger.info("Loaded general plugin: %s", plugin.name) + logger.info("plugin %s loaded.", plugin.name) except Exception: - logger.exception("Failed to load general plugin: %s", - plugin.name) + logger.exception("Failed to load plugin %s", plugin.name) _torch_compile_backend: Optional[Union[Callable, str]] = None