From 0478dce08f5282f265ce62e6d14045e095e23b02 Mon Sep 17 00:00:00 2001
From: Ricky Xu <rickyx@anyscale.com>
Date: Mon, 18 Nov 2024 15:39:14 -0800
Subject: [PATCH] [misc] partial prefix & random input generation benchmark
 (#9929)

Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
---
 benchmarks/benchmark_prefix_caching.py | 116 +++++++++++++++++++------
 1 file changed, 91 insertions(+), 25 deletions(-)

diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index 6d33096ca1d11..5e9381f712e10 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
     print(f"cost time {end_time - start_time}")
 
 
-def sample_requests(
+@dataclasses.dataclass
+class Request:
+    prompt: str
+    prompt_len: int
+    output_len: int
+
+
+def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+    vocab = tokenizer.get_vocab()
+    # Remove the special tokens.
+    vocab = {
+        k: v
+        for k, v in vocab.items() if k not in tokenizer.all_special_ids
+    }
+    return random.choices(list(vocab.values()), k=length)
+
+
+def sample_requests_from_dataset(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     input_length_range: Tuple[int, int],
     fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+) -> List[Request]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
@@ -77,31 +94,55 @@ def sample_requests(
     random.shuffle(dataset)
 
     min_len, max_len = input_length_range
+    assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
 
     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_requests: List[Request] = []
+
     for i in range(len(dataset)):
-        if len(filtered_dataset) == num_requests:
+        if len(filtered_requests) == num_requests:
             break
 
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
-        prompt_token_ids = tokenizer(prompt).input_ids
+        prompt_token_ids = tokenizer(dataset[i][0]).input_ids
+        prompt = tokenizer.decode(prompt_token_ids)
         completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
-        output_len = len(completion_token_ids
-                         ) if fixed_output_len is None else fixed_output_len
-        if prompt_len < 4 or output_len < 4:
-            # Prune too short sequences.
-            continue
+        output_len = (len(completion_token_ids)
+                      if fixed_output_len is None else fixed_output_len)
         if min_len <= prompt_len <= max_len:
-            filtered_dataset.append((prompt, prompt_len, output_len))
+            filtered_requests.append(Request(prompt, prompt_len, output_len))
+
+    return filtered_requests
+
+
+def sample_requests_from_random(
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    input_length_range: Tuple[int, int],
+    fixed_output_len: Optional[int],
+    prefix_len: int,
+) -> List[Request]:
 
-    return filtered_dataset
+    requests = []
+    prefix_token_ids = sample_tokens(tokenizer, prefix_len)
+    min_len, max_len = input_length_range
+
+    for i in range(num_requests):
+        unique_part_token_ids = sample_tokens(
+            tokenizer,
+            random.randint(min_len - prefix_len, max_len - prefix_len))
+        prompt_token_ids = prefix_token_ids + unique_part_token_ids
+        prompt = tokenizer.decode(prompt_token_ids)
+        prompt_len = len(prompt_token_ids)
+        assert (min_len <= prompt_len <= max_len
+                ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
+        requests.append(Request(prompt, prompt_len, fixed_output_len))
+    return requests
 
 
-def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
+def repeat_and_sort_requests(requests: List[Request],
                              repeat_count: int,
                              sort: bool = False) -> List[str]:
     repeated_requests = requests * repeat_count
@@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
         repeated_requests.sort(key=lambda x: x[1])
     else:
         random.shuffle(repeated_requests)
-    return [req[0] for req in repeated_requests]
+    return [req.prompt for req in repeated_requests]
 
 
 def main(args):
@@ -117,9 +158,12 @@ def main(args):
     input_length_range = tuple(map(int, args.input_length_range.split(':')))
     random.seed(args.seed)
     if args.dataset_path is not None:
-        print(f"Start to sample {args.num_prompts} prompts"
+        if args.prefix_len > 0:
+            raise ValueError("prefix-len is not supported when "
+                             "dataset-path is provided.")
+        print(f"Start to sample {args.num_prompts} prompts "
               f"from {args.dataset_path}")
-        filtered_datasets = sample_requests(
+        filtered_requests = sample_requests_from_dataset(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -127,9 +171,22 @@ def main(args):
             fixed_output_len=args.output_len,
         )
     else:
-        prompt_len = len(tokenizer(PROMPT).input_ids)
-        filtered_datasets = [(PROMPT, prompt_len, args.output_len)
-                             ] * args.num_prompts
+        print(f"Start to sample {args.num_prompts} prompts from random")
+        filtered_requests = sample_requests_from_random(
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            input_length_range=input_length_range,
+            fixed_output_len=args.output_len,
+            prefix_len=args.prefix_len,
+        )
+
+    # Print some helpful stats of the requests.
+    print(f"Sampled {len(filtered_requests)} requests.")
+    prompt_lens = [req.prompt_len for req in filtered_requests]
+    print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
+    print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
+    print(f"Min Prompt Length: {min(prompt_lens)}")
+    print(f"Max Prompt Length: {max(prompt_lens)}")
 
     engine_args = EngineArgs.from_cli_args(args)
 
@@ -137,8 +194,8 @@ def main(args):
 
     sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
 
-    print("Testing filtered datasets")
-    prompts = repeat_and_sort_requests(filtered_datasets,
+    print("Testing filtered requests")
+    prompts = repeat_and_sort_requests(filtered_requests,
                                        repeat_count=args.repeat_count,
                                        sort=args.sort)
 
@@ -161,20 +218,29 @@ def main(args):
     parser.add_argument('--output-len', type=int, default=10)
     parser.add_argument('--num-prompts',
                         type=int,
-                        default=1,
+                        required=True,
                         help="Number of the prompts sampled from dataset")
     parser.add_argument('--repeat-count',
                         type=int,
-                        default=100,
+                        default=1,
                         help='Number of times to repeat each prompt')
     parser.add_argument('--sort',
                         action='store_true',
                         help='Sort prompts by input length')
     parser.add_argument('--input-length-range',
                         type=str,
-                        default='128:256',
+                        required=True,
                         help='Range of input lengths for sampling prompts,'
                         'specified as "min:max" (e.g., "128:256").')
+    parser.add_argument(
+        "--prefix-len",
+        type=int,
+        default=0,
+        help="Specifies the length of a common prefix to be "
+        "added to the input prompt. The input-length-range will "
+        "subtract this length when filtering prompts. Only used "
+        "when dataset-path is not provided.",
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()