Add Glaive conversation format support #1365

Merged
41 changes: 41 additions & 0 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -1,10 +1,15 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

from typing import Any, Dict, Optional

from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
    chatml_to_conversation,
    merge_consecutive_messages,
)


def register_chatml_template(system_message=None):
@@ -19,6 +24,16 @@ def register_chatml_template(system_message=None):
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -77,6 +92,20 @@ def load_guanaco(tokenizer, cfg):
    )


def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    conversation = (
        ds_cfg["conversation"]
        if ds_cfg and "conversation" in ds_cfg
        else "chatml_glaive"
    )
    return GlaiveShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(conversation=conversation),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )


class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
@@ -158,3 +187,15 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns


class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""

def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)

return conversation
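For context, here is a rough sketch (not part of the diff) of how the new load_glaive entry point could be exercised end to end, assuming axolotl with this change is installed. The SimpleNamespace stands in for a full axolotl config object, and the ChatML special tokens are not added to the tokenizer here, so token ids will differ from the test vectors further down.

```python
# Hedged smoke-test sketch for the new Glaive loader; not part of this PR.
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.sharegpt import load_glaive, register_chatml_template

register_chatml_template()  # registers the "chatml" and "chatml_glaive" templates

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)  # stand-in config

strategy = load_glaive(tokenizer, cfg)  # defaults to the "chatml_glaive" template
row = {
    "system": "SYSTEM: This is a system prompt",
    "chat": "USER: Can you book a flight for me? ASSISTANT: I'm sorry, I can't. <|endoftext|>",
}
example = strategy.tokenize_prompt(row)
print(len(example["input_ids"]))
```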
16 changes: 13 additions & 3 deletions src/axolotl/prompt_tokenizers.py
@@ -360,11 +360,19 @@ def tokenize_prompt(self, prompt):
LOG.warning(f"expected tuple, got {part}")
continue

user, assistant = conversation.roles
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part

# Uses "in" because role contains extra characters
if user in role:
if user_role_label in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -384,7 +392,7 @@
                else:
                    # everything from this is masked out from the labels
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif assistant in role:
            elif assistant_role_label in role:
                role = (
                    role.replace(role_remap[1]["from"], role_remap[1]["to"])
                    if role_remap
@@ -426,6 +434,8 @@
                else:
                    # everything from this is masked out from the labels
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif tool_role_label and tool_role_label in role:
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            else:
                LOG.warning(f"unhandled role: {role}")
                continue
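Taken together, the hunks above unpack a third role label when the conversation template defines one and mask every token of a tool turn out of the loss. A self-contained toy restatement of that masking rule (illustrative only; -100 matches axolotl's IGNORE_TOKEN_ID sentinel):

```python
# Toy restatement of the tool-turn masking rule added above; standalone code,
# not the axolotl implementation.
IGNORE_TOKEN_ID = -100  # HuggingFace's conventional ignore index


def labels_for_turn(role, input_ids, tool_role_label="<|im_start|>tool"):
    # Tool / function-response turns contribute context tokens but no labels.
    if tool_role_label and tool_role_label in role:
        return [IGNORE_TOKEN_ID] * len(input_ids)
    return list(input_ids)


assert labels_for_turn("<|im_start|>tool", [5, 6, 7]) == [-100, -100, -100]
assert labels_for_turn("<|im_start|>assistant", [5, 6, 7]) == [5, 6, 7]
```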
7 changes: 7 additions & 0 deletions src/axolotl/prompters.py
@@ -267,13 +267,16 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods

role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None

    def __init__(
        self,
        prompt_style=None,  # pylint: disable=unused-argument
        conversation: Optional[Union[str, Conversation]] = None,
        role_key_human: Optional[str] = None,
        role_key_model: Optional[str] = None,
        role_key_tool: Optional[str] = None,
    ):
        if conversation:
            if isinstance(conversation, Conversation):
@@ -286,6 +289,8 @@ def __init__(
            self.role_key_human = role_key_human
        if role_key_model:
            self.role_key_model = role_key_model
        if role_key_tool:
            self.role_key_tool = role_key_tool

    def _build_result(self, source):
        if len(source) < 2:
@@ -303,6 +308,8 @@ def _build_result(self, source):
            source.pop(0)

        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
        if self.role_key_tool:
            roles[self.role_key_tool] = conv.roles[2]

        try:
            # Apply prompt templates
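The effect of the prompters.py change: when a prompter is constructed with a role_key_tool and a three-role conversation template (such as chatml_glaive above), _build_result extends its sender-to-role mapping with the third role. A minimal sketch of the resulting mapping, reusing the chatml_glaive role headers:

```python
# Sketch of the sender-to-role mapping _build_result produces for a
# three-role template; the values mirror the chatml_glaive registration above.
conv_roles = ("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool")
role_key_human, role_key_model, role_key_tool = "human", "gpt", "tool"

roles = {role_key_human: conv_roles[0], role_key_model: conv_roles[1]}
if role_key_tool:
    roles[role_key_tool] = conv_roles[2]

assert roles["tool"] == "<|im_start|>tool"  # ShareGPT "tool" turns get the tool header
```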
64 changes: 64 additions & 0 deletions src/axolotl/utils/tokenization.py
@@ -2,6 +2,8 @@


import logging
import re
from typing import Dict, List

from termcolor import colored

@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
LOG.info("\n\n\n")

return " ".join(colored_tokens)


GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",
"USER": "human",
"ASSISTANT": "gpt",
"FUNCTION RESPONSE": "tool",
}

GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")


def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
"""
Converts a ChatML formatted row to a list of messages in ShareGPT format.
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
"""

system_prompt = row.get("system")
if system_prompt:
system_prompt = system_prompt.removeprefix("SYSTEM: ")

chat_str = row["chat"]
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]

chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
]

if system_prompt:
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
] + chat_msg_dicts

return chat_msg_dicts


def merge_consecutive_messages(messages):
"""
Merge consecutive messages from the same sender into a single message.
This can be useful with datasets that contain multiple consecutive tool calls.
"""

merged_messages = []
current_from = None
current_message = ""

for msg in messages:
if current_from == msg["from"]:
current_message += msg["value"]
else:
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
current_from = msg["from"]
current_message = msg["value"]

if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})

return merged_messages
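A quick worked example of the two helpers above (the row content is invented for illustration; the comments show the resulting senders):

```python
# Worked example for chatml_to_conversation + merge_consecutive_messages;
# the row below is made up for illustration.
from axolotl.utils.tokenization import (
    chatml_to_conversation,
    merge_consecutive_messages,
)

row = {
    "system": "SYSTEM: You can call functions.",
    "chat": (
        "USER: What's the weather? "
        "ASSISTANT: Let me check. "
        'FUNCTION RESPONSE: {"temp": 21} '
        'FUNCTION RESPONSE: {"unit": "C"} '
        "ASSISTANT: It's 21C."
    ),
}

messages = chatml_to_conversation(row)
# senders: system, human, gpt, tool, tool, gpt

merged = merge_consecutive_messages(messages)
# the two consecutive tool messages are concatenated into one
assert [m["from"] for m in merged] == ["system", "human", "gpt", "tool", "gpt"]
```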
40 changes: 40 additions & 0 deletions tests/prompt_strategies/test_sharegpt.py
@@ -1,13 +1,15 @@
"""
Test module for sharegpt integration w chatml
"""

import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
    GlaiveShareGPTPromptTokenizingStrategy,
    SimpleShareGPTPromptTokenizingStrategy,
    register_chatml_template,
)
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
    )


@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
    return Dataset.from_list(
        [
            {
                "system": "SYSTEM: This is a system prompt",
                "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -156,3 +170,29 @@ def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_chatml_glaive(self, glaive_dataset, tokenizer):
        strategy = GlaiveShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            True,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, glaive_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
        ]
        # fmt: on
19 changes: 19 additions & 0 deletions tests/test_prompt_tokenizers.py
@@ -1,4 +1,5 @@
"""Module for testing prompt tokenizers."""

import json
import logging
import unittest
@@ -18,6 +19,7 @@
    Llama2ChatPrompter,
    LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    ShareGPTPromptTokenizingStrategy,
@@ -266,6 +268,23 @@ def test_sharegpt_assistant_label_ignore(self):
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100

    def test_glaive_tool_label_ignore(self):
        conversation = {
            "system": "SYSTEM: This is a system prompt",
            "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
        }
        prompter = ShareGPTPrompterV2()
        strat = GlaiveShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            res = strat.tokenize_prompt(conversation)
        idx = res["input_ids"].index(13566)  # assistant token
        assert res["labels"][idx] == -100

    def test_no_sys_prompt(self):
        """
        tests the interface between the user and assistant parts