Skip to content

Commit

Permalink
API + Model: Add new parameters and clean up documentation
Browse files Browse the repository at this point in the history
The example JSON fields changed because of the new sampler default
strategy. Fix them by setting the example values explicitly.

Also add support for fasttensors and expose generate_window to
the API. Adjusting generate_window is not recommended, as it's
dynamically scaled based on max_seq_len by default.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Jan 25, 2024
1 parent 97d4f40 commit 21c6b1a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
18 changes: 16 additions & 2 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,25 @@ def progress(loaded_modules: int, total_modules: int,
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)

# Enable CFG if present
use_cfg = unwrap(kwargs.get("use_cfg"), False)
if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
else:
self.use_cfg = use_cfg
elif use_cfg:
logger.warning(
"CFG is not supported by the currently installed ExLlamaV2 version."
)

# Enable fasttensors loading if present
use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
if hasattr(ExLlamaV2Config, "fasttensors"):
self.config.fasttensors = use_fasttensors
elif use_fasttensors:
logger.warning(
"fasttensors is not supported by "
"the currently installed ExllamaV2 version."
)

# Turn off flash attention if CFG is on
# Workaround until batched FA2 is fixed in exllamav2 upstream
self.config.no_flash_attn = (
Expand Down Expand Up @@ -563,6 +575,7 @@ def generate_gen(self, prompt: str, **kwargs):
generate_window = max(
unwrap(kwargs.get("generate_window"), 512), max_tokens // 8
)
print(generate_window)

# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()
Expand Down Expand Up @@ -668,6 +681,7 @@ def generate_gen(self, prompt: str, **kwargs):
**vars(gen_settings),
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
generate_window=generate_window,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
stop_conditions=stop_conditions,
Expand Down
31 changes: 23 additions & 8 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ class SamplerParams(BaseModel):
"""Common class for sampler params that are used in APIs"""

max_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("max_tokens", 150)
default_factory=lambda: get_default_sampler_value("max_tokens", 150),
examples=[150]
)

generate_window: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("generate_window"),
examples=[512]
)

stop: Optional[Union[str, List[str]]] = Field(
Expand All @@ -29,7 +35,8 @@ class SamplerParams(BaseModel):
)

temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0)
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0]
)

temperature_last: Optional[bool] = Field(
Expand All @@ -41,7 +48,8 @@ class SamplerParams(BaseModel):
)

top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0)
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0]
)

top_a: Optional[float] = Field(
Expand All @@ -65,7 +73,8 @@ class SamplerParams(BaseModel):
)

repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0)
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0]
)

repetition_decay: Optional[int] = Field(
Expand All @@ -77,19 +86,22 @@ class SamplerParams(BaseModel):
)

mirostat_tau: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5)
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
examples=[1.5]
)

mirostat_eta: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3)
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
examples=[0.3]
)

add_bos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
)

ban_eos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False)
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
examples=[False]
)

logit_bias: Optional[Dict[int, float]] = Field(
Expand All @@ -106,6 +118,7 @@ class SamplerParams(BaseModel):
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0]
)

penalty_range: Optional[int] = Field(
Expand All @@ -122,6 +135,7 @@ class SamplerParams(BaseModel):
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
examples=[1.0]
)

def to_gen_params(self):
Expand All @@ -135,8 +149,9 @@ def to_gen_params(self):
self.stop = [self.stop]

return {
"stop": self.stop,
"max_tokens": self.max_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
Expand Down
3 changes: 3 additions & 0 deletions config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ model:
# WARNING: This flag disables Flash Attention! (a stopgap until batched FA2 is fixed upstream)
#use_cfg: False

# Enables fasttensors, which may speed up model loading (default: False)
#fasttensors: true

# Options for draft models (speculative decoding). This will use more VRAM!
#draft:
# Overrides the directory to look for draft (default: models)
Expand Down
5 changes: 5 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ token_healing:
override: false
force: false

# Commented out because the default is dynamically scaled
#generate_window:
#override: 512
#force: false

# MARK: Temperature
temperature:
override: 1.0
Expand Down

0 comments on commit 21c6b1a

Please sign in to comment.