From 2ccff7fc37f5d819fc92b6a53120ce476c1e396a Mon Sep 17 00:00:00 2001 From: Ricky Xu Date: Mon, 18 Nov 2024 15:39:14 -0800 Subject: [PATCH] [misc] partial prefix & random input generation benchmark (#9929) Signed-off-by: rickyx --- benchmarks/benchmark_prefix_caching.py | 116 +++++++++++++++++++------ 1 file changed, 91 insertions(+), 25 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 6d33096ca1d11..5e9381f712e10 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") -def sample_requests( +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: + vocab = tokenizer.get_vocab() + # Remove the special tokens. + vocab = { + k: v + for k, v in vocab.items() if k not in tokenizer.all_special_ids + } + return random.choices(list(vocab.values()), k=length) + + +def sample_requests_from_dataset( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: Tuple[int, int], fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: +) -> List[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -77,31 +94,55 @@ def sample_requests( random.shuffle(dataset) min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_requests: List[Request] = [] + for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: + if len(filtered_requests) == num_requests: break # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = tokenizer(dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) if min_len <= prompt_len <= max_len: - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests + + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: Tuple[int, int], + fixed_output_len: Optional[int], + prefix_len: int, +) -> List[Request]: - return filtered_dataset + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, + random.randint(min_len - prefix_len, max_len - prefix_len)) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert (min_len <= prompt_len <= max_len + ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests -def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], +def repeat_and_sort_requests(requests: List[Request], repeat_count: int, sort: bool = False) -> List[str]: repeated_requests = requests * repeat_count @@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], repeated_requests.sort(key=lambda x: x[1]) else: random.shuffle(repeated_requests) - return [req[0] for req in repeated_requests] + return [req.prompt for req in repeated_requests] def main(args): @@ -117,9 +158,12 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts" + if args.prefix_len > 0: + raise ValueError("prefix-len is not supported when " + "dataset-path is provided.") + print(f"Start to sample {args.num_prompts} prompts " f"from {args.dataset_path}") - filtered_datasets = sample_requests( + filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, @@ -127,9 +171,22 @@ def main(args): fixed_output_len=args.output_len, ) else: - prompt_len = len(tokenizer(PROMPT).input_ids) - filtered_datasets = [(PROMPT, prompt_len, args.output_len) - ] * args.num_prompts + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + # Print some helpful stats of the requests. + print(f"Sampled {len(filtered_requests)} requests.") + prompt_lens = [req.prompt_len for req in filtered_requests] + print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") + print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") + print(f"Min Prompt Length: {min(prompt_lens)}") + print(f"Max Prompt Length: {max(prompt_lens)}") engine_args = EngineArgs.from_cli_args(args) @@ -137,8 +194,8 @@ def main(args): sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - print("Testing filtered datasets") - prompts = repeat_and_sort_requests(filtered_datasets, + print("Testing filtered requests") + prompts = repeat_and_sort_requests(filtered_requests, repeat_count=args.repeat_count, sort=args.sort) @@ -161,20 +218,29 @@ def main(args): parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--num-prompts', type=int, - default=1, + required=True, help="Number of the prompts sampled from dataset") parser.add_argument('--repeat-count', type=int, - default=100, + default=1, help='Number of times to repeat each prompt') parser.add_argument('--sort', action='store_true', help='Sort prompts by input length') parser.add_argument('--input-length-range', type=str, - default='128:256', + required=True, help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. Only used " + "when dataset-path is not provided.", + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args()