diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 95f8e8dd8dd..447fa4f76c8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -32,10 +32,12 @@ from .. import EVENT_LOGGER_NAME from ..base import Handoff as HandoffBase from ..base import Response +from ..memory._base_memory import Memory from ..messages import ( AgentEvent, ChatMessage, HandoffMessage, + MemoryQueryEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -133,6 +135,7 @@ class AssistantAgent(BaseChatAgent): will be returned as the response. Available variables: `{tool_name}`, `{arguments}`, `{result}`. For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`. + memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`. Raises: ValueError: If tool names are not unique. @@ -253,9 +256,22 @@ def __init__( ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", reflect_on_tool_use: bool = False, tool_call_summary_format: str = "{result}", + memory: Sequence[Memory] | None = None, ): super().__init__(name=name, description=description) self._model_client = model_client + self._memory = None + if memory is not None: + if isinstance(memory, Memory): + self._memory = [memory] + elif isinstance(memory, list): + self._memory = memory + else: + raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") + + self._system_messages: List[ + SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage + ] = [] if system_message is None: self._system_messages = [] else: @@ -338,6 +354,15 @@ async def on_messages_stream( # Inner messages. inner_messages: List[AgentEvent | ChatMessage] = [] + # Update the model context with memory content. + if self._memory: + for memory in self._memory: + memory_query_result = await memory.transform(self._model_context) + if memory_query_result and len(memory_query_result) > 0: + memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name) + inner_messages.append(memory_query_event_msg) + yield memory_query_event_msg + # Generate an inference result based on the current model context. 
llm_messages = self._system_messages + await self._model_context.get_messages() result = await self._model_client.create( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py new file mode 100644 index 00000000000..beba13fcbc7 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py @@ -0,0 +1,10 @@ +from ._base_memory import Memory, MemoryContent, MemoryMimeType +from ._list_memory import ListMemory, ListMemoryConfig + +__all__ = [ + "Memory", + "MemoryContent", + "MemoryMimeType", + "ListMemory", + "ListMemoryConfig", +] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py new file mode 100644 index 00000000000..d392ccb8cf3 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -0,0 +1,107 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Protocol, Union, runtime_checkable + +from autogen_core import CancellationToken, Image +from autogen_core.model_context import ChatCompletionContext +from pydantic import BaseModel, ConfigDict, Field + + +class MemoryMimeType(Enum): + """Supported MIME types for memory content.""" + + TEXT = "text/plain" + JSON = "application/json" + MARKDOWN = "text/markdown" + IMAGE = "image/*" + BINARY = "application/octet-stream" + + +ContentType = Union[str, bytes, Dict[str, Any], Image] + + +class MemoryContent(BaseModel): + content: ContentType + mime_type: MemoryMimeType | str + metadata: Dict[str, Any] | None = None + timestamp: datetime | None = None + source: str | None = None + score: float | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class BaseMemoryConfig(BaseModel): + """Base configuration for memory implementations.""" + + k: int = Field(default=5, description="Number of results to return") + score_threshold: float | None = Field(default=None, description="Minimum relevance score") + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +@runtime_checkable +class Memory(Protocol): + """Protocol defining the interface for memory implementations.""" + + @property + def name(self) -> str | None: + """The name of this memory implementation.""" + ... + + @property + def config(self) -> BaseMemoryConfig: + """The configuration for this memory implementation.""" + ... + + async def transform( + self, + model_context: ChatCompletionContext, + ) -> List[MemoryContent]: + """ + Transform the provided model context using relevant memory content. + + Args: + model_context: The context to transform + + Returns: + List of memory entries with relevance scores + """ + ... + + async def query( + self, + query: MemoryContent, + cancellation_token: "CancellationToken | None" = None, + **kwargs: Any, + ) -> List[MemoryContent]: + """ + Query the memory store and return relevant entries. + + Args: + query: Query content item + cancellation_token: Optional token to cancel operation + **kwargs: Additional implementation-specific parameters + + Returns: + List of memory entries with relevance scores + """ + ... + + async def add(self, content: MemoryContent, cancellation_token: "CancellationToken | None" = None) -> None: + """ + Add new content to memory. + + Args: + content: The memory content to add + cancellation_token: Optional token to cancel operation + """ + ...
+ + async def clear(self) -> None: + """Clear all entries from memory.""" + ... + + async def cleanup(self) -> None: + """Clean up any resources used by the memory implementation.""" + ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py new file mode 100644 index 00000000000..8ff3e9c0c08 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -0,0 +1,251 @@ +import logging +from difflib import SequenceMatcher +from typing import Any, List + +from autogen_core import CancellationToken, Image +from autogen_core.model_context import ChatCompletionContext +from autogen_core.models import ( + SystemMessage, +) +from pydantic import Field + +from .. import EVENT_LOGGER_NAME +from ._base_memory import BaseMemoryConfig, Memory, MemoryContent, MemoryMimeType + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class ListMemoryConfig(BaseMemoryConfig): + """Configuration for list-based memory implementation.""" + + similarity_threshold: float = Field( + default=0.35, description="Minimum similarity score for text matching", ge=0.0, le=1.0 + ) + + +class ListMemory(Memory): + """Simple list-based memory using text similarity matching. + + This memory implementation stores contents in a list and retrieves them based on + text similarity matching. It supports various content types and can transform + model contexts by injecting relevant memory content. + + Example: + ```python + # Initialize memory with custom config + memory = ListMemory(name="chat_history", config=ListMemoryConfig(similarity_threshold=0.7, k=3)) + + # Add memory content + content = MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT) + await memory.add(content) + + # Transform a model context with memory (mutates the context in place) + results = await memory.transform(model_context) + ``` + + Attributes: + name (str): Identifier for this memory instance + config (ListMemoryConfig): Configuration controlling memory behavior + """ + + def __init__(self, name: str | None = None, config: ListMemoryConfig | None = None) -> None: + self._name = name or "default_list_memory" + self._config = config or ListMemoryConfig() + self._contents: List[MemoryContent] = [] + + @property + def name(self) -> str: + return self._name + + @property + def config(self) -> ListMemoryConfig: + return self._config + + async def transform( + self, + model_context: ChatCompletionContext, + ) -> List[MemoryContent]: + """Transform the model context by injecting relevant memory content. + + This method mutates the provided model_context by adding relevant memory content: + + 1. Extracts the last message from the context + 2. Uses it to query memory for relevant content + 3. Formats matching content into a system message + 4. Mutates the context by adding the system message + + Args: + model_context: The context to transform. Will be mutated if relevant + memories exist.
+ + Returns: + List[MemoryContent]: A list of matching memory content with scores + + Example: + ```python + # The context will be mutated to include relevant memories + results = await memory.transform(model_context) + + # Any subsequent model calls will see the injected memories + messages = await model_context.get_messages() + ``` + """ + messages = await model_context.get_messages() + if not messages: + return [] + + # Extract query from last message + last_message = messages[-1] + query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) + query = MemoryContent(content=query_text, mime_type=MemoryMimeType.TEXT) + + # Query memory and format results + results: List[str] = [] + query_results = await self.query(query) + for i, result in enumerate(query_results, 1): + if isinstance(result.content, str): + results.append(f"{i}. {result.content}") + event_logger.debug(f"Retrieved memory {i}. {result.content}, score: {result.score}") + + # Add memory results to context + if results: + memory_context = ( + "\n The following results were retrieved from memory for this task. You may choose to use them or not:\n" + + "\n".join(results) + + "\n" + ) + await model_context.add_message(SystemMessage(content=memory_context)) + + return query_results + + async def query( + self, + query: MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> List[MemoryContent]: + """Query memory content based on text similarity. + + Searches memory content using text similarity matching against the query. + Only content exceeding the configured similarity threshold is returned, + sorted by relevance score in descending order. + + Args: + query: The content to match against memory content. Must contain + text that can be compared against stored content. + cancellation_token: Optional token to cancel long-running queries + **kwargs: Additional parameters passed to the similarity calculation + + Returns: + List[MemoryContent]: Matching content with similarity scores, + sorted by score in descending order. Limited to config.k entries. + + Raises: + ValueError: If query content cannot be converted to comparable text + + Example: + ```python + # Query memories similar to some text + query = MemoryContent(content="What's the weather?", mime_type=MemoryMimeType.TEXT) + results = await memory.query(query) + + # Check similarity scores + for result in results: + print(f"Score: {result.score}, Content: {result.content}") + ``` + """ + try: + query_text = self._extract_text(query) + except ValueError as e: + raise ValueError("Query must contain text content") from e + + results: List[MemoryContent] = [] + + for content in self._contents: + try: + content_text = self._extract_text(content) + except ValueError: + continue + + score = self._calculate_similarity(query_text, content_text) + + if score >= self._config.similarity_threshold and ( + self._config.score_threshold is None or score >= self._config.score_threshold + ): + result_content = content.model_copy() + result_content.score = score + results.append(result_content) + + results.sort(key=lambda x: x.score if x.score is not None else float("-inf"), reverse=True) + return results[: self._config.k] + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """Calculate text similarity score using SequenceMatcher.
+ + Args: + text1: First text to compare + text2: Second text to compare + + Returns: + float: Similarity score between 0 and 1, where 1 means identical + + Note: + Uses difflib's SequenceMatcher for basic text similarity. + For production use cases, consider using more sophisticated + similarity metrics or embeddings. + """ + return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() + + def _extract_text(self, content_item: MemoryContent) -> str: + """Extract searchable text from MemoryContent. + + Converts various content types into text that can be used for + similarity matching. + + Args: + content_item: Content to extract text from + + Returns: + str: Extracted text representation + + Raises: + ValueError: If content cannot be converted to text + + Note: + Currently supports TEXT, MARKDOWN, and JSON content types. + Images and binary content cannot be converted to text. + """ + content = content_item.content + + if content_item.mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: + return str(content) + elif content_item.mime_type == MemoryMimeType.JSON: + if isinstance(content, dict): + return str(content) + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError(f"Unsupported content type: {content_item.mime_type}") + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + """Add new content to memory. + + Args: + content: Memory content to store + cancellation_token: Optional token to cancel operation + + Note: + Content is stored in chronological order. No deduplication is + performed. For production use cases, consider implementing + deduplication or content-based filtering. + """ + self._contents.append(content) + + async def clear(self) -> None: + """Clear all memory content.""" + self._contents = [] + + async def cleanup(self) -> None: + """Cleanup resources if needed.""" + pass diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 923b569602e..5c4ba0e03d6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -12,6 +12,8 @@ class and includes specific fields relevant to the type of message being sent. 
from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated, deprecated +from autogen_agentchat.memory import MemoryContent + class BaseMessage(BaseModel, ABC): """Base class for all message types.""" @@ -123,13 +125,22 @@ class ToolCallSummaryMessage(BaseChatMessage): type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" +class MemoryQueryEvent(BaseAgentEvent): + """An event signaling the results of memory queries.""" + + content: List[MemoryContent] + """The memory query results.""" + + type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent" + + ChatMessage = Annotated[ TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")] +AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" @@ -140,7 +151,8 @@ class ToolCallSummaryMessage(BaseChatMessage): | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent - | ToolCallSummaryMessage, + | ToolCallSummaryMessage + | MemoryQueryEvent, Field(discriminator="type"), ] """(Deprecated, will be removed in 0.4.0) All message and event types.""" @@ -157,6 +169,7 @@ class ToolCallSummaryMessage(BaseChatMessage): "ToolCallMessage", "ToolCallResultMessage", "ToolCallSummaryMessage", + "MemoryQueryEvent", "ChatMessage", "AgentEvent", "AgentMessage", diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index ca079ce407b..75928847972 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,9 +7,11 @@ from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult +from autogen_agentchat.memory import Memory, ListMemory, MemoryContent, MemoryMimeType from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, + MemoryQueryEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -508,4 +510,81 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None: # Check if the mock client is called with only the last two messages. 
assert len(mock.calls) == 1 - assert len(mock.calls[0]) == 3 # 2 message from the context + 1 system message + # 2 messages from the context + 1 system message + assert len(mock.calls[0]) == 3 + + +@pytest.mark.asyncio +async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Hello", role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + ), + ] + b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + + # Test basic memory properties and empty context + memory = ListMemory(name="test_memory") + assert memory.name == "test_memory" + assert memory.config is not None + + empty_context = BufferedChatCompletionContext(buffer_size=2) + empty_results = await memory.transform(empty_context) + assert len(empty_results) == 0 + + # Test various content types and memory transforms + memory = ListMemory() + await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT)) + await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON)) + await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) + + # Invalid query should raise error + with pytest.raises(ValueError, match="Query must contain text content"): + await memory.query(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) + + # Test clear and cleanup + await memory.clear() + assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == [] + await memory.cleanup() # Should not raise + + # Test invalid memory type + with pytest.raises(TypeError): + AssistantAgent( + "test_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + memory="invalid", # type: ignore + ) + + # Test with agent + memory2 = ListMemory() + await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT)) + + agent = AssistantAgent( + "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2] + ) + + result = await agent.run(task="test task") + assert len(result.messages) > 0 + memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) + assert memory_event is not None + + # Test memory protocol + class BadMemory: + pass + + assert not isinstance(BadMemory(), Memory) + assert isinstance(ListMemory(), Memory) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md index 0bab2460a3b..75653cde7c9 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md @@ -81,6 +81,7 @@ tutorial/swarm tutorial/termination tutorial/custom-agents tutorial/state +tutorial/memory ``` ```{toctree} diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb new file mode 100644 index
00000000000..d25d0e94f32 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -0,0 +1,284 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Memory \n", + "\n", + "There are several use cases where it is valuable to maintain a _store_ of useful facts that can be intelligently added to the context of the agent just before a specific step. The typical use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", + "\n", + "\n", + "AgentChat provides a {py:class}`~autogen_agentchat.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", + "\n", + "- `query`: retrieve relevant information from the memory store \n", + "- `transform`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", + "- `add`: add new entries to the memory store\n", + "- `clear`: clear all entries from the memory store\n", + "- `cleanup`: clean up any resources used by the memory store \n", + "\n", + "\n", + "## ListMemory Example\n", + "\n", + "{py:class}`~autogen_agentchat.memory.ListMemory` is provided as an example implementation of the {py:class}`~autogen_agentchat.memory.Memory` protocol. It is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score between the query text and the content of each memory entry is calculated using the `SequenceMatcher` class from the `difflib` module. \n", + "\n", + "In the following example, we will use ListMemory to simulate a memory bank of user preferences and explore how it might be used in personalizing the agent's responses."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.memory import ListMemory, MemoryContent, MemoryMimeType\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize user memory\n", + "user_memory = ListMemory()\n", + "\n", + "# Add user preferences to memory\n", + "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "await user_memory.add(MemoryContent(content=\"Meal recipe must be vegan\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "\n", + "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", + " if units == \"imperial\":\n", + " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", + " elif units == \"metric\":\n", + " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", + "\n", + "\n", + "assistant_agent = AssistantAgent(\n", + " name=\"assistant_agent\",\n", + " model_client=OpenAIChatCompletionClient(\n", + " model=\"gpt-4o-2024-08-06\",\n", + " ),\n", + " tools=[get_weather],\n", + " memory=[user_memory],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "What is the weather in New York?\n", + "---------- assistant_agent ----------\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=0.463768115942029)]\n", + "---------- assistant_agent ----------\n", + "[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[Prompt tokens: 128, Completion tokens: 20]\n", + "---------- assistant_agent ----------\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')]\n", + "---------- assistant_agent ----------\n", + "The weather in New York is 23 degrees and Sunny.\n", + "---------- Summary ----------\n", + "Number of messages: 5\n", + "Finish reason: None\n", + "Total prompt tokens: 128\n", + "Total completion tokens: 20\n", + "Duration: 0.80 seconds\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New
York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run the agent with a task.\n", + "stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", + "await Console(stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the `assistant_agent`'s `model_context` to verify that it was indeed updated with the retrieved memory entries. The `transform` method formats the retrieved memory entries into a system message that is added to the model context. In this case, we simply concatenate the content of each memory entry into a single string." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", + " SystemMessage(content='\\n The following results were retrieved from memory for this task. You may choose to use them or not:\\n1. The weather should be in metric units\\n', type='SystemMessage'),\n", + " AssistantMessage(content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='FunctionExecutionResultMessage')]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await assistant_agent._model_context.get_messages()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'),\n", + " MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'),\n", + " ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=219, completion_tokens=20), content=[FunctionCall(id='call_YPwxZOz0bTEW15beow3zXsaI', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'),\n", + " ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_YPwxZOz0bTEW15beow3zXsaI')], type='ToolCallExecutionEvent'),\n", + " ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = await assistant_agent.run(task=\"What is the weather in New York?\")\n", + "result.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see above that the weather is returned in metric units (Celsius), as stated in the user preferences. \n", + "\n", + "Similarly, if we ask a separate question about generating a meal plan, the agent is able to retrieve relevant information from the memory store and provide a personalized response."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Write brief meal recipe with broth\n", + "---------- assistant_agent ----------\n", + "[MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=0.5084745762711864)]\n", + "---------- assistant_agent ----------\n", + "Here's a brief vegan recipe using broth:\n", + "\n", + "**Vegan Vegetable Noodle Soup**\n", + "\n", + "**Ingredients:**\n", + "- 4 cups vegetable broth\n", + "- 1 cup water\n", + "- 1 cup carrots, sliced\n", + "- 1 cup celery, chopped\n", + "- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\n", + "- 2 cups kale or spinach, chopped\n", + "- 2 cloves garlic, minced\n", + "- 1 tablespoon olive oil\n", + "- Salt and pepper to taste\n", + "- Lemon juice (optional)\n", + "\n", + "**Instructions:**\n", + "\n", + "1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\n", + "\n", + "2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\n", + "\n", + "3. **Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\n", + "\n", + "4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \n", + "\n", + "5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\n", + "\n", + "This simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \n", + "\n", + "TERMINATE\n", + "[Prompt tokens: 306, Completion tokens: 294]\n", + "---------- Summary ----------\n", + "Number of messages: 3\n", + "Finish reason: None\n", + "Total prompt tokens: 306\n", + "Total completion tokens: 294\n", + "Duration: 4.39 seconds\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=0.5084745762711864)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=306, completion_tokens=294), content=\"Here's a brief vegan recipe using broth:\\n\\n**Vegan Vegetable Noodle Soup**\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup water\\n- 1 cup carrots, sliced\\n- 1 cup celery, chopped\\n- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\\n- 2 cups kale or spinach, chopped\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Lemon juice (optional)\\n\\n**Instructions:**\\n\\n1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\\n\\n2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\\n\\n3.
**Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\\n\\n4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \\n\\n5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\\n\\nThis simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n", + "await Console(stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Memory Stores (Vector DBs, etc.)\n", + "\n", + "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences.\n", + "\n", + "Specifically, you will need to implement the `query`, `transform`, and `add` methods to provide the desired functionality and pass the memory store to your agent.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/uv.lock b/python/uv.lock index 067520cb3e8..7994cf17b86 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -2,6 +2,7 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_version < '0'", "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", "python_full_version == '3.11.*' and sys_platform == 'darwin'", @@ -11,7 +12,6 @@ resolution-markers = [ "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", "python_full_version >= '3.12.4' and sys_platform == 'darwin'", - "python_version < '0'", "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", ]
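
To make the tutorial's custom-store guidance concrete, here is a minimal sketch of a vector-style memory written against the `Memory` protocol introduced in this diff. It is illustrative only: `BagOfWordsMemory`, `_embed`, and `_cosine` are names invented for this example, a term-frequency vector with cosine similarity stands in for a real embedding model and vector database, and `BaseMemoryConfig` is imported from the private `_base_memory` module because the diff's `__init__.py` does not re-export it.

```python
import math
from collections import Counter
from typing import Any, List, Tuple

from autogen_core import CancellationToken
from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import SystemMessage

from autogen_agentchat.memory import Memory, MemoryContent, MemoryMimeType
from autogen_agentchat.memory._base_memory import BaseMemoryConfig  # not re-exported in this PR


def _embed(text: str) -> "Counter[str]":
    # Stand-in "embedding": a sparse term-frequency vector. A real store
    # would call an embedding model and persist the vector in a vector DB.
    return Counter(text.lower().split())


def _cosine(a: "Counter[str]", b: "Counter[str]") -> float:
    # Cosine similarity between two sparse term-frequency vectors.
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class BagOfWordsMemory(Memory):
    """Toy vector-style memory satisfying the ``Memory`` protocol."""

    def __init__(self, name: str | None = None, config: BaseMemoryConfig | None = None) -> None:
        self._name = name or "bag_of_words_memory"
        self._config = config or BaseMemoryConfig()
        # Parallel to a vector DB collection: each entry is (vector, payload).
        self._store: List[Tuple["Counter[str]", MemoryContent]] = []

    @property
    def name(self) -> str:
        return self._name

    @property
    def config(self) -> BaseMemoryConfig:
        return self._config

    async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
        # Index the entry at insertion time, as a vector database would.
        self._store.append((_embed(str(content.content)), content))

    async def query(
        self,
        query: MemoryContent,
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> List[MemoryContent]:
        query_vec = _embed(str(query.content))
        scored: List[MemoryContent] = []
        for vec, content in self._store:
            score = _cosine(query_vec, vec)
            if self._config.score_threshold is None or score >= self._config.score_threshold:
                item = content.model_copy()
                item.score = score
                scored.append(item)
        scored.sort(key=lambda c: c.score or 0.0, reverse=True)
        return scored[: self._config.k]

    async def transform(self, model_context: ChatCompletionContext) -> List[MemoryContent]:
        # Mirror ListMemory: query with the last message in the context, then
        # inject the matches into the context as a system message.
        messages = await model_context.get_messages()
        if not messages:
            return []
        last = messages[-1]
        query_text = last.content if isinstance(last.content, str) else str(last)
        results = await self.query(MemoryContent(content=query_text, mime_type=MemoryMimeType.TEXT))
        if results:
            formatted = "\n".join(f"{i}. {r.content}" for i, r in enumerate(results, 1))
            await model_context.add_message(SystemMessage(content=f"Relevant memory entries:\n{formatted}"))
        return results

    async def clear(self) -> None:
        self._store = []

    async def cleanup(self) -> None:
        # Nothing to release here; a real store would close client connections.
        pass
```

Assuming the package from this diff is installed, such a store should be usable exactly like `ListMemory`, e.g. `AssistantAgent("agent", model_client=client, memory=[BagOfWordsMemory()])`, since `AssistantAgent.on_messages_stream` relies only on the protocol's `transform` method.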