Skip to content

Commit

Permalink
Model: Repetition penalty range -> penalty range
Browse files Browse the repository at this point in the history
All penalties can have a sustain (range) applied to them in exl2,
so clarify the parameter.

However, the default behavior changes depending on whether frequency OR
presence penalty is enabled. For the sanity of OAI users, frequency and
presence penalties apply only to the output tokens when the range is -1
(the default).

Repetition penalty, however, still functions the same way: -1 means
the range is the max sequence length.

Doing this prevents gibberish output when using the more modern frequency
and presence penalties, similar to llamacpp.

NOTE: This logic is still subject to change in the future, but I believe
it hits the happy medium for users who want defaults and users who want
to tinker around with the sampling knobs.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Dec 28, 2023
1 parent c72d309 commit 5dc2df6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
10 changes: 7 additions & 3 deletions OAI/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ class CommonCompletionRequest(BaseModel):
logit_bias: Optional[Dict[int, float]] = None

# Aliased variables
repetition_range: Optional[int] = Field(
penalty_range: Optional[int] = Field(
default=-1,
validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
)

def to_gen_params(self):
Expand All @@ -106,7 +110,7 @@ def to_gen_params(self):
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"repetition_range": self.repetition_range,
"penalty_range": self.penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
Expand Down
19 changes: 16 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def generate_gen(self, prompt: str, **kwargs):
'presence_penalty' (float): Token presence penalty (default: 0.0)
'repetition_penalty' (float): Token repetition penalty
(default: 1.15)
'repetition_range' (int): Repetition penalty range
'penalty_range' (int): Penalty range
(default: whole context)
'repetition_decay' (int): Repetition penalty range
(default: same as range)
Expand Down Expand Up @@ -575,15 +575,24 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)

# Applies for all penalties despite being called token_repetition_range
gen_settings.token_repetition_range = unwrap(
kwargs.get("repetition_range"), self.config.max_seq_len
kwargs.get("penalty_range"), self.config.max_seq_len
)

# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1

# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
# fallback
# Always default to 0 if something goes wrong
if gen_settings.token_repetition_range <= 0:
if gen_settings.token_repetition_range < 0:
fallback_decay = 0
else:
fallback_decay = gen_settings.token_repetition_range
Expand All @@ -609,6 +618,7 @@ def generate_gen(self, prompt: str, **kwargs):
max_tokens=max_tokens,
**vars(gen_settings),
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
stop_conditions=stop_conditions,
Expand Down Expand Up @@ -684,6 +694,9 @@ def generate_gen(self, prompt: str, **kwargs):
loras=self.active_loras,
)

if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens

# Generate
chunk, eos, tokens = self.generator.stream()

Expand Down

0 comments on commit 5dc2df6

Please sign in to comment.