Skip to content

Commit

Permalink
Refactor Conversation Memory class and drivers (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo authored Aug 27, 2024
1 parent 3c604af commit ef61c53
Show file tree
Hide file tree
Showing 30 changed files with 322 additions and 211 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.
- Parameter `meta: dict` on `BaseEvent`.

### Changed
- **BREAKING**: Parameter `driver` on `BaseConversationMemory` renamed to `conversation_memory_driver`.
- **BREAKING**: `BaseConversationMemory.add_to_prompt_stack` now takes a `prompt_driver` parameter.
- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], Optional[dict]]`.
- **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: Optional[dict]` as input.
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`.
- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`.

### Fixed
- Parsing streaming response with some OpenAi compatible services.

Expand Down
67 changes: 65 additions & 2 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,79 @@ This document provides instructions for migrating your codebase to accommodate b
Drivers, Loaders, and Engines will now raises exceptions rather than returning `ErrorArtifact`s.
Update any logic that expects `ErrorArtifact` to handle exceptions instead.

#### Before
```python
# Before
artifacts = WebLoader().load("https://www.griptape.ai")

if isinstance(artifacts, ErrorArtifact):
raise Exception(artifacts.value)
```

# After
#### After
```python
try:
artifacts = WebLoader().load("https://www.griptape.ai")
except Exception as e:
raise e
```

### LocalConversationMemoryDriver `file_path` renamed to `persist_file`

`LocalConversationMemoryDriver.file_path` has been renamed to `persist_file` and is now `Optional[str]`. If `persist_file` is not passed as a parameter, nothing will be persisted and no errors will be raised. `LocalConversationMemoryDriver` is now the default driver in the global `Defaults` object.

#### Before
```python
local_driver_with_file = LocalConversationMemoryDriver(
file_path="my_file.json"
)

local_driver = LocalConversationMemoryDriver()

assert local_driver_with_file.file_path == "my_file.json"
assert local_driver.file_path == "griptape_memory.json"
```

#### After
```python
local_driver_with_file = LocalConversationMemoryDriver(
persist_file="my_file.json"
)

local_driver = LocalConversationMemoryDriver()

assert local_driver_with_file.persist_file == "my_file.json"
assert local_driver.persist_file is None
```

### Changes to BaseConversationMemoryDriver

`BaseConversationMemoryDriver.driver` has been renamed to `conversation_memory_driver`. Method signatures for `.store` and `.load` have been changed.

#### Before
```python
memory_driver = LocalConversationMemoryDriver()

conversation_memory = ConversationMemory(
driver=memory_driver
)

load_result: BaseConversationMemory = memory_driver.load()

memory_driver.store(conversation_memory)
```

#### After
```python
memory_driver = LocalConversationMemoryDriver()

conversation_memory = ConversationMemory(
conversation_memory_driver=memory_driver
)

load_result: tuple[list[Run], dict[str, Any]] = memory_driver.load()

memory_driver.store(
conversation_memory.runs,
conversation_memory.meta
)
```
2 changes: 1 addition & 1 deletion docs/examples/src/amazon_dynamodb_sessions_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

structure = Agent(
conversation_memory=ConversationMemory(
driver=AmazonDynamoDbConversationMemoryDriver(
conversation_memory_driver=AmazonDynamoDbConversationMemoryDriver(
session=boto3.Session(
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

local_driver = LocalConversationMemoryDriver(file_path="memory.json")
agent = Agent(conversation_memory=ConversationMemory(driver=local_driver))
local_driver = LocalConversationMemoryDriver(persist_file="memory.json")
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=local_driver))

agent.run("Surfing is my favorite sport.")
agent.run("What is my favorite sport?")
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
partition_key_value=conversation_id,
)

agent = Agent(conversation_memory=ConversationMemory(driver=dynamodb_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=dynamodb_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
conversation_id=conversation_id,
)

agent = Agent(conversation_memory=ConversationMemory(driver=redis_conversation_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=redis_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cloud_conversation_driver = GriptapeCloudConversationMemoryDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
)
agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver))
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=cloud_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
6 changes: 3 additions & 3 deletions griptape/configs/drivers/base_drivers_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

Expand Down Expand Up @@ -38,7 +38,7 @@ class BaseDriversConfig(ABC, SerializableMixin):
_vector_store_driver: BaseVectorStoreDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store_driver"
)
_conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
_conversation_memory_driver: BaseConversationMemoryDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory_driver"
)
_text_to_speech_driver: BaseTextToSpeechDriver = field(
Expand Down Expand Up @@ -70,7 +70,7 @@ def vector_store_driver(self) -> BaseVectorStoreDriver: ...

@lazy_property()
@abstractmethod
def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: ...
def conversation_memory_driver(self) -> BaseConversationMemoryDriver: ...

@lazy_property()
@abstractmethod
Expand Down
7 changes: 4 additions & 3 deletions griptape/configs/drivers/drivers_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define

Expand All @@ -13,6 +13,7 @@
DummyPromptDriver,
DummyTextToSpeechDriver,
DummyVectorStoreDriver,
LocalConversationMemoryDriver,
)
from griptape.utils.decorators import lazy_property

Expand Down Expand Up @@ -52,8 +53,8 @@ def vector_store_driver(self) -> BaseVectorStoreDriver:
return DummyVectorStoreDriver(embedding_driver=self.embedding_driver)

@lazy_property()
def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]:
return None
def conversation_memory_driver(self) -> BaseConversationMemoryDriver:
return LocalConversationMemoryDriver()

@lazy_property()
def text_to_speech_driver(self) -> BaseTextToSpeechDriver:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
if TYPE_CHECKING:
import boto3

from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


@define
Expand All @@ -27,35 +27,26 @@ class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
table: Any = field(init=False)

def __attrs_post_init__(self) -> None:
dynamodb = self.session.resource("dynamodb")
self.table = self.session.resource("dynamodb").Table(self.table_name)

self.table = dynamodb.Table(self.table_name)

def store(self, memory: BaseConversationMemory) -> None:
def store(self, runs: list[Run], metadata: dict) -> None:
self.table.update_item(
Key=self._get_key(),
UpdateExpression="set #attr = :value",
ExpressionAttributeNames={"#attr": self.value_attribute_key},
ExpressionAttributeValues={":value": memory.to_json()},
ExpressionAttributeValues={
":value": json.dumps(self._to_params_dict(runs, metadata)),
},
)

def load(self) -> Optional[BaseConversationMemory]:
from griptape.memory.structure import BaseConversationMemory

def load(self) -> tuple[list[Run], dict[str, Any]]:
response = self.table.get_item(Key=self._get_key())

if "Item" in response and self.value_attribute_key in response["Item"]:
memory_dict = json.loads(response["Item"][self.value_attribute_key])
# needed to avoid recursive method calls
memory_dict["autoload"] = False

memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

return memory
return self._from_params_dict(memory_dict)
else:
return None
return [], {}

def _get_key(self) -> dict[str, str | int]:
key: dict[str, str | int] = {self.partition_key: self.partition_key_value}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any

from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


class BaseConversationMemoryDriver(SerializableMixin, ABC):
@abstractmethod
def store(self, memory: BaseConversationMemory) -> None: ...
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: ...

@abstractmethod
def load(self) -> Optional[BaseConversationMemory]: ...
def load(self) -> tuple[list[Run], dict[str, Any]]: ...

def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict:
return {"runs": [run.to_dict() for run in runs], "metadata": metadata}

def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]:
from griptape.memory.structure import Run

return [Run.from_dict(run) for run in params_dict.get("runs", [])], params_dict.get("metadata", {})
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

import os
import uuid
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import urljoin

import requests
from attrs import Attribute, Factory, define, field

from griptape.artifacts import BaseArtifact
from griptape.drivers import BaseConversationMemoryDriver
from griptape.utils import dict_merge

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory
from griptape.memory.structure import Run


@define(kw_only=True)
Expand Down Expand Up @@ -55,26 +56,38 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

def store(self, memory: BaseConversationMemory) -> None:
# serliaze the run artifacts to json strings
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs]
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
# serialize the run artifacts to json strings
messages = [
dict_merge(
{
"input": run.input.to_json(),
"output": run.output.to_json(),
"metadata": {"run_id": run.id},
},
run.meta,
)
for run in runs
]

# serialize the metadata to a json string
# remove runs because they are already stored as Messages
metadata = memory.to_dict()
del metadata["runs"]
body = dict_merge(
{
"messages": messages,
},
metadata,
)

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"messages": messages, "metadata": metadata},
json=body,
headers=self.headers,
)
response.raise_for_status()

def load(self) -> BaseConversationMemory:
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run
def load(self) -> tuple[list[Run], dict[str, Any]]:
from griptape.memory.structure import Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
Expand All @@ -86,33 +99,16 @@ def load(self) -> BaseConversationMemory:
thread_response.raise_for_status()
thread_response = thread_response.json()

messages = messages_response.get("messages", [])

runs = [
Run(
id=m["message_id"],
id=m["metadata"].pop("run_id"),
meta=m["metadata"],
input=BaseArtifact.from_json(m["input"]),
output=BaseArtifact.from_json(m["output"]),
)
for m in messages
for m in messages_response.get("messages", [])
]
metadata = thread_response.get("metadata")

# the metadata will contain the serialized
# ConversationMemory object with the runs removed
# autoload=False to prevent recursively loading the memory
if metadata is not None and metadata != {}:
memory = BaseConversationMemory.from_dict(
{
**metadata,
"runs": [run.to_dict() for run in runs],
"autoload": False,
}
)
memory.driver = self
return memory
# no metadata found, return a new ConversationMemory object
return ConversationMemory(runs=runs, autoload=False, driver=self)
return runs, thread_response.get("metadata", {})

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers)
Expand Down
Loading

0 comments on commit ef61c53

Please sign in to comment.