[Feature] Update benchmark_throughput.py to support image input (vllm-project#9851)

Signed-off-by: Linkun Chen <[email protected]>
Co-authored-by: Linkun Chen <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
2 people authored and tlrmchlsmth committed Nov 23, 2024
1 parent c04ba4a commit 3cbe655
Showing 2 changed files with 75 additions and 18 deletions.
11 changes: 11 additions & 0 deletions benchmarks/README.md
@@ -6,3 +6,14 @@ You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

## Downloading the ShareGPT4V dataset

The JSON file references images from several datasets (COCO, LLaVA, etc.). The benchmark
script skips any datapoint whose referenced image is missing.

```bash
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
mkdir -p coco
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
unzip coco/train2017.zip -d coco/
```
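
Once the images are in place, point the throughput benchmark at the JSON file. A minimal invocation sketch (flag names follow the script's arguments; the Pixtral checkpoint name is an assumption, not part of this change):

```bash
python benchmarks/benchmark_throughput.py \
    --backend vllm \
    --model mistralai/Pixtral-12B-2409 \
    --dataset sharegpt4v_instruct_gpt4-vision_cap100k.json \
    --num-prompts 100
```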
82 changes: 64 additions & 18 deletions benchmarks/benchmark_throughput.py
@@ -8,6 +8,7 @@

import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
multi_modal_data: Optional[MultiModalDataDict] = None


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[SampleRequest]:
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
"""Prepend and append special tokens around the question to form a prompt.
Args:
question: The input question text to wrap with special tokens
model: The name of the model being used, to determine which special
tokens to add
Returns:
The formatted prompt string with appropriate special tokens for the
model
Raises:
ValueError: If an unsupported model name is provided
"""
model = model.lower()
if "pixtral" in model:
return f"<s>[INST]{question}\n[IMG][/INST]"
raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
model: str = args.model
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -52,23 +74,36 @@ def sample_requests(
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
for data in dataset:
if len(filtered_dataset) == num_requests:
break

# Only keep the first two turns of each conversation.
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]

multi_modal_data: Optional[MultiModalDataDict] = None
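# A ShareGPT4V-style datapoint may carry an "image" key holding a single
# file path (e.g. under coco/train2017/, see benchmarks/README.md);
# datapoints whose image file is missing are skipped below.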
if "image" in data:
multi_modal_data = multi_modal_data or {}
image_path = data["image"]
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
assert isinstance(image_path,
                  str), "Only single image input is supported"
try:
multi_modal_data["image"] = Image.open(image_path).convert(
"RGB")
except FileNotFoundError:
# Ignore datapoint where asset is missing
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
@@ -82,7 +117,8 @@ def sample_requests(
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))
expected_output_len=output_len,
multi_modal_data=multi_modal_data))

return filtered_dataset
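
# Shape of the returned requests (illustrative values only; lengths come
# from the tokenizer, and multi_modal_data is set only for image inputs):
#   SampleRequest(prompt="<s>[INST]...\n[IMG][/INST]",
#                 prompt_len=42,
#                 expected_output_len=128,
#                 multi_modal_data={"image": <PIL.Image.Image>})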

@@ -99,7 +135,9 @@ def run_vllm(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
requests = sample_requests(tokenizer, args)

is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
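# Used below to warn that the token-throughput metrics undercount
# multi-modal runs, since image tokens are not included in the totals.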
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
Expand All @@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
if is_multi_modal:
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
"following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details.")
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
