[V1] VLM prefix caching: Add hashing of images #10497

Draft · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -508,6 +508,7 @@ class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_hash: List[str]
mm_inputs: List["MultiModalKwargs"]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
@@ -525,6 +526,7 @@ def from_request(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_hash=request.mm_hash,
mm_inputs=request.mm_inputs,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -36,6 +36,7 @@ class EngineCoreRequest:
prompt: Optional[str]
prompt_token_ids: List[int]
mm_data: Optional[MultiModalDataDict]
mm_hash: List[str]
mm_placeholders: Optional[MultiModalPlaceholderDict]
mm_processor_kwargs: Optional[Dict[str, Any]]
sampling_params: SamplingParams
33 changes: 33 additions & 0 deletions vllm/v1/engine/core.py
@@ -7,8 +7,10 @@
from multiprocessing.sharedctypes import Synchronized
from typing import Any, Iterator, List, Tuple, Type, Union

import PIL
import zmq
import zmq.asyncio
from blake3 import blake3
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
@@ -93,6 +95,34 @@ def _initialize_kv_caches(self,
self.model_executor.initialize_cache(num_gpu_blocks)
return num_gpu_blocks, num_cpu_blocks

def hash_mm_data(self, req: EngineCoreRequest):
assert req.mm_data # Data exists
assert not req.mm_hash # No hash

print("hash_mm_data: req_id = {}".format(req.request_id))

# FIXME(alexm):
# 1. Support other modalities
# 2. Support multiple images
image = req.mm_data.get("image")
assert isinstance(image, PIL.Image.Image)

print(" type(data) = {}, data = {}".format(type(image), image))

# Convert image to bytes
start_time = time.time()
bytes = image.tobytes()
elapsed_time = time.time() - start_time
print(" tobytes time = {}".format(elapsed_time))

# Hash image bytes
start_time = time.time()
hasher = blake3()
hasher.update(bytes)
req.mm_hash.append(hasher.hexdigest())
elapsed_time = time.time() - start_time
print(" hash time = {}".format(elapsed_time))

def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""

@@ -101,6 +131,9 @@ def add_request(self, request: EngineCoreRequest):
# take 10-50 ms, which can cause a spike in the latency. We should
# consider moving this to a separate thread.
if req.mm_data:

Review comment (Contributor):
Thoughts on doing this on the frontend engine process (i.e. v1/engine/processor.py::Processor) before sending to the EngineCore? IIUC, this add_request is called on the EngineCore process, meaning it is also synchronously blocking the model executor?

Reply (Member @ywang96, Nov 20, 2024):
Yeah, this is already planned. Eventually the multimodal data processor will live on the frontend, together with the input token sequence processor. #10044 is working towards this direction.

Reply (Collaborator, author):
@rickyyx I think it is a good idea; I can try it.

self.hash_mm_data(req)

req.mm_inputs = self.mm_input_mapper.process_inputs(
req.mm_data, req.mm_processor_kwargs)
self.scheduler.add_request(req)
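A side note on the hash_mm_data hunk above: the FIXME leaves multiple images and other modalities for follow-up work. Below is a minimal standalone sketch (not part of this PR) of how the same BLAKE3-over-pixel-bytes idea could cover a list of PIL images, assuming only the blake3 and Pillow packages; mixing in the image size and mode keeps two images with coincidentally identical raw pixel buffers from colliding.

```python
# Illustration only (not part of this PR): hashing a list of PIL images with
# BLAKE3, covering the multi-image case that the FIXME above defers.
from typing import List

from blake3 import blake3
from PIL import Image


def hash_images(images: List[Image.Image]) -> List[str]:
    """Return one BLAKE3 hex digest per image."""
    digests: List[str] = []
    for image in images:
        hasher = blake3()
        # tobytes() is raw pixel data only, so mix in size and mode as well,
        # e.g. so a 640x480 and a 480x640 image can never share a digest.
        hasher.update(str(image.size).encode())
        hasher.update(image.mode.encode())
        hasher.update(image.tobytes())
        digests.append(hasher.hexdigest())
    return digests


if __name__ == "__main__":
    red = Image.new("RGB", (640, 480), "red")
    blue = Image.new("RGB", (640, 480), "blue")
    # First two digests match, third differs.
    print(hash_images([red, red, blue]))
```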
1 change: 1 addition & 0 deletions vllm/v1/engine/processor.py
@@ -114,6 +114,7 @@ def process_inputs(
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
decoder_inputs.multi_modal_data,
[], # Initially, mm hash is empty
decoder_inputs.multi_modal_placeholders,
decoder_inputs.mm_processor_kwargs,
sampling_params,
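On the review suggestion to hash on the frontend: a rough sketch of a helper that Processor.process_inputs could call before building the EngineCoreRequest, so the `[]` placeholder above would already be populated when the request reaches EngineCore. The name compute_mm_hashes and the assumption that multi_modal_data holds a single PIL image, or a list of them, under the "image" key are illustrative, not part of this PR or of vLLM's current API.

```python
# Hypothetical frontend-side helper (names and data layout are assumptions):
# hash the images before the EngineCoreRequest is sent to EngineCore, so that
# EngineCore.add_request no longer spends time hashing on the core process.
from typing import Any, Dict, List, Optional

from blake3 import blake3
from PIL import Image


def compute_mm_hashes(multi_modal_data: Optional[Dict[str, Any]]) -> List[str]:
    if not multi_modal_data or "image" not in multi_modal_data:
        return []
    images = multi_modal_data["image"]
    if isinstance(images, Image.Image):
        images = [images]  # normalize a single image to a one-element list
    return [blake3(image.tobytes()).hexdigest() for image in images]
```

With something along these lines, the empty-list argument above could become compute_mm_hashes(decoder_inputs.multi_modal_data), and the hashing in EngineCore.add_request would no longer block the model executor.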
3 changes: 3 additions & 0 deletions vllm/v1/request.py
@@ -57,6 +57,9 @@ def __init__(
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []

# FIXME(alexm): Support other modalities (not just image)
self.mm_hash: List[str] = []

@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
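The diff stops at threading mm_hash through the request objects; how the hash feeds into prefix caching is not part of this change. Purely to illustrate the direction named in the PR title, the sketch below shows one way a KV-cache block key could fold in the image hash so that blocks covering image placeholder tokens are shared by content rather than by request. block_cache_key and its signature are assumptions for illustration, not vLLM's API.

```python
# Illustration only: a content-based key for one KV-cache block. Two requests
# with the same prompt prefix and byte-identical images produce equal keys, so
# their cached blocks can be reused; a different image changes every key from
# the image placeholder block onward because keys chain through parent_key.
from typing import List, Optional, Tuple


def block_cache_key(parent_key: Optional[Tuple],
                    token_ids: List[int],
                    mm_hash: Optional[str] = None) -> Tuple:
    return (parent_key, tuple(token_ids), mm_hash)


# Usage sketch: only blocks covering image placeholder tokens carry the hash.
k0 = block_cache_key(None, [1, 2, 3, 4])
k1 = block_cache_key(k0, [5, 6, 7, 8],
                     mm_hash="3a7bd3e2360a3d29eea436fcfb7e44c8")  # hypothetical digest
```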
2 changes: 2 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -169,6 +169,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_id=req_id,
prompt_token_ids=req_data.prompt_token_ids,
prompt=req_data.prompt,
mm_hash=req_data.mm_hash,
mm_inputs=req_data.mm_inputs,
mm_positions=req_data.mm_positions,
sampling_params=sampling_params,
@@ -599,6 +600,7 @@ class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_hash: List[str]
mm_inputs: List[MultiModalKwargs]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams