From 5dc2df68be05e5d9ad5f604b4856061280b5b8f0 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 28 Dec 2023 18:10:19 -0500 Subject: [PATCH] Model: Repetition penalty range -> penalty range All penalties can have a sustain (range) applied to them in exl2, so clarify the parameter. However, the default behaviors change based on if freq OR pres pen is enabled. For the sanity of OAI users, have freq and pres pen only apply on the output tokens when range is -1 (default). But, repetition penalty still functions the same way where -1 means the range is the max seq len. Doing this prevents gibberish output when using the more modern freq and presence penalties similar to llamacpp. NOTE: This logic is still subject to change in the future, but I believe it hits the happy medium for users who want defaults and users who want to tinker around with the sampling knobs. Signed-off-by: kingbri --- OAI/types/common.py | 10 +++++++--- model.py | 19 ++++++++++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/OAI/types/common.py b/OAI/types/common.py index 5047e17b..a6b23810 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -77,9 +77,13 @@ class CommonCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None # Aliased variables - repetition_range: Optional[int] = Field( + penalty_range: Optional[int] = Field( default=-1, - validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + ), ) def to_gen_params(self): @@ -106,7 +110,7 @@ def to_gen_params(self): "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, "repetition_penalty": self.repetition_penalty, - "repetition_range": self.repetition_range, + "penalty_range": self.penalty_range, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, diff --git a/model.py b/model.py index 0ddddb53..a1177aba 100644 --- a/model.py +++ b/model.py @@ -521,7 +521,7 @@ def generate_gen(self, prompt: str, **kwargs): 'presence_penalty' (float): Token presence penalty (default: 0.0) 'repetition_penalty' (float): Token repetition penalty (default: 1.15) - 'repetition_range' (int): Repetition penalty range + 'penalty_range' (int): Penalty range (default: whole context) 'repetition_decay' (int): Repetition penalty range (default: same as range) @@ -575,15 +575,24 @@ def generate_gen(self, prompt: str, **kwargs): gen_settings.token_repetition_penalty = unwrap( kwargs.get("repetition_penalty"), 1.0 ) + + # Applies for all penalties despite being called token_repetition_range gen_settings.token_repetition_range = unwrap( - kwargs.get("repetition_range"), self.config.max_seq_len + kwargs.get("penalty_range"), self.config.max_seq_len ) + # Dynamically scale penalty range to output tokens + # Only do this if freq/pres pen is enabled and the repetition range is -1 + auto_scale_penalty_range = ( + gen_settings.token_frequency_penalty != 0 + or gen_settings.token_presence_penalty != 0 + ) and gen_settings.token_repetition_range == -1 + # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed # fallback # Always default to 0 if something goes wrong - if gen_settings.token_repetition_range <= 0: + if gen_settings.token_repetition_range < 0: fallback_decay = 0 else: fallback_decay = gen_settings.token_repetition_range @@ -609,6 +618,7 @@ def generate_gen(self, prompt: str, **kwargs): max_tokens=max_tokens, **vars(gen_settings), token_healing=token_healing, + auto_scale_penalty_range=auto_scale_penalty_range, add_bos_token=add_bos_token, ban_eos_token=ban_eos_token, stop_conditions=stop_conditions, @@ -684,6 +694,9 @@ def generate_gen(self, prompt: str, **kwargs): loras=self.active_loras, ) + if auto_scale_penalty_range: + gen_settings.token_repetition_range = generated_tokens + # Generate chunk, eos, tokens = self.generator.stream()