From 0dc12d82d545fb46562ab58df393aded5c413611 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 30 Dec 2023 00:24:15 -0500 Subject: [PATCH] Model: Add fallback for freq and presence pen Previous behavior aliased freq pen to rep pen. Keep this behavior when using the freq pen parameter with a legacy ExLlamaV2 version rather than ignoring both entirely. Signed-off-by: kingbri --- model.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/model.py b/model.py index 66b7eb55..d10cfe80 100644 --- a/model.py +++ b/model.py @@ -474,14 +474,6 @@ def check_unsupported_settings(self, **kwargs): "installed ExLlamaV2 version." ) - if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr( - ExLlamaV2Sampler.Settings, "token_frequency_penalty" - ): - logger.warning( - "Frequency penalty is not supported by the currently " - "installed ExLlamaV2 version." - ) - if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr( ExLlamaV2Sampler.Settings, "token_presence_penalty" ): @@ -568,9 +560,7 @@ def generate_gen(self, prompt: str, **kwargs): # Default tau and eta fallbacks don't matter if mirostat is off gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5) gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1) - gen_settings.token_frequency_penalty = unwrap( - kwargs.get("frequency_penalty"), 0.0 - ) + gen_settings.token_presence_penalty = unwrap( kwargs.get("presence_penalty"), 0.0 ) @@ -582,13 +572,28 @@ def generate_gen(self, prompt: str, **kwargs): gen_settings.token_repetition_range = unwrap( kwargs.get("penalty_range"), self.config.max_seq_len ) + auto_scale_penalty_range = False - # Dynamically scale penalty range to output tokens - # Only do this if freq/pres pen is enabled and the repetition range is -1 - auto_scale_penalty_range = ( - gen_settings.token_frequency_penalty != 0 - or gen_settings.token_presence_penalty != 0 - ) and 
gen_settings.token_repetition_range == -1 + # Frequency penalty = repetition penalty if the user is on an older exl2 version + frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0) + if (frequency_penalty) != 0.0 and not hasattr( + gen_settings, "token_frequency_penalty" + ): + logger.warning( + "Frequency penalty is not supported by the currently " + "installed ExLlamaV2 version. Setting this value to repetition penalty " + "instead." + ) + gen_settings.token_repetition_penalty = frequency_penalty + else: + # Dynamically scale penalty range to output tokens + # Only do this if freq/pres pen is enabled and the repetition range is -1 + auto_scale_penalty_range = ( + gen_settings.token_frequency_penalty != 0 + or gen_settings.token_presence_penalty != 0 + ) and gen_settings.token_repetition_range == -1 + + gen_settings.token_frequency_penalty = frequency_penalty # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed