diff --git a/src/guidellm/main.py b/src/guidellm/main.py index 4016ece..4089030 100644 --- a/src/guidellm/main.py +++ b/src/guidellm/main.py @@ -151,6 +151,15 @@ "until the user exits. " ), ) +@click.option( + "--subset-size", + type=int, + default=None, + help=( + "The number of subsets to use from the dataset. " + "If not provided, all subsets will be used." + ), +) def generate_benchmark_report_cli( target: str, backend: BackendEnginePublic, @@ -164,6 +173,7 @@ def generate_benchmark_report_cli( max_requests: Union[Literal["dataset"], int, None], output_path: str, enable_continuous_refresh: bool, + subset_size: Optional[int], ): """ Generate a benchmark report for a specified backend and dataset. @@ -181,6 +191,7 @@ def generate_benchmark_report_cli( max_requests=max_requests, output_path=output_path, cont_refresh_table=enable_continuous_refresh, + subset_size=subset_size, ) @@ -197,6 +208,7 @@ def generate_benchmark_report( max_requests: Union[Literal["dataset"], int, None], output_path: str, cont_refresh_table: bool, + subset_size: Optional[int], ) -> GuidanceReport: """ Generate a benchmark report for a specified backend and dataset. @@ -251,7 +263,7 @@ def generate_benchmark_report( request_generator = FileRequestGenerator(path=data, tokenizer=tokenizer_inst) elif data_type == "transformers": request_generator = TransformersDatasetRequestGenerator( - dataset=data, tokenizer=tokenizer_inst + dataset=data, tokenizer=tokenizer_inst, subset_size=subset_size ) else: raise ValueError(f"Unknown data type: {data_type}") diff --git a/src/guidellm/request/transformers.py b/src/guidellm/request/transformers.py index 3fd2404..d9ac2e1 100644 --- a/src/guidellm/request/transformers.py +++ b/src/guidellm/request/transformers.py @@ -33,6 +33,8 @@ class TransformersDatasetRequestGenerator(RequestGenerator): :type mode: str :param async_queue_size: The size of the request queue. 
:type async_queue_size: int + :param subset_size: The number of samples to use from the start of the dataset. If not provided, the full dataset is used. + :type subset_size: Optional[int] """ def __init__( @@ -45,6 +47,7 @@ def __init__( tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, mode: GenerationMode = "async", async_queue_size: int = 50, + subset_size: Optional[int] = None, **kwargs, ): self._dataset = dataset @@ -58,6 +61,9 @@ def __init__( self._hf_column = resolve_transformers_dataset_column( self._hf_dataset, column=column ) + if subset_size is not None and isinstance(self._hf_dataset, Dataset): + self._hf_dataset = self._hf_dataset.select(range(subset_size)) + self._hf_dataset_iterator = iter(self._hf_dataset) # NOTE: Must be after all the parameters since the queue population