-
-
Notifications
You must be signed in to change notification settings - Fork 210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: tool api call broken when user answer no to when asked for confirmation #371
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
from pathlib import Path | ||
|
||
from .commands import action_descriptions, execute_cmd | ||
from .constants import PROMPT_USER | ||
from .constants import INTERRUPT_CONTENT, PROMPT_USER | ||
from .init import init | ||
from .llm import reply | ||
from .llm.models import get_model | ||
|
@@ -136,11 +136,7 @@ def confirm_func(msg) -> bool: | |
) | ||
except KeyboardInterrupt: | ||
console.log("Interrupted. Stopping current execution.") | ||
manager.append( | ||
Message( | ||
"system", "User hit Ctrl-c to interrupt the process" | ||
) | ||
) | ||
manager.append(Message("system", INTERRUPT_CONTENT)) | ||
break | ||
finally: | ||
clear_interruptible() | ||
|
@@ -209,7 +205,7 @@ def step( | |
if ( | ||
not last_msg | ||
or (last_msg.role in ["assistant"]) | ||
or last_msg.content == "Interrupted" | ||
or last_msg.content == INTERRUPT_CONTENT | ||
or last_msg.pinned | ||
or not any(role == "user" for role in [m.role for m in log]) | ||
): # pragma: no cover | ||
|
@@ -239,9 +235,6 @@ def step( | |
if msg_response: | ||
yield msg_response.replace(quiet=True) | ||
yield from execute_msg(msg_response, confirm) | ||
except KeyboardInterrupt: | ||
clear_interruptible() | ||
yield Message("system", "User hit Ctrl-c to interrupt the process") | ||
finally: | ||
clear_interruptible() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this is correct, I think the I think the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nevermind, it's there in the except clause for that very reason, but no longer needed. Not quite clear to me if the new behavior is correct though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm no it's probably not the same. Let me find something. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've updated the MR. To stay on the safe side, I've called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could also call |
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -200,9 +200,8 @@ def stream( | |
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: | ||
for message in message_dicts: | ||
# Format tool result as expected by the model | ||
if message["role"] == "system" and "call_id" in message: | ||
if message["role"] == "user" and "call_id" in message: | ||
modified_message = dict(message) | ||
modified_message["role"] = "user" | ||
modified_message["content"] = [ | ||
{ | ||
"type": "tool_result", | ||
|
@@ -358,22 +357,35 @@ def _transform_system_messages( | |
# unless a `call_id` is present, indicating the tool_format is 'tool'. | ||
# Tool responses are handled separately by _handle_tool. | ||
for i, message in enumerate(messages): | ||
if message.role == "system" and message.call_id is None: | ||
if message.role == "system": | ||
content = ( | ||
f"<system>{message.content}</system>" | ||
if message.call_id is None | ||
else message.content | ||
) | ||
|
||
messages[i] = Message( | ||
"user", | ||
content=f"<system>{message.content}</system>", | ||
content=content, | ||
files=message.files, # type: ignore | ||
call_id=message.call_id, | ||
) | ||
|
||
# find consecutive user role messages and merge them together | ||
messages_new: list[Message] = [] | ||
while messages: | ||
message = messages.pop(0) | ||
if messages_new and messages_new[-1].role == "user" and message.role == "user": | ||
if ( | ||
messages_new | ||
and messages_new[-1].role == "user" | ||
and message.role == "user" | ||
and message.call_id == messages_new[-1].call_id | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this condition needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When using the function call API, multiple calls may occur within a single answer (parallel calls). To handle this properly, ensure there is one tool response per tool call. By verifying that the call_id is distinct, we can ensure messages are not merged unless they originate from the same function call. |
||
): | ||
messages_new[-1] = Message( | ||
"user", | ||
content=f"{messages_new[-1].content}\n\n{message.content}", | ||
files=messages_new[-1].files + message.files, # type: ignore | ||
call_id=messages_new[-1].call_id, | ||
) | ||
else: | ||
messages_new.append(message) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is now handled in the
execute_msg
as we need thecall_id
in the message response.