
Commit

rebase on top of upstream main
Signed-off-by: Linkun Chen <[email protected]>
Linkun Chen committed Nov 4, 2024
1 parent 84237f5 commit 13c726c
Showing 1 changed file with 62 additions and 5 deletions.
67 changes: 62 additions & 5 deletions benchmarks/benchmark_throughput.py
@@ -9,6 +9,7 @@
import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
@@ -60,6 +61,33 @@

def _get_prompt_for_image_model(question: str, *, model: str) -> str:
    """Prepend and append special tokens around the question to form a prompt.

    Args:
        question: The input question text to wrap with special tokens
        model: The name of the model being used, to determine which special
            tokens to add

    Returns:
        The formatted prompt string with appropriate special tokens for the
        model

    Raises:
        ValueError: If an unsupported model name is provided
    """
    model = model.lower()
    if "pixtral" in model:
        return f"<s>[INST]{question}\n[IMG][/INST]"
    raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> List[SampleRequest]:
    dataset_path: str = args.dataset
    num_requests: int = args.num_prompts
    fixed_output_len: Optional[int] = args.output_len
    model: str = args.model
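For orientation, a minimal usage sketch of the helper above (the model names here are illustrative, not taken from the diff): matching is lowercased and substring-based, so any identifier containing "pixtral" gets the Mistral-style wrapper, and anything else raises.

# Hypothetical example; model names are made up for illustration.
question = "What is shown in this image?"

prompt = _get_prompt_for_image_model(
    question=question, model="mistralai/Pixtral-12B-2409")
assert prompt == f"<s>[INST]{question}\n[IMG][/INST]"

try:
    _get_prompt_for_image_model(question=question, model="some-other-model")
except ValueError as err:
    print(err)  # Unsupported model some-other-model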
@@ -79,6 +107,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

    # Filter out sequences that are too long or too short
    filtered_dataset: List[SampleRequest] = []
    for data in dataset:
        if len(filtered_dataset) == num_requests:
            break
@@ -87,6 +116,25 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

        # Only keep the first two turns of each conversation.
        prompt = data["conversations"][0]["value"]
        completion = data["conversations"][1]["value"]

        multi_modal_data: Optional[MultiModalDataDict] = None
        if "image" in data:
            multi_modal_data = multi_modal_data or {}
            image_path = data["image"]
            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
            assert isinstance(image_path,
                              str), "Only support single image input"
            try:
                multi_modal_data["image"] = Image.open(image_path).convert(
                    "RGB")
            except FileNotFoundError:
                # Ignore datapoints whose image asset is missing
                continue
            prompt = _get_prompt_for_image_model(question=prompt, model=model)
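For reference, the loop above assumes ShareGPT-style datapoints; a sketch of the expected shape, where only the "conversations" and "image" keys are actually read here and the values are invented:

# Illustrative datapoint; field values are made up.
data = {
    "image": "assets/cat.jpg",  # optional; a single image path only
    "conversations": [
        {"from": "human", "value": "Describe the picture."},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
}

Datapoints whose image file is missing are skipped via the FileNotFoundError handler rather than aborting the whole benchmark.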
@@ -119,6 +167,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

                          prompt_len=prompt_len,
                          expected_output_len=output_len,
                          multi_modal_data=multi_modal_data))

    return filtered_dataset
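A sketch of driving sample_requests directly, assuming the namespace carries at least the attributes read in this excerpt (in the benchmark they come from the CLI parser; the dataset path and model name are placeholders):

import argparse

from transformers import AutoTokenizer

args = argparse.Namespace(dataset="sharegpt.json",
                          num_prompts=8,
                          output_len=None,
                          model="mistralai/Pixtral-12B-2409")
tokenizer = AutoTokenizer.from_pretrained(args.model)
requests = sample_requests(tokenizer, args)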

@@ -135,6 +185,9 @@ def run_vllm(
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
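The loop bodies in run_vllm and run_vllm_async build the prompt list the same way; for the prompt side, an equivalent, behavior-preserving comprehension would be:

# Equivalent construction; multi_modal_data is simply None for
# text-only requests.
prompts = [
    TextPrompt(prompt=request.prompt,
               multi_modal_data=request.multi_modal_data)
    for request in requests
]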
@@ -186,6 +239,9 @@ async def run_vllm_async(
    prompts: List[TextPrompt] = []
    sampling_params: List[SamplingParams] = []
    for request in requests:
        prompts.append(
            TextPrompt(prompt=request.prompt,
                       multi_modal_data=request.multi_modal_data))
@@ -313,7 +369,10 @@ def main(args: argparse.Namespace):
        ]
    else:
        requests = sample_requests(tokenizer, args)

    is_multi_modal = any(request.multi_modal_data is not None
                         for request in requests)
    if args.backend == "vllm":

@@ -342,11 +401,9 @@ def main(args: argparse.Namespace):

    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if is_multi_modal:
        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
              "following metrics are not accurate because image tokens are "
              "not counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
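As a sanity check on the reported numbers, a worked example with made-up values (the diff also computes total_output_tokens, which is presumably reported the same way):

# Hypothetical run: 64 requests, 40,960 total tokens (20,480 of them
# output tokens), completed in 32 seconds.
num_requests = 64
total_num_tokens = 40_960
total_output_tokens = 20_480
elapsed_time = 32.0

print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
      f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
# Throughput: 2.00 requests/s, 1280.00 total tokens/s, 640.00 output tokens/s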
