Skip to content

Commit

Permalink
Initial prototype for multi-modal processor
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Nov 5, 2024
1 parent 93dee88 commit 5108119
Show file tree
Hide file tree
Showing 46 changed files with 943 additions and 351 deletions.
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Base Classes

.. autodata:: vllm.multimodal.MultiModalDataDict

.. autoclass:: vllm.multimodal.MultiModalInputs
.. autoclass:: vllm.multimodal.MultiModalKwargs
:members:
:show-inheritance:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/enabling_multimodal_inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL.Image import Image

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer

from .....conftest import IMAGE_ASSETS
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
# Ensure that we get the appropriately shaped pixel_values
# for images and image embeddings, respectively.
assert isinstance(mapped_img_data, MultiModalInputs)
assert isinstance(mapped_img_data, MultiModalKwargs)
assert "pixel_values" in mapped_img_data
assert mapped_img_data["pixel_values"].shape == expected_shape

Expand Down
22 changes: 11 additions & 11 deletions tests/multimodal/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.base import MultiModalKwargs, NestedTensors


def assert_nested_tensors_equal(expected: NestedTensors,
Expand All @@ -13,40 +13,40 @@ def assert_nested_tensors_equal(expected: NestedTensors,
assert_nested_tensors_equal(expected_item, actual_item)


def assert_multimodal_inputs_equal(expected: MultiModalInputs,
actual: MultiModalInputs):
def assert_multimodal_inputs_equal(expected: MultiModalKwargs,
actual: MultiModalKwargs):
assert set(expected.keys()) == set(actual.keys())
for key in expected:
assert_nested_tensors_equal(expected[key], actual[key])


def test_multimodal_input_batch_single_tensor():
t = torch.rand([1, 2])
result = MultiModalInputs.batch([{"image": t}])
result = MultiModalKwargs.batch([{"image": t}])
assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})


def test_multimodal_input_batch_multiple_tensors():
a = torch.rand([1, 1, 2])
b = torch.rand([1, 1, 2])
c = torch.rand([1, 1, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})


def test_multimodal_input_batch_multiple_heterogeneous_tensors():
a = torch.rand([1, 2, 2])
b = torch.rand([1, 3, 2])
c = torch.rand([1, 4, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": [a, b, c]})


def test_multimodal_input_batch_nested_tensors():
a = torch.rand([2, 3])
b = torch.rand([2, 3])
c = torch.rand([2, 3])
result = MultiModalInputs.batch([{
result = MultiModalKwargs.batch([{
"image": [a]
}, {
"image": [b]
Expand All @@ -65,7 +65,7 @@ def test_multimodal_input_batch_heterogeneous_lists():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(
result,
{"image": [torch.stack([a, b]), c.unsqueeze(0)]})
Expand All @@ -76,7 +76,7 @@ def test_multimodal_input_batch_multiple_batchable_lists():
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
d = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}])
assert_multimodal_inputs_equal(
result,
{"image": torch.stack([torch.stack([a, b]),
Expand All @@ -88,8 +88,8 @@ def test_multimodal_input_batch_mixed_stacking_depths():
b = torch.rand([1, 3, 3])
c = torch.rand([1, 4, 3])

result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})

result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}])
assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ModelConfig:
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
Expand Down
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
Expand Down Expand Up @@ -721,6 +722,9 @@ def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)

async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
Expand Down
5 changes: 4 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

Expand Down Expand Up @@ -338,7 +340,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
model_config)

self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer)
self.tokenizer,
mm_registry)

self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
Expand Down
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
parallel_config=engine_config.parallel_config,
enable_lora=bool(engine_config.lora_config),
)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)

# Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
Expand Down Expand Up @@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket):
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)

async def get_input_preprocessor(self) -> InputPreprocessor:
return self.input_preprocessor

async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

Expand Down
16 changes: 11 additions & 5 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def generate(
async def beam_search(
self,
prompt: PromptType,
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -74,13 +73,14 @@ async def beam_search(
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output

tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()

if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
processed_inputs = preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
Expand Down Expand Up @@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None:
Args:
request_id: The unique id of the request.
"""
...

@abstractmethod
async def get_model_config(self) -> ModelConfig:
Expand All @@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig:

@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine."""
...

@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine."""
...

@abstractmethod
async def get_tokenizer(
Expand Down
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ async def create_chat_completion(
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
Expand Down
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ async def create_completion(
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
Expand Down
4 changes: 3 additions & 1 deletion vllm/inputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import DummyData, InputContext, InputRegistry
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)

INPUT_REGISTRY = InputRegistry()
"""
Expand Down Expand Up @@ -32,6 +33,7 @@
"INPUT_REGISTRY",
"DummyData",
"InputContext",
"InputProcessingContext",
"InputRegistry",
]

Expand Down
13 changes: 7 additions & 6 deletions vllm/inputs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.multimodal.inputs import MultiModalInputsV2


class TextPrompt(TypedDict):
Expand Down Expand Up @@ -36,13 +37,13 @@ class TokensPrompt(TypedDict):

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
DEPRECATED: Optional multi-modal data to pass to the model,
if the model supports it.
"""

mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
Expand Down Expand Up @@ -176,7 +177,7 @@ def token_inputs(
return inputs


DecoderOnlyInputs = TokenInputs
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
Expand All @@ -191,14 +192,14 @@ class EncoderDecoderInputs(TypedDict):
This specifies the required data for encoder-decoder models.
"""
encoder: TokenInputs
encoder: Union[TokenInputs, "MultiModalInputsV2"]
"""The inputs for the encoder portion."""

decoder: TokenInputs
decoder: Union[TokenInputs, "MultiModalInputsV2"]
"""The inputs for the decoder portion."""


SingletonInputs = TokenInputs
SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
Expand Down
Loading

0 comments on commit 5108119

Please sign in to comment.