diff --git a/CHANGELOG.md b/CHANGELOG.md index 38d5784cf5..ca6b750a23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Chat.logger_level` for setting what the `Chat` utility sets the logger level to. - `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes. - `utils.execute_futures_list_dict` for executing a dict of lists of futures. +- `GriptapeCloudConversationMemoryDriver` to store conversation history in Griptape Cloud. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. diff --git a/docs/griptape-framework/drivers/conversation-memory-drivers.md b/docs/griptape-framework/drivers/conversation-memory-drivers.md index 29ab328fd8..bb4c1b35a8 100644 --- a/docs/griptape-framework/drivers/conversation-memory-drivers.md +++ b/docs/griptape-framework/drivers/conversation-memory-drivers.md @@ -9,6 +9,14 @@ You can persist and load memory by using Conversation Memory Drivers. You can bu ## Conversation Memory Drivers +### Griptape Cloud + +The [GriptapeCloudConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.md) allows you to persist Conversation Memory in Griptape Cloud. It provides seamless integration with Griptape's cloud-based `Threads` and `Messages` resources. + +```python +--8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py" +``` + ### Local The [LocalConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/local_conversation_memory_driver.md) allows you to persist Conversation Memory in a local JSON file. @@ -40,3 +48,4 @@ The [RedisConversationMemoryDriver](../../reference/griptape/drivers/memory/conv ```python --8<-- "docs/griptape-framework/drivers/src/conversation_memory_drivers_3.py" ``` + diff --git a/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py new file mode 100644 index 0000000000..35492e06bf --- /dev/null +++ b/docs/griptape-framework/drivers/src/conversation_memory_drivers_griptape_cloud.py @@ -0,0 +1,15 @@ +import os +import uuid + +from griptape.drivers import GriptapeCloudConversationMemoryDriver +from griptape.memory.structure import ConversationMemory +from griptape.structures import Agent + +conversation_id = uuid.uuid4().hex +cloud_conversation_driver = GriptapeCloudConversationMemoryDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], +) +agent = Agent(conversation_memory=ConversationMemory(driver=cloud_conversation_driver)) + +agent.run("My name is Jeff.") +agent.run("What is my name?") diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 9e1790b011..f19ec7d109 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -15,6 +15,7 @@ from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver from .memory.conversation.amazon_dynamodb_conversation_memory_driver import AmazonDynamoDbConversationMemoryDriver from .memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver +from .memory.conversation.griptape_cloud_conversation_memory_driver import GriptapeCloudConversationMemoryDriver from .embedding.base_embedding_driver import BaseEmbeddingDriver from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver @@ -149,6 +150,7 @@ "LocalConversationMemoryDriver", "AmazonDynamoDbConversationMemoryDriver", "RedisConversationMemoryDriver", + "GriptapeCloudConversationMemoryDriver", "BaseEmbeddingDriver", "OpenAiEmbeddingDriver", "AzureOpenAiEmbeddingDriver", diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index 44f214d7c0..b0c2485d6d 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field @@ -44,9 +45,11 @@ def load(self) -> Optional[BaseConversationMemory]: response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: - memory_value = response["Item"][self.value_attribute_key] + memory_dict = json.loads(response["Item"][self.value_attribute_key]) + # needed to avoid recursive method calls + memory_dict["autoload"] = False - memory = BaseConversationMemory.from_json(memory_value) + memory = BaseConversationMemory.from_dict(memory_dict) memory.driver = self diff --git a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py new file mode 100644 index 0000000000..fe5fb6a3e2 --- /dev/null +++ b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Optional +from urllib.parse import urljoin + +import requests +from attrs import Attribute, Factory, define, field + +from griptape.artifacts import BaseArtifact +from griptape.drivers import BaseConversationMemoryDriver + +if TYPE_CHECKING: + from griptape.memory.structure import BaseConversationMemory + + +@define(kw_only=True) +class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver): + """A driver for storing conversation memory in the Griptape Cloud. + + Attributes: + thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to + retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be + created. + base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable + `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. + api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will + attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. + + Raises: + ValueError: If `api_key` is not provided. + """ + + thread_id: str = field( + default=None, + metadata={"serializable": True}, + ) + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), + ) + api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY"))) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), + init=False, + ) + + def __attrs_post_init__(self) -> None: + if self.thread_id is None: + self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id()) + + @api_key.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: + if value is None: + raise ValueError(f"{self.__class__.__name__} requires an API key") + return value + + def store(self, memory: BaseConversationMemory) -> None: + # serliaze the run artifacts to json strings + messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs] + + # serialize the metadata to a json string + # remove runs because they are already stored as Messages + metadata = memory.to_dict() + del metadata["runs"] + + # patch the Thread with the new messages and metadata + # all old Messages are replaced with the new ones + response = requests.patch( + self._get_url(f"/threads/{self.thread_id}"), + json={"messages": messages, "metadata": metadata}, + headers=self.headers, + ) + response.raise_for_status() + + def load(self) -> BaseConversationMemory: + from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run + + # get the Messages from the Thread + messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers) + messages_response.raise_for_status() + messages_response = messages_response.json() + + # retrieve the Thread to get the metadata + thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers) + thread_response.raise_for_status() + thread_response = thread_response.json() + + messages = messages_response.get("messages", []) + + runs = [ + Run( + id=m["message_id"], + input=BaseArtifact.from_json(m["input"]), + output=BaseArtifact.from_json(m["output"]), + ) + for m in messages + ] + metadata = thread_response.get("metadata") + + # the metadata will contain the serialized + # ConversationMemory object with the runs removed + # autoload=False to prevent recursively loading the memory + if metadata is not None and metadata != {}: + memory = BaseConversationMemory.from_dict( + { + **metadata, + "runs": [run.to_dict() for run in runs], + "autoload": False, + } + ) + memory.driver = self + return memory + # no metadata found, return a new ConversationMemory object + return ConversationMemory(runs=runs, autoload=False, driver=self) + + def _get_thread_id(self) -> str: + res = requests.post(self._get_url("/threads"), json={"name": "test"}, headers=self.headers) + res.raise_for_status() + return res.json().get("thread_id") + + def _get_url(self, path: str) -> str: + path = path.lstrip("/") + return urljoin(self.base_url, f"/api/{path}") diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index f7b6e7d6e9..9a79accc3b 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -24,7 +25,11 @@ def load(self) -> Optional[BaseConversationMemory]: if not os.path.exists(self.file_path): return None - memory = BaseConversationMemory.from_json(Path(self.file_path).read_text()) + + memory_dict = json.loads(Path(self.file_path).read_text()) + # needed to avoid recursive method calls + memory_dict["autoload"] = False + memory = BaseConversationMemory.from_dict(memory_dict) memory.driver = self diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 9afc2f2045..8741cda509 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import uuid from typing import TYPE_CHECKING, Optional @@ -59,8 +60,11 @@ def load(self) -> Optional[BaseConversationMemory]: key = self.index memory_json = self.client.hget(key, self.conversation_id) - if memory_json: - memory = BaseConversationMemory.from_json(memory_json) + if memory_json is not None: + memory_dict = json.loads(memory_json) + # needed to avoid recursive method calls + memory_dict["autoload"] = False + memory = BaseConversationMemory.from_dict(memory_dict) memory.driver = self return memory return None diff --git a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py new file mode 100644 index 0000000000..707132ef5f --- /dev/null +++ b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py @@ -0,0 +1,91 @@ +import json + +import pytest + +from griptape.artifacts import BaseArtifact +from griptape.drivers import GriptapeCloudConversationMemoryDriver +from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run, SummaryConversationMemory + +TEST_CONVERSATION = '{"type": "SummaryConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' + + +class TestGriptapeCloudConversationMemoryDriver: + @pytest.fixture(autouse=True) + def _mock_requests(self, mocker): + def get(*args, **kwargs): + if str(args[0]).endswith("/messages"): + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: { + "messages": [ + { + "message_id": "123", + "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', + "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}', + "index": 0, + } + ] + }, + ) + else: + thread_id = args[0].split("/")[-1] + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: { + "metadata": json.loads(TEST_CONVERSATION), + "name": "test", + "thread_id": "test_metadata", + } + if thread_id == "test_metadata" + else {"name": "test", "thread_id": "test"}, + ) + + mocker.patch( + "requests.get", + side_effect=get, + ) + mocker.patch( + "requests.post", + return_value=mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"thread_id": "test", "name": "test"}, + ), + ) + mocker.patch( + "requests.patch", + return_value=mocker.Mock( + raise_for_status=lambda: None, + ), + ) + + @pytest.fixture() + def driver(self): + return GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test") + + def test_no_api_key(self): + with pytest.raises(ValueError): + GriptapeCloudConversationMemoryDriver(api_key=None, thread_id="test") + + def test_no_thread_id(self): + driver = GriptapeCloudConversationMemoryDriver(api_key="test") + assert driver.thread_id == "test" + + def test_store(self, driver): + memory = ConversationMemory( + runs=[ + Run(input=BaseArtifact.from_dict(run["input"]), output=BaseArtifact.from_dict(run["output"])) + for run in json.loads(TEST_CONVERSATION)["runs"] + ], + ) + assert driver.store(memory) is None + + def test_load(self, driver): + memory = driver.load() + assert isinstance(memory, BaseConversationMemory) + assert len(memory.runs) == 1 + + def test_load_metadata(self, driver): + driver.thread_id = "test_metadata" + memory = driver.load() + assert isinstance(memory, SummaryConversationMemory) + assert len(memory.runs) == 1