From 52ebc982c11cbb96d7d1a8da1c5755c7b8a58d7d Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 8 Feb 2024 00:17:48 -0500 Subject: [PATCH] Sampling: Fix dynatemp defaults Default max temp and min temp is 1.0 Signed-off-by: kingbri --- backends/exllamav2/model.py | 6 +++--- common/sampling.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 84c6724b..ea45fdcd 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -537,8 +537,8 @@ def generate_gen(self, prompt: str, **kwargs): gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) # DynaTemp settings - max_temp = unwrap(kwargs.get("max_temp"), 0.0) - min_temp = unwrap(kwargs.get("min_temp"), 0.0) + max_temp = unwrap(kwargs.get("max_temp"), 1.0) + min_temp = unwrap(kwargs.get("min_temp"), 1.0) if max_temp > min_temp: gen_settings.max_temp = max_temp @@ -553,7 +553,7 @@ def generate_gen(self, prompt: str, **kwargs): # Warn if max/min temp values are > 0 # and if they're less than or equal to each other if max_temp < min_temp or ( - 0 not in {min_temp, max_temp} and max_temp == min_temp + 1 not in {min_temp, max_temp} and max_temp == min_temp ): logger.warning( "Max temp is less than or equal to min temp, skipping DynaTemp." diff --git a/common/sampling.py b/common/sampling.py index a9ea9d32..e148acdd 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -142,13 +142,13 @@ class BaseSamplerRequest(BaseModel): ) max_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("max_temp", 0.0), + default_factory=lambda: get_default_sampler_value("max_temp", 1.0), validation_alias=AliasChoices("max_temp", "dynatemp_high"), description="Aliases: dynatemp_high", ) min_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("min_temp", 0.0), + default_factory=lambda: get_default_sampler_value("min_temp", 1.0), validation_alias=AliasChoices("min_temp", "dynatemp_low"), description="Aliases: dynatemp_low", )