diff --git a/common/sampling.py b/common/sampling.py index c2b4c3c7..7e5ded46 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -25,7 +25,9 @@ class BaseSamplerRequest(BaseModel): max_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("max_tokens"), - validation_alias=AliasChoices("max_tokens", "max_length"), + validation_alias=AliasChoices( + "max_tokens", "max_completion_tokens", "max_length" + ), description="Aliases: max_length", examples=[150], ge=0,