Skip to content

Commit

Permalink
[Misc] Refactor benchmark_throughput.py (vllm-project#9779)
Browse files Browse the repository at this point in the history
Signed-off-by: Linkun Chen <[email protected]>
Co-authored-by: Linkun Chen <[email protected]>
Co-authored-by: Linkun Chen <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2024
1 parent 67ed2df commit b56f422
Showing 1 changed file with 55 additions and 26 deletions.
81 changes: 55 additions & 26 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import random
import time
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import uvloop
Expand All @@ -15,16 +15,35 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
) -> List[SampleRequest]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

Expand All @@ -41,7 +60,7 @@ def sample_requests(
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
Expand All @@ -60,31 +79,34 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))

return filtered_dataset


def run_vllm(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

use_beam_search = False
Expand All @@ -94,11 +116,11 @@ def run_vllm(
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
Expand All @@ -112,7 +134,7 @@ def run_vllm(


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
Expand All @@ -123,17 +145,17 @@ async def run_vllm_async(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

generators = []
Expand All @@ -149,7 +171,7 @@ async def run_vllm_async(


def run_hf(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
Expand Down Expand Up @@ -207,14 +229,14 @@ def run_hf(


def run_mii(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]

start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
Expand Down Expand Up @@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
requests = [
SampleRequest(prompt=prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len)
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
Expand All @@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
Expand All @@ -299,7 +326,9 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len",
type=int,
default=None,
Expand Down

0 comments on commit b56f422

Please sign in to comment.