From 09a4c79847867cd1868a23171eaeac7aefadb76d Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 18 Mar 2024 22:54:59 -0400 Subject: [PATCH] Model: Auto-scale max_tokens by default If max_tokens is None, it automatically scales to fill up the context. This does not mean the generation will fill up that context since EOS stops also exist. Originally suggested by #86 Signed-off-by: kingbri --- backends/exllamav2/model.py | 46 ++++++++++++++++++++++--------------- common/sampling.py | 2 +- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 79bdbe2c..18eb626d 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -638,7 +638,6 @@ def generate_gen_sync(self, prompt: str, **kwargs): """ token_healing = unwrap(kwargs.get("token_healing"), False) - max_tokens = unwrap(kwargs.get("max_tokens"), 150) stream_interval = unwrap(kwargs.get("stream_interval"), 0) generate_window = max( unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8 @@ -761,24 +760,8 @@ def generate_gen_sync(self, prompt: str, **kwargs): gen_settings.top_p = 0 gen_settings.typical = 0 - # Log generation options to console - # Some options are too large, so log the args instead - log_generation_params( - max_tokens=max_tokens, - **vars(gen_settings), - token_healing=token_healing, - auto_scale_penalty_range=auto_scale_penalty_range, - generate_window=generate_window, - add_bos_token=add_bos_token, - ban_eos_token=ban_eos_token, - speculative_ngram=self.generator.speculative_ngram, - logprobs=request_logprobs, - stop_conditions=stop_conditions, - logit_bias=logit_bias, - ) - - # Log prompt to console - log_prompt(prompt, negative_prompt) + # Store the gen settings for logging purposes + gen_settings_log_dict = vars(gen_settings) # Set logit bias if logit_bias: @@ -854,6 +837,31 @@ def generate_gen_sync(self, prompt: str, **kwargs): prompt_tokens = ids.shape[-1] + # Automatically set max_tokens to fill up the context + # This should be an OK default, but may be changed in the future + max_tokens = unwrap( + kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens + ) + + # Log generation options to console + # Some options are too large, so log the args instead + log_generation_params( + max_tokens=max_tokens, + **gen_settings_log_dict, + token_healing=token_healing, + auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, + add_bos_token=add_bos_token, + ban_eos_token=ban_eos_token, + speculative_ngram=self.generator.speculative_ngram, + logprobs=request_logprobs, + stop_conditions=stop_conditions, + logit_bias=logit_bias, + ) + + # Log prompt to console + log_prompt(prompt, negative_prompt) + # Begin generated_tokens = 0 full_response = "" diff --git a/common/sampling.py b/common/sampling.py index 5a4ea942..a72dd4f2 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -14,7 +14,7 @@ class BaseSamplerRequest(BaseModel): """Common class for sampler params that are used in APIs""" max_tokens: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens", 150), + default_factory=lambda: get_default_sampler_value("max_tokens"), examples=[150], )