Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/N] Initial prototype for multi-modal processor #10044

Merged
merged 22 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/enabling_multimodal_inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL.Image import Image

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer

from .....conftest import IMAGE_ASSETS
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from vllm.multimodal.base import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors


def assert_nested_tensors_equal(expected: NestedTensors,
Expand Down
37 changes: 23 additions & 14 deletions tests/multimodal/test_processor_kwargs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from array import array
from typing import Mapping
from typing import Callable, Dict, Mapping, Optional
from unittest.mock import patch

import pytest
import torch

from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, token_inputs)
InputRegistry, ProcessorInputs, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

Expand Down Expand Up @@ -34,10 +34,9 @@ def custom_processor(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
# type validation, and just return the value of the kwarg that we
# clobber.
return num_crops
# For testing purposes, we don't worry about the prompt
return token_inputs(prompt_token_ids=[],
mm_processor_kwargs={"num_crops": num_crops})

with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor):
Expand Down Expand Up @@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
return init_kwargs, inference_kwargs, expected_seq_count


def _get_processed_num_crops(
processor: Callable[[ProcessorInputs], ProcessorInputs],
inference_kwargs: Optional[Dict[str, int]],
) -> int:
processed_inputs = processor(
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))

assert "type" in processed_inputs
assert processed_inputs["type"] == "token"
assert "mm_processor_kwargs" in processed_inputs
return processed_inputs["mm_processor_kwargs"]["num_crops"]


@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None),
(NUM_CROPS_OVERRIDE, None),
Expand All @@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,

ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
num_crops_val = _get_processed_num_crops(processor, inference_kwargs)

assert num_crops_val == expected_seq_count


Expand All @@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,

processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs
num_crops_val = processor(
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
assert num_crops_val == DEFAULT_NUM_CROPS


Expand Down
4 changes: 2 additions & 2 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs
from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens
Expand All @@ -8,7 +8,7 @@
def make_request(request_id, prompt_token_ids):
return Request(
request_id=request_id,
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids),
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
Expand Down
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ModelConfig:
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
Expand Down
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
Expand Down Expand Up @@ -729,6 +730,9 @@ def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)

async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
Expand Down
16 changes: 6 additions & 10 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType)
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
Expand All @@ -39,6 +39,7 @@
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

Expand Down Expand Up @@ -335,7 +337,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
model_config)

self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer)
self.tokenizer,
mm_registry)

self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
Expand Down Expand Up @@ -851,13 +854,6 @@ def add_request(
)
processed_inputs = self.input_processor(preprocessed_inputs)

# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
Expand Down Expand Up @@ -2019,7 +2015,7 @@ def _validate_model_inputs(self, inputs: ProcessorInputs,
else:
prompt_inputs = inputs

prompt_ids = prompt_inputs.get("prompt_token_ids")
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
Expand Down
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
parallel_config=engine_config.parallel_config,
enable_lora=bool(engine_config.lora_config),
)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)

# Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
Expand Down Expand Up @@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket):
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)

async def get_input_preprocessor(self) -> InputPreprocessor:
return self.input_preprocessor

async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

Expand Down
16 changes: 11 additions & 5 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def generate(
async def beam_search(
self,
prompt: PromptType,
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -74,13 +73,14 @@ async def beam_search(
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output

tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()

if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
processed_inputs = preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
Expand Down Expand Up @@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None:
Args:
request_id: The unique id of the request.
"""
...

@abstractmethod
async def get_model_config(self) -> ModelConfig:
Expand All @@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig:

@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine."""
...

@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine."""
...

@abstractmethod
async def get_tokenizer(
Expand Down
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ async def create_chat_completion(
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
Expand Down
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ async def create_completion(
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
Expand Down
12 changes: 8 additions & 4 deletions vllm/inputs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import DummyData, InputContext, InputRegistry
SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)

INPUT_REGISTRY = InputRegistry()
"""
Expand All @@ -26,12 +28,14 @@
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"SingletonInputsAdapter",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"DummyData",
"InputContext",
"InputProcessingContext",
"InputRegistry",
]

Expand Down
Loading