Skip to content

Commit

Permalink
[Bugfix] Generate exactly input_len tokens in benchmark_throughput (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Oct 23, 2024
1 parent 208cb34 commit 65050a4
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,16 @@ def main(args: argparse.Namespace):
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for i in range(-10, 10):
prompt = "hi " * (args.input_len + i)
tokenized_prompt = tokenizer(prompt).input_ids
if len(tokenized_prompt) == args.input_len:
break
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
Expand Down

0 comments on commit 65050a4

Please sign in to comment.