From 4bf1a71d7bdfb18cdc732c71b9c684a6caf29ecc Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 31 Aug 2024 22:56:49 -0400 Subject: [PATCH] Model: Fix model override application for draft args These have to be merged beforehand and the updated version needs to be re-fetched. It's possible to prevent the fetch of draft_args in the beginning of init. Signed-off-by: kingbri --- backends/exllamav2/model.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ab15836f..7fe08db9 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -347,6 +347,9 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): # Set user-configured draft model values if enable_draft: + # Fetch from the updated kwargs + draft_args = unwrap(kwargs.get("draft"), {}) + self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( @@ -378,9 +381,15 @@ def set_model_overrides(self, **kwargs): return kwargs with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_config = unwrap(yaml.safe_load(override_config_file), {}) - merged_kwargs = {**override_config, **kwargs} + override_args = unwrap(yaml.safe_load(override_config_file), {}) + + # Merge draft overrides beforehand + draft_override_args = unwrap(override_args.get("draft"), {}) + if self.draft_config and draft_override_args: + kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")} + # Merge the override and model kwargs + merged_kwargs = {**override_args, **kwargs} return merged_kwargs def find_prompt_template(self, prompt_template_name, model_directory):