From fa2acb2828d2ed4888f4e2c8e6e9da325e6660ed Mon Sep 17 00:00:00 2001 From: erinmaybe <90116178+erinmaybe@users.noreply.github.com> Date: Sat, 3 Feb 2024 21:51:29 -0500 Subject: [PATCH] Adds aliases for min_temp and max_temp (#58) * Adds aliases for min_temp and max_temp * Sampling: Add dynatemp_exponent alias --- common/sampling.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 53c7b2ee..a9ea9d32 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -43,19 +43,6 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("temperature_last", False) ) - max_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("max_temp", 0.0), - ) - - min_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("min_temp", 0.0), - ) - - temp_exponent: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0), - examples=[1.0], - ) - smoothing_factor: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0), ) @@ -154,6 +141,24 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + max_temp: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("max_temp", 0.0), + validation_alias=AliasChoices("max_temp", "dynatemp_high"), + description="Aliases: dynatemp_high", + ) + + min_temp: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("min_temp", 0.0), + validation_alias=AliasChoices("min_temp", "dynatemp_low"), + description="Aliases: dynatemp_low", + ) + + temp_exponent: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0), + validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"), + examples=[1.0], + ) + def to_gen_params(self): """Converts samplers to internal generation params"""