squash
vachillo committed Aug 19, 2024
1 parent 717753c commit 74a08a3
Showing 9 changed files with 258 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.
- `utils.execute_futures_list_dict` for executing a dict of lists of futures.
- `GriptapeCloudConversationMemoryDriver` to store conversation history in Griptape Cloud.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
@@ -9,6 +9,14 @@ You can persist and load memory by using Conversation Memory Drivers. You can bu

## Conversation Memory Drivers

### Griptape Cloud

The [GriptapeCloudConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.md) allows you to persist Conversation Memory in Griptape Cloud. It stores each conversation run as a `Message` on a Griptape Cloud `Thread`, so a conversation can be resumed across sessions.

```python
--8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py"
```
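
If you want to pick a conversation back up later, the driver also accepts an explicit `thread_id` (otherwise it falls back to the `GT_CLOUD_THREAD_ID` environment variable or creates a new Thread). The following is a minimal sketch, not part of the bundled snippet, assuming you saved a Thread ID from an earlier run (for example from `driver.thread_id`) into a placeholder environment variable:

```python
import os

from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

# Resume a previously created Thread by passing its ID explicitly.
# "MY_SAVED_THREAD_ID" is a hypothetical variable used only for this sketch.
driver = GriptapeCloudConversationMemoryDriver(
    api_key=os.environ["GT_CLOUD_API_KEY"],
    thread_id=os.environ["MY_SAVED_THREAD_ID"],
)

agent = Agent(conversation_memory=ConversationMemory(driver=driver))
agent.run("What did we talk about last time?")
```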

### Local

The [LocalConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/local_conversation_memory_driver.md) allows you to persist Conversation Memory in a local JSON file.
@@ -40,3 +48,4 @@ The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conv
```python
--8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py"
```

@@ -0,0 +1,15 @@
import os
import uuid

from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

conversation_id = uuid.uuid4().hex
cloud_conversation_driver = GriptapeCloudConversationMemoryDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
)
agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver))

agent.run("My name is Jeff.")
agent.run("What is my name?")
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -15,6 +15,7 @@
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver
from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver
from .memory.conversation.griptape_cloud_conversation_memory_driver import GriptapeCloudConversationMemoryDriver

from .embedding.base_embedding_driver import BaseEmbeddingDriver
from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver
@@ -149,6 +150,7 @@
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
"RedisConversationMemoryDriver",
"GriptapeCloudConversationMemoryDriver",
"BaseEmbeddingDriver",
"OpenAiEmbeddingDriver",
"AzureOpenAiEmbeddingDriver",
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
@@ -44,9 +45,11 @@ def load(self) -> Optional[BaseConversationMemory]:
response = self.table.get_item(Key=self._get_key())

if "Item" in response and self.value_attribute_key in response["Item"]:
memory_value = response["Item"][self.value_attribute_key]
memory_dict = json.loads(response["Item"][self.value_attribute_key])
# needed to avoid recursive method calls
memory_dict["autoload"] = False

memory = BaseConversationMemory.from_json(memory_value)
memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

@@ -0,0 +1,123 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin

import requests
from attrs import Attribute, Factory, define, field

from griptape.artifacts import BaseArtifact
from griptape.drivers import BaseConversationMemoryDriver

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory


@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
"""A driver for storing conversation memory in the Griptape Cloud.
Attributes:
thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be
created.
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will
attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.
Raises:
ValueError: If `api_key` is not provided.
"""

thread_id: str = field(
default=None,
metadata={"serializable": True},
)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

def __attrs_post_init__(self) -> None:
if self.thread_id is None:
            self.thread_id = os.getenv("GT_CLOUD_THREAD_ID") or self._get_thread_id()

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

def store(self, memory: BaseConversationMemory) -> None:
        # serialize the run artifacts to JSON strings
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs]

# serialize the metadata to a json string
# remove runs because they are already stored as Messages
metadata = memory.to_dict()
del metadata["runs"]

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"messages": messages, "metadata": metadata},
headers=self.headers,
)
response.raise_for_status()

def load(self) -> BaseConversationMemory:
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()

messages = messages_response.get("messages", [])

runs = [
Run(
id=m["message_id"],
input=BaseArtifact.from_json(m["input"]),
output=BaseArtifact.from_json(m["output"]),
)
for m in messages
]
metadata = thread_response.get("metadata")

# the metadata will contain the serialized
# ConversationMemory object with the runs removed
# autoload=False to prevent recursively loading the memory
if metadata is not None and metadata != {}:
memory = BaseConversationMemory.from_dict(
{
**metadata,
"runs": [run.to_dict() for run in runs],
"autoload": False,
}
)
memory.driver = self
return memory
# no metadata found, return a new ConversationMemory object
return ConversationMemory(runs=runs, autoload=False, driver=self)

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": "test"}, headers=self.headers)
res.raise_for_status()
return res.json().get("thread_id")

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
return urljoin(self.base_url, f"/api/{path}")
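
The `store`/`load` flow above PATCHes the Thread with the serialized runs plus the memory's metadata, and rebuilds a `ConversationMemory` from the Thread's Messages on the way back. A rough round-trip sketch, not part of the commit, assuming `GT_CLOUD_API_KEY` is set in the environment:

```python
from griptape.artifacts import TextArtifact
from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory, Run

# No thread_id supplied, so the driver resolves one on init
# (GT_CLOUD_THREAD_ID if set, otherwise a newly created Thread).
driver = GriptapeCloudConversationMemoryDriver()

memory = ConversationMemory(
    runs=[Run(input=TextArtifact("Hi there."), output=TextArtifact("Hello!"))],
    autoload=False,  # don't pull existing Messages while building the sketch
    driver=driver,
)

driver.store(memory)       # replaces the Thread's Messages with the run above
restored = driver.load()   # rebuilds memory from Messages + Thread metadata
print(len(restored.runs))  # expected: 1, assuming the API calls succeed
```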
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional
@@ -24,7 +25,11 @@ def load(self) -> Optional[BaseConversationMemory]:

if not os.path.exists(self.file_path):
return None
memory = BaseConversationMemory.from_json(Path(self.file_path).read_text())

memory_dict = json.loads(Path(self.file_path).read_text())
# needed to avoid recursive method calls
memory_dict["autoload"] = False
memory = BaseConversationMemory.from_dict(memory_dict)

memory.driver = self

@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import uuid
from typing import TYPE_CHECKING, Optional

@@ -59,8 +60,11 @@ def load(self) -> Optional[BaseConversationMemory]:

key = self.index
memory_json = self.client.hget(key, self.conversation_id)
if memory_json:
memory = BaseConversationMemory.from_json(memory_json)
if memory_json is not None:
memory_dict = json.loads(memory_json)
# needed to avoid recursive method calls
memory_dict["autoload"] = False
memory = BaseConversationMemory.from_dict(memory_dict)
memory.driver = self
return memory
return None
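
The DynamoDB, Local, and Redis drivers above now all deserialize through `from_dict` with `autoload` forced off. The reasoning, sketched below with a hypothetical helper that is not part of the commit: a memory built with `autoload` left on calls its driver's `load()` during initialization, which would deserialize yet another memory and recurse.

```python
import json

from griptape.memory.structure import BaseConversationMemory


def deserialize_memory(driver, memory_json: str) -> BaseConversationMemory:
    """Hypothetical helper mirroring the pattern used by the drivers above."""
    memory_dict = json.loads(memory_json)
    # Force autoload off so the reconstructed memory doesn't call
    # driver.load() again from its own initializer.
    memory_dict["autoload"] = False
    memory = BaseConversationMemory.from_dict(memory_dict)
    memory.driver = driver
    return memory
```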
@@ -0,0 +1,91 @@
import json

import pytest

from griptape.artifacts import BaseArtifact
from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run, SummaryConversationMemory

TEST_CONVERSATION = '{"type": "SummaryConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}'


class TestGriptapeCloudConversationMemoryDriver:
@pytest.fixture(autouse=True)
def _mock_requests(self, mocker):
def get(*args, **kwargs):
if str(args[0]).endswith("/messages"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"metadata": json.loads(TEST_CONVERSATION),
"name": "test",
"thread_id": "test_metadata",
}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)

mocker.patch(
"requests.get",
side_effect=get,
)
mocker.patch(
"requests.post",
return_value=mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"thread_id": "test", "name": "test"},
),
)
mocker.patch(
"requests.patch",
return_value=mocker.Mock(
raise_for_status=lambda: None,
),
)

@pytest.fixture()
def driver(self):
return GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test")

def test_no_api_key(self):
with pytest.raises(ValueError):
GriptapeCloudConversationMemoryDriver(api_key=None, thread_id="test")

def test_no_thread_id(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id == "test"

def test_store(self, driver):
memory = ConversationMemory(
runs=[
Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"]))
for run in json.loads(TEST_CONVERSATION)["runs"]
],
)
assert driver.store(memory) is None

def test_load(self, driver):
memory = driver.load()
assert isinstance(memory, BaseConversationMemory)
assert len(memory.runs) == 1

def test_load_metadata(self, driver):
driver.thread_id = "test_metadata"
memory = driver.load()
assert isinstance(memory, SummaryConversationMemory)
assert len(memory.runs) == 1
