Refactor Conversation Memory class and drivers #1084
Conversation
50d8e36 to 5cd4f87
Codecov Report: All modified and coverable lines are covered by tests ✅
    default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True
)
prompt_driver: BasePromptDriver = field(
    default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
)
runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True})
metadata: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
[self.add_run(r) for r in memory.runs]
if self.autoload:
    runs, metadata = self.conversation_memory_driver.load()
    runs.extend(self.runs)
If we're going to merge Driver `Run`s in (I like this idea), I think we should add them after user-defined runs.
I went back and forth. My thought was that, if you passed new Runs here, they would probably be the most recent, and the data that gets loaded would be "historical".
Revisiting this, I think we have the same intention, but my implementation is wrong.
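For clarity, the two orderings under discussion look roughly like this (an illustrative sketch reusing the names from the quoted snippets, not the final implementation):

```python
loaded_runs, loaded_metadata = self.conversation_memory_driver.load()

# Ordering in the current snippet: driver-loaded "historical" runs first,
# then any runs passed in by the user.
merged = [*loaded_runs, *self.runs]

# Ordering as the review comment suggests: user-defined runs first,
# with driver-loaded runs appended after.
merged = [*self.runs, *loaded_runs]
```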
if self.persist_file is not None and not self.lazy_load:
    self._load_file()
How would one use `lazy_load`? Should this be present in all Drivers?
Mostly for "backwards compatibility" with the previous implementation. It wouldn't try to create the file unless accessed. We could just make one default and stick with it.
I vote for picking one way.
I kind of lean towards the lazy way; otherwise, it'll just create an empty file, right? Though I could see the predictable behavior being nice, I suppose... though if the only goal is to load it back into the same driver, then the driver can check if it exists... yeah, I vote for the lazy way. That said, I don't feel too strongly about it.
Don't really care either way either. I'll remove the param, default to the existing lazy behavior, and see how that looks.
Maybe I'm misunderstanding the functionality, but can we just check if `persist_file` is set and only save then?
It's about the timing of when the file is actually created. `LocalVectorStoreDriver` creates the file on post_init, whereas the original implementation of `LocalConversationMemoryDriver` waited until `.store()` to create the file if needed.
Got it, so when would we ever want `lazy_load=False`? Seems easiest to just check if `persist_file` is set during `store`, and that's the only time we create a file.
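A minimal sketch of that lazy behavior, assuming the driver keeps an optional `persist_file` attribute and mirroring the `store` signature quoted later in this thread (names are illustrative, not the final code):

```python
import json
from pathlib import Path


def store(self, runs: list[Run], metadata: dict) -> None:
    # Only touch the filesystem when persist_file is configured; the file is
    # created lazily on the first write rather than eagerly at init time.
    if self.persist_file is None:
        return

    data = {"runs": [run.to_dict() for run in runs], "metadata": metadata}
    path = Path(self.persist_file)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data))
```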
griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py (resolved)
if not overwrite:
    loaded_str = Path(self.persist_file).read_text()
    if loaded_str:
        loaded_data = json.loads(loaded_str)
        loaded_data["runs"] += data["runs"]
        loaded_data["metadata"] = dict_merge(loaded_data["metadata"], data["metadata"])
        data = loaded_data
This may be a slightly cleaner method of appending to the json file.
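The linked suggestion itself isn't shown here; purely as an illustration, the read-merge-write path could be pulled into a small helper along these lines, reusing `dict_merge` from the quoted snippet (the helper name `_merge_with_persisted` is hypothetical):

```python
import json
from pathlib import Path


def _merge_with_persisted(self, data: dict) -> dict:
    # Merge new runs/metadata into whatever is already on disk, if anything.
    path = Path(self.persist_file)
    existing = path.read_text() if path.exists() else ""
    if not existing.strip():
        return data

    loaded_data = json.loads(existing)
    loaded_data["runs"] += data["runs"]
    loaded_data["metadata"] = dict_merge(loaded_data["metadata"], data["metadata"])
    return loaded_data
```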
thread_metadata = response.json().get("metadata", {})
thread_metadata = dict_merge(thread_metadata, metadata)
response = requests.patch(
    self._get_url(f"/threads/{self.thread_id}"),
    json={"metadata": thread_metadata},
    headers=self.headers,
)
response.raise_for_status()
Can this be DRY'd up with the overwrite logic?
On the cloud side? Nvm, I'll see how I can update it.
Yeah, just meant that we should try to only have a single `requests.patch` for both conditional flows.
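A sketch of that shape, reusing names from the snippet above; the `overwrite` branch is an assumption about the surrounding flow:

```python
# Decide on the metadata payload first, then issue a single PATCH for both flows.
if overwrite:
    thread_metadata = metadata
else:
    thread_metadata = dict_merge(response.json().get("metadata", {}), metadata)

response = requests.patch(
    self._get_url(f"/threads/{self.thread_id}"),
    json={"metadata": thread_metadata},
    headers=self.headers,
)
response.raise_for_status()
```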
loaded_str = self.client.hget(self.index, self.conversation_id)
if loaded_str is not None:
    loaded_data = json.loads(loaded_str)
    loaded_data["runs"] += data["runs"]
    loaded_data["metadata"] = dict_merge(loaded_data["metadata"], data["metadata"])
    data = loaded_data
Can we use an LPUSH to append to the list?
does that work for pushing to nested lists?
I'm honestly not sure, but it does seem like something we should at least try.
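For reference, `LPUSH`/`RPUSH` operate on a Redis list stored at its own key, not on a list nested inside a JSON string held in a hash field, so an in-place append would likely require a different storage layout. A hedged sketch of one such layout (key names and structure are assumptions, not the current driver design):

```python
import json

import redis

client = redis.Redis()
conversation_id = "example-conversation"
runs_key = f"griptape:conversation:{conversation_id}:runs"      # plain Redis list
meta_key = f"griptape:conversation:{conversation_id}:metadata"  # plain Redis string

# Appending runs no longer requires read-modify-write of the whole document.
new_runs = [{"input": "hi", "output": "hello"}]
if new_runs:
    client.rpush(runs_key, *(json.dumps(run) for run in new_runs))
client.set(meta_key, json.dumps({"foo": "bar"}))

# Loading reads the list back and decodes each entry.
loaded_runs = [json.loads(r) for r in client.lrange(runs_key, 0, -1)]
loaded_metadata = json.loads(client.get(meta_key) or "{}")
```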
Nice work!
Comments are mainly topics for discussion. Some may be out of scope for this PR
griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py (resolved)
griptape/drivers/memory/conversation/base_conversation_memory_driver.py (outdated, resolved)
class BaseConversationMemoryDriver(SerializableMixin, ABC):
    @abstractmethod
    def store(self, memory: BaseConversationMemory) -> None: ...  # old signature (removed)
    def store(self, runs: list[Run], metadata: dict, *, overwrite: bool = False) -> None: ...  # new signature (added)
Alternative idea: how about adding a `clear()` (or `reset()`) method instead of `overwrite`? (Or am I misunderstanding the behavior?)
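A rough sketch of what that alternative surface might look like, just to make the idea concrete (method names are hypothetical, base classes reused from the quoted diff):

```python
from abc import ABC, abstractmethod


class BaseConversationMemoryDriver(SerializableMixin, ABC):
    @abstractmethod
    def store(self, runs: list[Run], metadata: dict) -> None: ...

    @abstractmethod
    def load(self) -> tuple[list[Run], dict]: ...

    def clear(self) -> None:
        """Drop any persisted runs/metadata so the next store() starts fresh."""
        ...
```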
Nice, just a couple of minor questions/bits of feedback.
def _get_dict(self, runs: list[Run], metadata: Optional[dict] = None) -> dict:
    data: dict = {"runs": [run.to_dict() for run in runs]}
    if metadata is not None:
        data["metadata"] = metadata
    return data
`_get_dict` doesn't tell me much about why this method exists. Maybe rename to something like `_to_params`?
griptape/utils/dict_utils.py (outdated)
def dict_merge_opt(dct: Optional[dict], merge_dct: Optional[dict], *, add_keys: bool = True) -> Optional[dict]:
    if dct is None:
        return merge_dct
    if merge_dct is None:
        return dct

    return dict_merge(dct, merge_dct, add_keys=add_keys)
Why does this method need to exist? Can its functionality be merged into `dict_merge`?
Every instance of its usage would need an updated type hint; I tried to be unobtrusive.
deleted
for run in runs:
    run_dict = dict_merge_opt(
        {
            "input": run.input.to_json(),
            "output": run.output.to_json(),
            "metadata": {"run_id": run.id},
        },
        run.meta,
    )
Use a list comprehension instead.
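For reference, the quoted loop as a list comprehension (assuming its only purpose is building the per-run dicts):

```python
run_dicts = [
    dict_merge_opt(
        {
            "input": run.input.to_json(),
            "output": run.output.to_json(),
            "metadata": {"run_id": run.id},
        },
        run.meta,
    )
    for run in runs
]
```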
input: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
output: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
id: str = field(default=Factory(lambda: uuid.uuid4().hex), metadata={"serializable": True})
meta: Optional[dict] = field(default=None, metadata={"serializable": True})
I think this should default to `meta: dict[str, Any]` instead of `Optional`. More consistent with other meta implementations, and then we don't need `merge_dict_opt`.
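In other words, something along these lines, mirroring the attrs style of the surrounding fields (`Any` would come from `typing`):

```python
meta: dict[str, Any] = field(factory=dict, metadata={"serializable": True})
```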
Must have been looking at `BaseVectorStoreDriver.Entry`.
griptape/utils/__init__.py (outdated)
Sorry to be annoying, can you remove this change from the PR?
tests/unit/utils/test_dict_utils.py (outdated)
And this
Approved outside of the last two requested changes.
MIGRATION.md (outdated)
### LocalConversationMemoryDriver `file_path` renamed to `persist_file`

`LocalConversationMemoryDriver.file_path` has been renamed to `persist_file` and is now `Optional[str]`. If `persist_file` is not passed as a parameter, nothing will be persisted and no errors will be raised. `LocalConversationMemoryDriver` is now the default driver in the global `Defaults` object.

#### 0.30.X
```python
local_driver_with_file = LocalConversationMemoryDriver(
    file_path="my_file.json"
)

local_driver = LocalConversationMemoryDriver()

assert local_driver_with_file.file_path == "my_file.json"
assert local_driver.file_path == "griptape_memory.json"
```

#### 0.31.X
```python
local_driver_with_file = LocalConversationMemoryDriver(
    persist_file="my_file.json"
)

local_driver = LocalConversationMemoryDriver()

assert local_driver_with_file.persist_file == "my_file.json"
assert local_driver.persist_file is None
```
Let's move this higher up since the other examples use it.
MIGRATION.md (outdated)
### Changes to BaseConversationMemoryDriver

`BaseConversationMemoryDriver` has updated parameter names and different method signatures for `.store` and `.load`.
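Roughly, the signature change being documented looks like this; the 0.30.X `load` return type is inferred from the PR description, and the 0.31.X `store` signature is the one quoted earlier in this thread, so treat both as approximate:

```python
# 0.30.X
def store(self, memory: BaseConversationMemory) -> None: ...
def load(self) -> Optional[BaseConversationMemory]: ...

# 0.31.X
def store(self, runs: list[Run], metadata: dict, *, overwrite: bool = False) -> None: ...
def load(self) -> tuple[list[Run], dict[str, Any]]: ...
```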
Should also call out `driver` being renamed to `conversation_memory_driver`.
Nice work!
key = self.index
memory_json = self.client.hget(key, self.conversation_id)
def load(self) -> tuple[list[Run], dict[str, Any]]:
so much cleaner!
Describe your changes
- `LocalConversationMemory` with an optional `persist_file`, consistent with `LocalVectorStoreDriver`
- Changed `load` to return a tuple instead of an instance
- Changed `BaseConversationMemory.to_prompt_stack` to take a prompt driver (see the sketch below)
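A hedged usage sketch of the changes above; the import paths and the `to_prompt_stack` parameter name are assumptions, not exact signatures:

```python
from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

# persist_file is optional; without it, nothing is written to disk.
driver = LocalConversationMemoryDriver(persist_file="memory.json")
memory = ConversationMemory(conversation_memory_driver=driver)

# load() now returns a (runs, metadata) tuple instead of a memory instance.
runs, metadata = driver.load()

# to_prompt_stack now takes a prompt driver (keyword name assumed here).
prompt_stack = memory.to_prompt_stack(prompt_driver=memory.prompt_driver)
```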
Issue ticket number and link