From dc1cf2fbc06b40cc8f2a955ccd010ab35bc8b734 Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Tue, 1 Oct 2024 17:01:21 -0700 Subject: [PATCH 1/2] add random_seed to sample_hf_requests in benchmark_serving script --- benchmarks/benchmark_serving.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 56c37b241a359..b804e10b613eb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -202,6 +202,7 @@ def sample_hf_requests( dataset_split: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, + random_seed: int, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: dataset = load_dataset(dataset_path, @@ -210,7 +211,7 @@ def sample_hf_requests( streaming=True) assert "conversations" in dataset.features, ( "HF Dataset must have 'conversations' column.") - filtered_dataset = dataset.shuffle().filter( + filtered_dataset = dataset.shuffle(seed=random_seed).filter( lambda x: len(x["conversations"]) >= 2) sampled_requests: List[Tuple[str, int, int, Dict[str, Collection[str]]]] = [] @@ -651,6 +652,7 @@ def main(args: argparse.Namespace): dataset_split=args.hf_split, num_requests=args.num_prompts, tokenizer=tokenizer, + random_seed=args.seed, fixed_output_len=args.hf_output_len, ) From a0c0958f207129cc10918de2991011b1acb983ef Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Thu, 17 Oct 2024 09:23:35 -0700 Subject: [PATCH 2/2] Update benchmarks/benchmark_serving.py formated Co-authored-by: Isotr0py <2037008807@qq.com> --- benchmarks/benchmark_serving.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 27c9062314d5b..1381004c9f02b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -211,8 +211,8 @@ def sample_hf_requests( streaming=True) assert "conversations" in dataset.features, ( "HF Dataset must have 'conversations' column.") - filtered_dataset = dataset.shuffle(seed=random_seed).filter( - lambda x: len(x["conversations"]) >= 2) + filter_func = lambda x: len(x["conversations"]) >= 2 + filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) sampled_requests: List[Tuple[str, int, int, Dict[str, Collection[str]]]] = [] for data in filtered_dataset: