[Frontend] Multi-Modality Support for Loading Local Image Files
FIX #8730

Signed-off-by: chaunceyjiang <[email protected]>
chaunceyjiang committed Nov 4, 2024
1 parent b67feb1 commit 5aa7b79
Showing 5 changed files with 106 additions and 30 deletions.
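In brief, the commit adds an `allowed_local_media_path` setting (exposed as `--allowed-local-media-path` on the server and as an `LLM` constructor argument) and teaches the image loaders to accept `file://` URLs that resolve under that directory. A hedged end-to-end sketch of how a client might use it against an OpenAI-compatible vLLM server (model name, port, and paths are assumptions, not part of this commit):

    from openai import OpenAI

    # Assumes the server was started with something like:
    #   vllm serve llava-hf/llava-1.5-7b-hf --allowed-local-media-path /mnt/shared/images
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",  # assumed multimodal model
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # The file must live under --allowed-local-media-path on the server.
                {"type": "image_url",
                 "image_url": {"url": "file:///mnt/shared/images/example.jpg"}},
            ],
        }],
    )
    print(response.choices[0].message.content)
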
16 changes: 12 additions & 4 deletions vllm/config.py
@@ -55,6 +55,10 @@ class ModelConfig:
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allows API requests to read local images or
videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
@@ -134,6 +138,7 @@ def __init__(
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = None,

Check failure on line 141 in vllm/config.py
GitHub Actions / mypy (3.8, 3.9, 3.10, 3.11, 3.12)
Incompatible default for argument "allowed_local_media_path" (default has type "None", argument has type "str") [assignment]

revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
@@ -164,6 +169,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
@@ -1319,6 +1325,8 @@ def maybe_create_spec_config(
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
@@ -1386,10 +1394,10 @@ def maybe_create_spec_config(
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
typical_acceptance_sampler_posterior_threshold=
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
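The mypy annotations above flag the `str = None` default added to `ModelConfig.__init__`. A minimal sketch of the two type-consistent alternatives (illustrative only, not what this commit ships):

    from typing import Optional

    def init_with_optional(allowed_local_media_path: Optional[str] = None) -> None:
        """Annotation matches the None default; None means the feature is disabled."""

    def init_with_empty_default(allowed_local_media_path: str = "") -> None:
        """Keeps the str annotation; an empty string means disabled, matching
        EngineArgs.allowed_local_media_path in arg_utils.py below."""
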
27 changes: 17 additions & 10 deletions vllm/engine/arg_utils.py
@@ -92,6 +92,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
@@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos"
"from directories specified by the server file system."
"This is a security risk."
"Should only be enabled in trusted environments")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
@@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig:
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
@@ -971,8 +980,8 @@ def create_engine_config(self) -> VllmConfig:
f"'bitsandbytes' load format, but got {self.load_format}")

if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
@@ -1059,10 +1068,9 @@ def create_engine_config(self) -> VllmConfig:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
speculative_model_quantization=self.speculative_model_quantization,
speculative_draft_tensor_parallel_size=self.
speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
@@ -1072,8 +1080,7 @@ def create_engine_config(self) -> VllmConfig:
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
draft_token_acceptance_method=self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
@@ -1136,7 +1143,7 @@ def create_engine_config(self) -> VllmConfig:
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
@@ -1147,7 +1154,7 @@ def create_engine_config(self) -> VllmConfig:
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
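For reference, a minimal sketch of how the new argument flows from `EngineArgs` into `ModelConfig` via `create_model_config()` (model name and path are assumptions; building the config fetches the model's Hugging Face config, so treat this as illustrative rather than an offline-runnable snippet):

    from vllm.engine.arg_utils import EngineArgs

    # CLI equivalent (assumed): vllm serve <model> --allowed-local-media-path /mnt/shared/images
    engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",               # assumed multimodal model
        allowed_local_media_path="/mnt/shared/images",  # directory the server may read from
    )
    model_config = engine_args.create_model_config()
    assert model_config.allowed_local_media_path == "/mnt/shared/images"
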
13 changes: 9 additions & 4 deletions vllm/entrypoints/chat_utils.py
@@ -62,7 +62,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
@@ -73,7 +73,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):

class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
@@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
@@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
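With the flag enabled, both the standard OpenAI-style content part and the simpler plain `image_url` form documented above should accept a `file://` URL; a hedged sketch of the simple form (path is hypothetical):

    # A chat message using the "simple" image content part from the docstring above.
    # The path is hypothetical and must sit under --allowed-local-media-path.
    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image."},
            {"image_url": "file:///mnt/shared/images/example.jpg"},
        ],
    }
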
11 changes: 9 additions & 2 deletions vllm/entrypoints/llm.py
@@ -58,6 +58,10 @@ class LLM:
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allows API requests to read local images
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
@@ -139,6 +143,7 @@ def __init__(
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
@@ -179,6 +184,7 @@ def __init__(
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
@@ -475,7 +481,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
@@ -931,7 +937,8 @@ def _run_engine(
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
in_spd = total_in_toks / \
pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
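For offline inference, the new `allowed_local_media_path` constructor argument enables the same `file://` URLs through `LLM.chat`; a minimal sketch under assumed model and paths:

    from vllm import LLM

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",               # assumed multimodal model
        allowed_local_media_path="/mnt/shared/images",  # local directory the engine may read
    )
    outputs = llm.chat([{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url",
             "image_url": {"url": "file:///mnt/shared/images/example.jpg"}},
        ],
    }])
    print(outputs[0].outputs[0].text)
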
69 changes: 59 additions & 10 deletions vllm/multimodal/utils.py
@@ -1,6 +1,7 @@
import base64
from functools import lru_cache
from io import BytesIO
from os import path
from typing import Any, List, Optional, Tuple, TypeVar, Union

import numpy as np
@@ -18,19 +19,54 @@
cached_get_tokenizer = lru_cache(get_tokenizer)


def _load_image_from_bytes(b: bytes):
def _load_image_from_bytes(b: bytes) -> Image.Image:
image = Image.open(BytesIO(b))
image.load()
return image


def _load_image_from_data_url(image_url: str):
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
# Resolve both paths first so that ".." segments cannot escape the
# allowed directory, then compare their common prefix.
image_path = path.abspath(image_path)
allowed_local_media_path = path.abspath(allowed_local_media_path)
# Check if the common path is the same as allowed_local_media_path
return path.commonpath([image_path,
allowed_local_media_path]) == allowed_local_media_path


def _load_image_from_file(image_url: str,
allowed_local_media_path: str) -> Image.Image:
if not allowed_local_media_path:
raise ValueError("Invalid 'image_url': Cannot load local files without"
"'--allowed-local-media-path'.")
if allowed_local_media_path:
if not path.exists(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': The path does not exist."
)
if not path.isdir(allowed_local_media_path):
raise ValueError("Invalid '--allowed-local-media-path': "
"The path must be a directory.")

# Only split once and assume the second part is the image path
_, image_path = image_url.split("file://", 1)
if not _is_subpath(image_path, allowed_local_media_path):
raise ValueError("Invalid 'image_url': The file path must be a"
" subpath of '--allowed-local-media-path'.")

image = Image.open(image_path)
image.load()
return image


def _load_image_from_data_url(image_url: str) -> Image.Image:
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)


def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
def fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
@@ -43,16 +79,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)


async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from an HTTP URL, a base64 data URL, or a local file URL.
@@ -65,9 +104,11 @@ async def async_fetch_image(image_url: str,

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)

@@ -126,8 +167,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
def get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = fetch_image(image_url,
allowed_local_media_path=allowed_local_media_path)
return {"image": image}


@@ -136,8 +181,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
async def async_get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = await async_fetch_image(
image_url, allowed_local_media_path=allowed_local_media_path)
return {"image": image}


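As a standalone illustration of the containment check performed by `_is_subpath` above (paths are hypothetical): the common path of the requested file and the allowed directory must resolve to the allowed directory itself, otherwise the request is rejected.

    from os import path

    def is_under_allowed_root(image_path: str, allowed_root: str) -> bool:
        # Same idea as _is_subpath: resolve both paths, then require that their
        # common prefix is exactly the allowed root directory.
        allowed_root = path.abspath(allowed_root)
        return path.commonpath([path.abspath(image_path), allowed_root]) == allowed_root

    print(is_under_allowed_root("/mnt/shared/images/cat.jpg", "/mnt/shared/images"))  # True
    print(is_under_allowed_root("/etc/passwd", "/mnt/shared/images"))                 # False
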
