Commit

Fix Qwen2-VL and Qwen2-Audio
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 23, 2024
1 parent b5e5b8a commit cf24a1b
Showing 3 changed files with 196 additions and 22 deletions.
172 changes: 172 additions & 0 deletions benchmarks/mmmu_bench.py
@@ -0,0 +1,172 @@
r"""Benchmark offline inference throughput with MMMU-PRO Vision
e.g,
python3 benchmarks/mmmu_bench.py \
--model mistralai/Pixtral-12B-2409 \
--tokenizer-mode mistral \
--num-prompts 1000 \
--image-hit-rate 0.5
python3 benchmarks/mmmu_bench.py \
--model allenai/Molmo-72B-0924 \
--tensor-parallel-size 4 \
--trust-remote-code \
--num-prompts 1000
"""
import argparse
import asyncio
import base64
import dataclasses
import io
import math
import random
import time
from itertools import chain

from datasets import load_dataset
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.utils import FlexibleArgumentParser


def sample_mmmu_pro_vision_requests(
dataset,
num_requests: int,
image_hit_rate: float,
):
sampled_requests = []
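    # With an image hit rate of r, roughly a (1 - r) fraction of the requests
    # use a unique image; the remaining requests repeat already-sampled images.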
num_unique_images = max(int(num_requests * (1 - image_hit_rate)), 1)
print(
f"Total {num_requests} requests with {num_unique_images} unique images"
)
dataset = dataset.take(num_unique_images)

# The dataset with streaming=True fetches (downloads) 64 rows at a time.
print("Fetching data. This may take a while...")
for data in dataset:
if len(sampled_requests) == num_requests:
break

# MMMU-Pro vision direct prompt
# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
prompt = (
"Answer with the option letter from the given choices directly. "
"The last line of your response should be of the following "
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
"options.")

image: Image.Image = data["image"]
image = image.convert("RGB")
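        # Re-encode the image as a base64 JPEG data URL so it can be embedded
        # inline in the chat message.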
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}

messages = [{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
mm_content,
],
}]
sampled_requests.append(messages)

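    # Duplicate each sampled request ceil(num_requests / num_unique_images)
    # times, then truncate to num_requests so that repeated images approximate
    # the requested hit rate.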
n = math.ceil(num_requests / num_unique_images)
sampled_requests = list(
chain.from_iterable([x] * n for x in sampled_requests))[:num_requests]

return sampled_requests


def sample_hf_requests(
num_requests: int,
random_seed: int,
image_hit_rate: float,
):
dataset = load_dataset('MMMU/MMMU_Pro',
name='vision',
split="test",
streaming=True)
dataset = dataset.shuffle(seed=random_seed)
return sample_mmmu_pro_vision_requests(dataset, num_requests,
image_hit_rate)


def initialize_llm(engine_args):
print("Initializing LLM...")
return LLM(**dataclasses.asdict(engine_args))


async def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
engine_args = EngineArgs.from_cli_args(args)

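    # Greedy decoding (temperature=0) with a fixed max output length keeps
    # runs comparable.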
sampling_params = SamplingParams(max_tokens=args.output_len, temperature=0)
chat_template = load_chat_template(args.chat_template)

# Concurrently initialize the LLM and sample data. Note that since
# both initialize_llm and sample_hf_requests are blocking, we need to
# use asyncio.to_thread to create async coroutines.
st = time.perf_counter()
sampling_task = asyncio.create_task(
asyncio.to_thread(sample_hf_requests, args.num_prompts, args.seed,
args.image_hit_rate))
llm_task = asyncio.create_task(
asyncio.to_thread(initialize_llm, engine_args))

sampled, llm = await asyncio.gather(sampling_task, llm_task)
print(f"Data sampling + LLM init time: {time.perf_counter() - st:.2f}s")

st = time.perf_counter()
outputs = llm.chat(sampled,
sampling_params=sampling_params,
chat_template=chat_template)
duration = time.perf_counter() - st

total_generated_tokens = sum(
len(output.outputs[0].token_ids) for output in outputs)

print(f"Request throughput: {args.num_prompts / duration:.2f} req/s")
print(f"Total generated tokens: {total_generated_tokens}")
print(
f"Token generation rate: {total_generated_tokens / duration:.2f} tok/s"
)


if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--output-len",
type=int,
default=128,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--image-hit-rate",
type=float,
default=0.0,
help="Image hit rate between 0 and 1.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="Set the chat template to use.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

asyncio.run(main(args))
14 changes: 6 additions & 8 deletions tests/multimodal/test_processing.py
@@ -509,14 +509,12 @@ def _rand_video(
min_wh: int,
max_wh: int,
):
-    num_frames = rng.randint(min_frames, max_frames)
-    w, h = rng.randint(min_wh, max_wh, size=(2, ))
-
# Temporary fix. Qwen2-VL video processor fails on video of shape
# (b, 199, 178, 3) where b in (3, 5, 7)
-    w = (w // 32) * 32
-    h = (h // 32) * 32
+    num_frames = rng.randint(min_frames, max_frames)
+    num_frames = (num_frames // 2) * 2

+    w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)


@@ -527,7 +525,7 @@ def _rand_audio(
sr: int,
):
audio_len = rng.randint(min_len, max_len)
-    return rng.randint(0, 255, size=(audio_len, ), dtype=np.uint8), sr
+    return rng.rand(audio_len), sr


# yapf: disable
@@ -542,7 +540,7 @@ def _rand_audio(
("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [10])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
@@ -588,7 +586,7 @@ def test_processing_cache_correctness(
"video":
partial(_rand_video,
rng,
-                min_frames=1,
+                min_frames=2,
max_frames=8,
min_wh=128,
max_wh=256),
32 changes: 18 additions & 14 deletions vllm/multimodal/inputs.py
@@ -280,23 +280,20 @@ def from_hf_inputs(
for key, config in config_by_key.items() if key in hf_inputs
}

-        if enable_sanity_checks:
-            batch_sizes = {k: len(v) for k, v in items_by_key.items()}
-            batch_size = next(iter(batch_sizes.values()), 0)
-            assert all(bs == batch_size for bs in batch_sizes.values()), dict(
-                batch_sizes=batch_sizes, items_by_key=items_by_key)
-
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
data = {k: hf_inputs[k] for k in items_by_key}

-        return MultiModalKwargs(data, items_by_key=items_by_key)
+        return MultiModalKwargs(data,
+                                items_by_key=items_by_key,
+                                enable_sanity_checks=enable_sanity_checks)

def __init__(
self,
data: Mapping[str, NestedTensors],
*,
items_by_key: Optional[Mapping[str, list[MultiModalFieldItem]]] = None,
+        enable_sanity_checks: bool = False,
) -> None:
if items_by_key is None:
items_by_key = {}
@@ -313,6 +310,17 @@ def __init__(

self._keys_by_modality = dict(keys_by_modality)

+        if enable_sanity_checks:
+            for modality, keys in keys_by_modality.items():
+                items_in_modality = {k: items_by_key[k] for k in keys}
+                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
+                batch_size = next(iter(batch_sizes.values()), 0)
+                assert all(bs == batch_size
+                           for bs in batch_sizes.values()), dict(
+                               modality=modality,
+                               batch_sizes=batch_sizes,
+                               items_by_key=items_by_key)
+
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
@@ -435,18 +443,14 @@ def from_items_by_modality(
for k, v in field.items():
items_by_key[k].append(v)

-        if enable_sanity_checks:
-            batch_sizes = {k: len(v) for k, v in items_by_key.items()}
-            batch_size = next(iter(batch_sizes.values()), 0)
-            assert all(bs == batch_size for bs in batch_sizes.values()), dict(
-                batch_sizes=batch_sizes, items_by_key=items_by_key)
-
data = {
k: items[0].field.reduce(items).data
for k, items in items_by_key.items()
}

-        return MultiModalKwargs(data, items_by_key=items_by_key)
+        return MultiModalKwargs(data,
+                                items_by_key=items_by_key,
+                                enable_sanity_checks=enable_sanity_checks)


MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
