Commit

Use the inner processor to enable fine-grained caching
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 20, 2024
1 parent 32e5197 commit 7264d4e
Showing 2 changed files with 90 additions and 57 deletions.
41 changes: 37 additions & 4 deletions vllm/inputs/registry.py
@@ -1,7 +1,7 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
from typing import (TYPE_CHECKING, Any, Callable, Literal, Mapping, NamedTuple,
Optional, Protocol, Union)

from torch import nn
@@ -111,6 +111,39 @@ def get_hf_processor(

return hf_processor

def get_modality_processor(
self,
hf_processor: ProcessorMixin,
modality_data_key: Literal["text", "images", "videos", "audios"],
) -> Callable[..., BatchFeature]:
"""
Get the HuggingFace modality-specific processor which is
a child of a :class:`transformers.ProcessorMixin`, identified by
the corresponding keyword argument in its `__call__` method.
"""
if modality_data_key == "text":
attributes = ["tokenizer"]
elif modality_data_key == "images":
attributes = ["image_processor"]
elif modality_data_key == "videos":
attributes = ["video_processor"]
elif modality_data_key == "audios":
attributes = ["audio_processor", "feature_extractor"]
else:
assert_never(modality_data_key)

modality_processor = next(
(getattr(hf_processor, attr)
for attr in attributes if hasattr(hf_processor, attr)),
None,
)
if modality_processor is None:
raise AttributeError(
f"Cannot found HuggingFace processor for "
f"{modality_data_key} inside {type(hf_processor)}")

return modality_processor
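
For illustration, a minimal sketch of how this helper is intended to be used. This is a hypothetical usage example, not part of the commit: `ctx` stands for an already-constructed context that exposes this method, the attribute names depend on the concrete HF processor class, and the LLaVA checkpoint is only an example:

    # Hypothetical usage sketch (assumed names: ctx, hf_processor).
    from transformers import AutoProcessor

    hf_processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # Resolves hf_processor.tokenizer via the attribute lookup above.
    text_processor = ctx.get_modality_processor(hf_processor, "text")

    # Resolves hf_processor.image_processor.
    image_processor = ctx.get_modality_processor(hf_processor, "images")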


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
@@ -131,15 +164,15 @@ def get_hf_processor(

def call_hf_processor(
self,
hf_processor: ProcessorMixin,
hf_processor: Union[ProcessorMixin, Callable[..., BatchFeature]],
data: Mapping[str, object],
kwargs: Optional[Mapping[str, object]] = None,
) -> BatchFeature:
assert callable(hf_processor)

if kwargs is None:
kwargs = {}

assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
106 changes: 53 additions & 53 deletions vllm/multimodal/processing.py
@@ -4,7 +4,8 @@
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import lru_cache, partial
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union, cast
from typing import (Any, Literal, NamedTuple, Optional, Protocol, TypeVar,
Union, cast)

import numpy as np
import torch
@@ -616,8 +617,8 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None:
def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]:
# Recursive cases
if isinstance(obj, (list, tuple)):
for elem in obj:
yield from self._iter_bytes_to_hash(key, elem)
for i, elem in enumerate(obj):
yield from self._iter_bytes_to_hash(f"{key}.{i}", elem)
return
if isinstance(obj, dict):
for k, v in obj.items():
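
For reference, a tiny sketch of what the change above does to the bytes fed into the cache hash. This is a hypothetical call; `hasher`, `image_a`, and `image_b` are assumed names standing for the cache object that owns this method and two multimodal items:

    # Each list element is now hashed under an index-qualified key such as
    # "images.0", "images.1", instead of every element reusing the bare key
    # "images".
    chunks = list(hasher._iter_bytes_to_hash("images", [image_a, image_b]))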
@@ -664,66 +665,64 @@ def _cached_call_fine(
self,
ctx: InputProcessingContext,
hf_processor: ProcessorMixin,
prompt: str,
mm_data: Mapping[str, list[object]],
text: str,
mm_data: Mapping[Literal["images", "videos", "audios"], list[Any]],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_mm_items = defaultdict[str, list[torch.Tensor]]()

num_items = len(next(iter(mm_data.values())))
for idx in range(num_items):
mm_item = {k: [v[idx]] for k, v in mm_data.items()}

self.maybe_log_cache_stats(self._fine_mm_cache, "fine_mm_cache")

processed_mm_item = self._fine_mm_cache.get_or_put(
self._hash_kwargs(**mm_item, **mm_kwargs),
default_factory=partial(
ctx.call_hf_processor,
hf_processor,
mm_item,
mm_kwargs,
),
)

for k, v in processed_mm_item.items():
# Remove the extra batch dimension
processed_mm_items[k].append(v[0])

# NOTE: Some processors (e.g. llava) do not accept mm-only input,
# in which case we have to fall back to processing `prompt` and `mm_data`
# together. Therefore, we place the text processing last to avoid
# redundant computation.
self.maybe_log_cache_stats(self._fine_text_cache, "fine_text_cache")

processed_text = self._fine_text_cache.get_or_put(
prompt,
text,
default_factory=partial(
ctx.call_hf_processor,
hf_processor,
dict(text=prompt),
ctx.get_modality_processor(hf_processor, "text"),
dict(text=text),
),
)

processed_data = dict(**processed_text, **processed_mm_items)
processed_data = dict(**processed_text)
for data_key, items in mm_data.items():
processed_modal_items = defaultdict[str, list[torch.Tensor]](list)

for item in items:
self.maybe_log_cache_stats(self._fine_mm_cache,
"fine_mm_cache")

modal_item = cast(Mapping[str, object], {data_key: item})
processed_modal_item = self._fine_mm_cache.get_or_put(
self._hash_kwargs(**modal_item, **mm_kwargs),
default_factory=partial(
ctx.call_hf_processor,
ctx.get_modality_processor(hf_processor, data_key),
modal_item,
mm_kwargs,
),
)

for k, v in processed_modal_item.items():
# Remove the extra batch dimension
processed_modal_items[k].append(v[0])

processed_data.update(processed_modal_items)

return BatchFeature(processed_data)
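
The point of the per-item path above is that multimodal items are processed one at a time through their modality-specific sub-processor and cached individually, so an item shared across requests is only processed once. A rough sketch of the intended reuse, with hypothetical images and processor (`cache` stands for the processing cache instance and `ctx` for an InputProcessingContext, mirroring the call sites later in this diff):

    # First request: img_a and img_b are each processed and cached separately.
    out1 = cache.call_hf_processor(
        ctx,
        hf_processor,
        text="USER: <image><image> Describe both images.",
        mm_data={"images": [img_a, img_b]},
        mm_kwargs={},
    )

    # Second request: img_a hits the fine-grained cache; only img_c is processed.
    out2 = cache.call_hf_processor(
        ctx,
        hf_processor,
        text="USER: <image><image> Compare these images.",
        mm_data={"images": [img_a, img_c]},
        mm_kwargs={},
    )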

def _cached_call_coarse(
self,
ctx: InputProcessingContext,
hf_processor: ProcessorMixin,
prompt: str,
text: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache")

processed_data = self._coarse_cache.get_or_put(
self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs),
self._hash_kwargs(text=text, **mm_data, **mm_kwargs),
default_factory=partial(
ctx.call_hf_processor,
hf_processor,
dict(text=prompt, **mm_data),
dict(text=text, **mm_data),
mm_kwargs,
),
)
@@ -737,34 +736,35 @@ def call_hf_processor(
ctx: InputProcessingContext,
# Assumes that hf_processor has been initialized according to kwargs
hf_processor: ProcessorMixin,
prompt: str,
text: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Try to cache each item separately to improve hit rate
if mm_data and all(isinstance(v, list) for v in mm_data.values()):
extra_keys = mm_data.keys() - {"images", "videos", "audios"}
if (mm_data and not extra_keys
and all(isinstance(v, list) for v in mm_data.values())):
try:
return self._cached_call_fine(
ctx,
hf_processor,
prompt,
cast(Mapping[str, list[object]], mm_data),
mm_kwargs,
text=text,
mm_data=mm_data, # type: ignore[arg-type]
mm_kwargs=mm_kwargs,
)
except Exception:
# Failures are expected; see NOTE in `_cached_call_fine`
logger.debug(
"Failed to apply processor on each item separately",
logger.exception(
"Failed to apply processor on each item separately! "
"Falling back to coarse caching.",
stack_info=True,
)
pass

return self._cached_call_coarse(
ctx,
hf_processor,
prompt,
mm_data,
mm_kwargs,
text=text,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
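
To summarize the dispatch above: the fine-grained path is attempted only when every multimodal entry is a list stored under one of the known keys ("images", "videos", "audios"); any other shape, or a processor that rejects modality-only input, falls back to hashing the whole (text, mm_data, mm_kwargs) combination at once. A short sketch of the two cases with hypothetical inputs (`image_embeds` is only an example of an unrecognized key):

    # Known keys with list values -> per-item (fine-grained) caching is attempted.
    cache.call_hf_processor(ctx, hf_processor,
                            text=prompt,
                            mm_data={"images": [img_a, img_b]},
                            mm_kwargs={})

    # Unrecognized key or non-list value -> coarse caching of the whole input.
    cache.call_hf_processor(ctx, hf_processor,
                            text=prompt,
                            mm_data={"image_embeds": image_embeds},
                            mm_kwargs={})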


@@ -872,9 +872,9 @@ def _call_hf_processor(
return self.cache.call_hf_processor(
self.ctx,
self._get_hf_processor(**mm_kwargs),
prompt,
mm_data,
mm_kwargs,
text=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _apply_hf_processor(
