From 8824ea0205cb93e7e8226474abb053f2613aea4f Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 19 Apr 2024 22:52:32 -0400
Subject: [PATCH] Model: Add EOS token support from generation_config.json

GenerationConfig is meant to override various parts of the model on
generation within the transformers lib. Rather than implementing the
entire GenerationConfig framework (since it's pretty redundant), add
multi eos_token support like vLLM.

The GenerationConfig is currently used only for generation, but it can
be reused elsewhere if needed. If there are more necessary parameters
in the future, add those in as well.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 31 +++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b610d8ec..29342925 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -36,6 +36,7 @@
     get_template_from_model_json,
     get_template_from_file,
 )
+from common.transformers_utils import GenerationConfig
 from common.utils import coalesce, unwrap
 
@@ -57,6 +58,7 @@ class ExllamaV2Container:
     # Internal config vars
     cache_mode: str = "FP16"
     use_cfg: bool = False
+    generation_config: Optional[GenerationConfig] = None
 
     # GPU split vars
     gpu_split: Optional[list] = None
@@ -193,6 +195,21 @@ def progress(loaded_modules: int, total_modules: int,
             kwargs.get("prompt_template"), model_directory
         )
 
+        # Load generation config overrides
+        generation_config_path = (
+            pathlib.Path(self.config.model_dir) / "generation_config.json"
+        )
+        if generation_config_path.exists():
+            try:
+                self.generation_config = GenerationConfig.from_file(
+                    generation_config_path.parent
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
@@ -566,6 +583,7 @@ def decode_tokens(self, ids: List[int], **kwargs):
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
+    # TODO: Maybe support generation_config for eos_token
     def get_special_tokens(
         self, add_bos_token: bool = True, ban_eos_token: bool = False
     ):
@@ -840,13 +858,20 @@ def generate_gen_sync(
                 grammar_string, gen_settings, self.model, self.tokenizer
             )
 
+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
         # Ban the EOS token if specified. If not, append to stop conditions
        # as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
-            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+            gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
         else:
-            stop_conditions.append(self.tokenizer.eos_token_id)
+            stop_conditions += eos_tokens
 
         # Stop conditions
         self.generator.set_stop_conditions(stop_conditions)
@@ -891,6 +916,8 @@ def generate_gen_sync(
             token_healing=token_healing,
             auto_scale_penalty_range=auto_scale_penalty_range,
             generate_window=generate_window,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=eos_tokens,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             speculative_ngram=self.generator.speculative_ngram,
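
Note (editor, not part of the patch): the diff imports a GenerationConfig helper
from common/transformers_utils.py that is not shown here. Below is a minimal,
stdlib-only sketch of the interface the patch expects, namely from_file()
loading generation_config.json from a model directory and eos_tokens()
normalizing eos_token_id into a list. It is illustrative only; the real module
may parse more fields and validate differently.

import json
import pathlib
from typing import List, Optional, Union


class GenerationConfig:
    """Tiny subset of HuggingFace's generation_config.json.

    Only eos_token_id is read here; some models (e.g. Llama 3 with
    [128001, 128009]) list several IDs, which is why multi-EOS support
    is useful.
    """

    def __init__(
        self,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        **_ignored,
    ):
        # Extra keys from the JSON are accepted and ignored
        self.eos_token_id = eos_token_id

    @classmethod
    def from_file(cls, model_directory: pathlib.Path) -> "GenerationConfig":
        """Parse generation_config.json found inside a model directory."""

        config_path = model_directory / "generation_config.json"
        with open(config_path, "r", encoding="utf8") as config_file:
            return cls(**json.load(config_file))

    def eos_tokens(self) -> List[int]:
        """Return eos_token_id as a list, whether it was an int or a list."""

        if isinstance(self.eos_token_id, list):
            return self.eos_token_id

        # A single ID still comes back as a list; when nothing is present,
        # callers fall back to the tokenizer's EOS token.
        return [self.eos_token_id] if self.eos_token_id is not None else []

This matches how the patch uses it: GenerationConfig.from_file(
generation_config_path.parent) at model load time, then
generation_config.eos_tokens() when banning EOS or building stop conditions.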