diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 1aac029992dbf..a5bf6127f47d7 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -54,13 +54,31 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") -def sample_requests( +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: + vocab = tokenizer.get_vocab() + # Remove the special tokens. + vocab = { + k: v + for k, v in vocab.items() if k not in tokenizer.all_special_ids + } + return random.choices(list(vocab.values()), k=length) + + +def sample_requests_from_dataset( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: Tuple[int, int], fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: + prefix_len: int, +) -> List[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -77,28 +95,54 @@ def sample_requests( random.shuffle(dataset) min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_requests: List[Request] = [] + prefix_token_ids: List[int] = sample_tokens(tokenizer, prefix_len) + for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: + if len(filtered_requests) == num_requests: break # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = prefix_token_ids + tokenizer( + dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) if min_len <= prompt_len <= max_len: - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests - return filtered_dataset + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: Tuple[int, int], + fixed_output_len: Optional[int], + prefix_len: int, +) -> List[Request]: + + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, + random.randint(min_len - prefix_len, max_len - prefix_len)) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert (min_len <= prompt_len <= max_len + ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], @@ -117,19 +161,28 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts" - "from {args.dataset_path}") - filtered_datasets = sample_requests( + print( + f"Start to sample {args.num_prompts} prompts from {args.dataset_path}" + ) + filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, input_length_range=input_length_range, fixed_output_len=args.output_len, + prefix_len=args.prefix_len, ) else: - prompt_len = len(tokenizer(PROMPT).input_ids) - filtered_datasets = [(PROMPT, prompt_len, args.output_len) - ] * args.num_prompts + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + print(f"Sampled {len(filtered_requests)} requests.") engine_args = EngineArgs.from_cli_args(args) @@ -137,8 +190,8 @@ def main(args): sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - print("Testing filtered datasets") - prompts = repeat_and_sort_requests(filtered_datasets, + print("Testing filtered requests") + prompts = repeat_and_sort_requests(filtered_requests, repeat_count=args.repeat_count, sort=args.sort) @@ -182,6 +235,14 @@ def main(args): default='128:256', help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. ", + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args()