Skip to content

Commit

Permalink
Model: Add fallback for freq and presence pen
Browse files Browse the repository at this point in the history
Previous behavior aliased freq pen for rep pen. Keep this behavior
when using the freq pen parameter with a legacy exllamav2 version
rather than ignoring both entirely.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Dec 30, 2023
1 parent 79a5758 commit 0dc12d8
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,6 @@ def check_unsupported_settings(self, **kwargs):
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "token_presence_penalty"
):
Expand Down Expand Up @@ -568,9 +560,7 @@ def generate_gen(self, prompt: str, **kwargs):
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_frequency_penalty = unwrap(
kwargs.get("frequency_penalty"), 0.0
)

gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
Expand All @@ -582,13 +572,28 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.token_repetition_range = unwrap(
kwargs.get("penalty_range"), self.config.max_seq_len
)
auto_scale_penalty_range = False

# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
# Frequency penalty = repetition penalty if the user is on an older exl2 version
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
if (frequency_penalty) != 0.0 and not hasattr(
gen_settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version. Setting this value to repetition penalty "
"instead."
)
gen_settings.token_repetition_penalty = frequency_penalty
else:
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1

gen_settings.token_frequency_penalty = frequency_penalty

# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
Expand Down

0 comments on commit 0dc12d8

Please sign in to comment.