diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ab15836f..7fe08db9 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -347,6 +347,9 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): # Set user-configured draft model values if enable_draft: + # Fetch from the updated kwargs + draft_args = unwrap(kwargs.get("draft"), {}) + self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( @@ -378,9 +381,15 @@ def set_model_overrides(self, **kwargs): return kwargs with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_config = unwrap(yaml.safe_load(override_config_file), {}) - merged_kwargs = {**override_config, **kwargs} + override_args = unwrap(yaml.safe_load(override_config_file), {}) + + # Merge draft overrides beforehand + draft_override_args = unwrap(override_args.get("draft"), {}) + if self.draft_config and draft_override_args: + kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")} + # Merge the override and model kwargs + merged_kwargs = {**override_args, **kwargs} return merged_kwargs def find_prompt_template(self, prompt_template_name, model_directory):