From 97234be0ec67f48ed5e65bc0290f329dfb33798e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 23 Jul 2024 12:32:02 +0800 Subject: [PATCH] [Misc] Manage HTTP connections in one place (#6600) --- tests/conftest.py | 8 ++ tests/entrypoints/openai/test_vision.py | 10 +- tests/multimodal/test_utils.py | 10 +- vllm/assets/image.py | 13 +- vllm/connections.py | 167 ++++++++++++++++++++++++ vllm/multimodal/utils.py | 88 ++++--------- vllm/usage/usage_lib.py | 4 +- 7 files changed, 215 insertions(+), 85 deletions(-) create mode 100644 vllm/connections.py diff --git a/tests/conftest.py b/tests/conftest.py index 652d627377786..7f507310cd255 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig +from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) from vllm.inputs import TextPrompt @@ -74,6 +75,13 @@ def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """Singleton instance of :class:`_ImageAssets`.""" +@pytest.fixture(autouse=True) +def init_test_http_connection(): + # pytest_asyncio may use a different event loop per test + # so we need to make sure the async client is created anew + global_http_connection.reuse_client = False + + def cleanup(): destroy_model_parallel() destroy_distributed_environment() diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index cc5c8d619183f..843ba91f7a076 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -2,9 +2,8 @@ import openai import pytest -import pytest_asyncio -from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64 +from vllm.multimodal.utils import encode_image_base64, fetch_image from ...utils import VLLM_PATH, RemoteOpenAIServer @@ -42,11 +41,10 @@ def client(server): return server.get_async_client() -@pytest_asyncio.fixture(scope="session") -async def base64_encoded_image() -> Dict[str, str]: +@pytest.fixture(scope="session") +def base64_encoded_image() -> Dict[str, str]: return { - image_url: - encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url)) + image_url: encode_image_base64(fetch_image(image_url)) for image_url in TEST_IMAGE_URLS } diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 10cabdadb1dcd..cd1fc91c29374 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -7,7 +7,7 @@ import pytest from PIL import Image -from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image +from vllm.multimodal.utils import async_fetch_image, fetch_image # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: return (np.asarray(a) == np.asarray(b.convert(a.mode))).all() -@pytest.mark.asyncio(scope="module") +@pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_fetch_image_http(image_url: str): image_sync = fetch_image(image_url) - image_async = await ImageFetchAiohttp.fetch_image(image_url) + image_async = await async_fetch_image(image_url) assert _image_equals(image_sync, image_async) -@pytest.mark.asyncio(scope="module") +@pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: Dict[str, Image.Image], @@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], else: pass # Lossy format; only check that image can be opened - data_image_async = await ImageFetchAiohttp.fetch_image(data_url) + data_image_async = await async_fetch_image(data_url) assert _image_equals(data_image_sync, data_image_async) diff --git a/vllm/assets/image.py b/vllm/assets/image.py index ca6c3ac9e3a38..b865b1b3a5497 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -1,11 +1,12 @@ -import shutil from dataclasses import dataclass from functools import lru_cache from typing import Literal -import requests from PIL import Image +from vllm.connections import global_http_connection +from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT + from .base import get_cache_dir @@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image: if not image_path.exists(): base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava" - with requests.get(f"{base_url}/{filename}", stream=True) as response: - response.raise_for_status() - - with image_path.open("wb") as f: - shutil.copyfileobj(response.raw, f) + global_http_connection.download_file(f"{base_url}/{filename}", + image_path, + timeout=VLLM_IMAGE_FETCH_TIMEOUT) return Image.open(image_path) diff --git a/vllm/connections.py b/vllm/connections.py new file mode 100644 index 0000000000000..65d44176e2464 --- /dev/null +++ b/vllm/connections.py @@ -0,0 +1,167 @@ +from pathlib import Path +from typing import Mapping, Optional +from urllib.parse import urlparse + +import aiohttp +import requests + +from vllm.version import __version__ as VLLM_VERSION + + +class HTTPConnection: + """Helper class to send HTTP requests.""" + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession() + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError("Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'.") + + def _headers(self, **extras: str) -> Mapping[str, str]: + return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +global_http_connection = HTTPConnection() +"""The global :class:`HTTPConnection` instance used by vLLM.""" diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 8691a61343ab6..bafd208469788 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,26 +1,12 @@ import base64 from io import BytesIO -from typing import Optional, Union -from urllib.parse import urlparse +from typing import Union -import aiohttp -import requests from PIL import Image +from vllm.connections import global_http_connection from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT from vllm.multimodal.base import MultiModalDataDict -from vllm.version import __version__ as VLLM_VERSION - - -def _validate_remote_url(url: str, *, name: str): - parsed_url = urlparse(url) - if parsed_url.scheme not in ["http", "https"]: - raise ValueError(f"Invalid '{name}': A valid '{name}' " - "must have scheme 'http' or 'https'.") - - -def _get_request_headers(): - return {"User-Agent": f"vLLM/{VLLM_VERSION}"} def _load_image_from_bytes(b: bytes): @@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: By default, the image is converted into RGB format. """ if image_url.startswith('http'): - _validate_remote_url(image_url, name="image_url") - - headers = _get_request_headers() - - with requests.get(url=image_url, headers=headers) as response: - response.raise_for_status() - image_raw = response.content + image_raw = global_http_connection.get_bytes( + image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT) image = _load_image_from_bytes(image_raw) elif image_url.startswith('data:image'): @@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: return image.convert(image_mode) -class ImageFetchAiohttp: - aiohttp_client: Optional[aiohttp.ClientSession] = None - - @classmethod - def get_aiohttp_client(cls) -> aiohttp.ClientSession: - if cls.aiohttp_client is None: - timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT) - connector = aiohttp.TCPConnector() - cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout, - connector=connector) - - return cls.aiohttp_client - - @classmethod - async def fetch_image( - cls, - image_url: str, - *, - image_mode: str = "RGB", - ) -> Image.Image: - """ - Asynchronously load a PIL image from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. - """ - - if image_url.startswith('http'): - _validate_remote_url(image_url, name="image_url") - - client = cls.get_aiohttp_client() - headers = _get_request_headers() +async def async_fetch_image(image_url: str, + *, + image_mode: str = "RGB") -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. - async with client.get(url=image_url, headers=headers) as response: - response.raise_for_status() - image_raw = await response.read() - image = _load_image_from_bytes(image_raw) + By default, the image is converted into RGB format. + """ + if image_url.startswith('http'): + image_raw = await global_http_connection.async_get_bytes( + image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT) + image = _load_image_from_bytes(image_raw) - elif image_url.startswith('data:image'): - image = _load_image_from_data_url(image_url) - else: - raise ValueError( - "Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image' or 'http'.") + elif image_url.startswith('data:image'): + image = _load_image_from_data_url(image_url) + else: + raise ValueError("Invalid 'image_url': A valid 'image_url' must start " + "with either 'data:image' or 'http'.") - return image.convert(image_mode) + return image.convert(image_mode) async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = await ImageFetchAiohttp.fetch_image(image_url) + image = await async_fetch_image(image_url) return {"image": image} diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index fb6a6d854e21c..515e0a4d8abe7 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -16,6 +16,7 @@ import torch import vllm.envs as envs +from vllm.connections import global_http_connection from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -204,7 +205,8 @@ def _report_continous_usage(self): def _send_to_server(self, data): try: - requests.post(_USAGE_STATS_SERVER, json=data) + global_http_client = global_http_connection.get_sync_client() + global_http_client.post(_USAGE_STATS_SERVER, json=data) except requests.exceptions.RequestException: # silently ignore unless we are using debug log logging.debug("Failed to send usage data to server")