Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Conversation Memory Autoload #1033

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Fixed
- Conversation Memory not using `StructureConfig.conversation_memory_driver` when autoloading.

### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

local_driver = LocalConversationMemoryDriver(file_path="memory.json")
agent = Agent(conversation_memory=ConversationMemory(driver=local_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver))

agent.run("Surfing is my favorite sport.")
agent.run("What is my favorite sport?")
Expand Down Expand Up @@ -47,7 +47,7 @@ dynamodb_driver = AmazonDynamoDbConversationMemoryDriver(
partition_key_value=conversation_id,
)

agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
Expand Down Expand Up @@ -78,7 +78,7 @@ redis_conversation_driver = RedisConversationMemoryDriver(
conversation_id = conversation_id
)

agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
Expand Down
31 changes: 24 additions & 7 deletions griptape/memory/structure/base_conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,27 @@

@define
class BaseConversationMemory(SerializableMixin, ABC):
driver: Optional[BaseConversationMemoryDriver] = field(default=None, kw_only=True)
runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True})
structure: Structure = field(init=False)
autoload: bool = field(default=True, kw_only=True)
autoprune: bool = field(default=True, kw_only=True)
max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
_conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
default=None, kw_only=True, alias="conversation_memory_driver"
)

def __attrs_post_init__(self) -> None:
if self.driver and self.autoload:
memory = self.driver.load()
if memory is not None:
[self.add_run(r) for r in memory.runs]
if self.autoload and not hasattr(self, "structure"):
self.load()

@property
def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]:
if self._conversation_memory_driver is None:
if hasattr(self, "structure"):
return self.structure.config.conversation_memory_driver
else:
return None
return self._conversation_memory_driver

Comment on lines +33 to 40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the strucuture just set the driver rather than setting the structure itself? also why the hasattr? should it just default to None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern is easier to apply to other areas of the framework (tasks) if the Structure sets itself.

init=False because we don't want users to set the Structure, it should only be done by the Structure itself.

def before_add_run(self) -> None:
pass
Expand All @@ -40,8 +49,16 @@ def add_run(self, run: Run) -> BaseConversationMemory:
return self

def after_add_run(self) -> None:
if self.driver:
self.driver.store(self)
if self.conversation_memory_driver is not None:
self.conversation_memory_driver.store(self)

def load(self) -> BaseConversationMemory:
if self.conversation_memory_driver is not None:
memory = self.conversation_memory_driver.load()
if memory is not None:
[self.add_run(r) for r in memory.runs]

return self

@abstractmethod
def try_add_run(self, run: Run) -> None: ...
Expand Down
4 changes: 3 additions & 1 deletion griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Structure(ABC, EventPublisherMixin):
logger_level: int = field(default=logging.INFO, kw_only=True)
conversation_memory: Optional[BaseConversationMemory] = field(
default=Factory(
lambda self: ConversationMemory(driver=self.config.conversation_memory_driver),
lambda self: ConversationMemory(conversation_memory_driver=self.config.conversation_memory_driver),
takes_self=True,
),
kw_only=True,
Expand Down Expand Up @@ -96,6 +96,8 @@ def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
def __attrs_post_init__(self) -> None:
if self.conversation_memory is not None:
self.conversation_memory.structure = self
if self.conversation_memory.autoload:
self.conversation_memory.load()
Comment on lines +99 to +100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels to me like this autoloading decision should be made from within the conversation memory class. Like either in the structure setter, or in another method called maybe preprocess (if we wanna be consistent with the way tasks are added to structures)

Also, won't load() get called twice? (Once in the __attrs_post_init__, when there is no structure, then once here after the structure is added) Is that ok?


self.config.structure = self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_store(self):
value_attribute_key=self.VALUE_ATTRIBUTE_KEY,
partition_key_value=self.PARTITION_KEY_VALUE,
)
memory = ConversationMemory(driver=memory_driver)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand All @@ -75,7 +75,7 @@ def test_store_with_sort_key(self):
sort_key="sortKey",
sort_key_value="foo",
)
memory = ConversationMemory(driver=memory_driver)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand All @@ -97,7 +97,7 @@ def test_load(self):
value_attribute_key=self.VALUE_ATTRIBUTE_KEY,
partition_key_value=self.PARTITION_KEY_VALUE,
)
memory = ConversationMemory(driver=memory_driver)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand All @@ -123,7 +123,7 @@ def test_load_with_sort_key(self):
sort_key="sortKey",
sort_key_value="foo",
)
memory = ConversationMemory(driver=memory_driver)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _run_before_and_after_tests(self):
def test_store(self):
prompt_driver = MockPromptDriver()
memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH)
memory = ConversationMemory(driver=memory_driver, autoload=False)
memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand All @@ -43,7 +43,7 @@ def test_store(self):
def test_load(self):
prompt_driver = MockPromptDriver()
memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH)
memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5)
memory = ConversationMemory(conversation_memory_driver=memory_driver, autoload=False, max_runs=5)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))
Expand All @@ -62,15 +62,15 @@ def test_load(self):
def test_autoload(self):
prompt_driver = MockPromptDriver()
memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH)
memory = ConversationMemory(driver=memory_driver)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory)

pipeline.add_task(PromptTask("test"))

pipeline.run()
pipeline.run()

autoloaded_memory = ConversationMemory(driver=memory_driver)
autoloaded_memory = ConversationMemory(conversation_memory_driver=memory_driver)

assert autoloaded_memory.type == "ConversationMemory"
assert len(autoloaded_memory.runs) == 2
Expand Down
Loading