diff --git a/CHANGELOG.md b/CHANGELOG.md index efe70904d..3e3daaaad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +### Fixed +- Conversation Memory not using `StructureConfig.conversation_memory_driver` when autoloading. ### Added - `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store. diff --git a/docs/griptape-framework/drivers/conversation-memory-drivers.md b/docs/griptape-framework/drivers/conversation-memory-drivers.md index acdb7c202..d21e78824 100644 --- a/docs/griptape-framework/drivers/conversation-memory-drivers.md +++ b/docs/griptape-framework/drivers/conversation-memory-drivers.md @@ -19,7 +19,7 @@ from griptape.drivers import LocalConversationMemoryDriver from griptape.memory.structure import ConversationMemory local_driver = LocalConversationMemoryDriver(file_path="memory.json") -agent = Agent(conversation_memory=ConversationMemory(driver=local_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver)) agent.run("Surfing is my favorite sport.") agent.run("What is my favorite sport?") @@ -47,7 +47,7 @@ dynamodb_driver = AmazonDynamoDbConversationMemoryDriver( partition_key_value=conversation_id, ) -agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") @@ -78,7 +78,7 @@ redis_conversation_driver = RedisConversationMemoryDriver( conversation_id = conversation_id ) -agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver)) +agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver)) agent.run("My name is Jeff.") agent.run("What is my name?") diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index c3d3c501e..a509df238 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -16,18 +16,27 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field(default=None, kw_only=True) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) autoload: bool = field(default=True, kw_only=True) autoprune: bool = field(default=True, kw_only=True) max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) + _conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, alias="conversation_memory_driver" + ) def __attrs_post_init__(self) -> None: - if self.driver and self.autoload: - memory = self.driver.load() - if memory is not None: - [self.add_run(r) for r in memory.runs] + if self.autoload and not hasattr(self, "structure"): + self.load() + + @property + def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: + if self._conversation_memory_driver is None: + if hasattr(self, "structure"): + return self.structure.config.conversation_memory_driver + else: + return None + return self._conversation_memory_driver def before_add_run(self) -> None: pass @@ -40,8 +49,16 @@ def add_run(self, run: Run) -> BaseConversationMemory: return self def after_add_run(self) -> None: - if self.driver: - self.driver.store(self) + if self.conversation_memory_driver is not None: + self.conversation_memory_driver.store(self) + + def load(self) -> BaseConversationMemory: + if self.conversation_memory_driver is not None: + memory = self.conversation_memory_driver.load() + if memory is not None: + [self.add_run(r) for r in memory.runs] + + return self @abstractmethod def try_add_run(self, run: Run) -> None: ... diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 765910f5c..d9fc6298b 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -62,7 +62,7 @@ class Structure(ABC, EventPublisherMixin): logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory( - lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), + lambda self: ConversationMemory(conversation_memory_driver=self.config.conversation_memory_driver), takes_self=True, ), kw_only=True, @@ -96,6 +96,8 @@ def validate_rules(self, _: Attribute, rules: list[Rule]) -> None: def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self + if self.conversation_memory.autoload: + self.conversation_memory.load() self.config.structure = self diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index 8e700d0a5..d992f724c 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -48,7 +48,7 @@ def test_store(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -75,7 +75,7 @@ def test_store_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -97,7 +97,7 @@ def test_load(self): value_attribute_key=self.VALUE_ATTRIBUTE_KEY, partition_key_value=self.PARTITION_KEY_VALUE, ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -123,7 +123,7 @@ def test_load_with_sort_key(self): sort_key="sortKey", sort_key_value="foo", ) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index e1a383ab9..a75365265 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -24,7 +24,7 @@ def _run_before_and_after_tests(self): def test_store(self): prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -43,7 +43,7 @@ def test_store(self): def test_load(self): prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) + memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False, max_runs=5) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -62,7 +62,7 @@ def test_load(self): def test_autoload(self): prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) - memory = ConversationMemory(driver=memory_driver) + memory = ConversationMemory(conversation_memory_driver=memory_driver) pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -70,7 +70,7 @@ def test_autoload(self): pipeline.run() pipeline.run() - autoloaded_memory = ConversationMemory(driver=memory_driver) + autoloaded_memory = ConversationMemory(conversation_memory_driver=memory_driver) assert autoloaded_memory.type == "ConversationMemory" assert len(autoloaded_memory.runs) == 2