diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1e5967bd9bf8b..c1b10b3cf8f58 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,7 +4,8 @@ import json import random import time -from typing import List, Optional +from functools import cache +from typing import Dict, List, Optional, Tuple import torch import uvloop @@ -17,8 +18,11 @@ from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.inputs import TextPrompt +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -28,15 +32,17 @@ class SampleRequest: Attributes: prompt: The input text prompt for the model. - multi_modal_data: Optional dictionary containing multi-modal data (e.g. - images). prompt_len: The length of the prompt in tokens. expected_output_len: The expected length of the output in tokens. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + lora_request: Optional LoRARequest specifying the LoRA to use. """ prompt: str prompt_len: int expected_output_len: int multi_modal_data: Optional[MultiModalDataDict] = None + lora_request: Optional[LoRARequest] = None def _get_prompt_for_image_model(question: str, *, model: str) -> str: @@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str: raise ValueError(f"Unsupported model {model}") +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} + + +def get_random_lora_request( + args: argparse.Namespace +) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: + global lora_tokenizer_cache + lora_id = random.randint(1, args.max_loras) + lora_request = LoRARequest(lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(args.lora_path)) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + return lora_request, lora_tokenizer_cache[lora_id] + + def sample_requests(tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace) -> List[SampleRequest]: + dataset_path: str = args.dataset num_requests: int = args.num_prompts fixed_output_len: Optional[int] = args.output_len @@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] - for data in dataset: + for data in tqdm(dataset, + total=len(filtered_dataset), + desc="sampling requests"): if len(filtered_dataset) == num_requests: break @@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, continue prompt = _get_prompt_for_image_model(question=prompt, model=model) + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Tokenize the prompts and completions. - prompt_token_ids = tokenizer(prompt).input_ids - completion_token_ids = tokenizer(completion).input_ids + prompt_token_ids = request_tokenizer(prompt).input_ids + completion_token_ids = request_tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len @@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, SampleRequest(prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - multi_modal_data=multi_modal_data)) + multi_modal_data=multi_modal_data, + lora_request=lora_request)) return filtered_dataset @@ -146,14 +184,21 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests: Optional[List[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] use_beam_search = False if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: + assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] # output_len should be the same for all requests. output_len = requests[0][2] @@ -185,6 +230,7 @@ async def run_vllm_async( # Add the requests to the engine. prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] + lora_requests: List[Optional[LoRARequest]] = [] for request in requests: prompts.append( TextPrompt(prompt=request.prompt, @@ -197,11 +243,16 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = llm.generate(prompt, sp, request_id=f"test{i}") + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -297,6 +348,14 @@ def main(args: argparse.Namespace): vocab_size = tokenizer.vocab_size requests = [] for _ in range(args.num_prompts): + + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Synthesize a prompt with the given input length. candidate_ids = [ random.randint(0, vocab_size - 1) @@ -305,8 +364,8 @@ def main(args: argparse.Namespace): # As tokenizer may add additional tokens like BOS, we need to try # different lengths to get the desired input length. for _ in range(5): # Max attempts to correct - candidate_prompt = tokenizer.decode(candidate_ids) - tokenized_len = len(tokenizer.encode(candidate_prompt)) + candidate_prompt = request_tokenizer.decode(candidate_ids) + tokenized_len = len(request_tokenizer.encode(candidate_prompt)) if tokenized_len == args.input_len: break @@ -323,7 +382,8 @@ def main(args: argparse.Namespace): requests.append( SampleRequest(prompt=candidate_prompt, prompt_len=args.input_len, - expected_output_len=args.output_len)) + expected_output_len=args.output_len, + lora_request=lora_request)) else: requests = sample_requests(tokenizer, args) @@ -422,6 +482,14 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: @@ -431,6 +499,8 @@ def main(args: argparse.Namespace): assert args.output_len is not None else: assert args.input_len is None + if args.enable_lora: + assert args.lora_path is not None if args.backend == "vllm": if args.hf_max_batch_size is not None: @@ -440,6 +510,9 @@ def main(args: argparse.Namespace): raise ValueError("HF max batch size is required for HF backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") elif args.backend == "mii": if args.dtype != "auto": raise ValueError("dtype must be auto for MII backend.") @@ -452,4 +525,7 @@ def main(args: argparse.Namespace): if args.tokenizer != args.model: raise ValueError("Tokenizer must be the same as the model for MII " "backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") main(args)