From 6e5c165e1e4ea25d105d4aa32ce178c74b4b083c Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Wed, 20 Nov 2024 17:00:18 +0000
Subject: [PATCH] [V1] VLM prefix caching: Add hashing of images

---
 vllm/v1/core/scheduler.py          |  2 ++
 vllm/v1/engine/__init__.py         |  1 +
 vllm/v1/engine/core.py             | 33 ++++++++++++++++++++++++++++++
 vllm/v1/engine/processor.py        |  1 +
 vllm/v1/request.py                 |  3 +++
 vllm/v1/worker/gpu_model_runner.py |  2 ++
 6 files changed, 42 insertions(+)

diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index ba50a9786d805..bc30bd7b25256 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -508,6 +508,7 @@ class NewRequestData:
     req_id: str
     prompt_token_ids: List[int]
     prompt: Optional[str]
+    mm_hash: List[str]
     mm_inputs: List["MultiModalKwargs"]
     mm_positions: List["PlaceholderRange"]
     sampling_params: SamplingParams
@@ -525,6 +526,7 @@ def from_request(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
             prompt=request.prompt,
+            mm_hash=request.mm_hash,
             mm_inputs=request.mm_inputs,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index edfb8bd7c2fc1..c57679b4fffb1 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -36,6 +36,7 @@ class EngineCoreRequest:
     prompt: Optional[str]
     prompt_token_ids: List[int]
     mm_data: Optional[MultiModalDataDict]
+    mm_hash: List[str]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     mm_processor_kwargs: Optional[Dict[str, Any]]
     sampling_params: SamplingParams
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 35ed131d50de9..bb4fccdcf9256 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -7,8 +7,10 @@
 from multiprocessing.sharedctypes import Synchronized
 from typing import Any, Iterator, List, Tuple, Type, Union
 
+import PIL.Image
 import zmq
 import zmq.asyncio
+from blake3 import blake3
 from msgspec import msgpack
 
 from vllm.config import CacheConfig, VllmConfig
@@ -93,6 +95,34 @@ def _initialize_kv_caches(self,
         self.model_executor.initialize_cache(num_gpu_blocks)
         return num_gpu_blocks, num_cpu_blocks
 
+    def hash_mm_data(self, req: EngineCoreRequest):
+        assert req.mm_data  # Data exists
+        assert not req.mm_hash  # Not hashed yet
+
+        print(f"hash_mm_data: req_id = {req.request_id}")
+
+        # FIXME(alexm):
+        #  1. Support other modalities
+        #  2. Support multiple images
+        image = req.mm_data.get("image")
+        assert isinstance(image, PIL.Image.Image)
+
+        print(f"  type(data) = {type(image)}, data = {image}")
+
+        # Convert the image to raw pixel bytes
+        start_time = time.time()
+        image_bytes = image.tobytes()
+        elapsed_time = time.time() - start_time
+        print(f"  tobytes time = {elapsed_time}")
+
+        # Hash the image bytes with BLAKE3
+        start_time = time.time()
+        hasher = blake3()
+        hasher.update(image_bytes)
+        req.mm_hash.append(hasher.hexdigest())
+        elapsed_time = time.time() - start_time
+        print(f"  hash time = {elapsed_time}")
+
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
 
@@ -101,6 +131,9 @@ def add_request(self, request: EngineCoreRequest):
         # take 10-50 ms, which can cause a spike in the latency. We should
         # consider moving this to a separate thread.
         if req.mm_data:
+
+            self.hash_mm_data(req)
+
             req.mm_inputs = self.mm_input_mapper.process_inputs(
                 req.mm_data, req.mm_processor_kwargs)
         self.scheduler.add_request(req)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 5c1577190c75a..6b7367bc17d09 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -114,6 +114,7 @@ def process_inputs(
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
             decoder_inputs.multi_modal_data,
+            [],  # Initially, the mm hash list is empty
             decoder_inputs.multi_modal_placeholders,
             decoder_inputs.mm_processor_kwargs,
             sampling_params,
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 51fb4003e5fe0..cfc1d56576889 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -57,6 +57,9 @@ def __init__(
         # Output of the mm input mapper (e.g., image tensors).
         self.mm_inputs: List[MultiModalKwargs] = []
 
+        # FIXME(alexm): Support other modalities (not just image)
+        self.mm_hash: List[str] = []
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1f9b544637bf7..f91c96c2adb13 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -169,6 +169,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
+                mm_hash=req_data.mm_hash,
                 mm_inputs=req_data.mm_inputs,
                 mm_positions=req_data.mm_positions,
                 sampling_params=sampling_params,
@@ -599,6 +600,7 @@ class CachedRequestState:
     req_id: str
     prompt_token_ids: List[int]
     prompt: Optional[str]
+    mm_hash: List[str]
     mm_inputs: List[MultiModalKwargs]
     mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
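
A note on the hashing scheme above: `Image.tobytes()` returns only the raw
pixel buffer, so two images that differ in size or mode can serialize to
identical bytes and collide (e.g., a 2x8 and a 4x4 RGB image both yield 48
bytes). Below is a minimal standalone sketch of a collision-safer variant
that folds the mode and size into the digest; the `hash_image` helper is
illustrative only and not part of this patch:

    from blake3 import blake3
    from PIL import Image

    def hash_image(image: Image.Image) -> str:
        # Raw pixels alone are ambiguous across shapes, so commit to
        # the mode and size before the pixel data.
        hasher = blake3()
        hasher.update(image.mode.encode())
        hasher.update(str(image.size).encode())
        hasher.update(image.tobytes())
        return hasher.hexdigest()

    # Identical image content yields identical digests.
    img = Image.new("RGB", (64, 64), color=(255, 0, 0))
    assert hash_image(img) == hash_image(img.copy())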
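
For context on why the hash is threaded through NewRequestData and
CachedRequestState: with prefix caching, the placeholder token ids that
stand in for an image are identical across requests, so a KV-cache block
key derived from token ids alone would produce false hits between different
images. A minimal sketch of a composite block key follows; `block_key` is a
hypothetical helper, since the block hasher itself is not part of this diff:

    from typing import List, Optional, Tuple

    def block_key(parent_key: Optional[int], token_ids: List[int],
                  mm_hashes: Tuple[str, ...]) -> int:
        # Placeholder token ids are the same for every image, so the
        # image content hash must participate in the key to avoid
        # false cache hits. Sketch only; not the patch's design.
        return hash((parent_key, tuple(token_ids), mm_hashes))

    # Identical prompts with different images get different keys,
    # so their KV blocks are not shared.
    k1 = block_key(None, [1, 2, 3], ("hash_of_image_a",))
    k2 = block_key(None, [1, 2, 3], ("hash_of_image_b",))
    assert k1 != k2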