From cf24a1b2077d8def5ba7ca027caf8ce7006c070d Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Mon, 23 Dec 2024 11:35:39 +0000
Subject: [PATCH] Fix Qwen2-VL and Qwen2-Audio

Signed-off-by: DarkLight1337
---
 benchmarks/mmmu_bench.py            | 172 ++++++++++++++++++++++++++++
 tests/multimodal/test_processing.py |  14 +--
 vllm/multimodal/inputs.py           |  32 +++---
 3 files changed, 196 insertions(+), 22 deletions(-)
 create mode 100644 benchmarks/mmmu_bench.py

diff --git a/benchmarks/mmmu_bench.py b/benchmarks/mmmu_bench.py
new file mode 100644
index 0000000000000..6315ddf853bd6
--- /dev/null
+++ b/benchmarks/mmmu_bench.py
@@ -0,0 +1,172 @@
+r"""Benchmark offline inference throughput with MMMU-PRO Vision
+e.g.,
+python3 benchmarks/mmmu_bench.py \
+    --model mistralai/Pixtral-12B-2409 \
+    --tokenizer-mode mistral \
+    --num-prompts 1000 \
+    --image-hit-rate 0.5
+
+python3 benchmarks/mmmu_bench.py \
+    --model allenai/Molmo-72B-0924 \
+    --tensor-parallel-size 4 \
+    --trust-remote-code \
+    --num-prompts 1000
+"""
+import argparse
+import asyncio
+import base64
+import dataclasses
+import io
+import math
+import random
+import time
+from itertools import chain
+
+from datasets import load_dataset
+from PIL import Image
+
+from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.utils import FlexibleArgumentParser
+
+
+def sample_mmmu_pro_vision_requests(
+    dataset,
+    num_requests: int,
+    image_hit_rate: float,
+):
+    sampled_requests = []
+    num_unique_images = max(int(num_requests * (1 - image_hit_rate)), 1)
+    print(
+        f"Total {num_requests} requests with {num_unique_images} unique images"
+    )
+    dataset = dataset.take(num_unique_images)
+
+    # The dataset with streaming=True fetches (downloads) 64 rows at a time.
+    print("Fetching data. This may take a while...")
+    for data in dataset:
+        if len(sampled_requests) == num_requests:
+            break
+
+        # MMMU-Pro vision direct prompt
+        # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
+        prompt = (
+            "Answer with the option letter from the given choices directly. "
+            "The last line of your response should be of the following "
+            "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
+            "options.")
+
+        image: Image.Image = data["image"]
+        image = image.convert("RGB")
+        image_data = io.BytesIO()
+        image.save(image_data, format='JPEG')
+        image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
+        mm_content = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/jpeg;base64,{image_base64}"
+            },
+        }
+
+        messages = [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": prompt
+                },
+                mm_content,
+            ],
+        }]
+        sampled_requests.append(messages)
+
+    n = math.ceil(num_requests / num_unique_images)
+    sampled_requests = list(
+        chain.from_iterable([x] * n for x in sampled_requests))[:num_requests]
+
+    return sampled_requests
+
+
+def sample_hf_requests(
+    num_requests: int,
+    random_seed: int,
+    image_hit_rate: float,
+):
+    dataset = load_dataset('MMMU/MMMU_Pro',
+                           name='vision',
+                           split="test",
+                           streaming=True)
+    dataset = dataset.shuffle(seed=random_seed)
+    return sample_mmmu_pro_vision_requests(dataset, num_requests,
+                                           image_hit_rate)
+
+
+def initialize_llm(engine_args):
+    print("Initializing LLM...")
+    return LLM(**dataclasses.asdict(engine_args))
+
+
+async def main(args: argparse.Namespace):
+    print(args)
+    random.seed(args.seed)
+    engine_args = EngineArgs.from_cli_args(args)
+
+    sampling_params = SamplingParams(max_tokens=args.output_len, temperature=0)
+    chat_template = load_chat_template(args.chat_template)
+
+    # Concurrently initialize the LLM and sample data. Note that since
+    # both initialize_llm and sample_hf_requests are blocking, we need to
+    # use asyncio.to_thread to create async coroutines.
+    st = time.perf_counter()
+    sampling_task = asyncio.create_task(
+        asyncio.to_thread(sample_hf_requests, args.num_prompts, args.seed,
+                          args.image_hit_rate))
+    llm_task = asyncio.create_task(
+        asyncio.to_thread(initialize_llm, engine_args))
+
+    sampled, llm = await asyncio.gather(sampling_task, llm_task)
+    print(f"Data sampling + LLM init time: {time.perf_counter() - st:.2f}s")
+
+    st = time.perf_counter()
+    outputs = llm.chat(sampled,
+                       sampling_params=sampling_params,
+                       chat_template=chat_template)
+    duration = time.perf_counter() - st
+
+    total_generated_tokens = sum(
+        len(output.outputs[0].token_ids) for output in outputs)
+
+    print(f"Request throughput: {args.num_prompts / duration:.2f} req/s")
+    print(f"Total generated tokens: {total_generated_tokens}")
+    print(
+        f"Token generation rate: {total_generated_tokens / duration:.2f} tok/s"
+    )
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=128,
+                        help="Output length for each request. Overrides the "
+                        "output length from the dataset.")
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
+                        help="Number of prompts to process.")
+    parser.add_argument("--image-hit-rate",
+                        type=float,
+                        default=0.0,
+                        help="Image hit rate between 0 and 1.")
+    parser.add_argument("--chat-template",
+                        type=str,
+                        default=None,
+                        help="Set the chat template to use.")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+
+    asyncio.run(main(args))
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 0c0db419a5957..e3bf5c2565605 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -509,14 +509,12 @@ def _rand_video(
     min_wh: int,
     max_wh: int,
 ):
-    num_frames = rng.randint(min_frames, max_frames)
-    w, h = rng.randint(min_wh, max_wh, size=(2, ))
-
     # Temporary fix. Qwen2-VL video processor fails on video of shape
     # (b, 199, 178, 3) where b in (3, 5, 7)
-    w = (w // 32) * 32
-    h = (h // 32) * 32
+    num_frames = rng.randint(min_frames, max_frames)
+    num_frames = (num_frames // 2) * 2
+    w, h = rng.randint(min_wh, max_wh, size=(2, ))
 
     return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
 
@@ -527,7 +525,7 @@ def _rand_audio(
     sr: int,
 ):
     audio_len = rng.randint(min_len, max_len)
-    return rng.randint(0, 255, size=(audio_len, ), dtype=np.uint8), sr
+    return rng.rand(audio_len), sr
 
 
 # yapf: disable
@@ -542,7 +540,7 @@
         ("fixie-ai/ultravox-v0_3", {"audio"}),
     ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
-@pytest.mark.parametrize("num_batches", [10])
+@pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_cache_correctness(
@@ -588,7 +586,7 @@
         "video":
         partial(_rand_video,
                 rng,
-                min_frames=1,
+                min_frames=2,
                 max_frames=8,
                 min_wh=128,
                 max_wh=256),
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index eab829a2983d4..d46afc7e23b6c 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -280,23 +280,20 @@ def from_hf_inputs(
             for key, config in config_by_key.items()
             if key in hf_inputs
         }
-        if enable_sanity_checks:
-            batch_sizes = {k: len(v) for k, v in items_by_key.items()}
-            batch_size = next(iter(batch_sizes.values()), 0)
-            assert all(bs == batch_size for bs in batch_sizes.values()), dict(
-                batch_sizes=batch_sizes, items_by_key=items_by_key)
-
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
         data = {k: hf_inputs[k] for k in items_by_key}
 
-        return MultiModalKwargs(data, items_by_key=items_by_key)
+        return MultiModalKwargs(data,
+                                items_by_key=items_by_key,
+                                enable_sanity_checks=enable_sanity_checks)
 
     def __init__(
         self,
         data: Mapping[str, NestedTensors],
         *,
         items_by_key: Optional[Mapping[str, list[MultiModalFieldItem]]] = None,
+        enable_sanity_checks: bool = False,
     ) -> None:
         if items_by_key is None:
             items_by_key = {}
@@ -313,6 +310,17 @@ def __init__(
 
         self._keys_by_modality = dict(keys_by_modality)
 
+        if enable_sanity_checks:
+            for modality, keys in keys_by_modality.items():
+                items_in_modality = {k: items_by_key[k] for k in keys}
+                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
+                batch_size = next(iter(batch_sizes.values()), 0)
+                assert all(bs == batch_size
+                           for bs in batch_sizes.values()), dict(
+                               modality=modality,
+                               batch_sizes=batch_sizes,
+                               items_by_key=items_by_key)
+
     @staticmethod
     def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
         """
@@ -435,18 +443,14 @@ def from_items_by_modality(
             for k, v in field.items():
                 items_by_key[k].append(v)
 
-        if enable_sanity_checks:
-            batch_sizes = {k: len(v) for k, v in items_by_key.items()}
-            batch_size = next(iter(batch_sizes.values()), 0)
-            assert all(bs == batch_size for bs in batch_sizes.values()), dict(
-                batch_sizes=batch_sizes, items_by_key=items_by_key)
-
         data = {
             k: items[0].field.reduce(items).data
             for k, items in items_by_key.items()
         }
 
-        return MultiModalKwargs(data, items_by_key=items_by_key)
+        return MultiModalKwargs(data,
+                                items_by_key=items_by_key,
+                                enable_sanity_checks=enable_sanity_checks)
 
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
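
Note (not part of the patch): the --image-hit-rate handling in
sample_mmmu_pro_vision_requests is the least obvious piece of the new
benchmark, so below is a minimal standalone sketch of the same logic. The
helper name duplicate_for_hit_rate is hypothetical and used only for
illustration; the benchmark applies the same max()/ceil()/
chain.from_iterable() arithmetic to the sampled chat messages.

import math
from itertools import chain


def duplicate_for_hit_rate(requests, num_requests, image_hit_rate):
    # Hypothetical helper mirroring the benchmark's hit-rate logic:
    # keep only the "unique" fraction of requests, then repeat each one
    # so the remaining slots reuse images that were already seen.
    num_unique = max(int(num_requests * (1 - image_hit_rate)), 1)
    unique = requests[:num_unique]
    n = math.ceil(num_requests / num_unique)
    return list(chain.from_iterable([r] * n for r in unique))[:num_requests]


# 10 requests at a 0.5 hit rate -> 5 unique requests, each issued twice.
print(duplicate_for_hit_rate([f"req{i}" for i in range(100)], 10, 0.5))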