-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use PromptRegistry for all chat templates
This removes the mapping of model families in SDG itself between granite, mixtral, mistral, merlinite, etc. Instead, it uses the PromptRegistry to lookup chat templates based on the model family given. And, if no model family is given, it still falls back to doing a best-guess based on the file path of the selected teacher model. A simple test was added to demonstrate how to register and use custom chat templates for generating prompts via the PromptRegistry. Signed-off-by: Ben Browning <[email protected]>
- Loading branch information
Showing
10 changed files
with
82 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,24 @@ | ||
# Local | ||
from .registry import PromptRegistry | ||
|
||
MODEL_FAMILY_MIXTRAL = "mixtral" | ||
MODEL_FAMILY_MERLINITE = "merlinite" | ||
# {{ prompt }} gives us the config's raw prompt string, not wrapped in | ||
# any messages format | ||
|
||
# {{ messages }} gives us the config's prompt in messages format, | ||
# where the config's prompt becomes the content value of a user role | ||
# message | ||
|
||
|
||
@PromptRegistry.register("blank") | ||
def blank_chat_template(): | ||
return """{{ messages }}""" | ||
return """{{ prompt }}""" | ||
|
||
|
||
@PromptRegistry.register(MODEL_FAMILY_MERLINITE) | ||
@PromptRegistry.register("merlinite", "granite") | ||
def merlinite_chat_template(): | ||
return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}""" | ||
|
||
|
||
@PromptRegistry.register(MODEL_FAMILY_MIXTRAL) | ||
@PromptRegistry.register("mixtral", "mistral") | ||
def mixtral_chat_template(): | ||
return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n<s>\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + '</s>'}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# First Party | ||
from instructlab.sdg import PromptRegistry | ||
|
||
|
||
# Register our custom chat template under the "custom_model_family" | ||
# model family | ||
@PromptRegistry.register("custom_model_family") | ||
def custom_chat_template(): | ||
return """{% for message in messages %}{% if message['role'] == 'system' %}{{ '<<SYSTEM>>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<<USER>>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<<ASSISTANT>>' + '\n' + message['content'] + ('' if loop.last else '\n') }}{% endif %}{% endfor %}""" | ||
|
||
|
||
# Lookup the chat template for "custom_model_family" model family | ||
template = PromptRegistry.get_template("custom_model_family") | ||
assert template is not None | ||
|
||
# Ensure the template found is our custom one | ||
prompt = template.render( | ||
messages=[ | ||
{"role": "system", "content": "system prompt goes here"}, | ||
{"role": "user", "content": "user content goes here"}, | ||
] | ||
) | ||
expected_prompt = ( | ||
"<<SYSTEM>>\nsystem prompt goes here\n<<USER>>\nuser content goes here\n" | ||
) | ||
assert prompt == expected_prompt |