diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c803c15c..77908d8b 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -20,7 +20,12 @@
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
-from common.gen_logging import log_generation_params, log_metrics, log_prompt, log_response
+from common.gen_logging import (
+    log_generation_params,
+    log_metrics,
+    log_prompt,
+    log_response,
+)
 from common.templating import (
     PromptTemplate,
     find_template_from_model,
@@ -598,7 +603,17 @@ def generate(self, prompt: str, **kwargs):
 
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
-        pass
+        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
+            ExLlamaV2StreamingGenerator, "speculative_ngram"
+        ):
+            logger.warning(
+                "Speculative ngram is not supported by the currently "
+                "installed ExLlamaV2 version."
+            )
+
+            kwargs.pop("speculative_ngram")
+
+        return kwargs
 
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
@@ -656,7 +671,7 @@ def generate_gen(self, prompt: str, **kwargs):
         gen_settings = ExLlamaV2Sampler.Settings()
 
         # Check unsupported settings for dev wheels
-        self.check_unsupported_settings(**kwargs)
+        kwargs = self.check_unsupported_settings(**kwargs)
 
         # Apply settings
         gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -758,6 +773,11 @@ def generate_gen(self, prompt: str, **kwargs):
         request_logprobs = unwrap(kwargs.get("logprobs"), 0)
         self.generator.return_top_tokens = request_logprobs
 
+        # Speculative Ngram
+        self.generator.speculative_ngram = unwrap(
+            kwargs.get("speculative_ngram"), False
+        )
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
             gen_settings.temperature = 1.0
@@ -775,6 +795,7 @@ def generate_gen(self, prompt: str, **kwargs):
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
diff --git a/common/sampling.py b/common/sampling.py
index aa7e7d0d..5a4ea942 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -123,6 +123,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("grammar_string"),
     )
 
+    speculative_ngram: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
+    )
+
     # Aliased variables
     typical: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
@@ -268,6 +272,7 @@ def to_gen_params(self, **kwargs):
             "negative_prompt": self.negative_prompt,
             "json_schema": self.json_schema,
             "grammar_string": self.grammar_string,
+            "speculative_ngram": self.speculative_ngram,
         }
 
         return {**gen_params, **kwargs}
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index da28dbd0..88dee472 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -17,6 +17,9 @@ stop:
 token_healing:
     override: false
     force: false
+speculative_ngram:
+    override: false
+    force: false
 
 # Commented out because the default is dynamically scaled
 #generate_window:
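
Usage sketch (illustrative, not part of the patch): a minimal client-side example of exercising the new "speculative_ngram" sampler field once this diff is applied. It assumes a locally running tabbyAPI instance exposing an OpenAI-style /v1/completions route on port 5000 and accepting an x-api-key header; the URL, port, and header name are assumptions outside this diff.

# Hypothetical client sketch: send the new sampler flag with a completion
# request. Server-side, BaseSamplerRequest carries the field into
# to_gen_params(), check_unsupported_settings() drops it when the installed
# ExLlamaV2 wheel lacks generator-level ngram decoding, and generate_gen()
# otherwise sets generator.speculative_ngram before generation.
import requests  # assumption: the 'requests' package is installed

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "speculative_ngram": True,  # new sampler field added by this diff
}

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",   # assumed endpoint and port
    headers={"x-api-key": "<your-api-key>"},  # assumed auth header
    json=payload,
    timeout=60,
)
print(response.json())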