Skip to content

Commit

Permalink
Comments
Browse files: browse the repository at this point in the history.
Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx committed Nov 7, 2024
1 parent b1bf34f commit 987613b
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def sample_requests_from_dataset(
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
prefix_len: int,
) -> List[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
Expand All @@ -99,15 +98,13 @@ def sample_requests_from_dataset(

# Filter out sequences that are too long or too short
filtered_requests: List[Request] = []
prefix_token_ids: List[int] = sample_tokens(tokenizer, prefix_len)

for i in range(len(dataset)):
if len(filtered_requests) == num_requests:
break

# Tokenize the prompts and completions.
prompt_token_ids = prefix_token_ids + tokenizer(
dataset[i][0]).input_ids
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
prompt = tokenizer.decode(prompt_token_ids)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
Expand Down Expand Up @@ -161,6 +158,9 @@ def main(args):
input_length_range = tuple(map(int, args.input_length_range.split(':')))
random.seed(args.seed)
if args.dataset_path is not None:
if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when "
"dataset-path is provided.")
print(f"Start to sample {args.num_prompts} prompts "
f"from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset(
Expand All @@ -169,7 +169,6 @@ def main(args):
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
prefix_len=args.prefix_len,
)
else:
print(f"Start to sample {args.num_prompts} prompts from random")
Expand All @@ -181,7 +180,13 @@ def main(args):
prefix_len=args.prefix_len,
)

# Print some helpful stats of the requests.
print(f"Sampled {len(filtered_requests)} requests.")
prompt_lens = [req.prompt_len for req in filtered_requests]
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
print(f"Min Prompt Length: {min(prompt_lens)}")
print(f"Max Prompt Length: {max(prompt_lens)}")

engine_args = EngineArgs.from_cli_args(args)

Expand Down Expand Up @@ -240,7 +245,8 @@ def main(args):
default=0,
help="Specifies the length of a common prefix to be "
"added to the input prompt. The input-length-range will "
"subtract this length when filtering prompts. ",
"subtract this length when filtering prompts. Only used "
"when dataset-path is not provided.",
)

parser = EngineArgs.add_cli_args(parser)
Expand Down

0 comments on commit 987613b

Please sign in to comment.