Skip to content

Commit

Permalink
[Frontend] Multi-Modality Support for Loading Local Image Files
Browse files Browse the repository at this point in the history
FIX #8730

Signed-off-by: chaunceyjiang <[email protected]>
  • Loading branch information
chaunceyjiang committed Nov 4, 2024
1 parent b67feb1 commit f7f2303
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 12 deletions.
8 changes: 8 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class ModelConfig:
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
Expand Down Expand Up @@ -134,6 +138,7 @@ def __init__(
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
Expand Down Expand Up @@ -164,6 +169,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
Expand Down Expand Up @@ -1319,6 +1325,8 @@ def maybe_create_spec_config(
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
Expand Down
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
Expand Down Expand Up @@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos"
"from directories specified by the server file system."
"This is a security risk."
"Should only be enabled in trusted environments")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
Expand Down Expand Up @@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig:
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
Expand Down
9 changes: 7 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
Expand All @@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class LLM:
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
Expand Down Expand Up @@ -139,6 +143,7 @@ def __init__(
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
Expand Down Expand Up @@ -179,6 +184,7 @@ def __init__(
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
Expand Down
72 changes: 62 additions & 10 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os
from functools import lru_cache
from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union
Expand All @@ -18,19 +19,57 @@
cached_get_tokenizer = lru_cache(get_tokenizer)


def _load_image_from_bytes(b: bytes):
def _load_image_from_bytes(b: bytes) -> Image.Image:
image = Image.open(BytesIO(b))
image.load()
return image


def _load_image_from_data_url(image_url: str):
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
# Get the common path
common_path = os.path.commonpath([image_path, allowed_local_media_path])
# Check if the common path is the same as allowed_local_media_path
return common_path == os.path.abspath(allowed_local_media_path)


def _load_image_from_file(image_url: str,
allowed_local_media_path: str) -> Image.Image:
if not allowed_local_media_path:
raise ValueError("Invalid 'image_url': Cannot load local files without"
"'--allowed-local-media-path'.")
if allowed_local_media_path:
if not os.path.exists(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} does not exist.")
if not os.path.isdir(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} must be a directory.")

# Only split once and assume the second part is the image path
_, image_path = image_url.split("file://", 1)
if not _is_subpath(image_path, allowed_local_media_path):
raise ValueError(
f"Invalid 'image_url': The file path {image_path} must"
" be a subpath of '--allowed-local-media-path'"
f" '{allowed_local_media_path}'.")

image = Image.open(image_path)
image.load()
return image


def _load_image_from_data_url(image_url: str) -> Image.Image:
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)


def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
def fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
Expand All @@ -43,16 +82,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)


async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
Expand All @@ -65,9 +107,11 @@ async def async_fetch_image(image_url: str,

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)

Expand Down Expand Up @@ -126,8 +170,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
def get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = fetch_image(image_url,
allowed_local_media_path=allowed_local_media_path)
return {"image": image}


Expand All @@ -136,8 +184,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
async def async_get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = await async_fetch_image(
image_url, allowed_local_media_path=allowed_local_media_path)
return {"image": image}


Expand Down

0 comments on commit f7f2303

Please sign in to comment.