Fix MistralTool format
AlongWY committed Sep 18, 2024
1 parent 558b983 commit 746a8f6
Showing 4 changed files with 199 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/llamafactory/data/formatter.py
@@ -147,7 +147,7 @@ def apply(self, **kwargs) -> SLOTS:

        elements = []
        for name, arguments in functions:
            elements.append(f""""{{"name":"{name}","arguments":{arguments}}}""")
            elements.append(f"""{{"name": "{name}", "arguments": {arguments}}}""")
        elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"]

        return elements
@@ -163,14 +163,14 @@ def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        tool_results: List[Tuple[str, str]]
        try:
            tool_results = [json.dumps(result) for result in json.loads(content)]
            tool_results = json.loads(content)
        except json.JSONDecodeError:
            tool_results = []

        elements = []
        for content in tool_results:
            elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]")
        return ["".join(elements)]
        for result in tool_results:
            elements.append(f"[TOOL_RESULTS] {{\"content\": {result}}}[/TOOL_RESULTS]")
        return elements


@dataclass
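Note: the formatter changes above mean tool calls are serialized as a single "[TOOL_CALLS] [...]" element with standard JSON spacing (the old line also had a stray extra quotation mark at the start of the f-string), and each tool result becomes its own "[TOOL_RESULTS] {...}[/TOOL_RESULTS]" element instead of being joined into one string. A minimal standalone sketch of the new behavior (the helper names are illustrative, not the library API):

import json

# Mirrors the patched function-call formatting: one "[TOOL_CALLS] [...]" element.
def render_tool_calls(functions):
    elements = [f"""{{"name": "{name}", "arguments": {arguments}}}""" for name, arguments in functions]
    return ["[TOOL_CALLS] [" + ", ".join(elements) + "]"]

# Mirrors the patched tool-result formatting: one element per result instead of a joined string.
def render_tool_results(content):
    try:
        tool_results = json.loads(content)
    except json.JSONDecodeError:
        tool_results = []
    return [f"[TOOL_RESULTS] {{\"content\": {result}}}[/TOOL_RESULTS]" for result in tool_results]

print(render_tool_calls([("get_news", '{"category": "sports"}')]))
# ['[TOOL_CALLS] [{"name": "get_news", "arguments": {"category": "sports"}}]']
print(render_tool_results(json.dumps([json.dumps({"temperature": 22})])))
# ['[TOOL_RESULTS] {"content": {"temperature": 22}}[/TOOL_RESULTS]']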
2 changes: 1 addition & 1 deletion src/llamafactory/data/template.py
@@ -723,7 +723,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:

_register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
    format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}"]),  # mistral adds the space here
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"),
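Note: the user slot now matches the reference Mistral chat template used in the new test below, which renders "[INST] " + content + "[/INST]" with no space before "[/INST]"; the single separating space before the assistant reply comes from the assistant slot alone. A rough illustration of a rendered turn, assuming plain string substitution and ignoring BOS/EOS handling:

# Assumed simplification of the two slots above; not the actual template engine.
user_slot = "[INST] {content}[/INST]"  # was "[INST] {content} [/INST]"
assistant_slot = " {content}"  # the Mistral format adds the space on the assistant side
print(user_slot.format(content="Hi there") + assistant_slot.format(content="Hello!"))
# [INST] Hi there[/INST] Hello!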
4 changes: 2 additions & 2 deletions src/llamafactory/data/tool_utils.py
@@ -38,7 +38,7 @@
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)

MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools} [/AVAILABLE_TOOLS]"
MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools}[/AVAILABLE_TOOLS]"

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

@@ -176,7 +176,7 @@ def get_function_slots() -> SLOTS:
    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        tools = [{"type": "function", "function": tool} for tool in tools]
        tools = json.dumps([{"type": "function", "function": tool} for tool in tools], ensure_ascii=False)
        return MISTRAL_TOOL_PROMPT.format(tools=tools)

    @override
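Note: the old line interpolated the Python list of tool dicts straight into the prompt, which produces repr-style output with single quotes rather than JSON; the new line serializes it explicitly with json.dumps(..., ensure_ascii=False) so the "[AVAILABLE_TOOLS]" block is valid JSON and Chinese descriptions stay unescaped. A standalone sketch of the patched tool_formatter (written as a free function purely for illustration):

import json
from typing import Any, Dict, List

MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools}[/AVAILABLE_TOOLS]"

def mistral_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    # Wrap each tool as {"type": "function", "function": ...} and emit real JSON.
    wrapped = json.dumps([{"type": "function", "function": tool} for tool in tools], ensure_ascii=False)
    return MISTRAL_TOOL_PROMPT.format(tools=wrapped)

print(mistral_tool_formatter([{"name": "get_news", "description": "获取最新新闻文章"}]))
# [AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_news", "description": "获取最新新闻文章"}}][/AVAILABLE_TOOLS]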
194 changes: 191 additions & 3 deletions tests/data/test_template.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import TYPE_CHECKING, List, Sequence

@@ -21,11 +21,9 @@
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


HF_TOKEN = os.environ.get("HF_TOKEN", None)

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -37,6 +35,81 @@
{"role": "assistant", "content": "很高兴认识你!"},
]

TOOL_MESSAGES = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_news",
                "description": "获取最新新闻文章",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "category": {"type": "string", "description": "要检索的新闻文章类别"},
                        "country": {"type": "string", "description": "获取新闻文章的国家"}
                    },
                    "required": ["category"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "search_books",
                "description": "根据提供的标准搜索书籍",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "title": {"type": "string", "description": "这本书的标题"},
                        "author": {"type": "string", "description": "这本书的作者"},
                        "genre": {"type": "string", "description": "这本书的类型"}
                    }
                }
            }
        }
    ],
    "messages": [
        {
            "role": "user",
            "content": "你能帮我找到最新的美国体育新闻吗?"
        },
        {
            "role": "tool_calls",
            "content": [
                {
                    "type": "function",
                    "function": {"name": "get_news", "arguments": {"category": "运动", "country": "美国"}}
                }
            ]
        },
        {
            "role": "tool",
            "content": json.dumps(
                {"title": "NBA总决赛:湖人队对阵热火队", "link": "NBA官方网站"},
                ensure_ascii=False
            ),
        },
        {
            "role": "tool",
            "content": json.dumps(
                {"title": "NFL:爱国者队击败酋长队", "link": "https://www.nfl.com/新闻"},
                ensure_ascii=False
            ),
        },
        {
            "role": "tool",
            "content": json.dumps(
                {"title": "MLB:道奇队赢得世界系列赛", "link": "https://www.mlb.com/新闻"},
                ensure_ascii=False
            )
        },
        {
            "role": "assistant",
            "content": "1. NBA总决赛:湖人队对阵热火队\n2. NFL:爱国者队击败酋长队\n3. MLB:道奇队赢得世界系列赛"
        }
    ],
}


def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
@@ -168,3 +241,118 @@ def test_yi_template():
    )
    answer_str = "很高兴认识你!<|im_end|>"
    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)


@pytest.mark.xfail(reason="The fast tokenizer of mistral model is corrupted.")
def test_mistral_template():
    TEMPLATE = r"""
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- set user_messages = messages | selectattr("role", "equalto", "user") | list %}
{%- for message in messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + "}[/TOOL_RESULTS]" }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}
"""
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("/home/share/models/Mistral-7B-v0.3")
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="mistral"))

    content_str = tokenizer.apply_chat_template(
        conversation=TOOL_MESSAGES['messages'],
        tools=TOOL_MESSAGES['tools'],
        chat_template=TEMPLATE,
        tokenize=False
    )
    content_ids = tokenizer.apply_chat_template(
        conversation=TOOL_MESSAGES['messages'],
        tools=TOOL_MESSAGES['tools'],
        chat_template=TEMPLATE,
        tokenize=True
    )
    encoded_pairs = template.encode_multiturn(
        tokenizer,
        [
            TOOL_MESSAGES['messages'][0],
            {
                "role": "function",
                "content": json.dumps([function['function'] for function in TOOL_MESSAGES['messages'][1]['content']])
            },
            {
                "role": "observation",
                "content": json.dumps([item['content'] for item in TOOL_MESSAGES['messages'][2:-1]])
            },
            TOOL_MESSAGES['messages'][-1],
        ],
        tools=json.dumps([tool['function'] for tool in TOOL_MESSAGES['tools']])
    )

    final_ids = []
    for prompt, response in encoded_pairs:
        final_ids.extend(prompt)
        final_ids.extend(response)

    final_str = tokenizer.decode(final_ids)

    assert content_str == final_str
    assert content_ids == final_ids
