Skip to content

Commit

Permalink
Model: Fix frequency penalty fallback
Browse files Browse the repository at this point in the history
The appropriate branches weren't firing when frequency penalty is
0.0. Also fix repetition penalty overriding.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Dec 31, 2023
1 parent 47744fe commit 72bc303
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,26 +574,29 @@ def generate_gen(self, prompt: str, **kwargs):
)
auto_scale_penalty_range = False

# Frequency penalty = repetition penalty if the user is on an older exl2 version
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
if (frequency_penalty) != 0.0 and not hasattr(
gen_settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version. Setting this value to repetition penalty "
"instead."
)
gen_settings.token_repetition_penalty = frequency_penalty
else:
if hasattr(gen_settings, "token_frequency_penalty"):
gen_settings.token_frequency_penalty = frequency_penalty

# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
elif frequency_penalty != 0.0:
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)

# Override the repetition penalty value if it isn't set already
# if the user is on an older exl2 version
if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0:
gen_settings.token_repetition_penalty = frequency_penalty
logger.warning("Setting this value to repetition penalty instead.")

gen_settings.token_frequency_penalty = frequency_penalty

# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
Expand Down

0 comments on commit 72bc303

Please sign in to comment.