From efc01d947bedc61487d9ecb3561aa218e23023f9 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Wed, 13 Mar 2024 23:32:11 -0400
Subject: [PATCH] API + Model: Add speculative ngram decoding

Speculative ngram decoding is like speculative decoding without the
draft model. It's not as useful because it only speeds up decoding on
predictable sequences, but its value depends on the use case.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py         | 27 ++++++++++++++++++++++++---
 common/sampling.py                  |  5 +++++
 sampler_overrides/sample_preset.yml |  3 +++
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c803c15c..77908d8b 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -20,7 +20,12 @@ from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
-from common.gen_logging import log_generation_params, log_metrics, log_prompt, log_response
+from common.gen_logging import (
+    log_generation_params,
+    log_metrics,
+    log_prompt,
+    log_response,
+)
 from common.templating import (
     PromptTemplate,
     find_template_from_model,
@@ -598,7 +603,17 @@ def generate(self, prompt: str, **kwargs):
 
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
-        pass
+        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
+            ExLlamaV2StreamingGenerator, "speculative_ngram"
+        ):
+            logger.warning(
+                "Speculative ngram is not supported by the currently "
+                "installed ExLlamaV2 version."
+            )
+
+            kwargs.pop("speculative_ngram")
+
+        return kwargs
 
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
@@ -656,7 +671,7 @@ def generate_gen(self, prompt: str, **kwargs):
         gen_settings = ExLlamaV2Sampler.Settings()
 
         # Check unsupported settings for dev wheels
-        self.check_unsupported_settings(**kwargs)
+        kwargs = self.check_unsupported_settings(**kwargs)
 
         # Apply settings
         gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -758,6 +773,11 @@ def generate_gen(self, prompt: str, **kwargs):
         request_logprobs = unwrap(kwargs.get("logprobs"), 0)
         self.generator.return_top_tokens = request_logprobs
 
+        # Speculative Ngram
+        self.generator.speculative_ngram = unwrap(
+            kwargs.get("speculative_ngram"), False
+        )
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
             gen_settings.temperature = 1.0
@@ -775,6 +795,7 @@ def generate_gen(self, prompt: str, **kwargs):
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
diff --git a/common/sampling.py b/common/sampling.py
index aa7e7d0d..5a4ea942 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -123,6 +123,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("grammar_string"),
     )
 
+    speculative_ngram: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
+    )
+
     # Aliased variables
     typical: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
@@ -268,6 +272,7 @@ def to_gen_params(self, **kwargs):
             "negative_prompt": self.negative_prompt,
             "json_schema": self.json_schema,
             "grammar_string": self.grammar_string,
+            "speculative_ngram": self.speculative_ngram,
         }
 
         return {**gen_params, **kwargs}
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index da28dbd0..88dee472 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -17,6 +17,9 @@ stop:
 token_healing:
     override: false
     force: false
+speculative_ngram:
+    override: false
+    force: false
 
 # Commented out because the default is dynamically scaled
 #generate_window:
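
For context on what the patch enables: speculative ngram decoding drafts tokens by matching the most recent n-gram of the context against earlier occurrences, then verifies the draft against the full model and keeps the longest correct prefix. The sketch below is a minimal pure-Python illustration of that idea, not ExLlamaV2's implementation; `model_next_token` is a hypothetical stand-in for a real forward pass, and a production verifier would score all draft positions in one batched pass rather than one call per position.

```python
# Minimal sketch of ngram speculative decoding, assuming a greedy toy
# model. This is NOT ExLlamaV2's implementation; `model_next_token` is
# a hypothetical stand-in for a real (batched) forward pass.

from typing import Callable, List


def propose_ngram_draft(
    tokens: List[int], ngram_size: int = 3, max_draft: int = 8
) -> List[int]:
    """Draft tokens by matching the trailing ngram earlier in the context.

    If the last `ngram_size` tokens occurred before, the tokens that
    followed that occurrence become the draft. Returns [] on a miss,
    which is the common case for unpredictable text.
    """
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Scan right to left so the most recent repetition wins
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start : start + ngram_size] == tail:
            follow = start + ngram_size
            return tokens[follow : follow + max_draft]
    return []


def decode_step(
    tokens: List[int], model_next_token: Callable[[List[int]], int]
) -> List[int]:
    """One decode step: verify the draft, keep the longest correct prefix.

    The first mismatch is replaced by the model's own token, so every
    step yields at least one token, exactly like normal decoding.
    """
    accepted: List[int] = []
    for drafted in propose_ngram_draft(tokens):
        actual = model_next_token(tokens + accepted)
        accepted.append(actual)
        if actual != drafted:
            return accepted  # draft diverged; stop speculating
    # Empty draft, or the whole draft matched: decode one token normally
    accepted.append(model_next_token(tokens + accepted))
    return accepted
```

On repetitive text (quoting a prompt back, code edits, extraction from provided documents) whole drafts can be accepted per verification pass; on novel text the lookup misses and the loop degrades to ordinary one-token decoding, which matches the caveat in the commit message.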
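Since the patch threads `speculative_ngram` through `BaseSamplerRequest.to_gen_params`, a client can toggle it per request. A hypothetical call sketch, assuming a local tabbyAPI instance on its default port and the OpenAI-compatible completions route (adjust the URL and auth header for your deployment):

```python
# Hypothetical request enabling the new sampler field; the URL, port,
# and auth header are assumptions about a local tabbyAPI deployment.
import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"Authorization": "Bearer <api-key>"},
    json={
        "prompt": "Repeat the list verbatim, then continue it: ...",
        "max_tokens": 200,
        # Field added by this patch; on ExLlamaV2 wheels without
        # generator.speculative_ngram it is dropped with a warning
        "speculative_ngram": True,
    },
    timeout=60,
)
print(response.json())
```

The sample_preset.yml entry is the server-side counterpart: consistent with the neighboring token_healing entry, `override` carries the preset's value for the sampler and `force` controls whether that value wins over what the client sends.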