Skip to content

Commit

Permalink
Sampling: Update DRY
Browse files Browse the repository at this point in the history
Switch to new parameters and remove dry_max_ngram as that's not supposed
to be changed.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Sep 7, 2024
1 parent 05c3f11 commit ae37f3f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
17 changes: 12 additions & 5 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,14 +1081,21 @@ async def generate_gen(
)

# DRY options
dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0)
dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)

# 0 = disabled
if dry_allowed_length:
gen_settings.dry_allowed_length = dry_allowed_length
# < 0 = disabled
if dry_multiplier > 0:
gen_settings.dry_allowed_length = unwrap(
kwargs.get("dry_allowed_length"), 0
)
gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0)
gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0)
gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20)

# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
# Use max_seq_len as the fallback to stay consistent
gen_settings.dry_range = unwrap(
kwargs.get("dry_range"), self.config.max_seq_len
)

# Tokenize sequence breakers
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
Expand Down
9 changes: 4 additions & 5 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0)
)

# TODO: Remove these aliases
dry_max_ngram: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20),
alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"),
dry_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_range", 0),
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
description=("Aliases: dry_penalty_last_n"),
)

dry_sequence_breakers: Optional[str] = Field(
Expand Down Expand Up @@ -371,7 +371,6 @@ def to_gen_params(self, **kwargs):
"penalty_range": self.penalty_range,
"dry_allowed_length": self.dry_allowed_length,
"dry_base": self.dry_base,
"dry_max_ngram": self.dry_max_ngram,
"dry_multiplier": self.dry_multiplier,
"dry_sequence_breakers": self.dry_sequence_breakers,
"repetition_decay": self.repetition_decay,
Expand Down

0 comments on commit ae37f3f

Please sign in to comment.