Skip to content

gpt-oss: implement harmony parsing #15181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/chat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class common_chat_msg_partial_exception : public std::runtime_error {
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};

// Exception type signalling that chat-message input did not match the
// expected format. It carries only a human-readable description, which is
// forwarded to std::runtime_error and retrievable through what().
class common_chat_msg_parse_exception : public std::runtime_error {
  public:
    // `what_arg` is the description later returned by what().
    common_chat_msg_parse_exception(const std::string & what_arg)
        : std::runtime_error(what_arg) {
    }
};

class common_chat_msg_parser {
std::string input_;
bool is_partial_;
Expand Down
211 changes: 205 additions & 6 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
Expand Down Expand Up @@ -1314,17 +1315,215 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;

// Tool calls: when tools are provided, constrain generation with a grammar (lazy unless tool_choice is "required").
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
// tool calls can appear in commentary or analysis channels
auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )");

std::vector<std::string> tool_rules_recipient_in_role;
std::vector<std::string> tool_rules_recipient_in_channel;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);

tool_rules_recipient_in_role.push_back(
builder.add_rule(name + "-call",
"\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " +
builder.add_schema(name + "-args", parameters)
)
);

tool_rules_recipient_in_channel.push_back(
builder.add_rule(name + "-call",
"\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " +
builder.add_schema(name + "-args", parameters)
)
);
});

auto recipient_in_role = builder.add_rule("recipient_in_role",
"\"<|start|>assistant\"? \" to=functions.\" " +
string_join(tool_rules_recipient_in_role, " | ")
);

auto recipient_in_channel = builder.add_rule("recipient_in_channel",
channel + " \" to=functions.\" " +
string_join(tool_rules_recipient_in_channel, " | ")
);

builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);

// Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|channel\\|>(commentary|analysis) to"
});

// Trigger tool calls that appear in the role section, either at the
// start or in the middle.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
"^ to"
});

data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|start\\|>assistant to"
});

data.preserved_tokens = {
"<|channel|>",
"<|constrain|>",
"<|message|>",
"<|start|>",
"<|end|>",
};
});
}

return data;
}
// Parses gpt-oss ("harmony") formatted model output held by `builder`.
//
// The input is treated as a sequence of messages. The first message is parsed
// from the current position; every following message must be introduced by
// `<|start|>assistant`. Each message is one of:
//   - `<|channel|>final<|message|>...`       -> content
//   - `<|channel|>analysis<|message|>...`    -> reasoning (or content when the
//                                               reasoning format is NONE)
//   - `<|channel|>commentary<|message|>...`  -> content
//   - `<|channel|>(commentary|analysis) to=functions.NAME ...<|message|>{json}`
//   - ` to=functions.NAME<|channel|>(commentary|analysis) ...<|message|>{json}`
//     (recipient given in the role section)      -> tool call
//
// Error handling:
//   - common_chat_msg_parse_exception is caught here and logged at debug
//     level; parsing then continues with the next `<|start|>assistant`.
//   - common_chat_msg_partial_exception (incomplete tool call) is NOT caught
//     here and propagates to the caller.
//
// Anything left after the last recognised message is consumed and discarded.
//
// NOTE(review): the previous implementation carried a TODO about `--special`
// not being supported; confirm whether this parser resolves it.
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
    static const common_regex message_regex("<\\|message\\|>");
    static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
    static const common_regex tool_call_channel_regex("<\\|channel\\|>(commentary|analysis)");
    static const common_regex start_regex("<\\|start\\|>assistant");
    static const common_regex end_regex("<\\|end\\|>");
    static const common_regex to_regex(" to=");
    static const common_regex function_regex("functions\\.([a-zA-Z_][a-zA-Z0-9_]*)");
    static const common_regex user_tool_call_regex("(?: <\\|constrain\\|>([a-zA-Z]+))?<\\|message\\|>");
    static const common_regex builtin_tool_call_regex("(?:browser|python)[\\s\\S]*<\\|message\\|>");

    // Save the start of the message so we can roll back when we encounter a tool call and parse_tool_calls == false.
    size_t message_start_pos = 0;

    // Similarly, save the channel start so we can roll back to defer reasoning parsing to builder.
    size_t channel_start_pos = 0;

    // Returns the text up to the next `<|start|>assistant` (searching from
    // `from`) and leaves the parser positioned at that marker, so the main
    // loop can pick it up. If there is no further marker, returns the rest
    // of the input. The `false` argument presumably keeps the skipped text
    // out of the message content - confirm against try_find_regex.
    auto consume_until_next = [&](size_t from = std::string::npos) {
        if (auto res = builder.try_find_regex(start_regex, from, false)) {
            auto begin = res->groups[0].begin;
            builder.move_to(begin);
            return res->prelude;
        }
        return builder.consume_rest();
    };

    // Consumes `<|message|>` and the message body. Returns false (consuming
    // nothing) when the next token is not `<|message|>`.
    auto try_consume_message = [&]() {
        if (builder.try_consume_regex(message_regex)) {
            // When `<|end|>` is found, try_find_regex is relied on to handle
            // the text before it (presumably added to content by default -
            // confirm). Without `<|end|>` the remainder is the content.
            if (!builder.try_find_regex(end_regex)) {
                builder.add_content(builder.consume_rest());
            }
            return true;
        }
        return false;
    };

    // Parses a tool call; the parser is positioned right after " to=".
    // `recipient_in_role` is true when the recipient appeared in the role
    // section, in which case the channel token must follow the function name.
    auto tool_call = [&](bool recipient_in_role) {
        if (!builder.syntax().parse_tool_calls) {
            // Move back to the start and consume up to the next message
            builder.move_to(message_start_pos);
            builder.add_content(consume_until_next(message_start_pos + 1));
            return;
        }

        if (auto res = builder.try_consume_regex(function_regex)) {
            auto name = builder.str(res->groups[1]);

            if (recipient_in_role) {
                if (!builder.try_consume_regex(tool_call_channel_regex)) {
                    throw common_chat_msg_parse_exception("expected <|channel|>(commentary|analysis), got: " + consume_until_next());
                }
            }

            if (builder.try_consume_regex(user_tool_call_regex)) {
                if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
                    // A rejected or still-incomplete call is reported as partial.
                    if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
                        throw common_chat_msg_partial_exception("incomplete tool call");
                    }
                }
            } else {
                throw common_chat_msg_parse_exception("expected function args, got: " + consume_until_next());
            }
        } else if (builder.try_consume_regex(builtin_tool_call_regex)) {
            // Built-in tools (browser/python) are recognised but dropped.
            builder.consume_rest();
            LOG_ERR("builtin tool calls not implemented\n");
        } else {
            throw common_chat_msg_parse_exception("expected function name, got: " + consume_until_next());
        }
    };

    // Commentary channel: either a tool call or a plain message.
    auto commentary = [&]() {
        if (builder.try_consume_regex(to_regex)) {
            tool_call(false);
        } else if (!try_consume_message()) {
            throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
        }
    };

    // Analysis channel: either a tool call or reasoning text.
    auto analysis = [&]() {
        if (builder.try_consume_regex(to_regex)) {
            tool_call(false); // built-in tools can be called in the analysis channel
        } else if (builder.try_consume_regex(message_regex)) {
            // Defer reasoning parsing to builder
            builder.move_to(channel_start_pos);

            if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE) {
                builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
            } else {
                // Reasoning extraction disabled: emit the raw channel text as content.
                builder.add_content(consume_until_next());
            }
        } else {
            throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
        }
    };

    // Dispatches on the channel name captured by channel_regex.
    auto channel = [&](const common_chat_msg_parser::find_regex_result & match) {
        auto type = builder.str(match.groups[1]);
        if (type == "analysis") {
            analysis();
        } else if (type == "commentary") {
            commentary();
        } else if (type == "final") {
            if (!try_consume_message()) {
                throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
            }
        } else {
            throw common_chat_msg_parse_exception("expected one of: [analysis, commentary, final], got: " + consume_until_next());
        }
    };

    // Parses one message: a channel header, or a recipient in the role section.
    auto message = [&]() {
        if (auto res = builder.try_consume_regex(channel_regex)) {
            channel_start_pos = res->groups[0].begin;
            channel(*res);
        } else if (builder.try_consume_regex(to_regex)) {
            tool_call(true);
        } else {
            throw common_chat_msg_parse_exception("expected: <|channel|> or \" to\", got: " + consume_until_next());
        }
    };

    // First message: parsed from the current position, with no
    // `<|start|>assistant` consumed beforehand.
    try {
        message();
    } catch (const common_chat_msg_parse_exception & e) {
        LOG_DBG("Parse error: %s\n", e.what());
    }

    // Read in complete messages until done or partial exception raised
    while (auto res = builder.try_consume_regex(start_regex)) {
        message_start_pos = res->groups[0].begin;
        try {
            message();
        } catch (const common_chat_msg_parse_exception & e) {
            LOG_DBG("Parse error: %s\n", e.what());
        }
    }

    // Discard whatever is left after the last recognised message.
    builder.consume_rest();
}

static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
Expand Down
Loading
Loading