-
Notifications
You must be signed in to change notification settings - Fork 19.4k
fix(core): remove orphaned ToolMessages in trim_messages #33268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1440,11 +1440,10 @@ def _first_max_tokens( | |
| # When all messages fit, only apply end_on filtering if needed | ||
| if end_on: | ||
| for _ in range(len(messages)): | ||
| if not _is_message_type(messages[-1], end_on): | ||
| messages.pop() | ||
| else: | ||
| if not messages or _is_message_type(messages[-1], end_on): | ||
| break | ||
| return messages | ||
| messages.pop() | ||
| return _remove_orphaned_tool_messages(messages) | ||
|
|
||
| # Use binary search to find the maximum number of messages within token limit | ||
| left, right = 0, len(messages) | ||
|
|
@@ -1535,7 +1534,7 @@ def _first_max_tokens( | |
| else: | ||
| break | ||
|
|
||
| return messages[:idx] | ||
| return _remove_orphaned_tool_messages(messages[:idx]) | ||
|
|
||
|
|
||
| def _last_max_tokens( | ||
|
|
@@ -1594,7 +1593,41 @@ def _last_max_tokens( | |
| if system_message: | ||
| result = [system_message, *result] | ||
|
|
||
| return result | ||
| return _remove_orphaned_tool_messages(result) | ||
|
|
||
|
|
||
| def _remove_orphaned_tool_messages( | ||
| messages: Sequence[BaseMessage], | ||
| ) -> list[BaseMessage]: | ||
| """Drop tool messages whose corresponding tool calls are absent.""" | ||
| if not messages: | ||
| return [] | ||
|
|
||
| valid_tool_call_ids: set[str] = set() | ||
| for message in messages: | ||
| if isinstance(message, AIMessage): | ||
| if message.tool_calls: | ||
| for tool_call in message.tool_calls: | ||
| tool_call_id = tool_call.get("id") | ||
| if tool_call_id: | ||
| valid_tool_call_ids.add(tool_call_id) | ||
| if isinstance(message.content, list): | ||
| for block in message.content: | ||
| if ( | ||
| isinstance(block, dict) | ||
| and block.get("type") == "tool_use" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this branch is needed, |
||
| and block.get("id") | ||
| ): | ||
| valid_tool_call_ids.add(block["id"]) | ||
|
|
||
| cleaned_messages: list[BaseMessage] = [] | ||
| for message in messages: | ||
| if isinstance(message, ToolMessage) and ( | ||
| not valid_tool_call_ids or message.tool_call_id not in valid_tool_call_ids | ||
| ): | ||
| continue | ||
| cleaned_messages.append(message) | ||
| return cleaned_messages | ||
|
|
||
|
|
||
| _MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -393,6 +393,84 @@ def test_trim_messages_last_30_include_system_allow_partial_end_on_human() -> No | |
| assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY | ||
|
|
||
|
|
||
| def test_trim_messages_last_removes_orphaned_tool_message() -> None: | ||
| messages = [ | ||
| HumanMessage("What's the weather in Florida?"), | ||
| AIMessage( | ||
| [ | ||
| {"type": "text", "text": "Let's check the weather in Florida"}, | ||
| { | ||
| "type": "tool_use", | ||
| "id": "abc123", | ||
| "name": "get_weather", | ||
| "input": {"location": "Florida"}, | ||
| }, | ||
| ], | ||
| tool_calls=[ | ||
| { | ||
| "name": "get_weather", | ||
| "args": {"location": "Florida"}, | ||
| "id": "abc123", | ||
| "type": "tool_call", | ||
| } | ||
| ], | ||
| ), | ||
| ToolMessage("It's sunny.", name="get_weather", tool_call_id="abc123"), | ||
| HumanMessage("I see"), | ||
| AIMessage("Do you want to know anything else?"), | ||
| HumanMessage("No, thanks"), | ||
| AIMessage("You're welcome! Have a great day!"), | ||
| ] | ||
|
|
||
| trimmed = trim_messages( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this test passes if we add to out of curiosity, would that update be sufficient for your use case? |
||
| messages, | ||
| strategy="last", | ||
| token_counter=len, | ||
| max_tokens=5, | ||
| ) | ||
|
|
||
| expected = [ | ||
| HumanMessage("I see"), | ||
| AIMessage("Do you want to know anything else?"), | ||
| HumanMessage("No, thanks"), | ||
| AIMessage("You're welcome! Have a great day!"), | ||
| ] | ||
|
|
||
| assert trimmed == expected | ||
| assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this condition was just checking that we don't mutate |
||
|
|
||
|
|
||
| def test_trim_messages_last_preserves_tool_message_when_call_present() -> None: | ||
| messages = [ | ||
| HumanMessage("Start"), | ||
| AIMessage( | ||
| "Sure, let me check", | ||
| tool_calls=[ | ||
| { | ||
| "name": "search", | ||
| "args": {"query": "status"}, | ||
| "id": "tool-1", | ||
| "type": "tool_call", | ||
| } | ||
| ], | ||
| ), | ||
| ToolMessage("All systems operational", tool_call_id="tool-1"), | ||
| HumanMessage("Thanks"), | ||
| ] | ||
|
|
||
| trimmed = trim_messages( | ||
| messages, | ||
| strategy="last", | ||
| token_counter=len, | ||
| max_tokens=3, | ||
| ) | ||
|
|
||
| expected = messages[1:] | ||
|
|
||
| assert trimmed == expected | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this test passes out of the box without the |
||
| assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY | ||
|
|
||
|
|
||
| def test_trim_messages_last_40_include_system_allow_partial_start_on_human() -> None: | ||
| expected = [ | ||
| SystemMessage("This is a 4 token text."), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to modify this logic?