Skip to content

Commit

Permalink
added quadratic sampling (#56)
Browse files Browse the repository at this point in the history
* added quadratic sampling

* Update sample_preset.yml

* oops missed a spot

* Sampling: Fix smoothing factor semantics
  • Loading branch information
AAbushady authored Feb 3, 2024
1 parent 4a7b8b1 commit d7c1885
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
9 changes: 9 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,14 @@ def check_unsupported_settings(self, **kwargs):
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "smoothing_factor"
):
logger.warning(
"Smoothing factor is not supported by the currently "
"installed ExLlamaV2 version."
)

def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
Expand Down Expand Up @@ -593,6 +601,7 @@ def generate_gen(self, prompt: str, **kwargs):
# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0)
Expand Down
5 changes: 5 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class SamplerParams(BaseModel):
examples=[1.0],
)

# Smoothing factor for quadratic sampling. The default is resolved lazily via
# get_default_sampler_value, falling back to 0.0 when nothing else is supplied.
# NOTE: the keyword must be `default_factory` (as on the sibling `top_k`
# field); the previous spelling `default_factor` is not a Field() argument,
# so the lambda was never used to produce the default.
smoothing_factor: Optional[float] = Field(
    default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
)

top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
)
Expand Down Expand Up @@ -173,6 +177,7 @@ def to_gen_params(self):
"min_temp": self.min_temp,
"max_temp": self.max_temp,
"temp_exponent": self.temp_exponent,
"smoothing_factor": self.smoothing_factor,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
Expand Down
3 changes: 3 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ max_temp:
temp_exponent:
override: 0.0
force: false
smoothing_factor:
override: 0.0
force: false

# MARK: Alphabet soup
top_k:
Expand Down

0 comments on commit d7c1885

Please sign in to comment.