From 2f72195012aabc98ccd78c77b7423d3748013fb7 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 8 Oct 2024 22:54:48 -0600 Subject: [PATCH] [Misc] Improve validation errors around best_of and n (#9167) Signed-off-by: Travis Johnson --- vllm/sampling_params.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index e074312280584..95345df43b57d 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -330,8 +330,8 @@ def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if not isinstance(self.best_of, int): - raise ValueError(f'best_of must be an int, but is of ' - f'type {type(self.best_of)}') + raise ValueError(f"best_of must be an int, but is of " + f"type {type(self.best_of)}") if self.best_of < self.n: raise ValueError(f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.") @@ -390,10 +390,13 @@ def _verify_args(self) -> None: raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_greedy_sampling(self) -> None: + if self.n > 1: + raise ValueError("n must be 1 when using greedy sampling, " + f"got {self.n}.") assert isinstance(self.best_of, int) if self.best_of > 1: - raise ValueError("best_of must be 1 when using greedy sampling." - f"Got {self.best_of}.") + raise ValueError("best_of must be 1 when using greedy sampling, " + f"got {self.best_of}.") def update_from_generation_config( self,