Skip to content

Commit

Permalink
Use PromptRegistry for all chat templates
Browse files Browse the repository at this point in the history
This removes the mapping of model families in SDG itself between
granite, mixtral, mistral, merlinite, etc. Instead, it uses the
PromptRegistry to lookup chat templates based on the model family
given. And, if no model family is given, it still falls back to doing
a best-guess based on the file path of the selected teacher model.

A simple test was added to demonstrate how to register and use custom
chat templates for generating prompts via the PromptRegistry.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Nov 27, 2024
1 parent b77586c commit 1007e8e
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 36 deletions.
2 changes: 2 additions & 0 deletions docs/upgrading_from_v0.6.x_to_v0.7.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Advanced users are now able to supply custom Pipeline `Block` implementations by

See the `tests/testdata/custom_block.py` and `tests/testdata/custom_block_pipeline.yaml` files in this repository for an example of how to create custom blocks and use them from your own pipeline config yamls.

See the `tests/testdata/custom_prompt.py` file in this repository for an example of how to register custom chat templates that are used when formatting prompts.

## Breaking Changes

### Pipeline configs and Prompt templates switched to Jinja
Expand Down
3 changes: 0 additions & 3 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
"SamplePopulatorBlock",
"SelectorBlock",
"SetToMajorityValueBlock",
"MODEL_FAMILY_MERLINITE",
"MODEL_FAMILY_MIXTRAL",
"FULL_PIPELINES_PACKAGE",
"SIMPLE_PIPELINES_PACKAGE",
"generate_data",
Expand Down Expand Up @@ -59,7 +57,6 @@
PipelineConfigParserError,
PipelineContext,
)
from .prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from .registry import BlockRegistry, PromptRegistry
from .utils import GenerateException
from .utils.taxonomy import TaxonomyReadingException
9 changes: 3 additions & 6 deletions src/instructlab/sdg/blocks/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import openai

# Local
# Import prompts to register default chat templates
from .. import prompts as default_prompts # pylint: disable=unused-import
from ..registry import BlockRegistry, PromptRegistry
from .block import Block

Expand Down Expand Up @@ -156,7 +158,6 @@ def _parse(self, generated_string) -> dict:
# 3. Empty string - the pipeline has specified that no model prompt is needed
def _format_prompt(self, sample: Dict) -> str:
prompt_templated_str = self.prompt_template.render(sample).strip()
wrap_in_messages_format = True

model_prompt = None
if self.model_prompt is None:
Expand All @@ -167,12 +168,8 @@ def _format_prompt(self, sample: Dict) -> str:
# Our model prompt is an empty string, which we'll render
# verbatim without wrapping in the messages format
model_prompt = PromptRegistry.get_template("blank")
wrap_in_messages_format = False

if wrap_in_messages_format:
messages = [{"role": "user", "content": prompt_templated_str}]
else:
messages = prompt_templated_str
messages = [{"role": "user", "content": prompt_templated_str}]

return model_prompt.render(
messages=messages,
Expand Down
6 changes: 1 addition & 5 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
Pipeline,
PipelineContext,
)
from instructlab.sdg.prompts import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from instructlab.sdg.utils import GenerateException, models
from instructlab.sdg.utils.json import jldump
from instructlab.sdg.utils.taxonomy import (
Expand Down Expand Up @@ -355,10 +354,7 @@ def generate_data(

logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")

if models.get_model_family(model_family, model_name) == "mixtral":
model_family = MODEL_FAMILY_MIXTRAL
else:
model_family = MODEL_FAMILY_MERLINITE
model_family = models.get_model_family(model_family, model_name)

ctx = _context_init(
client,
Expand Down
14 changes: 9 additions & 5 deletions src/instructlab/sdg/prompts.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# Local
from .registry import PromptRegistry

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"
# {{ prompt }} gives us the config's raw prompt string, not wrapped in
# any messages format

# {{ messages }} gives us the config's prompt in messages format,
# where the config's prompt becomes the content value of a user role
# message


@PromptRegistry.register("blank")
def blank_chat_template():
return """{{ messages }}"""
return """{{ prompt }}"""


@PromptRegistry.register(MODEL_FAMILY_MERLINITE)
@PromptRegistry.register("merlinite", "granite")
def merlinite_chat_template():
return """{% for message in messages %}{% if message['role'] == 'pretraining' %}{{ '<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}{% endfor %}"""


@PromptRegistry.register(MODEL_FAMILY_MIXTRAL)
@PromptRegistry.register("mixtral", "mistral")
def mixtral_chat_template():
return """{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n<s>\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + '</s>'}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n"""
8 changes: 5 additions & 3 deletions src/instructlab/sdg/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PromptRegistry:
_registry: Dict[str, Template] = {}

@classmethod
def register(cls, name: str):
def register(cls, *names: str):
"""Decorator to register a Jinja2 template function by name.
:param names: One or more names under which to register the template.
Expand All @@ -55,8 +55,10 @@ def register(cls, name: str):

def decorator(func):
template_str = func()
cls._registry[name] = Template(template_str, undefined=StrictUndefined)
logger.debug(f"Registered prompt template '{name}'")
template = Template(template_str, undefined=StrictUndefined)
for name in names:
cls._registry[name] = template
logger.debug(f"Registered prompt template '{name}'")
return func

return decorator
Expand Down
21 changes: 9 additions & 12 deletions src/instructlab/sdg/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,22 @@
import re

# First Party
from instructlab.sdg.registry import PromptRegistry
from instructlab.sdg.utils import GenerateException

# When otherwise unknown, ilab uses this as the default family
DEFAULT_MODEL_FAMILY = "merlinite"

# Model families understood by ilab
MODEL_FAMILIES = set(("merlinite", "mixtral"))

# Map model names to their family
MODEL_FAMILY_MAPPINGS = {"granite": "merlinite", "mistral": "mixtral"}


def get_model_family(model_family, model_path):
model_family_retrieved = MODEL_FAMILY_MAPPINGS.get(model_family, model_family)
if model_family_retrieved and model_family_retrieved.lower() not in MODEL_FAMILIES:
raise GenerateException("Unknown model family: %s" % model_family_retrieved)
registry = PromptRegistry.get_registry()

# A model_family was given, so use it explicitly
if model_family:
if model_family not in registry:
raise GenerateException("Unknown model family: %s" % model_family)
return model_family

# Try to guess the model family based on the model's filename
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
guess = MODEL_FAMILY_MAPPINGS.get(guess, guess)

return guess if guess in MODEL_FAMILIES else DEFAULT_MODEL_FAMILY
return guess if guess in registry else DEFAULT_MODEL_FAMILY
5 changes: 5 additions & 0 deletions tests/functional/test_custom_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@
def test_custom_block(testdata_path: pathlib.Path):
    """Run the custom-block example script as a standalone process.

    The script exits non-zero on any failure, which check_call surfaces
    as CalledProcessError and thus fails this test.
    """
    script_path = str(testdata_path / "custom_block.py")
    subprocess.check_call([sys.executable, script_path], text=True)


def test_custom_prompt(testdata_path: pathlib.Path):
    """Run the custom-prompt example script as a standalone process.

    The script exits non-zero on any failure, which check_call surfaces
    as CalledProcessError and thus fails this test.
    """
    script_path = str(testdata_path / "custom_prompt.py")
    subprocess.check_call([sys.executable, script_path], text=True)
22 changes: 20 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestModels:
def test_granite_model_family(self):
assert (
models.get_model_family("granite", "./models/granite-7b-lab-Q4_K_M.gguf")
== "merlinite"
== "granite"
)

def test_merlinite_model_family(self):
Expand All @@ -32,12 +32,30 @@ def test_mixtral_model_family(self):
== "mixtral"
)

def test_mistral_model_family(self):
    """An explicitly given "mistral" family is returned as-is."""
    family = models.get_model_family(
        "mistral", "./models/mistral-7b-instruct-v0.2.Q4_K_M.gguf"
    )
    assert family == "mistral"

def test_default_model_family(self):
    """With no family given and an unrecognized file name, the default
    ("merlinite") family is used; both None and "" count as not given."""
    unknown_model = "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf"
    for empty_family in (None, ""):
        assert models.get_model_family(empty_family, unknown_model) == "merlinite"

def test_model_family_overrides(self):
assert (
models.get_model_family(
"mixtral", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf"
)
== "merlinite"
== "mixtral"
)

def test_unknown_model_family(self):
Expand Down
28 changes: 28 additions & 0 deletions tests/testdata/custom_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
from instructlab.sdg import PromptRegistry


# Registering under "custom_model_family" makes this template available
# anywhere the SDG looks up chat templates by model family name.
@PromptRegistry.register("custom_model_family")
def custom_chat_template():
    """Return the Jinja2 template source for the custom model family's chat format."""
    template_source = """{% for message in messages %}{% if message['role'] == 'system' %}{{ '<<SYSTEM>>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<<USER>>' + '\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<<ASSISTANT>>' + '\n' + message['content'] + ('' if loop.last else '\n') }}{% endif %}{% endfor %}"""
    return template_source


# Fetch back the template registered for "custom_model_family" and make
# sure the registry actually returned one.
custom_template = PromptRegistry.get_template("custom_model_family")
assert custom_template is not None

# Render a short conversation and check the output carries our custom
# markup, proving the registry handed back our template and not a default.
rendered = custom_template.render(
    messages=[
        {"role": "system", "content": "system prompt goes here"},
        {"role": "user", "content": "user content goes here"},
    ]
)
assert rendered == (
    "<<SYSTEM>>\nsystem prompt goes here\n<<USER>>\nuser content goes here\n"
)

0 comments on commit 1007e8e

Please sign in to comment.