diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index e4f0d2011133a..1b41f99ce9dbb 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -9,6 +9,7 @@
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -60,6 +61,33 @@
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+        model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
 def sample_requests(tokenizer: PreTrainedTokenizerBase,
                     args: argparse.Namespace) -> List[SampleRequest]:
     dataset_path: str = args.dataset
     num_requests: int = args.num_prompts
     fixed_output_len: Optional[int] = args.output_len
     model: str = args.model
@@ -79,6 +107,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
@@ -87,6 +116,25 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
         # Only keep the first two turns of each conversation.
         prompt = data["conversations"][0]["value"]
         completion = data["conversations"][1]["value"]
 
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+            assert isinstance(image_path,
+                              str), "Only support single image input"
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
@@ -119,6 +167,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                           prompt_len=prompt_len,
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))
 
     return filtered_dataset
@@ -135,6 +185,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
@@ -186,6 +239,9 @@ async def run_vllm_async(
     prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
     for request in requests:
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
@@ -313,7 +369,10 @@ def main(args: argparse.Namespace):
         ]
     else:
         requests = sample_requests(tokenizer, args)
+    is_multi_modal = any(request.multi_modal_data is not None
+                         for request in requests)
     if args.backend == "vllm":
@@ -342,11 +401,9 @@ def main(args: argparse.Namespace):
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
     if is_multi_modal:
-        print(
-            "\033[91mWARNING\033[0m: Multi-modal request detected. The "
-            "following metrics is not accurate because image tokens are not "
-            "counted. See vllm-project/vllm/issues/9778 for details."
-        )
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+              "following metrics is not accurate because image tokens are not "
+              "counted. See vllm-project/vllm/issues/9778 for details.")
         # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
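For reviewers, here is a minimal standalone sketch of the per-record flow this patch wires into `sample_requests`: wrap the question in the model-specific image prompt and attach the decoded PIL image as multi-modal data. The `load_sample` helper, the temporary test image, the example record, and the model name are illustrative assumptions rather than part of the patch; only the Pixtral prompt template and the `Image.open(...).convert("RGB")` / `FileNotFoundError` handling mirror the diff above.

```python
# Illustrative sketch only: load_sample, the temp image, and the example
# record are hypothetical; the prompt template and image handling follow
# the patch.
import tempfile
from typing import Any, Dict, Optional, Tuple

from PIL import Image


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    # Same logic as the helper added by the patch (Pixtral only for now).
    model = model.lower()
    if "pixtral" in model:
        return f"[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


def load_sample(data: Dict[str, Any],
                model: str) -> Optional[Tuple[str, Optional[Dict[str, Any]]]]:
    """Turn one ShareGPT-style record into (prompt, multi_modal_data)."""
    prompt = data["conversations"][0]["value"]

    multi_modal_data: Optional[Dict[str, Any]] = None
    if "image" in data:
        image_path = data["image"]
        assert isinstance(image_path, str), "Only support single image input"
        try:
            multi_modal_data = {
                "image": Image.open(image_path).convert("RGB")
            }
        except FileNotFoundError:
            return None  # skip records whose image asset is missing
        prompt = _get_prompt_for_image_model(question=prompt, model=model)

    return prompt, multi_modal_data


if __name__ == "__main__":
    # Tiny placeholder image so the example runs without real dataset assets.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        Image.new("RGB", (8, 8)).save(f, format="PNG")

    record = {
        "conversations": [{"value": "What is shown in this image?"},
                          {"value": "A small test image."}],
        "image": f.name,
    }
    prompt, mm_data = load_sample(record, model="mistralai/Pixtral-12B-2409")
    print(prompt)                 # Pixtral-style [INST]...[IMG][/INST] prompt
    print(mm_data["image"].size)  # (8, 8)
```

The benchmark itself then passes these values through `TextPrompt(prompt=..., multi_modal_data=...)`, as shown in the `run_vllm` and `run_vllm_async` hunks.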