From 78f74d6fcef2db1df72b9242c027080cf50ccc55 Mon Sep 17 00:00:00 2001
From: Jeremie Pardou <571533+jrmi@users.noreply.github.com>
Date: Mon, 30 Dec 2024 22:47:55 +0100
Subject: [PATCH 1/3] fix: tool response broken after user interruption
---
gptme/chat.py | 3 ---
gptme/tools/__init__.py | 10 +++++++++-
gptme/tools/base.py | 5 ++---
3 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/gptme/chat.py b/gptme/chat.py
index bb70e71b..7d1648f1 100644
--- a/gptme/chat.py
+++ b/gptme/chat.py
@@ -255,9 +255,6 @@ def step(
if msg_response:
yield msg_response.replace(quiet=True)
yield from execute_msg(msg_response, confirm)
- except KeyboardInterrupt:
- clear_interruptible()
- yield Message("system", "User hit Ctrl-c to interrupt the process")
finally:
clear_interruptible()
diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py
index 3288666b..de7381a8 100644
--- a/gptme/tools/__init__.py
+++ b/gptme/tools/__init__.py
@@ -113,7 +113,15 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
for tooluse in ToolUse.iter_from_content(msg.content):
if tooluse.is_runnable:
- yield from tooluse.execute(confirm)
+ try:
+ for tool_response in tooluse.execute(confirm):
+ yield tool_response.replace(call_id=tooluse.call_id)
+ except KeyboardInterrupt:
+ yield Message(
+ "system",
+ "User hit Ctrl-c to interrupt the process",
+ call_id=tooluse.call_id,
+ )
# Called often when checking streaming output for executable blocks,
diff --git a/gptme/tools/base.py b/gptme/tools/base.py
index 7b5f29c0..6aa5a2c2 100644
--- a/gptme/tools/base.py
+++ b/gptme/tools/base.py
@@ -287,10 +287,9 @@ def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]:
confirm,
)
if isinstance(ex, Generator):
- for msg in ex:
- yield msg.replace(call_id=self.call_id)
+ yield from ex
else:
- yield ex.replace(call_id=self.call_id)
+ yield ex
except Exception as e:
# if we are testing, raise the exception
logger.exception(e)
From 363865f10abfc104a8952ce2c4ca6f089e2985a0 Mon Sep 17 00:00:00 2001
From: Jeremie Pardou <571533+jrmi@users.noreply.github.com>
Date: Tue, 31 Dec 2024 00:10:27 +0100
Subject: [PATCH 2/3] fix: broken tool call after editing the file before
saving
---
gptme/llm/llm_anthropic.py | 22 +++++++++++++++-----
gptme/llm/llm_openai.py | 40 ++++++++++++++++++++++++++++++++++++-
gptme/tools/__init__.py | 4 ++++
tests/test_llm_anthropic.py | 8 +++++++-
tests/test_llm_openai.py | 6 +++++-
5 files changed, 72 insertions(+), 8 deletions(-)
diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py
index 393f6b02..fbc85f55 100644
--- a/gptme/llm/llm_anthropic.py
+++ b/gptme/llm/llm_anthropic.py
@@ -200,9 +200,8 @@ def stream(
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool result as expected by the model
- if message["role"] == "system" and "call_id" in message:
+ if message["role"] == "user" and "call_id" in message:
modified_message = dict(message)
- modified_message["role"] = "user"
modified_message["content"] = [
{
"type": "tool_result",
@@ -358,22 +357,35 @@ def _transform_system_messages(
# unless a `call_id` is present, indicating the tool_format is 'tool'.
# Tool responses are handled separately by _handle_tool.
for i, message in enumerate(messages):
- if message.role == "system" and message.call_id is None:
+ if message.role == "system":
+ content = (
+ f"{message.content}"
+ if message.call_id is None
+ else message.content
+ )
+
messages[i] = Message(
"user",
- content=f"{message.content}",
+ content=content,
files=message.files, # type: ignore
+ call_id=message.call_id,
)
# find consecutive user role messages and merge them together
messages_new: list[Message] = []
while messages:
message = messages.pop(0)
- if messages_new and messages_new[-1].role == "user" and message.role == "user":
+ if (
+ messages_new
+ and messages_new[-1].role == "user"
+ and message.role == "user"
+ and message.call_id == messages_new[-1].call_id
+ ):
messages_new[-1] = Message(
"user",
content=f"{messages_new[-1].content}\n\n{message.content}",
files=messages_new[-1].files + message.files, # type: ignore
+ call_id=messages_new[-1].call_id,
)
else:
messages_new.append(message)
diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py
index 38228a22..82ae6a77 100644
--- a/gptme/llm/llm_openai.py
+++ b/gptme/llm/llm_openai.py
@@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
modified_message["content"] = content
if tool_calls:
+ # Clean content property if empty otherwise the call fails
if not content:
del modified_message["content"]
modified_message["tool_calls"] = tool_calls
@@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
yield message
+def _merge_tool_results_with_same_call_id(
+    messages_dicts: Iterable[dict],
+) -> list[dict]:
+    """
+    Merge consecutive tool results that share the same tool call ID.
+
+    A tool can yield several messages, but the API expects exactly one tool
+    result per tool call, so adjacent "tool" messages with an identical
+    "tool_call_id" are collapsed into one whose content is a list of parts.
+    """
+    def as_parts(content) -> list:
+        # Wrap plain text in a *list*: a bare dict cannot be added to a list.
+        if isinstance(content, list):
+            return content
+        return [{"type": "text", "text": content}]
+
+    messages_new: list[dict] = []
+    for message in messages_dicts:  # plain loop: a falsy item must not end it
+        prev = messages_new[-1] if messages_new else {}
+        if (
+            message.get("role") == "tool"
+            and prev.get("role") == "tool"
+            and message["tool_call_id"] == prev["tool_call_id"]
+        ):
+            messages_new[-1] = {
+                "role": "tool",
+                "content": as_parts(prev["content"]) + as_parts(message["content"]),
+                "tool_call_id": prev["tool_call_id"],
+            }
+        else:
+            messages_new.append(message)
+
+    return messages_new
+
+
def _process_file(msg: dict, model: ModelMeta) -> dict:
message_content = msg["content"]
if model.provider == "deepseek":
@@ -423,7 +459,9 @@ def _prepare_messages_for_api(
tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None
if tools_dict is not None:
- messages_dicts = _handle_tools(messages_dicts)
+ messages_dicts = _merge_tool_results_with_same_call_id(
+ _handle_tools(messages_dicts)
+ )
messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model)
diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py
index de7381a8..e110779f 100644
--- a/gptme/tools/__init__.py
+++ b/gptme/tools/__init__.py
@@ -4,6 +4,8 @@
from gptme.config import get_config
+from ..util.interrupt import clear_interruptible
+
from ..message import Message
from .base import (
ToolFormat,
@@ -117,11 +119,13 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
for tool_response in tooluse.execute(confirm):
yield tool_response.replace(call_id=tooluse.call_id)
except KeyboardInterrupt:
+ clear_interruptible()
yield Message(
"system",
"User hit Ctrl-c to interrupt the process",
call_id=tooluse.call_id,
)
+ break
# Called often when checking streaming output for executable blocks,
diff --git a/tests/test_llm_anthropic.py b/tests/test_llm_anthropic.py
index 8a5081e5..869f7cf3 100644
--- a/tests/test_llm_anthropic.py
+++ b/tests/test_llm_anthropic.py
@@ -97,6 +97,7 @@ def test_message_conversion_with_tools():
content='\nSomething\n\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+ Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]
tool_save = get_tool("save")
@@ -152,7 +153,12 @@ def test_message_conversion_with_tools():
"content": [
{
"type": "tool_result",
- "content": [{"type": "text", "text": "Saved to toto.txt"}],
+ "content": [
+ {
+ "type": "text",
+ "text": "Saved to toto.txt\n\n(Modified by user)",
+ }
+ ],
"tool_use_id": "tool_call_id",
"cache_control": {"type": "ephemeral"},
}
diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py
index 34d75817..401ae99f 100644
--- a/tests/test_llm_openai.py
+++ b/tests/test_llm_openai.py
@@ -116,6 +116,7 @@ def test_message_conversion_with_tools():
content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
+ Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]
set_default_model("openai/gpt-4o")
@@ -193,7 +194,10 @@ def test_message_conversion_with_tools():
},
{
"role": "tool",
- "content": [{"type": "text", "text": "Saved to toto.txt"}],
+ "content": [
+ {"type": "text", "text": "Saved to toto.txt"},
+ {"type": "text", "text": "(Modified by user)"},
+ ],
"tool_call_id": "tool_call_id",
},
]
From 681bff4abd539645619cd0a65eef88359892e948 Mon Sep 17 00:00:00 2001
From: Jeremie Pardou <571533+jrmi@users.noreply.github.com>
Date: Sun, 5 Jan 2025 19:06:28 +0100
Subject: [PATCH 3/3] fix: shorten interrupt message content
---
gptme/chat.py | 10 +++-------
gptme/commands.py | 4 +++-
gptme/constants.py | 3 +++
gptme/tools/__init__.py | 4 +++-
4 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/gptme/chat.py b/gptme/chat.py
index 7d1648f1..a02d5d5f 100644
--- a/gptme/chat.py
+++ b/gptme/chat.py
@@ -11,7 +11,7 @@
from .commands import action_descriptions, execute_cmd
from .config import get_config
-from .constants import PROMPT_USER
+from .constants import INTERRUPT_CONTENT, PROMPT_USER
from .init import init
from .llm import reply
from .llm.models import get_model
@@ -148,11 +148,7 @@ def confirm_func(msg) -> bool:
)
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
- manager.append(
- Message(
- "system", "User hit Ctrl-c to interrupt the process"
- )
- )
+ manager.append(Message("system", INTERRUPT_CONTENT))
break
finally:
clear_interruptible()
@@ -225,7 +221,7 @@ def step(
if (
not last_msg
or (last_msg.role in ["assistant"])
- or last_msg.content == "Interrupted"
+ or last_msg.content == INTERRUPT_CONTENT
or last_msg.pinned
or not any(role == "user" for role in [m.role for m in log])
): # pragma: no cover
diff --git a/gptme/commands.py b/gptme/commands.py
index 81dc4935..258bb5f8 100644
--- a/gptme/commands.py
+++ b/gptme/commands.py
@@ -6,6 +6,8 @@
from time import sleep
from typing import Literal
+from .constants import INTERRUPT_CONTENT
+
from . import llm
from .logmanager import LogManager, prepare_messages
from .message import (
@@ -182,7 +184,7 @@ def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no c
try:
sleep(1)
except KeyboardInterrupt:
- yield Message("system", "User hit Ctrl-c to interrupt the process")
+ yield Message("system", INTERRUPT_CONTENT)
return
manager.edit(list(reversed(res)))
print("Applied edited messages, write /log to see the result")
diff --git a/gptme/constants.py b/gptme/constants.py
index 224ba140..de519545 100644
--- a/gptme/constants.py
+++ b/gptme/constants.py
@@ -26,3 +26,6 @@
PROMPT_ASSISTANT = (
f"[bold {ROLE_COLOR['assistant']}]Assistant[/bold {ROLE_COLOR['assistant']}]"
)
+
+
+INTERRUPT_CONTENT = "Interrupted by user"
diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py
index e110779f..b90fb9f0 100644
--- a/gptme/tools/__init__.py
+++ b/gptme/tools/__init__.py
@@ -4,6 +4,8 @@
from gptme.config import get_config
+from gptme.constants import INTERRUPT_CONTENT
+
from ..util.interrupt import clear_interruptible
from ..message import Message
@@ -122,7 +124,7 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
clear_interruptible()
yield Message(
"system",
- "User hit Ctrl-c to interrupt the process",
+ INTERRUPT_CONTENT,
call_id=tooluse.call_id,
)
break