diff --git a/gptme/chat.py b/gptme/chat.py index bb70e71b..a02d5d5f 100644 --- a/gptme/chat.py +++ b/gptme/chat.py @@ -11,7 +11,7 @@ from .commands import action_descriptions, execute_cmd from .config import get_config -from .constants import PROMPT_USER +from .constants import INTERRUPT_CONTENT, PROMPT_USER from .init import init from .llm import reply from .llm.models import get_model @@ -148,11 +148,7 @@ def confirm_func(msg) -> bool: ) except KeyboardInterrupt: console.log("Interrupted. Stopping current execution.") - manager.append( - Message( - "system", "User hit Ctrl-c to interrupt the process" - ) - ) + manager.append(Message("system", INTERRUPT_CONTENT)) break finally: clear_interruptible() @@ -225,7 +221,7 @@ def step( if ( not last_msg or (last_msg.role in ["assistant"]) - or last_msg.content == "Interrupted" + or last_msg.content == INTERRUPT_CONTENT or last_msg.pinned or not any(role == "user" for role in [m.role for m in log]) ): # pragma: no cover @@ -255,9 +251,6 @@ def step( if msg_response: yield msg_response.replace(quiet=True) yield from execute_msg(msg_response, confirm) - except KeyboardInterrupt: - clear_interruptible() - yield Message("system", "User hit Ctrl-c to interrupt the process") finally: clear_interruptible() diff --git a/gptme/commands.py b/gptme/commands.py index 81dc4935..258bb5f8 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -6,6 +6,8 @@ from time import sleep from typing import Literal +from .constants import INTERRUPT_CONTENT + from . 
import llm from .logmanager import LogManager, prepare_messages from .message import ( @@ -182,7 +184,7 @@ def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no c try: sleep(1) except KeyboardInterrupt: - yield Message("system", "User hit Ctrl-c to interrupt the process") + yield Message("system", INTERRUPT_CONTENT) return manager.edit(list(reversed(res))) print("Applied edited messages, write /log to see the result") diff --git a/gptme/constants.py b/gptme/constants.py index 224ba140..de519545 100644 --- a/gptme/constants.py +++ b/gptme/constants.py @@ -26,3 +26,6 @@ PROMPT_ASSISTANT = ( f"[bold {ROLE_COLOR['assistant']}]Assistant[/bold {ROLE_COLOR['assistant']}]" ) + + +INTERRUPT_CONTENT = "Interrupted by user" diff --git a/gptme/llm/llm_anthropic.py b/gptme/llm/llm_anthropic.py index 393f6b02..fbc85f55 100644 --- a/gptme/llm/llm_anthropic.py +++ b/gptme/llm/llm_anthropic.py @@ -200,9 +200,8 @@ def stream( def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: for message in message_dicts: # Format tool result as expected by the model - if message["role"] == "system" and "call_id" in message: + if message["role"] == "user" and "call_id" in message: modified_message = dict(message) - modified_message["role"] = "user" modified_message["content"] = [ { "type": "tool_result", @@ -358,22 +357,35 @@ def _transform_system_messages( # unless a `call_id` is present, indicating the tool_format is 'tool'. # Tool responses are handled separately by _handle_tool. 
for i, message in enumerate(messages): - if message.role == "system" and message.call_id is None: + if message.role == "system": + content = ( + f"<system>{message.content}</system>" + if message.call_id is None + else message.content + ) + messages[i] = Message( "user", - content=f"<system>{message.content}</system>", + content=content, files=message.files, # type: ignore + call_id=message.call_id, ) # find consecutive user role messages and merge them together messages_new: list[Message] = [] while messages: message = messages.pop(0) - if messages_new and messages_new[-1].role == "user" and message.role == "user": + if ( + messages_new + and messages_new[-1].role == "user" + and message.role == "user" + and message.call_id == messages_new[-1].call_id + ): messages_new[-1] = Message( "user", content=f"{messages_new[-1].content}\n\n{message.content}", files=messages_new[-1].files + message.files, # type: ignore + call_id=messages_new[-1].call_id, ) else: messages_new.append(message) diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py index 38228a22..82ae6a77 100644 --- a/gptme/llm/llm_openai.py +++ b/gptme/llm/llm_openai.py @@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: modified_message["content"] = content if tool_calls: + # Clean content property if empty otherwise the call fails if not content: del modified_message["content"] modified_message["tool_calls"] = tool_calls @@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]: yield message +def _merge_tool_results_with_same_call_id( + messages_dicts: Iterable[dict], +) -> list[dict]: # Generator[dict, None, None]: + """ + When we call a tool, this tool can potentially yield multiple messages. However + the API expect to have only one tool result per tool call. This function tries + to merge subsequent tool results with the same call ID as expected by + the API.
+ """ + + messages_dicts = iter(messages_dicts) + + messages_new: list[dict] = [] + while message := next(messages_dicts, None): + if messages_new and ( + message["role"] == "tool" + and messages_new[-1]["role"] == "tool" + and message["tool_call_id"] == messages_new[-1]["tool_call_id"] + ): + prev_msg = messages_new[-1] + content = message["content"] + if not isinstance(content, list): + content = [{"type": "text", "text": content}] + + messages_new[-1] = { + "role": "tool", + "content": prev_msg["content"] + content, + "tool_call_id": prev_msg["tool_call_id"], + } + else: + messages_new.append(message) + + return messages_new + + def _process_file(msg: dict, model: ModelMeta) -> dict: message_content = msg["content"] if model.provider == "deepseek": @@ -423,7 +459,9 @@ def _prepare_messages_for_api( tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None if tools_dict is not None: - messages_dicts = _handle_tools(messages_dicts) + messages_dicts = _merge_tool_results_with_same_call_id( + _handle_tools(messages_dicts) + ) messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model) diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index 3288666b..b90fb9f0 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -4,6 +4,10 @@ from gptme.config import get_config +from gptme.constants import INTERRUPT_CONTENT + +from ..util.interrupt import clear_interruptible + from ..message import Message from .base import ( ToolFormat, @@ -113,7 +117,17 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None, for tooluse in ToolUse.iter_from_content(msg.content): if tooluse.is_runnable: - yield from tooluse.execute(confirm) + try: + for tool_response in tooluse.execute(confirm): + yield tool_response.replace(call_id=tooluse.call_id) + except KeyboardInterrupt: + clear_interruptible() + yield Message( + "system", + INTERRUPT_CONTENT, + call_id=tooluse.call_id, + ) + break # Called often when
checking streaming output for executable blocks, diff --git a/gptme/tools/base.py b/gptme/tools/base.py index 7b5f29c0..6aa5a2c2 100644 --- a/gptme/tools/base.py +++ b/gptme/tools/base.py @@ -287,10 +287,9 @@ def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]: confirm, ) if isinstance(ex, Generator): - for msg in ex: - yield msg.replace(call_id=self.call_id) + yield from ex else: - yield ex.replace(call_id=self.call_id) + yield ex except Exception as e: # if we are testing, raise the exception logger.exception(e) diff --git a/tests/test_llm_anthropic.py b/tests/test_llm_anthropic.py index 8a5081e5..869f7cf3 100644 --- a/tests/test_llm_anthropic.py +++ b/tests/test_llm_anthropic.py @@ -97,6 +97,7 @@ def test_message_conversion_with_tools(): content='\nSomething\n\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}', ), Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"), + Message(role="system", content="(Modified by user)", call_id="tool_call_id"), ] tool_save = get_tool("save") @@ -152,7 +153,12 @@ def test_message_conversion_with_tools(): "content": [ { "type": "tool_result", - "content": [{"type": "text", "text": "Saved to toto.txt"}], + "content": [ + { + "type": "text", + "text": "Saved to toto.txt\n\n(Modified by user)", + } + ], "tool_use_id": "tool_call_id", "cache_control": {"type": "ephemeral"}, } diff --git a/tests/test_llm_openai.py b/tests/test_llm_openai.py index 34d75817..401ae99f 100644 --- a/tests/test_llm_openai.py +++ b/tests/test_llm_openai.py @@ -116,6 +116,7 @@ def test_message_conversion_with_tools(): content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}', ), Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"), + Message(role="system", content="(Modified by user)", call_id="tool_call_id"), ] set_default_model("openai/gpt-4o") @@ -193,7 +194,10 @@ def test_message_conversion_with_tools(): }, { "role": "tool", - "content": 
[{"type": "text", "text": "Saved to toto.txt"}], + "content": [ + {"type": "text", "text": "Saved to toto.txt"}, + {"type": "text", "text": "(Modified by user)"}, + ], "tool_call_id": "tool_call_id", }, ]