From 9ddd35a07190f320d8fc10b6f3b4af61e81f3cb1 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Mon, 4 Nov 2024 23:34:57 +0800 Subject: [PATCH] [Frontend] Multi-Modality Support for Loading Local Image Files (#9915) Signed-off-by: chaunceyjiang Signed-off-by: Richard Liu --- tests/multimodal/test_utils.py | 39 +++++++++++++++++- vllm/config.py | 8 ++++ vllm/engine/arg_utils.py | 9 ++++ vllm/entrypoints/chat_utils.py | 9 +++- vllm/entrypoints/llm.py | 6 +++ vllm/multimodal/utils.py | 75 +++++++++++++++++++++++++++++----- 6 files changed, 132 insertions(+), 14 deletions(-) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 69f04f0a69c0b..9869c8123f001 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -1,11 +1,12 @@ import base64 import mimetypes -from tempfile import NamedTemporaryFile +import os +from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Dict, Tuple import numpy as np import pytest -from PIL import Image +from PIL import Image, ImageChops from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import (async_fetch_image, fetch_image, @@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], assert _image_equals(data_image_sync, data_image_async) +@pytest.mark.asyncio +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_fetch_image_local_files(image_url: str): + with TemporaryDirectory() as temp_dir: + origin_image = fetch_image(image_url) + origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), + quality=100, + icc_profile=origin_image.info.get('icc_profile')) + + image_async = await async_fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + + image_sync = fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + # Check that the images are equal + assert not ImageChops.difference(image_sync, image_async).getbbox() + + with pytest.raises(ValueError): + await async_fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + with pytest.raises(ValueError): + await async_fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + + with pytest.raises(ValueError): + fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + with pytest.raises(ValueError): + fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") + + @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) def test_repeat_and_pad_placeholder_tokens(model): config = AutoConfig.from_pretrained(model) diff --git a/vllm/config.py b/vllm/config.py index 17e9b1c100498..0870eb9f70709 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -55,6 +55,10 @@ class ModelConfig: "mistral" will always use the tokenizer from `mistral_common`. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images or + videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. dtype: Data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. @@ -134,6 +138,7 @@ def __init__( trust_remote_code: bool, dtype: Union[str, torch.dtype], seed: int, + allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, @@ -164,6 +169,7 @@ def __init__( self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code + self.allowed_local_media_path = allowed_local_media_path self.seed = seed self.revision = revision self.code_revision = code_revision @@ -1319,6 +1325,8 @@ def maybe_create_spec_config( tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, + allowed_local_media_path=target_model_config. + allowed_local_media_path, dtype=target_model_config.dtype, seed=target_model_config.seed, revision=draft_revision, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index da06ab186821e..bd39e72d58caa 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -92,6 +92,7 @@ class EngineArgs: tokenizer_mode: str = 'auto' chat_template_text_format: str = 'string' trust_remote_code: bool = False + allowed_local_media_path: str = "" download_dir: Optional[str] = None load_format: str = 'auto' config_format: ConfigFormat = ConfigFormat.AUTO @@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') + parser.add_argument( + '--allowed-local-media-path', + type=str, + help="Allowing API requests to read local images or videos" + "from directories specified by the server file system." + "This is a security risk." + "Should only be enabled in trusted environments") parser.add_argument('--download-dir', type=nullable_str, default=EngineArgs.download_dir, @@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig: tokenizer_mode=self.tokenizer_mode, chat_template_text_format=self.chat_template_text_format, trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, dtype=self.dtype, seed=self.seed, revision=self.revision, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c9552977710d1..8da08d4b2c93c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker def parse_image(self, image_url: str) -> None: - image = get_and_parse_image(image_url) + image = get_and_parse_image(image_url, + allowed_local_media_path=self._tracker. + _model_config.allowed_local_media_path) placeholder = self._tracker.add("image", image) self._add_placeholder(placeholder) @@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: self._tracker = tracker def parse_image(self, image_url: str) -> None: - image_coro = async_get_and_parse_image(image_url) + image_coro = async_get_and_parse_image( + image_url, + allowed_local_media_path=self._tracker._model_config. + allowed_local_media_path) placeholder = self._tracker.add("image", image_coro) self._add_placeholder(placeholder) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3d62cb3598477..b18974c5a0c57 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -58,6 +58,10 @@ class LLM: from the input. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, @@ -139,6 +143,7 @@ def __init__( tokenizer_mode: str = "auto", skip_tokenizer_init: bool = False, trust_remote_code: bool = False, + allowed_local_media_path: str = "", tensor_parallel_size: int = 1, dtype: str = "auto", quantization: Optional[str] = None, @@ -179,6 +184,7 @@ def __init__( tokenizer_mode=tokenizer_mode, skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, tensor_parallel_size=tensor_parallel_size, dtype=dtype, quantization=quantization, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index c5ff552e06099..283c23c94d330 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,4 +1,5 @@ import base64 +import os from functools import lru_cache from io import BytesIO from typing import Any, List, Optional, Tuple, TypeVar, Union @@ -18,19 +19,60 @@ cached_get_tokenizer = lru_cache(get_tokenizer) -def _load_image_from_bytes(b: bytes): +def _load_image_from_bytes(b: bytes) -> Image.Image: image = Image.open(BytesIO(b)) image.load() return image -def _load_image_from_data_url(image_url: str): +def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool: + # Get the common path + common_path = os.path.commonpath([ + os.path.abspath(image_path), + os.path.abspath(allowed_local_media_path) + ]) + # Check if the common path is the same as allowed_local_media_path + return common_path == os.path.abspath(allowed_local_media_path) + + +def _load_image_from_file(image_url: str, + allowed_local_media_path: str) -> Image.Image: + if not allowed_local_media_path: + raise ValueError("Invalid 'image_url': Cannot load local files without" + "'--allowed-local-media-path'.") + if allowed_local_media_path: + if not os.path.exists(allowed_local_media_path): + raise ValueError( + "Invalid '--allowed-local-media-path': " + f"The path {allowed_local_media_path} does not exist.") + if not os.path.isdir(allowed_local_media_path): + raise ValueError( + "Invalid '--allowed-local-media-path': " + f"The path {allowed_local_media_path} must be a directory.") + + # Only split once and assume the second part is the image path + _, image_path = image_url.split("file://", 1) + if not _is_subpath(image_path, allowed_local_media_path): + raise ValueError( + f"Invalid 'image_url': The file path {image_path} must" + " be a subpath of '--allowed-local-media-path'" + f" '{allowed_local_media_path}'.") + + image = Image.open(image_path) + image.load() + return image + + +def _load_image_from_data_url(image_url: str) -> Image.Image: # Only split once and assume the second part is the base64 encoded image _, image_base64 = image_url.split(",", 1) return load_image_from_base64(image_base64) -def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: +def fetch_image(image_url: str, + *, + image_mode: str = "RGB", + allowed_local_media_path: str = "") -> Image.Image: """ Load a PIL image from a HTTP or base64 data URL. @@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: elif image_url.startswith('data:image'): image = _load_image_from_data_url(image_url) + elif image_url.startswith('file://'): + image = _load_image_from_file(image_url, allowed_local_media_path) else: raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image' or 'http'.") + "with either 'data:image', 'file://' or 'http'.") return image.convert(image_mode) async def async_fetch_image(image_url: str, *, - image_mode: str = "RGB") -> Image.Image: + image_mode: str = "RGB", + allowed_local_media_path: str = "") -> Image.Image: """ Asynchronously load a PIL image from a HTTP or base64 data URL. @@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str, elif image_url.startswith('data:image'): image = _load_image_from_data_url(image_url) + elif image_url.startswith('file://'): + image = _load_image_from_file(image_url, allowed_local_media_path) else: raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image' or 'http'.") + "with either 'data:image', 'file://' or 'http'.") return image.convert(image_mode) @@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: return {"audio": (audio, sr)} -def get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = fetch_image(image_url) +def get_and_parse_image( + image_url: str, + *, + allowed_local_media_path: str = "") -> MultiModalDataDict: + image = fetch_image(image_url, + allowed_local_media_path=allowed_local_media_path) return {"image": image} @@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: return {"audio": (audio, sr)} -async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = await async_fetch_image(image_url) +async def async_get_and_parse_image( + image_url: str, + *, + allowed_local_media_path: str = "") -> MultiModalDataDict: + image = await async_fetch_image( + image_url, allowed_local_media_path=allowed_local_media_path) return {"image": image}