
Commit

lint
rickyyx committed Nov 1, 2024
1 parent 5f8d807 commit ee497a5
Showing 1 changed file with 82 additions and 21 deletions.
103 changes: 82 additions & 21 deletions benchmarks/benchmark_prefix_caching.py
@@ -54,13 +54,31 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
print(f"cost time {end_time - start_time}")


def sample_requests(
@dataclasses.dataclass
class Request:
prompt: str
prompt_len: int
output_len: int


def sample_tokens(tokenizer: PreTrainedTokenizerBase,
                  length: int) -> List[int]:
    vocab = tokenizer.get_vocab()
    # Remove the special tokens: filter on token ids, not token strings.
    vocab = {
        k: v
        for k, v in vocab.items() if v not in tokenizer.all_special_ids
    }
    return random.choices(list(vocab.values()), k=length)
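
A quick sanity check of what sample_tokens returns, as a sketch; the tokenizer checkpoint here is an illustrative assumption, not part of this commit:

from transformers import AutoTokenizer

# Hypothetical check; the tokenizer checkpoint is an assumption chosen
# for illustration only.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
ids = sample_tokens(tok, 16)
assert len(ids) == 16                           # exactly `length` token ids
assert not set(ids) & set(tok.all_special_ids)  # special tokens filtered out
print(tok.decode(ids))                          # random tokens decoded to text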


def sample_requests_from_dataset(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
prefix_len: int,
) -> List[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -77,28 +95,54 @@ def sample_requests(
random.shuffle(dataset)

min_len, max_len = input_length_range
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_requests: List[Request] = []
prefix_token_ids: List[int] = sample_tokens(tokenizer, prefix_len)

for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
if len(filtered_requests) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_token_ids = prefix_token_ids + tokenizer(
dataset[i][0]).input_ids
prompt = tokenizer.decode(prompt_token_ids)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
output_len = (len(completion_token_ids)
if fixed_output_len is None else fixed_output_len)
if min_len <= prompt_len <= max_len:
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_requests.append(Request(prompt, prompt_len, output_len))

return filtered_requests

return filtered_dataset
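
A hypothetical invocation of sample_requests_from_dataset for orientation; the checkpoint, dataset path, and argument values below are placeholders rather than anything from this commit:

from transformers import AutoTokenizer

# Hypothetical usage; every value below is a placeholder.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
requests = sample_requests_from_dataset(
    dataset_path="sharegpt.json",        # placeholder dataset file
    num_requests=100,
    tokenizer=tokenizer,
    input_length_range=(128, 256),       # total prompt length, prefix included
    fixed_output_len=32,
    prefix_len=64,                       # shared random prefix length
)
# Every request starts with the same 64 sampled tokens, which is what lets
# the prefix cache reuse KV blocks across requests.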

def sample_requests_from_random(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
prefix_len: int,
) -> List[Request]:

requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range

for i in range(num_requests):
unique_part_token_ids = sample_tokens(
tokenizer,
random.randint(min_len - prefix_len, max_len - prefix_len))
prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests
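
The length arithmetic above keeps totals inside the requested range: the unique part is drawn from [min_len - prefix_len, max_len - prefix_len], so prepending the prefix_len shared tokens lands each prompt back in [min_len, max_len] (implicitly assuming prefix_len <= min_len). A minimal check of that invariant, with illustrative values:

import random

prefix_len, min_len, max_len = 64, 128, 256  # illustrative values
for _ in range(1000):
    unique_len = random.randint(min_len - prefix_len, max_len - prefix_len)
    total_len = prefix_len + unique_len  # mirrors prefix + unique sampling
    assert min_len <= total_len <= max_len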


def repeat_and_sort_requests(requests: List[Request],
@@ -117,28 +161,37 @@ def main(args):
input_length_range = tuple(map(int, args.input_length_range.split(':')))
random.seed(args.seed)
if args.dataset_path is not None:
print(f"Start to sample {args.num_prompts} prompts"
"from {args.dataset_path}")
filtered_datasets = sample_requests(
print(f"Start to sample {args.num_prompts} prompts "
f"from {args.dataset_path}"
)
filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
prefix_len=args.prefix_len,
)
else:
prompt_len = len(tokenizer(PROMPT).input_ids)
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
] * args.num_prompts
print(f"Start to sample {args.num_prompts} prompts from random")
filtered_requests = sample_requests_from_random(
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
prefix_len=args.prefix_len,
)

print(f"Sampled {len(filtered_requests)} requests.")

engine_args = EngineArgs.from_cli_args(args)

llm = LLM(**dataclasses.asdict(engine_args))

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

print("Testing filtered datasets")
prompts = repeat_and_sort_requests(filtered_datasets,
print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests,
repeat_count=args.repeat_count,
sort=args.sort)
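
The body of repeat_and_sort_requests is elided from this diff; a plausible sketch of the idea, with the sort key and return type as assumptions:

import random
from typing import List

# A sketch only: the real helper's body is not shown in this diff, so the
# sort key and return type here are assumptions.
def repeat_and_sort_requests_sketch(requests: List[Request],
                                    repeat_count: int,
                                    sort: bool = False) -> List[str]:
    repeated = requests * repeat_count
    if sort:
        # Grouping similar prompts together maximizes prefix-cache hits.
        repeated.sort(key=lambda r: r.prompt_len)
    else:
        random.shuffle(repeated)
    return [r.prompt for r in repeated]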

@@ -182,6 +235,14 @@ def main(args):
default='128:256',
help='Range of input lengths for sampling prompts, '
'specified as "min:max" (e.g., "128:256").')
parser.add_argument(
"--prefix-len",
type=int,
default=0,
help="Specifies the length of a common prefix to be "
"added to the input prompt. The input-length-range will "
"subtract this length when filtering prompts. ",
)
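
For reference, --input-length-range is parsed in main() via tuple(map(int, args.input_length_range.split(':'))), and --prefix-len consumes part of that budget for the shared prefix. A small illustration with made-up values:

input_length_range = tuple(map(int, "128:256".split(":")))  # -> (128, 256)
min_len, max_len = input_length_range
prefix_len = 64  # e.g. --prefix-len 64
# Budget left for the unique part of each random prompt
# (see sample_requests_from_random above):
unique_range = (min_len - prefix_len, max_len - prefix_len)  # -> (64, 192)
print(unique_range)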

parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
