diff --git a/CHANGELOG.md b/CHANGELOG.md index 59947dc5fb..6e66c3e8f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased + ### Added -- `BaseConversationMemory.prompt_driver` for use with autopruning. +- Parameter `lazy_load` on `LocalConversationMemoryDriver`. + +### Changed +- **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`. +- **BREAKING**: `BaseConversationMemory.add_to_prompt_stack` now takes a `prompt_driver` parameter. +- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], dict]`. +- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file`. ### Fixed - Parsing streaming response with some OpenAi compatible services. diff --git a/docs/examples/src/amazon_dynamodb_sessions_1.py b/docs/examples/src/amazon_dynamodb_sessions_1.py index f7a6d0cd68..d44ec8f561 100644 --- a/docs/examples/src/amazon_dynamodb_sessions_1.py +++ b/docs/examples/src/amazon_dynamodb_sessions_1.py @@ -18,7 +18,7 @@ structure = Agent( conversation_memory=ConversationMemory( - driver=AmazonDynamoDbConversationMemoryDriver( + conversation_memory_driver=AmazonDynamoDbConversationMemoryDriver( session=boto3.Session( aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py index 27829d8d2d..d87586d88b 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_1.py @@ -2,8 +2,8 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Agent -local_driver = LocalConversationMemoryDriver(file_path="memory.json") -agent = Agent(conversation_memory=ConversationMemory(driver=local_driver)) +local_driver = LocalConversationMemoryDriver(persist_file="memory.json") +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver)) agent.run("Surfing is my favorite sport.") agent.run("What is my favorite sport?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py index 9db525b426..0c32c1cc55 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_2.py @@ -13,7 +13,7 @@ partition_key_value=conversation_id, ) -agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py index 0f80d13934..5f07239405 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py @@ -14,7 +14,7 @@ conversation_id=conversation_id, ) -agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py index 35492e06bf..0723b5f751 100644 --- a/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py @@ -9,7 +9,7 @@ cloud_conversation_driver = GriptapeCloudConversationMemoryDriver( api_key=os.environ["GT_CLOUD_API_KEY"], ) -agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=cloud_conversation_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/griptape/configs/drivers/base_drivers_config.py b/griptape/configs/drivers/base_drivers_config.py index ec75034786..456249634d 100644 --- a/griptape/configs/drivers/base_drivers_config.py +++ b/griptape/configs/drivers/base_drivers_config.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define, field @@ -38,7 +38,7 @@ class BaseDriversConfig(ABC, SerializableMixin): _vector_store_driver: BaseVectorStoreDriver = field( default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store_driver" ) - _conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + _conversation_memory_driver: BaseConversationMemoryDriver = field( default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory_driver" ) _text_to_speech_driver: BaseTextToSpeechDriver = field( @@ -70,7 +70,7 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: ... @lazy_property() @abstractmethod - def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: ... + def conversation_memory_driver(self) -> BaseConversationMemoryDriver: ... @lazy_property() @abstractmethod diff --git a/griptape/configs/drivers/drivers_config.py b/griptape/configs/drivers/drivers_config.py index ed68bcf8c0..04edfd3033 100644 --- a/griptape/configs/drivers/drivers_config.py +++ b/griptape/configs/drivers/drivers_config.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define @@ -13,6 +13,7 @@ DummyPromptDriver, DummyTextToSpeechDriver, DummyVectorStoreDriver, + LocalConversationMemoryDriver, ) from griptape.utils.decorators import lazy_property @@ -52,8 +53,8 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: return DummyVectorStoreDriver(embedding_driver=self.embedding_driver) @lazy_property() - def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: - return None + def conversation_memory_driver(self) -> BaseConversationMemoryDriver: + return LocalConversationMemoryDriver() @lazy_property() def text_to_speech_driver(self) -> BaseTextToSpeechDriver: diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index b0c2485d6d..c505fcaf32 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -6,12 +6,12 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.utils import import_optional_dependency +from griptape.utils import dict_merge, import_optional_dependency if TYPE_CHECKING: import boto3 - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define @@ -27,35 +27,50 @@ class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): table: Any = field(init=False) def __attrs_post_init__(self) -> None: - dynamodb = self.session.resource("dynamodb") - - self.table = dynamodb.Table(self.table_name) - - def store(self, memory: BaseConversationMemory) -> None: - self.table.update_item( - Key=self._get_key(), - UpdateExpression="set #attr = :value", - ExpressionAttributeNames={"#attr": self.value_attribute_key}, - ExpressionAttributeValues={":value": memory.to_json()}, - ) - - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory + self.table = self.session.resource("dynamodb").Table(self.table_name) + + def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: + if overwrite: + self.table.update_item( + Key=self._get_key(), + UpdateExpression="set #attr = :value", + ExpressionAttributeNames={"#attr": self.value_attribute_key}, + ExpressionAttributeValues={ + ":value": json.dumps({"runs": [run.to_dict() for run in runs], "metadata": metadata}) + }, + ) + else: + response = self.table.get_item(Key=self._get_key()) + if "Item" in response and self.value_attribute_key in response["Item"]: + data = json.loads(response["Item"][self.value_attribute_key]) + data["runs"] += [run.to_dict() for run in runs] + data["metadata"] = dict_merge(data["metadata"], metadata) + self.table.update_item( + Key=self._get_key(), + UpdateExpression="set #attr = :value", + ExpressionAttributeNames={"#attr": self.value_attribute_key}, + ExpressionAttributeValues={":value": json.dumps(data)}, + ) + else: + self.table.put_item( + Item={ + **self._get_key(), + self.value_attribute_key: json.dumps( + {"runs": [run.to_dict() for run in runs], "metadata": metadata} + ), + } + ) + + def load(self) -> tuple[list[Run], dict]: + from griptape.memory.structure import Run response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: memory_dict = json.loads(response["Item"][self.value_attribute_key]) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - - memory = BaseConversationMemory.from_dict(memory_dict) - - memory.driver = self - - return memory + return [Run.from_dict(run) for run in memory_dict["runs"]], memory_dict.get("metadata", {}) else: - return None + return [], {} def _get_key(self) -> dict[str, str | int]: key: dict[str, str | int] = {self.partition_key: self.partition_key_value} diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index 1caeb902f2..577a47c8f8 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -6,12 +6,12 @@ from griptape.mixins import SerializableMixin if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod - def store(self, memory: BaseConversationMemory) -> None: ... + def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: ... @abstractmethod - def load(self) -> Optional[BaseConversationMemory]: ... + def load(self) -> tuple[list[Run], dict]: ... diff --git a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py index 2ea1d0d1ac..9624371545 100644 --- a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py @@ -10,9 +10,10 @@ from griptape.artifacts import BaseArtifact from griptape.drivers import BaseConversationMemoryDriver +from griptape.utils import dict_merge if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define(kw_only=True) @@ -55,26 +56,59 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: raise ValueError(f"{self.__class__.__name__} requires an API key") return value - def store(self, memory: BaseConversationMemory) -> None: - # serliaze the run artifacts to json strings - messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs] - - # serialize the metadata to a json string - # remove runs because they are already stored as Messages - metadata = memory.to_dict() - del metadata["runs"] - - # patch the Thread with the new messages and metadata - # all old Messages are replaced with the new ones - response = requests.patch( - self._get_url(f"/threads/{self.thread_id}"), - json={"messages": messages, "metadata": metadata}, - headers=self.headers, - ) - response.raise_for_status() + def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: + # serialize the run artifacts to json strings + messages = [] + for run in runs: + run_dict = { + "input": run.input.to_json(), + "output": run.output.to_json(), + "metadata": {"run_id": run.id}, + } + if run.meta is not None: + run_dict["metadata"].update(run.meta) + + messages.append(run_dict) + + body = { + "messages": messages, + "metadata": metadata, + } + + if overwrite: + # patch the Thread with the new messages and metadata + # all old Messages are replaced with the new ones + response = requests.patch( + self._get_url(f"/threads/{self.thread_id}"), + json=body, + headers=self.headers, + ) + response.raise_for_status() + else: + # add the new messages to the Thread + for message in body["messages"]: + response = requests.post( + self._get_url(f"/threads/{self.thread_id}/messages"), + json=message, + headers=self.headers, + ) + response.raise_for_status() + + # get the thread to merge the metadata + response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers) + response.raise_for_status() + + thread_metadata = response.json().get("metadata", {}) + thread_metadata = dict_merge(thread_metadata, metadata) + response = requests.patch( + self._get_url(f"/threads/{self.thread_id}"), + json={"metadata": thread_metadata}, + headers=self.headers, + ) + response.raise_for_status() - def load(self) -> BaseConversationMemory: - from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run + def load(self) -> tuple[list[Run], dict]: + from griptape.memory.structure import Run # get the Messages from the Thread messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers) @@ -90,29 +124,14 @@ def load(self) -> BaseConversationMemory: runs = [ Run( - id=m["message_id"], + id=m["metadata"].pop("run_id"), + meta=m["metadata"], input=BaseArtifact.from_json(m["input"]), output=BaseArtifact.from_json(m["output"]), ) for m in messages ] - metadata = thread_response.get("metadata") - - # the metadata will contain the serialized - # ConversationMemory object with the runs removed - # autoload=False to prevent recursively loading the memory - if metadata is not None and metadata != {}: - memory = BaseConversationMemory.from_dict( - { - **metadata, - "runs": [run.to_dict() for run in runs], - "autoload": False, - } - ) - memory.driver = self - return memory - # no metadata found, return a new ConversationMemory object - return ConversationMemory(runs=runs, autoload=False, driver=self) + return runs, thread_response.get("metadata", {}) def _get_thread_id(self) -> str: res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers) diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index 9a79accc3b..bc081d1266 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -8,29 +8,55 @@ from attrs import define, field from griptape.drivers import BaseConversationMemoryDriver +from griptape.utils import dict_merge if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run -@define +@define(kw_only=True) class LocalConversationMemoryDriver(BaseConversationMemoryDriver): - file_path: str = field(default="griptape_memory.json", kw_only=True, metadata={"serializable": True}) + persist_file: Optional[str] = field(default=None, metadata={"serializable": True}) + lazy_load: bool = field(default=True, metadata={"serializable": True}) - def store(self, memory: BaseConversationMemory) -> None: - Path(self.file_path).write_text(memory.to_json()) + def __attrs_post_init__(self) -> None: + if self.persist_file is not None and not self.lazy_load: + self._load_file() - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory + def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: + if self.persist_file is not None: + self._load_file() - if not os.path.exists(self.file_path): - return None + data = {"runs": [run.to_dict() for run in runs], "metadata": metadata} - memory_dict = json.loads(Path(self.file_path).read_text()) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - memory = BaseConversationMemory.from_dict(memory_dict) + if not overwrite: + loaded_str = Path(self.persist_file).read_text() + if loaded_str: + loaded_data = json.loads(loaded_str) + loaded_data["runs"] += data["runs"] + loaded_data["metadata"] = dict_merge(loaded_data["metadata"], data["metadata"]) + data = loaded_data - memory.driver = self + Path(self.persist_file).write_text(json.dumps(data)) - return memory + def load(self) -> tuple[list[Run], dict]: + from griptape.memory.structure import Run + + if self.persist_file is not None: + self._load_file() + loaded_str = Path(self.persist_file).read_text() + if loaded_str: + data = json.loads(loaded_str) + return [Run.from_dict(run) for run in data["runs"]], data.get("metadata", {}) + + return [], {} + + def _load_file(self) -> None: + if self.persist_file is not None and not os.path.exists(self.persist_file): + directory = os.path.dirname(self.persist_file) + + if directory and not os.path.exists(directory): + os.makedirs(directory) + + if not os.path.isfile(self.persist_file): + Path(self.persist_file).touch() diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 8741cda509..83837125c1 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -7,12 +7,12 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.utils.import_utils import import_optional_dependency +from griptape.utils import dict_merge, import_optional_dependency if TYPE_CHECKING: from redis import Redis - from griptape.memory.structure import BaseConversationMemory + from griptape.memory.structure import Run @define @@ -52,19 +52,22 @@ class RedisConversationMemoryDriver(BaseConversationMemoryDriver): ), ) - def store(self, memory: BaseConversationMemory) -> None: - self.client.hset(self.index, self.conversation_id, memory.to_json()) + def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: + data = {"runs": [run.to_dict() for run in runs], "metadata": metadata} + if not overwrite: + loaded_str = self.client.hget(self.index, self.conversation_id) + if loaded_str is not None: + loaded_data = json.loads(loaded_str) + loaded_data["runs"] += data["runs"] + loaded_data["metadata"] = dict_merge(loaded_data["metadata"], data["metadata"]) + data = loaded_data + self.client.hset(self.index, self.conversation_id, json.dumps(data)) - def load(self) -> Optional[BaseConversationMemory]: - from griptape.memory.structure import BaseConversationMemory + def load(self) -> tuple[list[Run], dict]: + from griptape.memory.structure import Run - key = self.index - memory_json = self.client.hget(key, self.conversation_id) + memory_json = self.client.hget(self.index, self.conversation_id) if memory_json is not None: memory_dict = json.loads(memory_json) - # needed to avoid recursive method calls - memory_dict["autoload"] = False - memory = BaseConversationMemory.from_dict(memory_dict) - memory.driver = self - return memory - return None + return [Run.from_dict(run) for run in memory_dict["runs"]], memory_dict.get("metadata", {}) + return [], {} diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 15d0a9e999..14050bab07 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -8,6 +8,7 @@ from griptape.common import PromptStack from griptape.configs import Defaults from griptape.mixins import SerializableMixin +from griptape.utils import dict_merge_opt if TYPE_CHECKING: from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver @@ -16,36 +17,33 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field( + conversation_memory_driver: BaseConversationMemoryDriver = field( default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True ) - prompt_driver: BasePromptDriver = field( - default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True - ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) + meta: Optional[dict] = field(default=None, kw_only=True, metadata={"serializable": True}) autoload: bool = field(default=True, kw_only=True) autoprune: bool = field(default=True, kw_only=True) max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) def __attrs_post_init__(self) -> None: - if self.driver and self.autoload: - memory = self.driver.load() - if memory is not None: - [self.add_run(r) for r in memory.runs] + if self.autoload: + runs, meta = self.conversation_memory_driver.load() + self.runs.extend(runs) + self.meta = dict_merge_opt(self.meta, meta) - def before_add_run(self) -> None: + def before_add_run(self, run: Run) -> None: pass def add_run(self, run: Run) -> BaseConversationMemory: - self.before_add_run() + self.before_add_run(run) self.try_add_run(run) - self.after_add_run() + self.after_add_run(run) return self - def after_add_run(self) -> None: - if self.driver: - self.driver.store(self) + def after_add_run(self, run: Run) -> None: + self.conversation_memory_driver.store([run], self.meta) @abstractmethod def try_add_run(self, run: Run) -> None: ... @@ -53,13 +51,16 @@ def try_add_run(self, run: Run) -> None: ... @abstractmethod def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... - def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: + def add_to_prompt_stack( + self, prompt_driver: BasePromptDriver, prompt_stack: PromptStack, index: Optional[int] = None + ) -> PromptStack: """Add the Conversation Memory runs to the Prompt Stack by modifying the messages in place. If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack as possible without exceeding the token limit. Args: + prompt_driver: The Prompt Driver to use for token counting. prompt_stack: The Prompt Stack to add the Conversation Memory to. index: Optional index to insert the Conversation Memory runs at. Defaults to appending to the end of the Prompt Stack. @@ -82,8 +83,8 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = temp_stack.messages.extend(memory_inputs) # Convert the Prompt Stack into tokens left. - tokens_left = self.prompt_driver.tokenizer.count_input_tokens_left( - self.prompt_driver.prompt_stack_to_string(temp_stack), + tokens_left = prompt_driver.tokenizer.count_input_tokens_left( + prompt_driver.prompt_stack_to_string(temp_stack), ) if tokens_left > 0: # There are still tokens left, no need to prune. diff --git a/griptape/memory/structure/run.py b/griptape/memory/structure/run.py index 3d8ca38690..5d2a182ad6 100644 --- a/griptape/memory/structure/run.py +++ b/griptape/memory/structure/run.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field -from griptape.artifacts.base_artifact import BaseArtifact from griptape.mixins import SerializableMixin +if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact + -@define +@define(kw_only=True) class Run(SerializableMixin): - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) - input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) - output: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) + id: str = field(default=Factory(lambda: uuid.uuid4().hex), metadata={"serializable": True}) + meta: Optional[dict] = field(default=None, metadata={"serializable": True}) + input: BaseArtifact = field(metadata={"serializable": True}) + output: BaseArtifact = field(metadata={"serializable": True}) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 17a73e4cdb..9c00600393 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -58,7 +58,7 @@ def prompt_stack(self) -> PromptStack: if memory is not None: # insert memory into the stack right before the user messages - memory.add_to_prompt_stack(stack, 1 if system_template else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) return stack diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 24607a352e..ff11944406 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -115,7 +115,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(stack, 1) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1) return stack diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 03725f59d4..8b4bcfe6ab 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -7,7 +7,12 @@ from .chat import Chat from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict from .token_counter import TokenCounter -from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively +from .dict_utils import ( + remove_null_values_in_dict_recursively, + dict_merge, + dict_merge_opt, + remove_key_in_dict_recursively, +) from .file_utils import load_file, load_files from .hash import str_to_hash from .import_utils import import_optional_dependency @@ -40,6 +45,7 @@ def minify_json(value: str) -> str: "TokenCounter", "remove_null_values_in_dict_recursively", "dict_merge", + "dict_merge_opt", "remove_key_in_dict_recursively", "Stream", "load_artifact_from_memory", diff --git a/griptape/utils/dict_utils.py b/griptape/utils/dict_utils.py index 0bf5f59db1..f8eb626371 100644 --- a/griptape/utils/dict_utils.py +++ b/griptape/utils/dict_utils.py @@ -55,3 +55,12 @@ def dict_merge(dct: Optional[dict], merge_dct: Optional[dict], *, add_keys: bool dct[key] = merge_dct[key] return dct + + +def dict_merge_opt(dct: Optional[dict], merge_dct: Optional[dict], *, add_keys: bool = True) -> Optional[dict]: + if dct is None: + return merge_dct + if merge_dct is None: + return dct + + return dict_merge(dct, merge_dct, add_keys=add_keys) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 129fe281f4..72be3c03f6 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -25,7 +25,11 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, "image_generation_driver": { "image_generation_model_driver": { @@ -77,7 +81,11 @@ def test_from_dict_with_values(self, config_with_values): def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, "image_generation_driver": { "image_generation_model_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index b2335d92a6..cc90a357b8 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -45,7 +45,11 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 5c514c9479..313dce9c88 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,7 +36,11 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 3c267d73da..b84cacb308 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -13,7 +13,11 @@ def test_to_dict(self, config): "type": "CohereDriversConfig", "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, "prompt_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 20cc0926c2..7b9ad3a9bb 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,7 +18,11 @@ def test_to_dict(self, config): "stream": False, "use_native_tools": False, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "embedding_driver": {"type": "DummyEmbeddingDriver"}, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, @@ -56,7 +60,7 @@ def test_lazy_init(self): assert Defaults.drivers_config.image_query_driver is not None assert Defaults.drivers_config.embedding_driver is not None assert Defaults.drivers_config.vector_store_driver is not None - assert Defaults.drivers_config.conversation_memory_driver is None + assert Defaults.drivers_config.conversation_memory_driver is not None assert Defaults.drivers_config.text_to_speech_driver is not None assert Defaults.drivers_config.audio_transcription_driver is not None @@ -65,6 +69,6 @@ def test_lazy_init(self): assert Defaults.drivers_config._image_query_driver is not None assert Defaults.drivers_config._embedding_driver is not None assert Defaults.drivers_config._vector_store_driver is not None - assert Defaults.drivers_config._conversation_memory_driver is None + assert Defaults.drivers_config._conversation_memory_driver is not None assert Defaults.drivers_config._text_to_speech_driver is not None assert Defaults.drivers_config._audio_transcription_driver is not None diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index f6df1afef5..ac9da6fbde 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -43,7 +43,11 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index 2425b178f8..182a6dff6d 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,7 +28,11 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, + "conversation_memory_driver": { + "type": "LocalConversationMemoryDriver", + "persist_file": None, + "lazy_load": True, + }, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index f1a5df1be9..71fbd86de1 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -46,7 +46,7 @@ def test_store(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -72,7 +72,7 @@ def test_store_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -85,6 +85,12 @@ def test_store_with_sort_key(self): response = table.get_item(TableName=self.DYNAMODB_TABLE_NAME, Key={"entryId": "bar", "sortKey": "foo"}) assert "Item" in response + memory_driver.store([], {"foo": "bar"}, overwrite=True) + runs, metadata = memory_driver.load() + + assert metadata == {"foo": "bar"} + assert len(runs) == 0 + def test_load(self): memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), @@ -93,7 +99,7 @@ def test_load(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -101,12 +107,10 @@ def test_load(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" + assert len(runs) == 2 + assert metadata == {} def test_load_with_sort_key(self): memory_driver = AmazonDynamoDbConversationMemoryDriver( @@ -118,7 +122,7 @@ def test_load_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver, meta={"foo": "bar"}) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -126,9 +130,7 @@ def test_load_with_sort_key(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" + assert len(runs) == 2 + assert metadata == {"foo": "bar"} diff --git a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py index 707132ef5f..fcd0460f48 100644 --- a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py @@ -1,10 +1,11 @@ import json +import os import pytest from griptape.artifacts import BaseArtifact from griptape.drivers import GriptapeCloudConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run, SummaryConversationMemory +from griptape.memory.structure import Run TEST_CONVERSATION = '{"type": "SummaryConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' @@ -23,6 +24,7 @@ def get(*args, **kwargs): "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}', "index": 0, + "metadata": {"run_id": "1234"}, } ] }, @@ -32,7 +34,7 @@ def get(*args, **kwargs): return mocker.Mock( raise_for_status=lambda: None, json=lambda: { - "metadata": json.loads(TEST_CONVERSATION), + "metadata": {"foo": "bar"}, "name": "test", "thread_id": "test_metadata", } @@ -44,12 +46,22 @@ def get(*args, **kwargs): "requests.get", side_effect=get, ) + + def post(*args, **kwargs): + if str(args[0]).endswith("/threads"): + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"thread_id": "test", "name": "test"}, + ) + else: + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"message_id": "test"}, + ) + mocker.patch( "requests.post", - return_value=mocker.Mock( - raise_for_status=lambda: None, - json=lambda: {"thread_id": "test", "name": "test"}, - ), + side_effect=post, ) mocker.patch( "requests.patch", @@ -66,26 +78,35 @@ def test_no_api_key(self): with pytest.raises(ValueError): GriptapeCloudConversationMemoryDriver(api_key=None, thread_id="test") - def test_no_thread_id(self): + def test_thread_id(self): driver = GriptapeCloudConversationMemoryDriver(api_key="test") assert driver.thread_id == "test" + os.environ["GT_CLOUD_THREAD_ID"] = "test_env" + driver = GriptapeCloudConversationMemoryDriver(api_key="test") + assert driver.thread_id == "test_env" + driver = GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test_init") + assert driver.thread_id == "test_init" - def test_store(self, driver): - memory = ConversationMemory( - runs=[ - Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) - for run in json.loads(TEST_CONVERSATION)["runs"] - ], - ) - assert driver.store(memory) is None + def test_store(self, driver: GriptapeCloudConversationMemoryDriver): + runs = [ + Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) + for run in json.loads(TEST_CONVERSATION)["runs"] + ] + assert driver.store(runs, {}) is None - def test_load(self, driver): - memory = driver.load() - assert isinstance(memory, BaseConversationMemory) - assert len(memory.runs) == 1 + def test_store_overwrite(self, driver): + runs = [ + Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) + for run in json.loads(TEST_CONVERSATION)["runs"] + ] + assert driver.store(runs, {}, overwrite=True) is None - def test_load_metadata(self, driver): + def test_load(self, driver): + runs, metadata = driver.load() + assert len(runs) == 1 + assert runs[0].id == "1234" + assert metadata == {} driver.thread_id = "test_metadata" - memory = driver.load() - assert isinstance(memory, SummaryConversationMemory) - assert len(memory.runs) == 1 + runs, metadata = driver.load() + assert len(runs) == 1 + assert metadata == {"foo": "bar"} diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index dff66d0fc1..bf5d0bf1b0 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -21,26 +21,21 @@ def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) def test_store(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) - try: - with open(self.MEMORY_FILE_PATH): - raise AssertionError() - except FileNotFoundError: - assert True + assert not os.path.exists(self.MEMORY_FILE_PATH) pipeline.run() - with open(self.MEMORY_FILE_PATH): - assert True + assert os.path.exists(self.MEMORY_FILE_PATH) def test_load(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False, max_runs=5) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -48,17 +43,26 @@ def test_load(self): pipeline.run() pipeline.run() - new_memory = memory_driver.load() + runs, metadata = memory_driver.load() - assert new_memory.type == "ConversationMemory" - assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input.value == "test" - assert new_memory.runs[0].output.value == "mock output" - assert new_memory.max_runs == 5 + assert len(runs) == 2 + assert runs[0].input.value == "test" + assert runs[0].output.value == "mock output" + assert metadata == {} + + runs[0].input.value = "new test" + + memory_driver.store(runs, {"foo": "bar"}, overwrite=True) + runs, metadata = memory_driver.load() + + assert len(runs) == 2 + assert runs[0].input.value == "new test" + assert runs[0].output.value == "mock output" + assert metadata == {"foo": "bar"} def test_autoload(self): - memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver) + memory_driver = LocalConversationMemoryDriver(persist_file=self.MEMORY_FILE_PATH, lazy_load=False) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -66,13 +70,13 @@ def test_autoload(self): pipeline.run() pipeline.run() - autoloaded_memory = ConversationMemory(driver=memory_driver) + autoloaded_memory = ConversationMemory(conversation_memory_driver=memory_driver) assert autoloaded_memory.type == "ConversationMemory" assert len(autoloaded_memory.runs) == 2 assert autoloaded_memory.runs[0].input.value == "test" assert autoloaded_memory.runs[0].output.value == "mock output" - def __delete_file(self, file_path) -> None: + def __delete_file(self, persist_file) -> None: with contextlib.suppress(FileNotFoundError): - os.remove(file_path) + os.remove(persist_file) diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py index 4a92a28a80..aaba13b9bb 100644 --- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -4,7 +4,8 @@ from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver from griptape.memory.structure.base_conversation_memory import BaseConversationMemory -TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' +TEST_DATA = '{"runs": [{"input": {"type": "TextArtifact", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "value": "Hello! How can I assist you today?"}}], "metadata": {"foo": "bar"}}' +TEST_MEMORY = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" INDEX = "griptape_conversation" HOST = "127.0.0.1" @@ -17,7 +18,7 @@ class TestRedisConversationMemoryDriver: def _mock_redis(self, mocker): mocker.patch.object(redis.StrictRedis, "hset", return_value=None) mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"test"]) - mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_CONVERSATION) + mocker.patch.object(redis.StrictRedis, "hget", return_value=TEST_DATA) fake_redisearch = mocker.MagicMock() fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[])) @@ -31,11 +32,11 @@ def driver(self): return RedisConversationMemoryDriver(host=HOST, port=PORT, db=0, index=INDEX, conversation_id=CONVERSATION_ID) def test_store(self, driver): - memory = BaseConversationMemory.from_json(TEST_CONVERSATION) - assert driver.store(memory) is None + memory = BaseConversationMemory.from_json(TEST_MEMORY) + assert driver.store(memory.runs, memory.meta) is None + assert driver.store(memory.runs, memory.meta, overwrite=True) is None def test_load(self, driver): - memory = driver.load() - assert memory.type == "ConversationMemory" - assert memory.max_runs == 2 - assert memory.runs == BaseConversationMemory.from_json(TEST_CONVERSATION).runs + runs, metadata = driver.load() + assert len(runs) == 1 + assert metadata == {"foo": "bar"} diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 3f9ac23443..84c8591d86 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -90,7 +90,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): prompt_stack = PromptStack() prompt_stack.add_user_message(TextArtifact("foo")) prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 12 @@ -116,7 +116,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 3 @@ -140,7 +140,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack) assert len(prompt_stack.messages) == 13 @@ -168,7 +168,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): prompt_stack.add_system_message("fizz") prompt_stack.add_user_message("foo") prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack, 1) + memory.add_to_prompt_stack(agent.prompt_driver, prompt_stack, 1) # We expect one run (2 Prompt Stack inputs) to be pruned. assert len(prompt_stack.messages) == 11 diff --git a/tests/unit/utils/test_dict_utils.py b/tests/unit/utils/test_dict_utils.py index 94e870e1a8..499cc8e38e 100644 --- a/tests/unit/utils/test_dict_utils.py +++ b/tests/unit/utils/test_dict_utils.py @@ -1,6 +1,11 @@ import pytest -from griptape.utils import dict_merge, remove_key_in_dict_recursively, remove_null_values_in_dict_recursively +from griptape.utils import ( + dict_merge, + dict_merge_opt, + remove_key_in_dict_recursively, + remove_null_values_in_dict_recursively, +) class TestDictUtils: @@ -64,3 +69,24 @@ def test_dict_merge_does_not_insert_new_keys(self): with pytest.raises(KeyError): assert dict_merge(a, b, add_keys=False)["b"]["b3"] == 6 + + def test_dict_merge_opt(self): + a = {"a": 1, "b": {"b1": 2, "b2": 3}} + b = {"a": 1, "b": {"b1": 4}} + + assert dict_merge_opt(a, b) == dict_merge(a, b) + + a = None + b = {"a": 1, "b": {"b1": 2, "b2": 3}} + + assert dict_merge_opt(a, b) == b + + a = {"a": 1, "b": {"b1": 2, "b2": 3}} + b = None + + assert dict_merge_opt(a, b) == a + + a = None + b = None + + assert dict_merge_opt(a, b) is None