Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Aug 20, 2024
1 parent b675717 commit 4721811
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 68 deletions.
4 changes: 2 additions & 2 deletions griptape/configs/drivers/drivers_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DummyPromptDriver,
DummyTextToSpeechDriver,
DummyVectorStoreDriver,
DefaultConversationMemoryDriver,
LocalConversationMemoryDriver,
)
from griptape.utils.decorators import lazy_property

Expand Down Expand Up @@ -54,7 +54,7 @@ def vector_store_driver(self) -> BaseVectorStoreDriver:

@lazy_property()
def conversation_memory_driver(self) -> BaseConversationMemoryDriver:
return DefaultConversationMemoryDriver()
return LocalConversationMemoryDriver()

@lazy_property()
def text_to_speech_driver(self) -> BaseTextToSpeechDriver:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

class BaseConversationMemoryDriver(SerializableMixin, ABC):
@abstractmethod
def get_id(self) -> str: ...
def store(self, runs: list[Run], metadata: Optional[dict] = None, *, overwrite: bool = False) -> None: ...

@abstractmethod
def store_runs(self, runs: list[Run], *, overwrite: bool = False, **kwargs) -> None: ...

@abstractmethod
def load_runs(self, memory: BaseConversationMemory) -> Optional[list[Run]]: ...
def load(self) -> tuple[list[Run], Optional[dict]]: ...

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,30 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
init=False,
)

def get_id(self) -> None:
def __attrs_post_init__(self) -> None:
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id())
return self.thread_id

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

def store_runs(self, runs: list[Run], overwrite: bool = False, **kwargs) -> None:
def store(self, runs: list[Run], metadata: Optional[dict] = None, overwrite: bool = False) -> None:
        # serialize the run artifacts to JSON strings
messages = [{"input": run.input.to_json(), "output": run.output.to_json()} for run in runs]

if overwrite:
# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
body = {"messages": messages}
if metadata is not None:
body["metadata"] = metadata

response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"messages": messages, "metadata": kwargs},
json=body,
headers=self.headers,
)
response.raise_for_status()
Expand All @@ -78,17 +81,24 @@ def store_runs(self, runs: list[Run], overwrite: bool = False, **kwargs) -> None
headers=self.headers,
)
response.raise_for_status()
if metadata is not None:
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json={"metadata": metadata},
headers=self.headers,
)
response.raise_for_status()

def load_runs(self, memory: BaseConversationMemory) -> list[Run]:
def load(self) -> tuple[list[Run], Optional[dict]]:
from griptape.memory.structure import Run

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{memory.id}/messages"), headers=self.headers)
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{memory.id}"), headers=self.headers)
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()

Expand All @@ -102,12 +112,7 @@ def load_runs(self, memory: BaseConversationMemory) -> list[Run]:
)
for m in messages
]
metadata = thread_response.get("metadata")
if metadata is not None and metadata != {}:
# TODO: do something?
pass

return runs
return runs, thread_response.get("metadata")

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,45 @@
from attrs import define, field

from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import Run
from griptape.utils import dict_merge

if TYPE_CHECKING:
from griptape.memory.structure import BaseConversationMemory


@define
@define(kw_only=True)
class LocalConversationMemoryDriver(BaseConversationMemoryDriver):
file_path: str = field(default="griptape_memory.json", kw_only=True, metadata={"serializable": True})
persist_file: Optional[str] = field(default=None, metadata={"serializable": True})

def store(self, memory: BaseConversationMemory) -> None:
Path(self.file_path).write_text(memory.to_json())
def __attrs_post_init__(self) -> None:
if self.persist_file is not None:
directory = os.path.dirname(self.persist_file)

def load(self) -> Optional[BaseConversationMemory]:
from griptape.memory.structure import BaseConversationMemory
if directory is not None and not os.path.exists(directory):
os.makedirs(directory)

if not os.path.exists(self.file_path):
return None
if not os.path.isfile(self.persist_file):
Path(self.persist_file).touch()

memory_dict = json.loads(Path(self.file_path).read_text())
# needed to avoid recursive method calls
memory_dict["autoload"] = False
memory = BaseConversationMemory.from_dict(memory_dict)
def store(self, runs: list[Run], metadata: Optional[dict] = None, overwrite: bool = False) -> None:
if self.persist_file is not None:
metadata_dict = metadata if metadata is not None else {}

memory.driver = self
if overwrite:
Path(self.persist_file).write_text(
json.dumps({"runs": [run.to_dict() for run in runs], "metadata": metadata_dict})
)
else:
data = json.loads(Path(self.persist_file).read_text())
data["runs"] += [run.to_dict() for run in runs]
data["metadata"] = dict_merge(data["metadata"], metadata_dict)
Path(self.persist_file).write_text(json.dumps(data))

return memory
def load(self) -> tuple[list[Run], Optional[dict]]:
if self.persist_file is not None:
data = json.loads(Path(self.persist_file).read_text())

return [Run.from_dict(run) for run in data["runs"]], data.get("metadata", None)

return [], None
15 changes: 4 additions & 11 deletions griptape/memory/structure/base_conversation_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,19 @@

@define
class BaseConversationMemory(SerializableMixin, ABC):
_id: str = field(default=None, kw_only=True, metadata={"serializable": True}, alias="id")
driver: BaseConversationMemoryDriver = field(
default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True
)
runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True})
metadata: Optional[dict] = field(default=None, kw_only=True, metadata={"serializable": True})
structure: Structure = field(init=False)
autoload: bool = field(default=True, kw_only=True)
autoprune: bool = field(default=True, kw_only=True)
max_runs: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

def __attrs_post_init__(self) -> None:
if self.autoload:
self.runs = self.driver.load_runs(self)
if self.runs is None:
raise ValueError(f"Failed to load runs from {self.driver.__class__.__name__}")

@property
def id(self) -> str:
if self._id is None:
self._id = self.driver.get_id()
return self._id
self.runs, self.metadata = self.driver.load(self)

def before_add_run(self, run: Run) -> None:
self.runs.append(run)
Expand All @@ -56,7 +48,7 @@ def add_run(self, run: Run) -> BaseConversationMemory:
return self

def after_add_run(self, run: Run) -> None:
self.driver.store_runs([run])
self.driver.store([run], self.metadata)

@abstractmethod
def try_add_run(self, run: Run) -> None: ...
Expand All @@ -73,6 +65,7 @@ def add_to_prompt_stack(
as possible without exceeding the token limit.
Args:
prompt_driver: The Prompt Driver to use for token counting.
prompt_stack: The Prompt Stack to add the Conversation Memory to.
index: Optional index to insert the Conversation Memory runs at.
Defaults to appending to the end of the Prompt Stack.
Expand Down

0 comments on commit 4721811

Please sign in to comment.