diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 378a35d91..7c1f5ad7c 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -30,6 +30,7 @@ def generate(
         sampling_parameters: SamplingParameters,
         *,
         sampling_params: Optional["SamplingParams"] = None,
+        use_tqdm: bool = True,
     ):
         """Generate text using vLLM.
 
@@ -47,11 +48,13 @@ def generate(
             An instance of `SamplingParameters`, a dataclass that contains the name
             of the sampler to use and related parameters as available in Outlines.
-        samplng_params
+        sampling_params
             An instance of `vllm.sampling_params.SamplingParams`. The values
             passed via this dataclass supersede the values of the parameters in
             `generation_parameters` and `sampling_parameters`. See the vLLM
             documentation for more details:
             https://docs.vllm.ai/en/latest/dev/sampling_params.html.
+        use_tqdm
+            Whether to display a tqdm progress bar during inference.
 
         Returns
         -------
@@ -103,7 +106,10 @@ def generate(
             sampling_params.use_beam_search = True
 
         results = self.model.generate(
-            prompts, sampling_params=sampling_params, lora_request=self.lora_request
+            prompts,
+            sampling_params=sampling_params,
+            lora_request=self.lora_request,
+            use_tqdm=use_tqdm,
         )
         results = [[sample.text for sample in batch.outputs] for batch in results]
 