From ef61c53a0ece043ee3aae2bc384b366de680ca92 Mon Sep 17 00:00:00 2001 From: Matt Vallillo Date: Tue, 27 Aug 2024 10:49:10 -0500 Subject: [PATCH] Refactor Conversation Memory class and drivers (#1084) --- CHANGELOG.md | 10 ++- MIGRATION.md | 67 ++++++++++++++++++- .../src/amazon_dynamodb_sessions_1.py | 2 +- .../src/conversation_memory_drivers_1.py | 4 +- .../src/conversation_memory_drivers_2.py | 2 +- .../src/conversation_memory_drivers_3.py | 2 +- ...versation_memory_drivers_griptape_cloud.py | 2 +- .../configs/drivers/base_drivers_config.py | 6 +- griptape/configs/drivers/drivers_config.py | 7 +- ...zon_dynamodb_conversation_memory_driver.py | 27 +++----- .../base_conversation_memory_driver.py | 16 +++-- ...iptape_cloud_conversation_memory_driver.py | 62 ++++++++--------- .../local_conversation_memory_driver.py | 43 ++++++------ .../redis_conversation_memory_driver.py | 26 +++---- .../structure/base_conversation_memory.py | 29 ++++---- griptape/memory/structure/run.py | 16 +++-- griptape/tasks/prompt_task.py | 2 +- griptape/tasks/toolkit_task.py | 2 +- .../test_amazon_bedrock_drivers_config.py | 10 ++- .../drivers/test_anthropic_drivers_config.py | 5 +- .../test_azure_openai_drivers_config.py | 5 +- .../drivers/test_cohere_drivers_config.py | 5 +- .../configs/drivers/test_drivers_config.py | 9 ++- .../drivers/test_google_drivers_config.py | 5 +- .../drivers/test_openai_driver_config.py | 5 +- ...est_dynamodb_conversation_memory_driver.py | 24 +++---- ...iptape_cloud_conversation_memory_driver.py | 60 ++++++++++------- .../test_local_conversation_memory_driver.py | 50 ++++++++------ .../test_redis_conversation_memory_driver.py | 22 +++--- .../structure/test_conversation_memory.py | 8 +-- 30 files changed, 322 insertions(+), 211 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b066e2e44..555306f90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased + ### Added -- `BaseConversationMemory.prompt_driver` for use with autopruning. - Parameter `meta: dict` on `BaseEvent`. +### Changed +- **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`. +- **BREAKING**: `BaseConversationMemory.add_to_prompt_stack` now takes a `prompt_driver` parameter. +- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], Optional[dict]]`. +- **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: Optional[dict]` as input. +- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. +- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. + ### Fixed - Parsing streaming response with some OpenAi compatible services. diff --git a/MIGRATION.md b/MIGRATION.md index 75b7218fb..89ba95494 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -9,16 +9,79 @@ This document provides instructions for migrating your codebase to accommodate b Drivers, Loaders, and Engines will now raise exceptions rather than returning `ErrorArtifact`s. Update any logic that expects `ErrorArtifact` to handle exceptions instead.
+#### Before ```python -# Before artifacts = WebLoader().load("https://www.griptape.ai") if isinstance(artifacts, ErrorArtifact): raise Exception(artifacts.value) +``` -# After +#### After +```python try: artifacts = WebLoader().load("https://www.griptape.ai") except Exception as e: raise e ``` + +### LocalConversationMemoryDriver `file_path` renamed to `persist_file` + +`LocalConversationMemoryDriver.file_path` has been renamed to `persist_file` and is now `Optional[str]`. If `persist_file` is not passed as a parameter, nothing will be persisted and no errors will be raised. `LocalConversationMemoryDriver` is now the default conversation memory driver in the global `Defaults` object. + +#### Before +```python +local_driver_with_file = LocalConversationMemoryDriver( + file_path="my_file.json" +) + +local_driver = LocalConversationMemoryDriver() + +assert local_driver_with_file.file_path == "my_file.json" +assert local_driver.file_path == "griptape_memory.json" +``` + +#### After +```python +local_driver_with_file = LocalConversationMemoryDriver( + persist_file="my_file.json" +) + +local_driver = LocalConversationMemoryDriver() + +assert local_driver_with_file.persist_file == "my_file.json" +assert local_driver.persist_file is None +``` + +### Changes to BaseConversationMemoryDriver + +Parameter `driver` on `BaseConversationMemory` has been renamed to `conversation_memory_driver`. Method signatures for `BaseConversationMemoryDriver.store` and `.load` have been changed. + +#### Before +```python +memory_driver = LocalConversationMemoryDriver() + +conversation_memory = ConversationMemory( + driver=memory_driver +) + +load_result: BaseConversationMemory = memory_driver.load() + +memory_driver.store(conversation_memory) +``` + +#### After +```python +memory_driver = LocalConversationMemoryDriver() + +conversation_memory = ConversationMemory( + conversation_memory_driver=memory_driver +) + +load_result: tuple[list[Run], dict[str, Any]] = memory_driver.load() + +memory_driver.store( + conversation_memory.runs, + conversation_memory.meta +) +``` diff --git a/docs/examples/src/amazon_dynamodb_sessions_1.py b/docs/examples/src/amazon_dynamodb_sessions_1.py index f7a6d0cd6..d44ec8f56 100644 --- a/docs/examples/src/amazon_dynamodb_sessions_1.py +++ b/docs/examples/src/amazon_dynamodb_sessions_1.py @@ -18,7 +18,7 @@ structure = Agent( conversation_memory=ConversationMemory( - driver=AmazonDynamoDbConversationMemoryDriver( + conversation_memory_driver=AmazonDynamoDbConversationMemoryDriver( session=boto3.Session( aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py index 27829d8d2..d87586d88 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py @@ -2,8 +2,8 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Agent -local_driver = LocalConversationMemoryDriver(file_path="memory.json") -agent = Agent(conversation_memory=ConversationMemory(driver=local_driver)) +local_driver = LocalConversationMemoryDriver(persist_file="memory.json") +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver)) agent.run("Surfing is my favorite sport.") agent.run("What is my favorite sport?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py
b/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py index 9db525b42..0c32c1cc5 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py @@ -13,7 +13,7 @@ partition_key_value=conversation_id, ) -agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py index 0f80d1393..5f0723940 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py @@ -14,7 +14,7 @@ conversation_id=conversation_id, ) -agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py index 35492e06b..0723b5f75 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py @@ -9,7 +9,7 @@ cloud_conversation_driver = GriptapeCloudConversationMemoryDriver( api_key=os.environ["GT_CLOUD_API_KEY"], ) -agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=cloud_conversation_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/griptape/configs/drivers/base_drivers_config.py b/griptape/configs/drivers/base_drivers_config.py index ec7503478..456249634 100644 --- a/griptape/configs/drivers/base_drivers_config.py +++ b/griptape/configs/drivers/base_drivers_config.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define, field @@ -38,7 +38,7 @@ class BaseDriversConfig(ABC, SerializableMixin): _vector_store_driver: BaseVectorStoreDriver = field( default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store_driver" ) - _conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + _conversation_memory_driver: BaseConversationMemoryDriver = field( default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory_driver" ) _text_to_speech_driver: BaseTextToSpeechDriver = field( @@ -70,7 +70,7 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: ... @lazy_property() @abstractmethod - def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: ... + def conversation_memory_driver(self) -> BaseConversationMemoryDriver: ... 
@lazy_property() @abstractmethod diff --git a/griptape/configs/drivers/drivers_config.py b/griptape/configs/drivers/drivers_config.py index ed68bcf8c..04edfd303 100644 --- a/griptape/configs/drivers/drivers_config.py +++ b/griptape/configs/drivers/drivers_config.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define @@ -13,6 +13,7 @@ DummyPromptDriver, DummyTextToSpeechDriver, DummyVectorStoreDriver, + LocalConversationMemoryDriver, ) from griptape.utils.decorators import lazy_property @@ -52,8 +53,8 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: return DummyVectorStoreDriver(embedding_driver=self.embedding_driver) @lazy_property() - def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: - return None + def conversation_memory_driver(self) -> BaseConversationMemoryDriver: + return LocalConversationMemoryDriver() @lazy_property() def text_to_speech_driver(self) -> BaseTextToSpeechDriver: diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index b0c2485d6..0842870eb 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: import boto3 - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define @@ -27,35 +27,26 @@ class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): table: Any = field(init=False) def __attrs_post_init__(self) -> None: - dynamodb = self.session.resource("dynamodb") + self.table = self.session.resource("dynamodb").Table(self.table_name) - self.table = dynamodb.Table(self.table_name) - - def store(self, memory: BaseConversationMemory) -> None: + def store(self, runs: list[Run], metadata: dict) -> None: self.table.update_item( Key=self._get_key(), UpdateExpression="set #attr = :value", ExpressionAttributeNames={"#attr": self.value_attribute_key}, - ExpressionAttributeValues={":value": memory.to_json()}, + ExpressionAttributeValues={ + ":value": json.dumps(self._to_params_dict(runs, metadata)), + }, ) - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory - + def load(self) -> tuple[list[Run], dict[str, Any]]: response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: memory_dict = json.loads(response["Item"][self.value_attribute_key]) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - - memory = BaseConversationMemory.from_dict(memory_dict) - - memory.driver = self - - return memory + return self._from_params_dict(memory_dict) else: - return None + return [], {} def _get_key(self) -> dict[str, str | int]: key: dict[str, str | int] = {self.partition_key: self.partition_key_value} diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index 1caeb902f..ea0a171f2 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -1,17 +1,25 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import 
TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any from griptape.mixins import SerializableMixin if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod - def store(self, memory: BaseConversationMemory) -> None: ... + def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: ... @abstractmethod - def load(self) -> Optional[BaseConversationMemory]: ... + def load(self) -> tuple[list[Run], dict[str, Any]]: ... + + def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict: + return {"runs": [run.to_dict() for run in runs], "metadata": metadata} + + def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]: + from griptape.memory.structure import Run + + return [Run.from_dict(run) for run in params_dict.get("runs", [])], params_dict.get("metadata", {}) diff --git a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py index 2ea1d0d1a..3aac74090 100644 --- a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py @@ -2,7 +2,7 @@ import os import uuid -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from urllib.parse import urljoin import requests @@ -10,9 +10,10 @@ from griptape.artifacts import BaseArtifact from griptape.drivers import BaseConversationMemoryDriver +from griptape.utils import dict_merge if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define(kw_only=True) @@ -55,26 +56,38 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: raise ValueError(f"{self.__class__.__name__} requires an API key") return value - def store(self, memory: BaseConversationMemory) -> None: - # serliaze the run artifacts to json strings - messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs] + def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: + # serialize the run artifacts to json strings + messages = [ + dict_merge( + { + "input": run.input.to_json(), + "output": run.output.to_json(), + "metadata": {"run_id": run.id}, + }, + run.meta, + ) + for run in runs + ] - # serialize the metadata to a json string - # remove runs because they are already stored as Messages - metadata = memory.to_dict() - del metadata["runs"] + body = dict_merge( + { + "messages": messages, + }, + metadata, + ) # patch the Thread with the new messages and metadata # all old Messages are replaced with the new ones response = requests.patch( self._get_url(f"/threads/{self.thread_id}"), - json={"messages": messages, "metadata": metadata}, + json=body, headers=self.headers, ) response.raise_for_status() - def load(self) -> BaseConversationMemory: - from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run + def load(self) -> tuple[list[Run], dict[str, Any]]: + from griptape.memory.structure import Run # get the Messages from the Thread messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers) @@ -86,33 +99,16 @@ def load(self) -> BaseConversationMemory: thread_response.raise_for_status() thread_response = 
thread_response.json() - messages = messages_response.get("messages", []) - runs = [ Run( - id=m["message_id"], + id=m["metadata"].pop("run_id"), + meta=m["metadata"], input=BaseArtifact.from_json(m["input"]), output=BaseArtifact.from_json(m["output"]), ) - for m in messages + for m in messages_response.get("messages", []) ] - metadata = thread_response.get("metadata") - - # the metadata will contain the serialized - # ConversationMemory object with the runs removed - # autoload=False to prevent recursively loading the memory - if metadata is not None and metadata != {}: - memory = BaseConversationMemory.from_dict( - { - **metadata, - "runs": [run.to_dict() for run in runs], - "autoload": False, - } - ) - memory.driver = self - return memory - # no metadata found, return a new ConversationMemory object - return ConversationMemory(runs=runs, autoload=False, driver=self) + return runs, thread_response.get("metadata", {}) def _get_thread_id(self) -> str: res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers) diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index 9a79accc3..c8ea540be 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -3,34 +3,33 @@ import json import os from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import define, field from griptape.drivers import BaseConversationMemoryDriver if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run -@define +@define(kw_only=True) class LocalConversationMemoryDriver(BaseConversationMemoryDriver): - file_path: str = field(default="griptape_memory.json", kw_only=True, metadata={"serializable": True}) - - def store(self, memory: BaseConversationMemory) -> None: - Path(self.file_path).write_text(memory.to_json()) - - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory - - if not os.path.exists(self.file_path): - return None - - memory_dict = json.loads(Path(self.file_path).read_text()) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - memory = BaseConversationMemory.from_dict(memory_dict) - - memory.driver = self - - return memory + persist_file: Optional[str] = field(default=None, metadata={"serializable": True}) + + def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: + if self.persist_file is not None: + Path(self.persist_file).write_text(json.dumps(self._to_params_dict(runs, metadata))) + + def load(self) -> tuple[list[Run], dict[str, Any]]: + if ( + self.persist_file is not None + and os.path.exists(self.persist_file) + and (loaded_str := Path(self.persist_file).read_text()) is not None + ): + try: + return self._from_params_dict(json.loads(loaded_str)) + except Exception as e: + raise ValueError(f"Unable to load data from {self.persist_file}") from e + + return [], {} diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 8741cda50..f30189e37 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ 
-2,17 +2,17 @@ import json import uuid -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.utils.import_utils import import_optional_dependency +from griptape.utils import import_optional_dependency if TYPE_CHECKING: from redis import Redis - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define @@ -52,19 +52,11 @@ class RedisConversationMemoryDriver(BaseConversationMemoryDriver): ), ) - def store(self, memory: BaseConversationMemory) -> None: - self.client.hset(self.index, self.conversation_id, memory.to_json()) + def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: + self.client.hset(self.index, self.conversation_id, json.dumps(self._to_params_dict(runs, metadata))) - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory - - key = self.index - memory_json = self.client.hget(key, self.conversation_id) + def load(self) -> tuple[list[Run], dict[str, Any]]: + memory_json = self.client.hget(self.index, self.conversation_id) if memory_json is not None: - memory_dict = json.loads(memory_json) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - memory = BaseConversationMemory.from_dict(memory_dict) - memory.driver = self - return memory - return None + return self._from_params_dict(json.loads(memory_json)) + return [], {} diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 15d0a9e99..92f5bd942 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -1,13 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from griptape.common import PromptStack from griptape.configs import Defaults from griptape.mixins import SerializableMixin +from griptape.utils import dict_merge if TYPE_CHECKING: from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver @@ -16,22 +17,20 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field( + conversation_memory_driver: BaseConversationMemoryDriver = field( default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True ) - prompt_driver: BasePromptDriver = field( - default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True - ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) + meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) autoload: bool = field(default=True, kw_only=True) autoprune: bool = field(default=True, kw_only=True) max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) def __attrs_post_init__(self) -> None: - if self.driver and self.autoload: - memory = self.driver.load() - if memory is not None: - [self.add_run(r) for r in memory.runs] + if self.autoload: + runs, meta = self.conversation_memory_driver.load() + self.runs.extend(runs) + self.meta = dict_merge(self.meta, meta) def before_add_run(self) -> None: pass @@ -44,8 +43,7 @@ def add_run(self, run: Run) -> BaseConversationMemory: return self def 
after_add_run(self) -> None: - if self.driver: - self.driver.store(self) + self.conversation_memory_driver.store(self.runs, self.meta) @abstractmethod def try_add_run(self, run: Run) -> None: ... @@ -53,13 +51,16 @@ def try_add_run(self, run: Run) -> None: ... @abstractmethod def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... - def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: + def add_to_prompt_stack( + self, prompt_driver: BasePromptDriver, prompt_stack: PromptStack, index: Optional[int] = None + ) -> PromptStack: """Add the Conversation Memory runs to the Prompt Stack by modifying the messages in place. If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack as possible without exceeding the token limit. Args: + prompt_driver: The Prompt Driver to use for token counting. prompt_stack: The Prompt Stack to add the Conversation Memory to. index: Optional index to insert the Conversation Memory runs at. Defaults to appending to the end of the Prompt Stack. @@ -82,8 +83,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = temp_stack.messages.extend(memory_inputs) # Convert the Prompt Stack into tokens left. - tokens_left = self.prompt_driver.tokenizer.count_input_tokens_left( - self.prompt_driver.prompt_stack_to_string(temp_stack), + tokens_left = prompt_driver.tokenizer.count_input_tokens_left( + prompt_driver.prompt_stack_to_string(temp_stack), ) if tokens_left > 0: # There are still tokens left, no need to prune. diff --git a/griptape/memory/structure/run.py b/griptape/memory/structure/run.py index 3d8ca3869..5d2a182ad 100644 --- a/griptape/memory/structure/run.py +++ b/griptape/memory/structure/run.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field -from griptape.artifacts.base_artifact import BaseArtifact from griptape.mixins import SerializableMixin +if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact + -@define +@define(kw_only=True) class Run(SerializableMixin): - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) - input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) - output: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) + id: str = field(default=Factory(lambda: uuid.uuid4().hex), metadata={"serializable": True}) + meta: Optional[dict] = field(default=None, metadata={"serializable": True}) + input: BaseArtifact = field(metadata={"serializable": True}) + output: BaseArtifact = field(metadata={"serializable": True}) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 17a73e4cd..9c0060039 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -58,7 +58,7 @@ def prompt_stack(self) -> PromptStack: if memory is not None: # insert memory into the stack right before the user messages - memory.add_to_prompt_stack(stack, 1 if system_template else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) return stack diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 24607a352..ff1194440 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -115,7 +115,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - 
memory.add_to_prompt_stack(stack, 1) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1) return stack diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 129fe281f..d74c273f6 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -25,7 +25,10 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, "image_generation_driver": { "image_generation_model_driver": { @@ -77,7 +80,10 @@ def test_from_dict_with_values(self, config_with_values): def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, "image_generation_driver": { "image_generation_model_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index b2335d92a..2bdb9497d 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -45,7 +45,10 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 5c514c947..01886962e 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,7 +36,10 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 3c267d73d..0d16d1ab2 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -13,7 +13,10 @@ def test_to_dict(self, config): "type": "CohereDriversConfig", "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, "prompt_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 
20cc0926c..e2476c437 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,7 +18,10 @@ def test_to_dict(self, config): "stream": False, "use_native_tools": False, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "embedding_driver": {"type": "DummyEmbeddingDriver"}, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, @@ -56,7 +59,7 @@ def test_lazy_init(self): assert Defaults.drivers_config.image_query_driver is not None assert Defaults.drivers_config.embedding_driver is not None assert Defaults.drivers_config.vector_store_driver is not None - assert Defaults.drivers_config.conversation_memory_driver is None + assert Defaults.drivers_config.conversation_memory_driver is not None assert Defaults.drivers_config.text_to_speech_driver is not None assert Defaults.drivers_config.audio_transcription_driver is not None @@ -65,6 +68,6 @@ def test_lazy_init(self): assert Defaults.drivers_config._image_query_driver is not None assert Defaults.drivers_config._embedding_driver is not None assert Defaults.drivers_config._vector_store_driver is not None - assert Defaults.drivers_config._conversation_memory_driver is None + assert Defaults.drivers_config._conversation_memory_driver is not None assert Defaults.drivers_config._text_to_speech_driver is not None assert Defaults.drivers_config._audio_transcription_driver is not None diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index f6df1afef..7a752c6de 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -43,7 +43,10 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 2425b178f..40a755e50 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,7 +28,10 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + }, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index f1a5df1be..96e2ca969 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -46,7 +46,7 @@ def test_store(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ 
-72,7 +72,7 @@ def test_store_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -93,7 +93,7 @@ def test_load(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver, meta={"foo": "bar"}) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -101,12 +101,10 @@ def test_load(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" + assert len(runs) == 2 + assert metadata == {"foo": "bar"} def test_load_with_sort_key(self): memory_driver = AmazonDynamoDbConversationMemoryDriver( @@ -118,7 +116,7 @@ def test_load_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver, meta={"foo": "bar"}) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -126,9 +124,7 @@ def test_load_with_sort_key(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" + assert len(runs) == 2 + assert metadata == {"foo": "bar"} diff --git a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py index 707132ef5..dccdc9fd0 100644 --- a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py @@ -1,10 +1,11 @@ import json +import os import pytest from griptape.artifacts import BaseArtifact from griptape.drivers import GriptapeCloudConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run, SummaryConversationMemory +from griptape.memory.structure import Run TEST_CONVERSATION = '{"type": "SummaryConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' @@ -23,6 +24,7 @@ def get(*args, **kwargs): "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! 
How can I assist you today?"}', "index": 0, + "metadata": {"run_id": "1234"}, } ] }, @@ -32,7 +34,7 @@ def get(*args, **kwargs): return mocker.Mock( raise_for_status=lambda: None, json=lambda: { - "metadata": json.loads(TEST_CONVERSATION), + "metadata": {"foo": "bar"}, "name": "test", "thread_id": "test_metadata", } @@ -44,12 +46,22 @@ def get(*args, **kwargs): "requests.get", side_effect=get, ) + + def post(*args, **kwargs): + if str(args[0]).endswith("/threads"): + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"thread_id": "test", "name": "test"}, + ) + else: + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"message_id": "test"}, + ) + mocker.patch( "requests.post", - return_value=mocker.Mock( - raise_for_status=lambda: None, - json=lambda: {"thread_id": "test", "name": "test"}, - ), + side_effect=post, ) mocker.patch( "requests.patch", @@ -66,26 +78,28 @@ def test_no_api_key(self): with pytest.raises(ValueError): GriptapeCloudConversationMemoryDriver(api_key=None, thread_id="test") - def test_no_thread_id(self): + def test_thread_id(self): driver = GriptapeCloudConversationMemoryDriver(api_key="test") assert driver.thread_id == "test" + os.environ["GT_CLOUD_THREAD_ID"] = "test_env" + driver = GriptapeCloudConversationMemoryDriver(api_key="test") + assert driver.thread_id == "test_env" + driver = GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test_init") + assert driver.thread_id == "test_init" - def test_store(self, driver): - memory = ConversationMemory( - runs=[ - Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) - for run in json.loads(TEST_CONVERSATION)["runs"] - ], - ) - assert driver.store(memory) is None + def test_store(self, driver: GriptapeCloudConversationMemoryDriver): + runs = [ + Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) + for run in json.loads(TEST_CONVERSATION)["runs"] + ] + assert driver.store(runs, {}) is None def test_load(self, driver): - memory = driver.load() - assert isinstance(memory, BaseConversationMemory) - assert len(memory.runs) == 1 - - def test_load_metadata(self, driver): + runs, metadata = driver.load() + assert len(runs) == 1 + assert runs[0].id == "1234" + assert metadata == {} driver.thread_id = "test_metadata" - memory = driver.load() - assert isinstance(memory, SummaryConversationMemory) - assert len(memory.runs) == 1 + runs, metadata = driver.load() + assert len(runs) == 1 + assert metadata == {"foo": "bar"} diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index dff66d0fc..52e8d31e2 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -1,5 +1,6 @@ import contextlib import os +from pathlib import Path import pytest @@ -21,26 +22,23 @@ def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) def test_store(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) - try: - 
with open(self.MEMORY_FILE_PATH): - raise AssertionError() - except FileNotFoundError: - assert True + assert not os.path.exists(self.MEMORY_FILE_PATH) pipeline.run() - with open(self.MEMORY_FILE_PATH): - assert True + assert os.path.exists(self.MEMORY_FILE_PATH) def test_load(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + memory = ConversationMemory( + conversation_memory_driver=memory_driver, autoload=False, max_runs=5, meta={"foo": "bar"} + ) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -48,17 +46,25 @@ def test_load(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" - assert new_memory.max_runs == 5 + assert len(runs) == 2 + assert runs[0].input.value == "test" + assert runs[0].output.value == "mock output" + assert metadata == {"foo": "bar"} + + runs[0].input.value = "new test" + + def test_load_bad_data(self): + Path(self.MEMORY_FILE_PATH).write_text("bad data") + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + + with pytest.raises(ValueError, match="Unable to load data from test_memory.json"): + ConversationMemory(conversation_memory_driver=memory_driver) def test_autoload(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -66,13 +72,13 @@ def test_autoload(self): pipeline.run() pipeline.run() - autoloaded_memory = ConversationMemory(driver=memory_driver) + autoloaded_memory = ConversationMemory(conversation_memory_driver=memory_driver) assert autoloaded_memory.type == "ConversationMemory" assert len(autoloaded_memory.runs) == 2 assert autoloaded_memory.runs[0].input.value == "test" assert autoloaded_memory.runs[0].output.value == "mock output" - def __delete_file(self, file_path) -> None: + def __delete_file(self, persist_file) -> None: with contextlib.suppress(FileNotFoundError): - os.remove(file_path) + os.remove(persist_file) diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py index 4a92a28a8..e7ef45c42 100644 --- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -4,7 +4,8 @@ from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver from griptape.memory.structure.base_conversation_memory import BaseConversationMemory -TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! 
How can I assist you today?"}}], "max_runs": 2}' +TEST_DATA = '{"runs": [{"input": {"type": "TextArtifact", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "value": "Hello! How can I assist you today?"}}], "metadata": {"foo": "bar"}}' +TEST_MEMORY = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" INDEX = "griptape_conversation" HOST = "127.0.0.1" @@ -17,7 +18,7 @@ class TestRedisConversationMemoryDriver: def _mock_redis(self, mocker): mocker.patch.object(redis.StrictRedis, "hset", return_value=None) mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"]) - mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION) + mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_DATA) fake_redisearch = mocker.MagicMock() fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[])) @@ -31,11 +32,16 @@ def driver(self): return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID) def test_store(self, driver): - memory = BaseConversationMemory.from_json(TEST_CONVERSATION) - assert driver.store(memory) is None + memory = BaseConversationMemory.from_json(TEST_MEMORY) + assert driver.store(memory.runs, memory.meta) is None def test_load(self, driver): - memory = driver.load() - assert memory.type == "ConversationMemory" - assert memory.max_runs == 2 - assert memory.runs == BaseConversationMemory.from_json(TEST_CONVERSATION).runs + runs, metadata = driver.load() + assert len(runs) == 1 + assert metadata == {"foo": "bar"} + + def test_load_empty(self, mocker, driver): + mocker.patch.object(redis.StrictRedis, "hget", return_value=None) + runs, metadata = driver.load() + assert len(runs) == 0 + assert metadata == {} diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 3f9ac2344..84c8591d8 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -90,7 +90,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): prompt_stack = PromptStack() prompt_stack.add_user_message(TextArtifact("foo")) prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 12 @@ -116,7 +116,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 3 @@ -140,7 +140,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 13 @@ -168,7 +168,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") 
prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack, 1) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack, 1) # We expect one run (2 Prompt Stack inputs) to be pruned. assert len(prompt_stack.messages) == 11
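As a supplement to the migration notes above, here is a minimal sketch of a custom driver written against the new `BaseConversationMemoryDriver` contract introduced by this patch: `store(runs, metadata)` plus `load() -> tuple[list[Run], dict]`, with the `_to_params_dict` / `_from_params_dict` helpers handling (de)serialization. The `InMemoryConversationMemoryDriver` name and its dict-backed storage are illustrative assumptions, not part of the patch.

```python
from __future__ import annotations

from typing import Any

from attrs import define, field

from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import Run


@define
class InMemoryConversationMemoryDriver(BaseConversationMemoryDriver):
    """Hypothetical driver that keeps runs in a plain dict instead of an external store."""

    data: dict = field(factory=dict, kw_only=True)

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        # The base-class helper packs runs and metadata into {"runs": [...], "metadata": {...}}.
        self.data = self._to_params_dict(runs, metadata)

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        # Mirror the built-in drivers: an empty store means no runs and no metadata.
        if not self.data:
            return [], {}
        return self._from_params_dict(self.data)
```

Wiring it up follows the same pattern as the built-in drivers, e.g. `ConversationMemory(conversation_memory_driver=InMemoryConversationMemoryDriver())`.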