Skip to content

Commit

Permalink
PR updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Aug 16, 2024
1 parent a207a06 commit 38a540b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin
Expand Down Expand Up @@ -58,16 +57,12 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:

def store(self, memory: BaseConversationMemory) -> None:
    # serialize the run artifacts to json strings
messages = [
{"input": json.dumps(run.input.to_dict()), "output": json.dumps(run.output.to_dict())}
for run in memory.runs
]
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in memory.runs]

# serialize the metadata to a json string
# remove runs because they are already stored as Messages
metadata = memory.to_dict()
del metadata["runs"]
metadata = json.dumps(metadata)

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
Expand All @@ -78,28 +73,26 @@ def store(self, memory: BaseConversationMemory) -> None:
)
response.raise_for_status()

def load(self) -> Optional[BaseConversationMemory]:
def load(self) -> BaseConversationMemory:
from griptape.memory.structure import BaseConversationMemory, ConversationMemory, Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()
if messages_response is None:
raise RuntimeError(f"Error getting messages for thread {self.thread_id}")

messages = sorted(messages_response.get("messages", []), key=lambda m: m["index"])

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()

from griptape.memory.structure import Run
messages = messages_response.get("messages", [])

runs = [
Run(
id=m["message_id"],
input=BaseArtifact.from_dict(json.loads(m["input"])),
output=BaseArtifact.from_dict(json.loads(m["output"])),
input=BaseArtifact.from_json(m["input"]),
output=BaseArtifact.from_json(m["output"]),
)
for m in messages
]
Expand All @@ -109,18 +102,14 @@ def load(self) -> Optional[BaseConversationMemory]:
# ConversationMemory object with the runs removed
# autoload=False to prevent recursively loading the memory
if metadata is not None and metadata != {}:
from griptape.memory.structure import BaseConversationMemory

return BaseConversationMemory.from_dict(
{
**json.loads(thread_response.get("metadata")),
**metadata,
"runs": [run.to_dict() for run in runs],
"autoload": False,
}
)
# no metadata found, return a new ConversationMemory object
from griptape.memory.structure import ConversationMemory

return ConversationMemory(runs=runs, autoload=False)

def _get_thread_id(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,32 @@ class TestGriptapeCloudConversationMemoryDriver:
@pytest.fixture(autouse=True)
def _mock_requests(self, mocker):
def get(*args, **kwargs):
if str(args[0]).startswith("https://cloud.griptape.ai/api/threads/"):
if str(args[0]).endswith("/messages"):
thread_id = args[0].split("/")[-2]
if thread_id == "test_error":
return mocker.Mock(raise_for_status=lambda: None, json=lambda: None)
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"metadata": TEST_CONVERSATION, "name": "test", "thread_id": "test_metadata"}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)
return None
if str(args[0]).endswith("/messages"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"metadata": json.loads(TEST_CONVERSATION),
"name": "test",
"thread_id": "test_metadata",
}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)

mocker.patch(
"requests.get",
Expand Down Expand Up @@ -90,8 +89,3 @@ def test_load_metadata(self, driver):
memory = driver.load()
assert isinstance(memory, SummaryConversationMemory)
assert len(memory.runs) == 1

def test_load_error(self, driver):
driver.thread_id = "test_error"
with pytest.raises(RuntimeError, match="Error getting messages for thread test_error"):
driver.load()

0 comments on commit 38a540b

Please sign in to comment.