Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add config for optional parameters in a chat message #2260

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ datasets:
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content
# Mapping of message property names expected by the chat template to the corresponding property names in the input dataset, e.g. role: from. (default: None)
message_property_mappings:

# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor

return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
Expand Down
3 changes: 1 addition & 2 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_property_mappings": ds_cfg.get("message_property_mappings", {}),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
Expand Down
108 changes: 73 additions & 35 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

from transformers import ProcessorMixin

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
Expand All @@ -25,8 +26,7 @@ def __init__(
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "role",
message_field_content: str = "content",
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
Expand All @@ -44,8 +44,10 @@ def __init__(
"tool": "tool",
}

self.message_field_role = message_field_role
self.message_field_content = message_field_content
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
chat_template
)
Comment on lines +47 to +49
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, messages_array_name refers to the key of the List[dict]. Should we change the signature so that field_messages is passed to the Prompter as well? To allow passing messages_array_name=field_messages?

self.message_property_mappings = message_property_mappings
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
Expand All @@ -54,6 +56,10 @@ def __init__(
self.max_length = max_length
self.drop_system_message = drop_system_message

@property
def chat_template_msg_variables(self) -> Set[str]:
    """Set of per-message variable names referenced by the chat template (computed once in ``__init__``)."""
    return self._chat_template_msg_variables

def build_prompt(self, conversation, add_generation_prompt=False, images=None):
if self.processor:
text = self.processor.apply_chat_template(
Expand Down Expand Up @@ -183,6 +189,12 @@ def adjust_train_details(

return adjusted_details

def get_chat_template_msg_variables(
    self, chat_template: str, messages_array_name: str = "messages"
) -> Set[str]:
    """Return the per-message variables referenced by ``chat_template``.

    Analyzes the Jinja template and collects the property names accessed
    on the items of the ``messages_array_name`` array.
    """
    return JinjaTemplateAnalyzer(chat_template).get_message_vars(
        messages_array_name
    )


class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Expand Down Expand Up @@ -212,6 +224,10 @@ def __init__(
self.train_on_eos = train_on_eos
self.images = "images"

LOG.info(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)

@property
def messages(self):
    """Name of the dataset field that holds the conversation's message list."""
    return self._messages
Expand Down Expand Up @@ -424,61 +440,83 @@ def find_turn(self, turns: list[dict], turn_idx: int):

def get_conversation_thread(self, prompt):
    """Build the list of chat turns for ``prompt``.

    Each raw message is remapped through ``transform_message`` (property
    mappings, role normalization, template-referenced extras) and then
    augmented with the per-message training flags. Optionally drops a
    leading system turn when the prompter is configured to do so.

    :param prompt: one dataset example; ``prompt[self.messages]`` is the
        list of raw message dicts.
    :return: list of turn dicts ready for the chat template.
    """
    turns = []
    for message in prompt[self.messages]:
        transformed_message = self.transform_message(message)
        # Debug level only: per-message logging at WARNING would flood
        # logs (and can leak dataset contents) on every single example.
        LOG.debug(f"Message: {message}")
        LOG.debug(f"Transformed message: {transformed_message}")

        turn = {
            **transformed_message,
            "training": message.get(self.prompter.message_field_training),
            "training_detail": message.get(
                self.prompter.message_field_training_detail
            ),
        }

        turns.append(turn)

    # Guard against an empty conversation before peeking at turns[0].
    if (
        self.prompter.drop_system_message
        and turns
        and turns[0]["role"] == "system"
    ):
        turns = turns[1:]

    return turns

def transform_message(self, message):
    """Remap a raw dataset message into the shape the chat template expects.

    Applies the configured property mappings, normalizes the role through
    the prompter's role table, and carries over any unmapped properties
    that the chat template actually references. ``None`` values are never
    copied over.
    """
    mappings = self.prompter.message_property_mappings

    transformed = {}
    for target, source in mappings.items():
        value = message.get(source)
        if value is not None:
            transformed[target] = value

    # Normalize the role via the configured alias table, falling back to
    # the raw value when no alias is defined for it.
    if "role" in transformed:
        raw_role = transformed["role"]
        transformed["role"] = self.prompter.roles.get(raw_role, raw_role)

    # Carry over keys that were not consumed by the mappings, but only
    # those the chat template actually references.
    consumed = set(mappings.values())
    for key, value in message.items():
        if (
            key not in consumed
            and key in self.prompter.chat_template_msg_variables
            and value is not None
        ):
            transformed[key] = value

    return transformed

def get_images(self, prompt):
    """Return the images attached to ``prompt``, or ``None`` when absent."""
    return prompt.get(self.images)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
def load(
tokenizer,
cfg,
ds_cfg: Optional[Dict[str, Any]] = None,
processor=None,
):
dataset_config = ds_cfg if ds_cfg else {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_property_mappings": dataset_config.get(
"message_property_mappings", {}
),
"message_field_training": dataset_config.get("message_field_training", None),
"message_field_training_detail": dataset_config.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
Expand All @@ -487,15 +525,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": dataset_config.get("roles_to_train", ["assistant"]),
"train_on_eos": dataset_config.get("train_on_eos", "turn"),
}

strategy = ChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)

if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
if "field_messages" in dataset_config and hasattr(strategy, "messages"):
strategy.messages = dataset_config["field_messages"]

return strategy
Loading
Loading