Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Online video support for VLMs #10020

Merged
merged 10 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/assets/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def download_video_asset(filename: str) -> str:


def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
cv2 = try_import_video_packages()
cv2, _ = try_import_video_packages()

cap = cv2.VideoCapture(path)
if not cap.isOpened():
Expand All @@ -59,7 +59,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:

def video_to_pil_images_list(path: str,
num_frames: int = -1) -> List[Image.Image]:
cv2 = try_import_video_packages()
cv2, _ = try_import_video_packages()
frames = video_to_ndarrays(path, num_frames)
return [
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
Expand Down
72 changes: 66 additions & 6 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once

Expand All @@ -52,6 +54,23 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
"""The type of the content part."""


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.

    Learn more in the
    [Vision guide](https://cookbook.openai.com/examples/gpt_with_vision_for_video_understanding).
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    """A chat content part referencing a video by URL (OpenAI-style)."""

    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore

Expand Down Expand Up @@ -82,12 +101,24 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
audio_url: Required[str]


class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.

Example:
{
"video_url": "https://example.com/video.mp4"
}
"""
video_url: Required[str]


# Every content-part payload shape accepted in chat messages: the
# OpenAI-standard parts, the audio/video/refusal extensions, the
# simplified single-URL forms, and bare strings.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam,
    CustomChatCompletionContentSimpleImageParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]


class CustomChatCompletionMessageParam(TypedDict, total=False):
Expand Down Expand Up @@ -208,6 +239,9 @@ def _placeholder_str(self, modality: ModalityStr,
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.video_token_index)
raise TypeError(f"Unknown {modality} model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
Expand Down Expand Up @@ -298,6 +332,10 @@ def parse_image(self, image_url: str) -> None:
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str) -> None:
        """Parse the video at ``video_url`` into tracked multimodal data;
        concrete sync/async parsers implement the actual fetching."""
        raise NotImplementedError


class MultiModalContentParser(BaseMultiModalContentParser):

Expand All @@ -318,6 +356,12 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)

placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)


class AsyncMultiModalContentParser(BaseMultiModalContentParser):

Expand All @@ -338,6 +382,12 @@ def parse_audio(self, audio_url: str) -> None:
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)

def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)

placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)


def validate_chat_template(chat_template: Optional[Union[Path, str]]):
"""Raises if the provided chat template appears invalid."""
Expand Down Expand Up @@ -418,6 +468,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
# Narrowing helpers: cast a generic content part to its specific
# TypedDict shape so typed ``.get`` access is possible.
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
# NOTE(review): presumably model types whose multi-modal message content
# is kept intact rather than flattened -- confirm at the use site.
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

# Define a mapping from part types to their corresponding parsing functions.
Expand All @@ -430,6 +481,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
}


Expand Down Expand Up @@ -474,7 +527,10 @@ def _parse_chat_message_content_mm_part(
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")

if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
part)
return "video_url", video_params.get("video_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")

Expand All @@ -484,7 +540,7 @@ def _parse_chat_message_content_mm_part(


# Part types whose empty content is logged as a warning and skipped,
# rather than rejected outright, during message parsing.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "audio_url", "video_url")


def _parse_chat_message_content_parts(
Expand Down Expand Up @@ -544,7 +600,7 @@ def _parse_chat_message_content_part(
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# if part_type is text/refusal/image_url/audio_url/video_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning(
Expand All @@ -563,6 +619,10 @@ def _parse_chat_message_content_part(
mm_parser.parse_audio(content)
return {'type': 'audio'} if wrap_dicts else None

if part_type == "video_url":
mm_parser.parse_video(content)
return {'type': 'video'} if wrap_dicts else None

raise NotImplementedError(f"Unknown part type: {part_type}")


Expand Down
3 changes: 3 additions & 0 deletions vllm/multimodal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
"""The input audio item(s) and corresponding sampling rate(s)."""

video: MultiModalData[Tuple[np.ndarray]]
"""The input video(s)."""


MultiModalDataDict = Union[MultiModalDataBuiltins,
Mapping[str, MultiModalData[object]]]
Expand Down
90 changes: 88 additions & 2 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,81 @@ async def async_fetch_image(image_url: str,
return image.convert(image_mode)


def _load_video_frames_from_bytes(b: bytes):
    """Decode raw image bytes (one encoded frame) into a numpy array."""
    decoded_frame = Image.open(BytesIO(b))
    return np.array(decoded_frame)


def load_video_frames_from_base64(frame: Union[bytes, str]):
    """Load frame from base64 format."""
    raw_bytes = base64.b64decode(frame)
    return _load_video_frames_from_bytes(raw_bytes)


def _load_video_from_bytes(b: bytes, num_frames: int = 32):
    """Decode a video from raw container bytes into an ndarray of frames.

    Args:
        b: Raw bytes of an encoded video (e.g. an MP4 file's contents).
        num_frames: Upper bound on the number of frames returned; longer
            videos are uniformly subsampled across their full duration.

    Returns:
        Frames as a numpy array of shape (frames, height, width, channels).
    """
    _, decord = try_import_video_packages()

    vr = decord.VideoReader(BytesIO(b), num_threads=1)
    total_frame_num = len(vr)

    if total_frame_num > num_frames:
        # Uniformly sample frame indices spanning the whole video.
        frame_idx = np.linspace(0,
                                total_frame_num - 1,
                                num_frames,
                                dtype=int).tolist()
    else:
        frame_idx = list(range(total_frame_num))

    return vr.get_batch(frame_idx).asnumpy()


def _load_video_from_data_url(video_url: str):
    """Load video frames from a ``data:video/...`` URL and stack them
    into a single numpy array."""
    # Drop the "data:video/...;base64" header; each remaining
    # comma-separated segment is an independently base64-encoded frame.
    frames_base64 = video_url.split(",")[1:]
    return np.stack([
        load_video_frames_from_base64(frame_base64)
        for frame_base64 in frames_base64
    ])


def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
    """
    Load video from a HTTP or base64 data URL.

    Args:
        video_url: Either an ``http(s)`` URL or a ``data:video`` base64
            data URL.
        num_frames: Upper bound on the number of frames decoded from an
            HTTP-fetched video (ignored for data URLs).

    Raises:
        ValueError: If ``video_url`` uses an unsupported scheme.
    """
    if video_url.startswith('http'):
        # NOTE(review): reuses the *image* fetch timeout; a dedicated
        # video timeout env var may be preferable.
        video_raw = global_http_connection.get_bytes(
            video_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        return _load_video_from_bytes(video_raw, num_frames)
    if video_url.startswith('data:video'):
        return _load_video_from_data_url(video_url)
    raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                     "with either 'data:video' or 'http'.")


async def async_fetch_video(video_url: str,
                            *,
                            num_frames: int = 32) -> npt.NDArray:
    """
    Asynchronously load video from a HTTP or base64 data URL.

    Args:
        video_url: Either an ``http(s)`` URL or a ``data:video`` base64
            data URL.
        num_frames: Upper bound on the number of frames decoded from an
            HTTP-fetched video (ignored for data URLs).

    Raises:
        ValueError: If ``video_url`` uses an unsupported scheme.
    """
    if video_url.startswith('http'):
        # NOTE(review): reuses the *image* fetch timeout; a dedicated
        # video timeout env var may be preferable.
        video_raw = await global_http_connection.async_get_bytes(
            video_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
        return _load_video_from_bytes(video_raw, num_frames)
    if video_url.startswith('data:video'):
        return _load_video_from_data_url(video_url)
    raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                     "with either 'data:video' or 'http'.")


def try_import_audio_packages() -> Tuple[Any, Any]:
try:
import librosa
Expand Down Expand Up @@ -131,6 +206,11 @@ def get_and_parse_image(image_url: str) -> MultiModalDataDict:
return {"image": image}


def get_and_parse_video(video_url: str) -> MultiModalDataDict:
    """Fetch the video at ``video_url`` and wrap it as multimodal data."""
    return {"video": fetch_video(video_url)}


async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = await async_fetch_audio(audio_url)
return {"audio": (audio, sr)}
Expand All @@ -141,6 +221,11 @@ async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
return {"image": image}


async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
    """Asynchronously fetch the video at ``video_url`` and wrap it as
    multimodal data."""
    return {"video": await async_fetch_video(video_url)}


def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
Expand Down Expand Up @@ -191,14 +276,15 @@ def rescale_image_size(image: Image.Image,
def try_import_video_packages() -> Tuple[Any, Any]:
    """Import and return the optional video dependencies ``(cv2, decord)``.

    Raises:
        ImportError: If either package is missing; install them with
            ``pip install vllm[video]``.
    """
    try:
        import cv2
        import decord
    except ImportError:
        raise ImportError(
            "Please install vllm[video] for video support.") from None
    return cv2, decord


def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
cv2 = try_import_video_packages()
cv2, _ = try_import_video_packages()

num_frames, _, _, channels = frames.shape
new_height, new_width = size
Expand Down