.. _adding_a_new_multimodal_model:

Adding a New Multimodal Model
=============================

This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.

.. note::
    The complexity of adding a new model depends heavily on the model's architecture.
    The process is considerably easier if the model shares a similar architecture with an existing model in vLLM.
    However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.

.. tip::
    If you run into issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
    We will be happy to help you out!


1. Set up the base vLLM model
-----------------------------

As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:

- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.

  .. code-block:: diff

      + from vllm.model_executor.models.interfaces import SupportsVision

      - class YourModelForImage2Seq(nn.Module):
      + class YourModelForImage2Seq(nn.Module, SupportsVision):

  .. note::
      The model class does not have to be named :code:`*ForCausalLM`.
      Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.

- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
  for each input tensor that corresponds to a multi-modal input, as shown in the following example
  (a standalone sketch of how such a parameter is typically consumed appears after this list):

  .. code-block:: diff

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
      +     pixel_values: torch.Tensor,
        ) -> SamplerOutput:

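To build intuition for what the model does with such a reserved parameter, here is a self-contained toy sketch in plain PyTorch (deliberately not the vLLM API; every class, attribute, token id, and size below is an illustrative assumption). It shows the usual pattern: encode the image tensor into embeddings, then merge them into the positions occupied by placeholder image tokens.

.. code-block:: python

    from typing import Optional

    import torch
    import torch.nn as nn

    IMAGE_TOKEN_ID = 32000  # hypothetical placeholder token id

    class ToyImage2Seq(nn.Module):
        """Toy model illustrating the 'encode, then merge' pattern."""

        def __init__(self, vocab_size: int = 32064, hidden_size: int = 64,
                     patch_dim: int = 768):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            # Stand-in for a real vision encoder (e.g., a CLIP vision tower).
            self.vision_proj = nn.Linear(patch_dim, hidden_size)

        def forward(self, input_ids: torch.Tensor,
                    pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
            hidden = self.embed_tokens(input_ids)              # (B, T, H)
            if pixel_values is not None:
                image_embeds = self.vision_proj(pixel_values)  # (B, P, H)
                # Scatter the image embeddings into the placeholder positions.
                # Assumes the prompt contains exactly B * P placeholder tokens.
                mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1)
                hidden = hidden.masked_scatter(mask, image_embeds.to(hidden.dtype))
            return hidden

In the real model, the embedding and vision components are vLLM layers and the merged hidden states are passed through the language model backbone; the LLaVA implementations referenced later in this guide show the complete pattern.
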

2. Register input mappers
-------------------------

For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff

      from vllm.model_executor.models.interfaces import SupportsVision
    + from vllm.multimodal import MULTIMODAL_REGISTRY

    + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
    + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      class YourModelForImage2Seq(nn.Module, SupportsVision):

A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.

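If the default mapper does not fit your model, you can pass your own function to the decorator instead. The sketch below shows roughly what such a mapper might look like; the exact signature and return type vary across vLLM versions (the assumptions here are that the mapper receives a context object plus the raw image data and returns a dict whose keys match the keyword arguments of :meth:`~torch.nn.Module.forward`), so compare it against the default mappers in :mod:`vllm.multimodal` before adapting it.

.. code-block:: python

    import torch

    # Hypothetical custom mapper; the name and signature are illustrative only.
    def custom_pixel_input_mapper(ctx, data):
        """Convert user-supplied image data into the ``pixel_values`` kwarg."""
        # ``ctx`` carries the model configuration; ``data`` is the raw
        # multi-modal input attached to the request.
        pixel_values = torch.as_tensor(data, dtype=torch.float32)
        # The keys of the returned dict become keyword arguments to forward().
        return {"pixel_values": pixel_values}

    # Registration would then look roughly like:
    #
    #   @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(custom_pixel_input_mapper)
    #   class YourModelForImage2Seq(nn.Module, SupportsVision):
    #       ...
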
.. seealso::
    :ref:`input_processing_pipeline`


3. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to allocate memory. By default, this consists of text input only, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):

Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

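As a rough sketch, a dummy-data factory typically builds a worst-case prompt filled with placeholder image tokens, together with a dummy image of the largest size the model supports. The token id, feature size, image shape, and return value below are illustrative assumptions; mirror the examples above for the exact types your vLLM version expects.

.. code-block:: python

    import torch

    from vllm.sequence import SequenceData

    IMAGE_TOKEN_ID = 32000     # assumption: your model's image placeholder token
    IMAGE_FEATURE_SIZE = 576   # assumption: image tokens contributed per image

    def dummy_data_for_your_model(ctx, seq_len: int):
        # Text side: placeholder image tokens followed by padding, so that
        # memory profiling sees the maximum possible sequence length.
        token_ids = [IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE
        token_ids += [0] * (seq_len - IMAGE_FEATURE_SIZE)
        seq_data = SequenceData(token_ids)

        # Multi-modal side: a dummy image at the largest resolution the model
        # accepts (the wrapper type expected here differs across versions).
        dummy_image = torch.zeros(3, 336, 336)
        return seq_data, dummy_image
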
.. seealso::
    :ref:`input_processing_pipeline`


4. (Optional) Register input processor
--------------------------------------

Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often because, unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):

A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:

- Insert a static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Insert a dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

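For illustration, the sketch below expands every image placeholder token into as many copies as the vision encoder produces embeddings, which is the pattern used by the examples above. The token id, feature size, and the construction of the returned inputs are assumptions; real processors (see the LLaVA models linked above) also adjust the prompt string and validate the image input.

.. code-block:: python

    from vllm.inputs import LLMInputs

    IMAGE_TOKEN_ID = 32000     # assumption: your model's image placeholder token
    IMAGE_FEATURE_SIZE = 576   # assumption: image tokens contributed per image

    def input_processor_for_your_model(ctx, llm_inputs):
        prompt_token_ids = llm_inputs["prompt_token_ids"]

        # Repeat every image placeholder so that there is one position per
        # image embedding; attention masks then cover the image automatically.
        new_token_ids = []
        for token_id in prompt_token_ids:
            if token_id == IMAGE_TOKEN_ID:
                new_token_ids.extend([IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE)
            else:
                new_token_ids.append(token_id)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=llm_inputs.get("multi_modal_data"))
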
.. seealso::
    :ref:`input_processing_pipeline`