Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GriptapeCloudConversationMemoryDriver #1063

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.
- `utils.execute_futures_list_dict` for executing a dict of lists of futures.
- `GriptapeCloudConversationMemoryDriver` to store conversation history in Griptape Cloud.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ You can persist and load memory by using Conversation Memory Drivers. You can bu

## Conversation Memory Drivers

### Griptape Cloud

The [GriptapeCloudConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.md) allows you to persist Conversation Memory in Griptape Cloud. It provides seamless integration with Griptape's cloud-based `Threads` and `Messages` resources.
collindutter marked this conversation as resolved.
Show resolved Hide resolved

```python
--8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py"
```

### Local

The [LocalConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/local_conversation_memory_driver.md) allows you to persist Conversation Memory in a local JSON file.
Expand Down Expand Up @@ -40,3 +48,4 @@ The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conv
```python
--8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py"
```

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
import uuid

from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

conversation_id = uuid.uuid4().hex
cloud_conversation_driver = GriptapeCloudConversationMemoryDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
)
agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver
from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver
from .memory.conversation.griptape_cloud_conversation_memory_driver import GriptapeCloudConversationMemoryDriver

from .embedding.base_embedding_driver import BaseEmbeddingDriver
from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver
Expand Down Expand Up @@ -149,6 +150,7 @@
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
"RedisConversationMemoryDriver",
"GriptapeCloudConversationMemoryDriver",
"BaseEmbeddingDriver",
"OpenAiEmbeddingDriver",
"AzureOpenAiEmbeddingDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
Expand Down Expand Up @@ -44,9 +45,11 @@ def load(self) -> Optional[BaseConversationMemory]:
response = self.table.get_item(Key=self._get_key())

if "Item" in response and self.value_attribute_key in response["Item"]:
memory_value = response["Item"][self.value_attribute_key]
memory_dict = json.loads(response["Item"][self.value_attribute_key])
# needed to avoid recursive method calls
memory_dict["autoload"] = False

memory = BaseConversationMemory.from_json(memory_value)
memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin

import requests
from attrs import Attribute, Factory, define, field

from griptape.artifacts import BaseArtifact
from griptape.drivers import BaseConversationMemoryDriver

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory


@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
"""A driver for storing conversation memory in the Griptape Cloud.

Attributes:
thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be
created.
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will
attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

Raises:
ValueError: If `api_key` is not provided.
"""

thread_id: str = field(
default=None,
metadata={"serializable": True},
)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

def __attrs_post_init__(self) -> None:
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id())
collindutter marked this conversation as resolved.
Show resolved Hide resolved

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

def store(self, memory: BaseConversationMemory) -> None:
# serliaze the run artifacts to json strings
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs]

# serialize the metadata to a json string
# remove runs because they are already stored as Messages
metadata = memory.to_dict()
del metadata["runs"]

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"messages": messages, "metadata": metadata},
headers=self.headers,
)
response.raise_for_status()

def load(self) -> BaseConversationMemory:
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()

messages = messages_response.get("messages", [])

runs = [
Run(
id=m["message_id"],
input=BaseArtifact.from_json(m["input"]),
output=BaseArtifact.from_json(m["output"]),
)
for m in messages
]
metadata = thread_response.get("metadata")

# the metadata will contain the serialized
# ConversationMemory object with the runs removed
# autoload=False to prevent recursively loading the memory
if metadata is not None and metadata != {}:
memory = BaseConversationMemory.from_dict(
{
**metadata,
"runs": [run.to_dict() for run in runs],
collindutter marked this conversation as resolved.
Show resolved Hide resolved
"autoload": False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For visibility of offline conversation, why is this necessary in this Driver but not others?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no idea

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you confirm that this is still necessary? Or can you test other Drivers to see if they should have this too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated other drivers with this fix

}
)
memory.driver = self
return memory
# no metadata found, return a new ConversationMemory object
return ConversationMemory(runs=runs, autoload=False, driver=self)

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": "test"}, headers=self.headers)
res.raise_for_status()
return res.json().get("thread_id")

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
return urljoin(self.base_url, f"/api/{path}")
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional
Expand All @@ -24,7 +25,11 @@ def load(self) -> Optional[BaseConversationMemory]:

if not os.path.exists(self.file_path):
return None
memory = BaseConversationMemory.from_json(Path(self.file_path).read_text())

memory_dict = json.loads(Path(self.file_path).read_text())
# needed to avoid recursive method calls
memory_dict["autoload"] = False
memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import uuid
from typing import TYPE_CHECKING, Optional

Expand Down Expand Up @@ -59,8 +60,11 @@ def load(self) -> Optional[BaseConversationMemory]:

key = self.index
memory_json = self.client.hget(key, self.conversation_id)
if memory_json:
memory = BaseConversationMemory.from_json(memory_json)
if memory_json is not None:
memory_dict = json.loads(memory_json)
# needed to avoid recursive method calls
memory_dict["autoload"] = False
memory = BaseConversationMemory.from_dict(memory_dict)
memory.driver = self
return memory
return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import json

import pytest

from griptape.artifacts import BaseArtifact
from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run, SummaryConversationMemory

TEST_CONVERSATION = '{"type": "SummaryConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}'


class TestGriptapeCloudConversationMemoryDriver:
@pytest.fixture(autouse=True)
def _mock_requests(self, mocker):
def get(*args, **kwargs):
if str(args[0]).endswith("/messages"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"metadata": json.loads(TEST_CONVERSATION),
"name": "test",
"thread_id": "test_metadata",
}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)

mocker.patch(
"requests.get",
side_effect=get,
)
mocker.patch(
"requests.post",
return_value=mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"thread_id": "test", "name": "test"},
),
)
mocker.patch(
"requests.patch",
return_value=mocker.Mock(
raise_for_status=lambda: None,
),
)

@pytest.fixture()
def driver(self):
return GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test")

def test_no_api_key(self):
with pytest.raises(ValueError):
GriptapeCloudConversationMemoryDriver(api_key=None, thread_id="test")

def test_no_thread_id(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id == "test"

def test_store(self, driver):
memory = ConversationMemory(
runs=[
Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"]))
for run in json.loads(TEST_CONVERSATION)["runs"]
],
)
assert driver.store(memory) is None

def test_load(self, driver):
memory = driver.load()
assert isinstance(memory, BaseConversationMemory)
assert len(memory.runs) == 1

def test_load_metadata(self, driver):
driver.thread_id = "test_metadata"
memory = driver.load()
assert isinstance(memory, SummaryConversationMemory)
assert len(memory.runs) == 1
Loading