diff --git a/OAI/types/common.py b/OAI/types/common.py
index 5047e17b..a6b23810 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -77,9 +77,13 @@ class CommonCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[int, float]] = None
 
     # Aliased variables
-    repetition_range: Optional[int] = Field(
+    penalty_range: Optional[int] = Field(
         default=-1,
-        validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
+        validation_alias=AliasChoices(
+            "penalty_range",
+            "repetition_range",
+            "repetition_penalty_range",
+        ),
     )
 
     def to_gen_params(self):
@@ -106,7 +110,7 @@
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "repetition_penalty": self.repetition_penalty,
-            "repetition_range": self.repetition_range,
+            "penalty_range": self.penalty_range,
             "repetition_decay": self.repetition_decay,
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
diff --git a/model.py b/model.py
index 0ddddb53..a1177aba 100644
--- a/model.py
+++ b/model.py
@@ -521,7 +521,7 @@ def generate_gen(self, prompt: str, **kwargs):
             'presence_penalty' (float): Token presence penalty (default: 0.0)
             'repetition_penalty' (float): Token repetition penalty
                 (default: 1.15)
-            'repetition_range' (int): Repetition penalty range
+            'penalty_range' (int): Penalty range
                 (default: whole context)
             'repetition_decay' (int): Repetition penalty range
                 (default: same as range)
@@ -575,15 +575,24 @@ def generate_gen(self, prompt: str, **kwargs):
         gen_settings.token_repetition_penalty = unwrap(
             kwargs.get("repetition_penalty"), 1.0
         )
+
+        # Applies for all penalties despite being called token_repetition_range
         gen_settings.token_repetition_range = unwrap(
-            kwargs.get("repetition_range"), self.config.max_seq_len
+            kwargs.get("penalty_range"), self.config.max_seq_len
         )
 
+        # Dynamically scale penalty range to output tokens
+        # Only do this if freq/pres pen is enabled and the repetition range is -1
+        auto_scale_penalty_range = (
+            gen_settings.token_frequency_penalty != 0
+            or gen_settings.token_presence_penalty != 0
+        ) and gen_settings.token_repetition_range == -1
+
         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed
         # fallback
         # Always default to 0 if something goes wrong
-        if gen_settings.token_repetition_range <= 0:
+        if gen_settings.token_repetition_range < 0:
             fallback_decay = 0
         else:
             fallback_decay = gen_settings.token_repetition_range
@@ -609,6 +618,7 @@ def generate_gen(self, prompt: str, **kwargs):
             max_tokens=max_tokens,
             **vars(gen_settings),
             token_healing=token_healing,
+            auto_scale_penalty_range=auto_scale_penalty_range,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             stop_conditions=stop_conditions,
@@ -684,6 +694,9 @@ def generate_gen(self, prompt: str, **kwargs):
                 loras=self.active_loras,
             )
 
+            if auto_scale_penalty_range:
+                gen_settings.token_repetition_range = generated_tokens
+
             # Generate
             chunk, eos, tokens = self.generator.stream()
 