Commit

Refactor ConversationMemory
vachillo committed Aug 22, 2024
1 parent 489453e commit a337cd4
Showing 32 changed files with 394 additions and 215 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
@@ -5,8 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.
- Parameter `lazy_load` on `LocalConversationMemoryDriver`.

### Changed
- **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`.
- **BREAKING**: `BaseConversationMemory.add_to_prompt_stack` now takes a `prompt_driver` parameter.
- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], dict]`.
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file`.

### Fixed
- Parsing streaming response with some OpenAi compatible services.
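
A minimal migration sketch for the breaking changes listed in the CHANGELOG above; the driver setup is illustrative and mirrors the docs examples updated in this commit:

```python
from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

# Before this commit (for comparison):
#   driver = LocalConversationMemoryDriver(file_path="memory.json")
#   memory = ConversationMemory(driver=driver)
#   loaded = driver.load()  # Optional[BaseConversationMemory]

# After this commit:
driver = LocalConversationMemoryDriver(persist_file="memory.json")
memory = ConversationMemory(conversation_memory_driver=driver)

runs, metadata = driver.load()  # load() now returns tuple[list[Run], dict]
```
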
2 changes: 1 addition & 1 deletion docs/examples/src/amazon_dynamodb_sessions_1.py
@@ -18,7 +18,7 @@

structure = Agent(
conversation_memory=ConversationMemory(
driver=AmazonDynamoDbConversationMemoryDriver(
conversation_memory_driver=AmazonDynamoDbConversationMemoryDriver(
session=boto3.Session(
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
@@ -2,8 +2,8 @@
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

local_driver = LocalConversationMemoryDriver(file_path="memory.json")
agent = Agent(conversation_memory=ConversationMemory(driver=local_driver))
local_driver = LocalConversationMemoryDriver(persist_file="memory.json")
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver))

agent.run("Surfing is my favorite sport.")
agent.run("What is my favorite sport?")
@@ -13,7 +13,7 @@
partition_key_value=conversation_id,
)

agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
@@ -14,7 +14,7 @@
conversation_id=conversation_id,
)

agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
@@ -9,7 +9,7 @@
cloud_conversation_driver = GriptapeCloudConversationMemoryDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
)
agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=cloud_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
6 changes: 3 additions & 3 deletions griptape/configs/drivers/base_drivers_config.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

@@ -38,7 +38,7 @@ class BaseDriversConfig(ABC, SerializableMixin):
_vector_store_driver: BaseVectorStoreDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store_driver"
)
_conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
_conversation_memory_driver: BaseConversationMemoryDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory_driver"
)
_text_to_speech_driver: BaseTextToSpeechDriver = field(
@@ -70,7 +70,7 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: ...

@lazy_property()
@abstractmethod
def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: ...
def conversation_memory_driver(self) -> BaseConversationMemoryDriver: ...

@lazy_property()
@abstractmethod
7 changes: 4 additions & 3 deletions griptape/configs/drivers/drivers_config.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define

@@ -13,6 +13,7 @@
DummyPromptDriver,
DummyTextToSpeechDriver,
DummyVectorStoreDriver,
LocalConversationMemoryDriver,
)
from griptape.utils.decorators import lazy_property

@@ -52,8 +53,8 @@ def vector_store_driver(self) -> BaseVectorStoreDriver:
return DummyVectorStoreDriver(embedding_driver=self.embedding_driver)

@lazy_property()
def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]:
return None
def conversation_memory_driver(self) -> BaseConversationMemoryDriver:
return LocalConversationMemoryDriver()

@lazy_property()
def text_to_speech_driver(self) -> BaseTextToSpeechDriver:
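
A small sketch of the default change above: a drivers config without an explicit conversation memory driver now lazily builds a `LocalConversationMemoryDriver` instead of returning `None`. The import path is assumed for illustration:

```python
from griptape.configs.drivers import DriversConfig  # import path assumed

config = DriversConfig()

# Previously this lazy property returned None; with this commit it falls back
# to a LocalConversationMemoryDriver on first access.
driver = config.conversation_memory_driver
print(type(driver).__name__)  # LocalConversationMemoryDriver
```
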
@@ -6,12 +6,12 @@
from attrs import Factory, define, field

from griptape.drivers import BaseConversationMemoryDriver
from griptape.utils import import_optional_dependency
from griptape.utils import dict_merge, import_optional_dependency

if TYPE_CHECKING:
import boto3

from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


@define
@@ -27,35 +27,50 @@ class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
table: Any = field(init=False)

def __attrs_post_init__(self) -> None:
dynamodb = self.session.resource("dynamodb")

self.table = dynamodb.Table(self.table_name)

def store(self, memory: BaseConversationMemory) -> None:
self.table.update_item(
Key=self._get_key(),
UpdateExpression="set #attr = :value",
ExpressionAttributeNames={"#attr": self.value_attribute_key},
ExpressionAttributeValues={":value": memory.to_json()},
)

def load(self) -> Optional[BaseConversationMemory]:
from griptape.memory.structure import BaseConversationMemory
self.table = self.session.resource("dynamodb").Table(self.table_name)

def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None:
if overwrite:
self.table.update_item(
Key=self._get_key(),
UpdateExpression="set #attr = :value",
ExpressionAttributeNames={"#attr": self.value_attribute_key},
ExpressionAttributeValues={
":value": json.dumps({"runs": [run.to_dict() for run in runs], "metadata": metadata})
},
)
else:
response = self.table.get_item(Key=self._get_key())
if "Item" in response and self.value_attribute_key in response["Item"]:
data = json.loads(response["Item"][self.value_attribute_key])
data["runs"] += [run.to_dict() for run in runs]
data["metadata"] = dict_merge(data["metadata"], metadata)
self.table.update_item(
Key=self._get_key(),
UpdateExpression="set #attr = :value",
ExpressionAttributeNames={"#attr": self.value_attribute_key},
ExpressionAttributeValues={":value": json.dumps(data)},
)
else:
self.table.put_item(
Item={
**self._get_key(),
self.value_attribute_key: json.dumps(
{"runs": [run.to_dict() for run in runs], "metadata": metadata}
),
}
)

def load(self) -> tuple[list[Run], dict]:
from griptape.memory.structure import Run

response = self.table.get_item(Key=self._get_key())

if "Item" in response and self.value_attribute_key in response["Item"]:
memory_dict = json.loads(response["Item"][self.value_attribute_key])
# needed to avoid recursive method calls
memory_dict["autoload"] = False

memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

return memory
return [Run.from_dict(run) for run in memory_dict["runs"]], memory_dict.get("metadata", {})
else:
return None
return [], {}

def _get_key(self) -> dict[str, str | int]:
key: dict[str, str | int] = {self.partition_key: self.partition_key_value}
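
For orientation, a hedged usage sketch of the refactored store/load contract on this driver; the table name, key configuration, and run content are placeholders:

```python
import boto3

from griptape.artifacts import TextArtifact
from griptape.drivers import AmazonDynamoDbConversationMemoryDriver
from griptape.memory.structure import Run

# Placeholder table and key configuration.
driver = AmazonDynamoDbConversationMemoryDriver(
    session=boto3.Session(),
    table_name="griptape_memory",
    partition_key="id",
    value_attribute_key="value",
    partition_key_value="conversation-123",
)

run = Run(input=TextArtifact("My name is Jeff."), output=TextArtifact("Hello, Jeff."))

# Default (overwrite=False): appends the runs and dict-merges the metadata
# into the existing item, or creates the item if none exists yet.
driver.store([run], metadata={"user": "jeff"})

# overwrite=True: replaces the stored runs and metadata wholesale.
driver.store([run], metadata={"user": "jeff"}, overwrite=True)

# load() now returns the raw runs and metadata instead of a BaseConversationMemory.
runs, metadata = driver.load()
```
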
@@ -6,12 +6,12 @@
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


class BaseConversationMemoryDriver(SerializableMixin, ABC):
@abstractmethod
def store(self, memory: BaseConversationMemory) -> None: ...
def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: ...

@abstractmethod
def load(self) -> Optional[BaseConversationMemory]: ...
def load(self) -> tuple[list[Run], dict]: ...
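
To illustrate the new abstract contract, a minimal in-memory driver sketch; this class is hypothetical and not part of the commit:

```python
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import Run


@define
class InMemoryConversationMemoryDriver(BaseConversationMemoryDriver):
    """Hypothetical driver that keeps runs and metadata in process memory."""

    runs: list[Run] = field(factory=list, kw_only=True)
    meta: dict = field(factory=dict, kw_only=True)

    def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None:
        if overwrite:
            self.runs = list(runs)
            self.meta = dict(metadata or {})
        else:
            self.runs.extend(runs)
            self.meta.update(metadata or {})

    def load(self) -> tuple[list[Run], dict]:
        return list(self.runs), dict(self.meta)
```
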
@@ -10,9 +10,10 @@

from griptape.artifacts import BaseArtifact
from griptape.drivers import BaseConversationMemoryDriver
from griptape.utils import dict_merge

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


@define(kw_only=True)
@@ -55,26 +56,59 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

def store(self, memory: BaseConversationMemory) -> None:
# serliaze the run artifacts to json strings
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs]

# serialize the metadata to a json string
# remove runs because they are already stored as Messages
metadata = memory.to_dict()
del metadata["runs"]

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"messages": messages, "metadata": metadata},
headers=self.headers,
)
response.raise_for_status()
def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None:
# serialize the run artifacts to json strings
messages = []
for run in runs:
run_dict = {
"input": run.input.to_json(),
"output": run.output.to_json(),
"metadata": {"run_id": run.id},
}
if run.meta is not None:
run_dict["metadata"].update(run.meta)

messages.append(run_dict)

body = {
"messages": messages,
"metadata": metadata,
}

if overwrite:
# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json=body,
headers=self.headers,
)
response.raise_for_status()
else:
# add the new messages to the Thread
for message in body["messages"]:
response = requests.post(
self._get_url(f"/threads/{self.thread_id}/messages"),
json=message,
headers=self.headers,
)
response.raise_for_status()

# get the thread to merge the metadata
response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
response.raise_for_status()

thread_metadata = response.json().get("metadata", {})
thread_metadata = dict_merge(thread_metadata, metadata)
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"metadata": thread_metadata},
headers=self.headers,
)
response.raise_for_status()

def load(self) -> BaseConversationMemory:
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run
def load(self) -> tuple[list[Run], dict]:
from griptape.memory.structure import Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
@@ -90,29 +124,14 @@ def load(self) -> BaseConversationMemory:

runs = [
Run(
id=m["message_id"],
id=m["metadata"].pop("run_id"),
meta=m["metadata"],
input=BaseArtifact.from_json(m["input"]),
output=BaseArtifact.from_json(m["output"]),
)
for m in messages
]
metadata = thread_response.get("metadata")

# the metadata will contain the serialized
# ConversationMemory object with the runs removed
# autoload=False to prevent recursively loading the memory
if metadata is not None and metadata != {}:
memory = BaseConversationMemory.from_dict(
{
**metadata,
"runs": [run.to_dict() for run in runs],
"autoload": False,
}
)
memory.driver = self
return memory
# no metadata found, return a new ConversationMemory object
return ConversationMemory(runs=runs, autoload=False, driver=self)
return runs, thread_response.get("metadata", {})

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers)
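
Finally, a hedged usage sketch of the cloud driver's updated contract; the API key and run content are placeholders, and the behavior follows the code above (append plus server-side metadata merge by default, full replacement when overwrite=True):

```python
import os

from griptape.artifacts import TextArtifact
from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import Run

# As in the docs example above, only api_key is supplied here.
driver = GriptapeCloudConversationMemoryDriver(api_key=os.environ["GT_CLOUD_API_KEY"])

run = Run(input=TextArtifact("My name is Jeff."), output=TextArtifact("Nice to meet you, Jeff."))

# Appends a Message to the Thread and merges the metadata into the Thread's metadata.
driver.store([run], metadata={"source": "example"})

# Replaces all of the Thread's Messages and sets the metadata directly.
driver.store([run], metadata={"source": "example"}, overwrite=True)

runs, metadata = driver.load()  # (list[Run], dict) per this commit
```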