From 3cbe655d92c9556f911159d02fa12c986f081095 Mon Sep 17 00:00:00 2001
From: lkchen
Date: Tue, 5 Nov 2024 11:30:02 -0800
Subject: [PATCH] [Feature] Update benchmark_throughput.py to support image input (#9851)

Signed-off-by: Linkun Chen
Co-authored-by: Linkun Chen
Signed-off-by: Tyler Michael Smith
---
 benchmarks/README.md               | 11 ++++
 benchmarks/benchmark_throughput.py | 82 +++++++++++++++++++++++-------
 2 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/benchmarks/README.md b/benchmarks/README.md
index 192d6c4022c83..2aa4a285021f1 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -6,3 +6,14 @@ You can download the dataset by running:
 ```bash
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 ```
+
+## Downloading the ShareGPT4V dataset
+
+The JSON file references several image datasets (COCO, LLaVA, etc.). The benchmark scripts
+will ignore a data point if the referenced image is missing.
+```bash
+wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
+mkdir -p coco
+wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
+unzip coco/train2017.zip -d coco/
+```
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 262b8652e49ff..159cf055737ce 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -8,6 +8,7 @@
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+            model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
@@ -52,23 +74,36 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
+        # Only keep the first two turns of each conversation.
+        prompt = data["conversations"][0]["value"]
+        completion = data["conversations"][1]["value"]
+
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+            assert isinstance(image_path,
+                              str), "Only support single image input"
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
@@ -82,7 +117,8 @@ def sample_requests(
         filtered_dataset.append(
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
-                          expected_output_len=output_len))
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))
 
     return filtered_dataset
 
@@ -99,7 +135,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
             for _ in range(args.num_prompts)
         ]
     else:
-        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                   args.output_len)
+        requests = sample_requests(tokenizer, args)
 
+    is_multi_modal = any(request.multi_modal_data is not None
+                         for request in requests)
     if args.backend == "vllm":
        if args.async_engine:
             elapsed_time = uvloop.run(
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
                            for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
+    if is_multi_modal:
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+              "following metrics are not accurate because image tokens are not"
+              " counted. See vllm-project/vllm/issues/9778 for details.")
+        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
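
Usage sketch (not part of the patch): with the ShareGPT4V JSON and COCO images downloaded as described in the README hunk above, the new image path can be exercised roughly as follows. The flag names are inferred from the `args` fields the patch reads (`dataset`, `num_prompts`, `model`, `backend`); the Pixtral checkpoint name is an illustrative assumption, chosen only because `_get_prompt_for_image_model()` currently recognizes model names containing "pixtral".

```bash
# Hypothetical invocation; adjust the model name and paths to your setup.
python benchmarks/benchmark_throughput.py \
    --backend vllm \
    --model mistralai/Pixtral-12B-2409 \
    --dataset sharegpt4v_instruct_gpt4-vision_cap100k.json \
    --num-prompts 100
```

When any sampled request carries image data, the script prints the red WARNING added above before the throughput numbers, since image tokens are not yet counted (vllm-project/vllm/issues/9778).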