Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tool API call broken when user answers no when asked for confirmation #371

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions gptme/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path

from .commands import action_descriptions, execute_cmd
from .constants import PROMPT_USER
from .constants import INTERRUPT_CONTENT, PROMPT_USER
from .init import init
from .llm import reply
from .llm.models import get_model
Expand Down Expand Up @@ -136,11 +136,7 @@ def confirm_func(msg) -> bool:
)
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
manager.append(
Message(
"system", "User hit Ctrl-c to interrupt the process"
)
)
manager.append(Message("system", INTERRUPT_CONTENT))
break
finally:
clear_interruptible()
Expand Down Expand Up @@ -209,7 +205,7 @@ def step(
if (
not last_msg
or (last_msg.role in ["assistant"])
or last_msg.content == "Interrupted"
or last_msg.content == INTERRUPT_CONTENT
or last_msg.pinned
or not any(role == "user" for role in [m.role for m in log])
): # pragma: no cover
Expand Down Expand Up @@ -239,9 +235,6 @@ def step(
if msg_response:
yield msg_response.replace(quiet=True)
yield from execute_msg(msg_response, confirm)
except KeyboardInterrupt:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now handled in the execute_msg as we need the call_id in the message response.

clear_interruptible()
yield Message("system", "User hit Ctrl-c to interrupt the process")
finally:
clear_interruptible()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is correct, I think the clear_interruptible is no longer called?

I think the finally clause might not run for the generator until the next() call.

Copy link
Owner

@ErikBjare ErikBjare Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, it's there in the except clause for that very reason, but no longer needed.

Not quite clear to me if the new behavior is correct though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm no it's probably not the same. Let me find something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the MR. To stay on the safe side, I've called clear_interruptible in the execute_msg function. Notice the break in the except clause. Before it was still calling subsequent tools if any after a user interruption 🤦

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also call clear_interruptible inside the handle_keyboard_interrupt itself just before raising the KeyboardInterrupt exception, what do you think?


Expand Down
4 changes: 3 additions & 1 deletion gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from time import sleep
from typing import Literal

from .constants import INTERRUPT_CONTENT

from . import llm
from .logmanager import LogManager, prepare_messages
from .message import (
Expand Down Expand Up @@ -183,7 +185,7 @@ def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no c
try:
sleep(1)
except KeyboardInterrupt:
yield Message("system", "User hit Ctrl-c to interrupt the process")
yield Message("system", INTERRUPT_CONTENT)
return
manager.edit(list(reversed(res)))
print("Applied edited messages, write /log to see the result")
Expand Down
3 changes: 3 additions & 0 deletions gptme/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@
PROMPT_ASSISTANT = (
f"[bold {ROLE_COLOR['assistant']}]Assistant[/bold {ROLE_COLOR['assistant']}]"
)


INTERRUPT_CONTENT = "Interrupted by user"
22 changes: 17 additions & 5 deletions gptme/llm/llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ def stream(
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool result as expected by the model
if message["role"] == "system" and "call_id" in message:
if message["role"] == "user" and "call_id" in message:
modified_message = dict(message)
modified_message["role"] = "user"
modified_message["content"] = [
{
"type": "tool_result",
Expand Down Expand Up @@ -358,22 +357,35 @@ def _transform_system_messages(
# unless a `call_id` is present, indicating the tool_format is 'tool'.
# Tool responses are handled separately by _handle_tool.
for i, message in enumerate(messages):
if message.role == "system" and message.call_id is None:
if message.role == "system":
content = (
f"<system>{message.content}</system>"
if message.call_id is None
else message.content
)

messages[i] = Message(
"user",
content=f"<system>{message.content}</system>",
content=content,
files=message.files, # type: ignore
call_id=message.call_id,
)

# find consecutive user role messages and merge them together
messages_new: list[Message] = []
while messages:
message = messages.pop(0)
if messages_new and messages_new[-1].role == "user" and message.role == "user":
if (
messages_new
and messages_new[-1].role == "user"
and message.role == "user"
and message.call_id == messages_new[-1].call_id
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this condition needed?

Copy link
Contributor Author

@jrmi jrmi Jan 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using the function call API, multiple calls may occur within a single answer (parallel calls). To handle this properly, ensure there is one tool response per tool call. By verifying that the call_id is distinct, we can ensure messages are not merged unless they originate from the same function call.

):
messages_new[-1] = Message(
"user",
content=f"{messages_new[-1].content}\n\n{message.content}",
files=messages_new[-1].files + message.files, # type: ignore
call_id=messages_new[-1].call_id,
)
else:
messages_new.append(message)
Expand Down
40 changes: 39 additions & 1 deletion gptme/llm/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
modified_message["content"] = content

if tool_calls:
# Clean content property if empty otherwise the call fails
if not content:
del modified_message["content"]
modified_message["tool_calls"] = tool_calls
Expand All @@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
yield message


def _merge_tool_results_with_same_call_id(
messages_dicts: Iterable[dict],
) -> list[dict]: # Generator[dict, None, None]:
"""
When we call a tool, this tool can potentially yield multiple messages. However
the API expect to have only one tool result per tool call. This function tries
to merge subsequent tool results with the same call ID as expected by
the API.
"""

messages_dicts = iter(messages_dicts)

messages_new: list[dict] = []
while message := next(messages_dicts, None):
if messages_new and (
message["role"] == "tool"
and messages_new[-1]["role"] == "tool"
and message["tool_call_id"] == messages_new[-1]["tool_call_id"]
):
prev_msg = messages_new[-1]
content = message["content"]
if not isinstance(content, list):
content = {"type": "text", "text": content}

messages_new[-1] = {
"role": "tool",
"content": prev_msg["content"] + content,
"tool_call_id": prev_msg["tool_call_id"],
}
else:
messages_new.append(message)

return messages_new


def _process_file(msg: dict, model: ModelMeta) -> dict:
message_content = msg["content"]
if model.provider == "deepseek":
Expand Down Expand Up @@ -423,7 +459,9 @@ def _prepare_messages_for_api(
tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)
messages_dicts = _merge_tool_results_with_same_call_id(
_handle_tools(messages_dicts)
)

messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model)

Expand Down
16 changes: 15 additions & 1 deletion gptme/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from collections.abc import Generator
from functools import lru_cache

from gptme.constants import INTERRUPT_CONTENT

from ..util.interrupt import clear_interruptible

from ..message import Message
from .base import (
ConfirmFunc,
Expand Down Expand Up @@ -120,7 +124,17 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,

for tooluse in ToolUse.iter_from_content(msg.content):
if tooluse.is_runnable:
yield from tooluse.execute(confirm)
try:
for tool_response in tooluse.execute(confirm):
yield tool_response.replace(call_id=tooluse.call_id)
except KeyboardInterrupt:
clear_interruptible()
yield Message(
"system",
INTERRUPT_CONTENT,
call_id=tooluse.call_id,
)
break


# Called often when checking streaming output for executable blocks,
Expand Down
5 changes: 2 additions & 3 deletions gptme/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,9 @@ def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]:
confirm,
)
if isinstance(ex, Generator):
for msg in ex:
yield msg.replace(call_id=self.call_id)
yield from ex
ErikBjare marked this conversation as resolved.
Show resolved Hide resolved
else:
yield ex.replace(call_id=self.call_id)
yield ex
except Exception as e:
# if we are testing, raise the exception
logger.exception(e)
Expand Down
8 changes: 7 additions & 1 deletion tests/test_llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_message_conversion_with_tools():
content='<thinking>\nSomething\n</thinking>\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]

tool_save = get_tool("save")
Expand Down Expand Up @@ -152,7 +153,12 @@ def test_message_conversion_with_tools():
"content": [
{
"type": "tool_result",
"content": [{"type": "text", "text": "Saved to toto.txt"}],
"content": [
{
"type": "text",
"text": "Saved to toto.txt\n\n(Modified by user)",
}
],
"tool_use_id": "tool_call_id",
"cache_control": {"type": "ephemeral"},
}
Expand Down
6 changes: 5 additions & 1 deletion tests/test_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_message_conversion_with_tools():
content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]

set_default_model("openai/gpt-4o")
Expand Down Expand Up @@ -193,7 +194,10 @@ def test_message_conversion_with_tools():
},
{
"role": "tool",
"content": [{"type": "text", "text": "Saved to toto.txt"}],
"content": [
{"type": "text", "text": "Saved to toto.txt"},
{"type": "text", "text": "(Modified by user)"},
],
"tool_call_id": "tool_call_id",
},
]
Loading