diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 6ba6221206ca6..5ce2538876d45 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -104,11 +104,12 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | messages = state["messages"] self._ensure_message_ids(messages) + # If max_tokens_before_summary is None, summarization is disabled + if self.max_tokens_before_summary is None: + return None + total_tokens = self.token_counter(messages) - if ( - self.max_tokens_before_summary is not None - and total_tokens < self.max_tokens_before_summary - ): + if total_tokens < self.max_tokens_before_summary: return None cutoff_index = self._find_safe_cutoff(messages) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 02fa96e6b65af..5434b50ba19d4 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -1056,12 +1056,18 @@ def test_summarization_middleware_no_summarization_cases() -> None: model = FakeToolCallingModel() middleware = SummarizationMiddleware(model=model, max_tokens_before_summary=1000) - # Test when summarization is disabled + # Test when summarization is disabled with few messages middleware_disabled = SummarizationMiddleware(model=model, max_tokens_before_summary=None) state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]} result = middleware_disabled.before_model(state, None) assert result is None + # Test when summarization is disabled with many messages (exceeding messages_to_keep) + many_messages = [HumanMessage(content=f"Message {i}") for i in range(25)] + state_many = {"messages": many_messages} + result = middleware_disabled.before_model(state_many, None) + assert result is None + # Test when token count is below threshold def mock_token_counter(messages): return 500 # Below threshold