From fc4570220cf0e5b42d5f5c999a7c0895228c750a Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 25 Jan 2024 00:11:30 -0500 Subject: [PATCH] API + Model: Add new parameters and clean up documentation The example JSON fields were changed because of the new sampler default strategy. Fix these by manually changing the values. Also add support for fasttensors and expose generate_window to the API. It's recommended to not adjust generate_window as it's dynamically scaled based on max_seq_len by default. Signed-off-by: kingbri --- backends/exllamav2/model.py | 17 ++++++++++++++-- common/sampling.py | 30 +++++++++++++++++++++-------- config_sample.yml | 3 +++ sampler_overrides/sample_preset.yml | 5 +++++ 4 files changed, 45 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ac939d44..52764e22 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -138,13 +138,25 @@ def progress(loaded_modules: int, total_modules: int, kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len) ) + # Enable CFG if present + use_cfg = unwrap(kwargs.get("use_cfg"), False) if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"): - self.use_cfg = unwrap(kwargs.get("use_cfg"), False) - else: + self.use_cfg = use_cfg + elif use_cfg: logger.warning( "CFG is not supported by the currently installed ExLlamaV2 version." ) + # Enable fasttensors loading if present + use_fasttensors = unwrap(kwargs.get("fasttensors"), False) + if hasattr(ExLlamaV2Config, "fasttensors"): + self.config.fasttensors = use_fasttensors + elif use_fasttensors: + logger.warning( + "fasttensors is not supported by " + "the currently installed ExllamaV2 version." + ) + # Turn off flash attention if CFG is on # Workaround until batched FA2 is fixed in exllamav2 upstream self.config.no_flash_attn = ( @@ -668,6 +680,7 @@ def generate_gen(self, prompt: str, **kwargs): **vars(gen_settings), token_healing=token_healing, auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, add_bos_token=add_bos_token, ban_eos_token=ban_eos_token, stop_conditions=stop_conditions, diff --git a/common/sampling.py b/common/sampling.py index 53defcc1..8c28002d 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -17,7 +17,13 @@ class SamplerParams(BaseModel): """Common class for sampler params that are used in APIs""" max_tokens: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens", 150) + default_factory=lambda: get_default_sampler_value("max_tokens", 150), + examples=[150], + ) + + generate_window: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("generate_window"), + examples=[512], ) stop: Optional[Union[str, List[str]]] = Field( @@ -29,7 +35,8 @@ class SamplerParams(BaseModel): ) temperature: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("temperature", 1.0) + default_factory=lambda: get_default_sampler_value("temperature", 1.0), + examples=[1.0], ) temperature_last: Optional[bool] = Field( @@ -41,7 +48,7 @@ class SamplerParams(BaseModel): ) top_p: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("top_p", 1.0) + default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0] ) top_a: Optional[float] = Field( @@ -65,7 +72,8 @@ class SamplerParams(BaseModel): ) repetition_penalty: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + examples=[1.0], ) repetition_decay: Optional[int] = Field( @@ -77,11 +85,13 @@ class SamplerParams(BaseModel): ) mirostat_tau: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5), + examples=[1.5], ) mirostat_eta: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3), + examples=[0.3], ) add_bos_token: Optional[bool] = Field( @@ -89,7 +99,8 @@ class SamplerParams(BaseModel): ) ban_eos_token: Optional[bool] = Field( - default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + examples=[False], ) logit_bias: Optional[Dict[int, float]] = Field( @@ -106,6 +117,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("typical", 1.0), validation_alias=AliasChoices("typical", "typical_p"), description="Aliases: typical_p", + examples=[1.0], ) penalty_range: Optional[int] = Field( @@ -122,6 +134,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), description="Aliases: guidance_scale", + examples=[1.0], ) def to_gen_params(self): @@ -135,8 +148,9 @@ def to_gen_params(self): self.stop = [self.stop] return { - "stop": self.stop, "max_tokens": self.max_tokens, + "generate_window": self.generate_window, + "stop": self.stop, "add_bos_token": self.add_bos_token, "ban_eos_token": self.ban_eos_token, "token_healing": self.token_healing, diff --git a/config_sample.yml b/config_sample.yml index 89368acf..cf1ddb53 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -97,6 +97,9 @@ model: # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream) #use_cfg: False + # Enables fasttensors to possibly increase model loading speeds (default: False) + #fasttensors: true + # Options for draft models (speculative decoding). This will use more VRAM! #draft: # Overrides the directory to look for draft (default: models) diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index 9c661a14..eae17ab4 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -18,6 +18,11 @@ token_healing: override: false force: false +# Commented out because the default is dynamically scaled +#generate_window: + #override: 512 + #force: false + # MARK: Temperature temperature: override: 1.0