Skip to content

Commit

Permalink
Model: Fix model override application for draft args
Browse files Browse the repository at this point in the history
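The draft overrides from the override config have to be merged into the
"draft" kwargs before the top-level merge, and the updated draft args then
need to be re-fetched in init. It may be possible to skip the initial fetch
of draft_args at the beginning of init.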

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Sep 1, 2024
1 parent 4aebe8a commit 4bf1a71
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):

# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})

self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
Expand Down Expand Up @@ -378,9 +381,15 @@ def set_model_overrides(self, **kwargs):
return kwargs

with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_config = unwrap(yaml.safe_load(override_config_file), {})
merged_kwargs = {**override_config, **kwargs}
override_args = unwrap(yaml.safe_load(override_config_file), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}

# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs

def find_prompt_template(self, prompt_template_name, model_directory):
Expand Down

0 comments on commit 4bf1a71

Please sign in to comment.