Skip to content

Commit

Permalink
[Misc] Manage HTTP connections in one place (vllm-project#6600)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Jul 23, 2024
1 parent c051bfe commit 97234be
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 85 deletions.
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
Expand Down Expand Up @@ -74,6 +75,13 @@ def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
"""Singleton instance of :class:`_ImageAssets`."""


@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Turn off HTTP client reuse on the global connection for each test.

    pytest_asyncio may use a different event loop per test, so the async
    client has to be created anew rather than taken from the cache.
    """
    global_http_connection.reuse_client = False


def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
Expand Down
10 changes: 4 additions & 6 deletions tests/entrypoints/openai/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import openai
import pytest
import pytest_asyncio

from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
from vllm.multimodal.utils import encode_image_base64, fetch_image

from ...utils import VLLM_PATH, RemoteOpenAIServer

Expand Down Expand Up @@ -42,11 +41,10 @@ def client(server):
return server.get_async_client()


@pytest_asyncio.fixture(scope="session")
async def base64_encoded_image() -> Dict[str, str]:
@pytest.fixture(scope="session")
def base64_encoded_image() -> Dict[str, str]:
return {
image_url:
encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url))
image_url: encode_image_base64(fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
}

Expand Down
10 changes: 5 additions & 5 deletions tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from PIL import Image

from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
from vllm.multimodal.utils import async_fetch_image, fetch_image

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
Expand Down Expand Up @@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()


@pytest.mark.asyncio(scope="module")
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await ImageFetchAiohttp.fetch_image(image_url)
image_async = await async_fetch_image(image_url)
assert _image_equals(image_sync, image_async)


@pytest.mark.asyncio(scope="module")
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
Expand Down Expand Up @@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
else:
pass # Lossy format; only check that image can be opened

data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
data_image_async = await async_fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)
13 changes: 6 additions & 7 deletions vllm/assets/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import shutil
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal

import requests
from PIL import Image

from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT

from .base import get_cache_dir


Expand All @@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image:
if not image_path.exists():
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"

with requests.get(f"{base_url}/{filename}", stream=True) as response:
response.raise_for_status()

with image_path.open("wb") as f:
shutil.copyfileobj(response.raw, f)
global_http_connection.download_file(f"{base_url}/{filename}",
image_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)

return Image.open(image_path)

Expand Down
167 changes: 167 additions & 0 deletions vllm/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from pathlib import Path
from typing import Any, Mapping, Optional
from urllib.parse import urlparse

import aiohttp
import requests

from vllm.version import __version__ as VLLM_VERSION


class HTTPConnection:
    """Helper class to send HTTP requests.

    Wraps a synchronous ``requests.Session`` and an asynchronous
    ``aiohttp.ClientSession``. Each client is created lazily on first use
    and, while ``reuse_client`` is true, cached for later requests.
    """

    def __init__(self, *, reuse_client: bool = True) -> None:
        super().__init__()

        # When False, every get_*_client() call builds a brand-new session.
        self.reuse_client = reuse_client

        self._sync_client: Optional[requests.Session] = None
        self._async_client: Optional[aiohttp.ClientSession] = None

    def get_sync_client(self) -> requests.Session:
        """Return the sync session, creating one if none is cached or if
        client reuse is disabled."""
        # NOTE(review): a session that gets replaced here is not closed
        # explicitly -- confirm this is acceptable for reuse_client=False.
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside async event loop
    async def get_async_client(self) -> aiohttp.ClientSession:
        """Return the async session, creating one if none is cached or if
        client reuse is disabled."""
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession()

        return self._async_client

    def _validate_http_url(self, url: str) -> None:
        """Raise ``ValueError`` unless ``url`` uses the http(s) scheme."""
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> Mapping[str, str]:
        """Build request headers: the vLLM User-Agent plus ``extras``.

        Entries in ``extras`` override the default on a key clash.
        """
        return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Send a synchronous GET request and return the response.

        Args:
            url: Target URL; must have scheme ``http`` or ``https``.
            stream: Forwarded to ``requests`` as ``stream``.
            timeout: Forwarded to ``requests`` as ``timeout``.
            extra_headers: Headers merged on top of the default ones.

        Raises:
            ValueError: If ``url`` is not an HTTP(S) URL.
        """
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Start an asynchronous GET request.

        The value of ``client.get(...)`` is returned without being awaited,
        so callers use it as ``async with await self.get_async_response(...)``.

        Args:
            url: Target URL; must have scheme ``http`` or ``https``.
            timeout: Total timeout in seconds; ``None`` sets no total limit.
            extra_headers: Headers merged on top of the default ones.

        Raises:
            ValueError: If ``url`` is not an HTTP(S) URL.
        """
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        # aiohttp documents ``timeout`` as a ClientTimeout; a bare number is
        # only kept for backward compatibility and is treated as the total.
        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=aiohttp.ClientTimeout(total=timeout))

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        """Fetch ``url`` and return the raw response body."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        """Asynchronously fetch ``url`` and return the raw response body."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        """Fetch ``url`` and return the response body decoded as text."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        """Asynchronously fetch ``url`` and return the body as text."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        """Fetch ``url`` and return the response body decoded from JSON."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        """Asynchronously fetch ``url`` and return the JSON-decoded body."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Download ``url`` to ``save_path`` and return ``save_path``.

        Args:
            url: Target URL; must have scheme ``http`` or ``https``.
            save_path: Destination file; opened in binary write mode.
            timeout: Forwarded to ``requests`` as ``timeout``.
            chunk_size: Number of bytes requested per ``iter_content`` chunk.
        """
        # stream=True defers reading the body, so iter_content() below reads
        # it chunk by chunk instead of after it was buffered in full.
        with self.get_response(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Asynchronously download ``url`` to ``save_path``.

        Args:
            url: Target URL; must have scheme ``http`` or ``https``.
            save_path: Destination file; opened in binary write mode.
            timeout: Total timeout in seconds; ``None`` sets no total limit.
            chunk_size: Maximum number of bytes per chunk read.

        Returns:
            ``save_path``.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            # The response body is read asynchronously; the file writes
            # themselves are ordinary blocking calls.
            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path


global_http_connection = HTTPConnection()
"""The global :class:`HTTPConnection` instance used by vLLM."""
88 changes: 22 additions & 66 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
import base64
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
from typing import Union

import aiohttp
import requests
from PIL import Image

from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION


def _validate_remote_url(url: str, *, name: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise ValueError(f"Invalid '{name}': A valid '{name}' "
"must have scheme 'http' or 'https'.")


def _get_request_headers():
return {"User-Agent": f"vLLM/{VLLM_VERSION}"}


def _load_image_from_bytes(b: bytes):
Expand All @@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")

headers = _get_request_headers()

with requests.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = response.content
image_raw = global_http_connection.get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)

elif image_url.startswith('data:image'):
Expand All @@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return image.convert(image_mode)


class ImageFetchAiohttp:
aiohttp_client: Optional[aiohttp.ClientSession] = None

@classmethod
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)

return cls.aiohttp_client

@classmethod
async def fetch_image(
cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""

if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")

client = cls.get_aiohttp_client()
headers = _get_request_headers()
async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = _load_image_from_bytes(image_raw)
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
image_raw = await global_http_connection.async_get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")

return image.convert(image_mode)
return image.convert(image_mode)


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
image = await async_fetch_image(image_url)
return {"image": image}


Expand Down
Loading

0 comments on commit 97234be

Please sign in to comment.