[misc] benchmark_throughput : Add LoRA #11267

Merged
102 changes: 89 additions & 13 deletions benchmarks/benchmark_throughput.py
@@ -4,7 +4,8 @@
import json
import random
import time
from typing import List, Optional
from functools import cache
from typing import Dict, List, Optional, Tuple

import torch
import uvloop
@@ -17,8 +18,11 @@
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@@ -28,15 +32,17 @@ class SampleRequest:

Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
lora_request: Optional LoRARequest specifying the LoRA to use.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
lora_request: Optional[LoRARequest] = None
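
For illustration, a minimal sketch of how a LoRA-targeted SampleRequest might be constructed; the prompt, token counts, and adapter path are placeholders, not values used by this benchmark.

# Hypothetical example: a request that should be served with LoRA adapter id 1.
# "./my-sql-lora" is a placeholder path, not an adapter shipped with this PR.
example_request = SampleRequest(
    prompt="Translate to SQL: how many users signed up today?",
    prompt_len=14,
    expected_output_len=128,
    lora_request=LoRARequest(lora_name="1",
                             lora_int_id=1,
                             lora_path="./my-sql-lora"))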


def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")


@cache
def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path)


lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}


def get_random_lora_request(
args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
global lora_tokenizer_cache
lora_id = random.randint(1, args.max_loras)
lora_request = LoRARequest(lora_name=str(lora_id),
lora_int_id=lora_id,
lora_path=lora_path_on_disk(args.lora_path))
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
return lora_request, lora_tokenizer_cache[lora_id]
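
A small usage sketch of the helper above. The argparse.Namespace and the adapter id are stand-ins for whatever --max-loras and --lora-path resolve to, and get_lora_tokenizer is assumed to return None when the adapter ships no tokenizer files (the case discussed in the review comment further down).

import argparse
import random

# Stand-in for the parsed CLI flags; the adapter id is purely illustrative.
demo_args = argparse.Namespace(max_loras=4,
                               lora_path="yard1/llama-2-7b-sql-lora-test")

random.seed(0)
lora_request, lora_tokenizer = get_random_lora_request(demo_args)
assert 1 <= lora_request.lora_int_id <= demo_args.max_loras
# lora_tokenizer may be None; callers then fall back to the base tokenizer.
print(lora_request.lora_name, lora_tokenizer is not None)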


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:

dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for data in dataset:
for data in tqdm(dataset,
total=num_requests,
desc="sampling requests"):
if len(filtered_dataset) == num_requests:
break

@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer

Collaborator:

Many LoRA adapters don't include tokenizer files, which leads to request_tokenizer being None. We need to handle this case.

Contributor (author):

Nice catch! Added a fallback to use the base tokenizer.

# Tokenize the prompts and completions.
prompt_token_ids = tokenizer(prompt).input_ids
completion_token_ids = tokenizer(completion).input_ids
prompt_token_ids = request_tokenizer(prompt).input_ids
completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))
multi_modal_data=multi_modal_data,
lora_request=lora_request))

return filtered_dataset

@@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests: Optional[List[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]

use_beam_search = False

if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0].expected_output_len
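
For context on the generate call above, a standalone sketch of passing one LoRARequest per prompt to LLM.generate; the model name, adapter paths, and sampling settings are placeholders, and the constructor arguments mirror the engine flags this benchmark already forwards.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model and adapters; enable_lora/max_loras mirror the CLI flags.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)
prompts = ["Question one?", "Question two?"]
lora_requests = [
    LoRARequest("1", 1, "./adapter-1"),  # first prompt -> adapter 1
    LoRARequest("2", 2, "./adapter-2"),  # second prompt -> adapter 2
]
# One LoRARequest per prompt, aligned by position with the prompt list.
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=32),
                       lora_request=lora_requests)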
@@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine.
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
lora_requests: List[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
lora_requests.append(request.lora_request)

generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
for i, (prompt, sp,
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size
requests = []
for _ in range(args.num_prompts):

request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer

# Synthesize a prompt with the given input length.
candidate_ids = [
random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct
candidate_prompt = tokenizer.decode(candidate_ids)
tokenized_len = len(tokenizer.encode(candidate_prompt))
candidate_prompt = request_tokenizer.decode(candidate_ids)
tokenized_len = len(request_tokenizer.encode(candidate_prompt))

if tokenized_len == args.input_len:
break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append(
SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len))
expected_output_len=args.output_len,
lora_request=lora_request))
else:
requests = sample_requests(tokenizer, args)

@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")

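Putting the new option together with the engine flags registered just below, a hypothetical invocation might look like the argv list sketched here; --enable-lora and --max-loras come from AsyncEngineArgs.add_cli_args, and the model, dataset, and adapter values are placeholders.

# Hypothetical benchmark invocation assembled as an argv list; all values
# (model, dataset file, adapter id) are placeholders, not defaults of this PR.
benchmark_argv = [
    "python", "benchmarks/benchmark_throughput.py",
    "--backend", "vllm",
    "--model", "meta-llama/Llama-2-7b-hf",
    "--dataset", "ShareGPT_V3_unfiltered_cleaned_split.json",
    "--num-prompts", "500",
    "--enable-lora",
    "--max-loras", "4",
    "--lora-path", "yard1/llama-2-7b-sql-lora-test",
]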
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
@@ -431,6 +499,8 @@
assert args.output_len is not None
else:
assert args.input_len is None
if args.enable_lora:
assert args.lora_path is not None

if args.backend == "vllm":
if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.enable_lora:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
if args.enable_lora:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
main(args)