Model: Fix prompt template initialization
The previous commit iterated through multiple try blocks, which
forced the user to provide a dummy prompt template. Template loading
is now fallback based.

Run through a list of lookup functions and return as soon as one of
them succeeds.

Signed-off-by: kingbri <[email protected]>
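
The fallback chain the message describes can be sketched in isolation. The
following is a minimal, self-contained illustration of the pattern, not the
repository's code: the three loader functions are hypothetical stand-ins for
get_template_from_file / get_template_from_model_json, and find_first_template
mirrors the loop added in this commit.

    def load_from_override():
        # Hypothetical: no user-supplied template file exists on disk
        raise FileNotFoundError("prompt template override not found")

    def load_from_tokenizer_config():
        # Hypothetical: tokenizer_config.json exists but has no chat_template
        return None

    def load_from_model_name():
        # Hypothetical: the model name matched a known bundled template
        return "chatml"

    def find_first_template(loaders):
        """Return the first non-None result, skipping loaders that fail."""
        for load in loaders:
            try:
                template = load()
                if template is not None:
                    return template
            except (FileNotFoundError, LookupError):
                continue  # fall through to the next loader
        return None  # every lookup failed; the caller must handle this

    print(find_first_template(
        [load_from_override, load_from_tokenizer_config, load_from_model_name]
    ))  # prints "chatml"

Because loaders are tried in list order, an explicit override (when provided)
can simply be inserted at index 0, exactly as the diff below does.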
bdashore3 committed Jan 25, 2024
1 parent 740b021 commit 90fb41a
Showing 1 changed file with 32 additions and 26 deletions.
backends/exllamav2/model.py (32 additions, 26 deletions)
@@ -158,32 +158,10 @@ def progress(loaded_modules: int, total_modules: int,
             self.config.set_low_mem()
         """
 
-        # Set prompt template override if provided
-        prompt_template_name = kwargs.get("prompt_template")
-        if prompt_template_name:
-            logger.info("Loading prompt template with name " f"{prompt_template_name}")
-            # Read the template
-            try:
-                self.prompt_template = get_template_from_file(prompt_template_name)
-            except FileNotFoundError:
-                self.prompt_template = None
-
-        # Then try finding the template from the tokenizer_config.json
-        try:
-            self.prompt_template = get_template_from_model_json(
-                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                "chat_template",
-                "from_tokenizer_config",
-            )
-        except FileNotFoundError:
-            self.prompt_template = None
-
-        # If that fails, attempt fetching from model name
-        try:
-            template_match = find_template_from_model(model_directory)
-            self.prompt_template = get_template_from_file(template_match)
-        except (LookupError, FileNotFoundError):
-            self.prompt_template = None
+        # Try to set prompt template
+        self.prompt_template = self.find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )
 
         # Catch all for template lookup errors
         if self.prompt_template:
@@ -250,6 +228,34 @@ def progress(loaded_modules: int, total_modules: int,
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
+    def find_prompt_template(self, prompt_template_name, model_directory):
+        """Tries to find a prompt template using various methods"""
+
+        logger.info("Loading prompt template with name " f"{prompt_template_name}")
+
+        find_template_functions = [
+            lambda: get_template_from_model_json(
+                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                "chat_template",
+                "from_tokenizer_config",
+            ),
+            lambda: get_template_from_file(find_template_from_model(model_directory)),
+        ]
+
+        # Add lookup from prompt template name if provided
+        if prompt_template_name:
+            find_template_functions.insert(
+                0, lambda: get_template_from_file(prompt_template_name)
+            )
+
+        for func in find_template_functions:
+            try:
+                prompt_template = func()
+                if prompt_template is not None:
+                    return prompt_template
+            except (FileNotFoundError, LookupError):
+                continue
+
     def calculate_rope_alpha(self, base_seq_len):
         """Calculate the rope alpha value for a given sequence length."""
         ratio = self.config.max_seq_len / base_seq_len
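
Note that when every lookup fails, find_prompt_template falls off the end of
its loop and implicitly returns None; the unchanged "Catch all for template
lookup errors" check after the call site is what handles that case.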
