Commit f756096: Add support for more modalities to conversation memory
collindutter committed Jun 14, 2024 (1 parent: e4dd283)
Showing 16 changed files with 94 additions and 111 deletions.
12 changes: 4 additions & 8 deletions griptape/common/prompt_stack/elements/prompt_stack_element.py
@@ -46,13 +46,9 @@ def to_text(self) -> str:
         return self.to_text_artifact().to_text()
 
     def to_text_artifact(self) -> TextArtifact:
-        if all(isinstance(content, TextPromptStackContent) for content in self.content):
-            artifact = TextArtifact(value="")
+        artifact = TextArtifact(value="")
 
-            for content in self.content:
-                if isinstance(content, TextPromptStackContent):
-                    artifact += content.artifact
+        for content in self.content:
+            artifact.value += content.artifact.to_text()
 
-            return artifact
-        else:
-            raise ValueError("Cannot convert to TextArtifact")
+        return artifact
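
A minimal sketch of what this rework allows: an element mixing text and image content no longer raises ValueError, since every part now contributes its to_text() representation. The ImageArtifact constructor arguments below are illustrative assumptions, not the exact signature.

from griptape.artifacts import TextArtifact, ImageArtifact
from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent

element = PromptStackElement(
    content=[
        TextPromptStackContent(TextArtifact("Describe this image: ")),
        ImagePromptStackContent(ImageArtifact(b"\x89PNG...", format="png", width=512, height=512)),
    ],
    role=PromptStackElement.USER_ROLE,
)

# Before this commit, mixed content raised ValueError("Cannot convert to TextArtifact");
# now each part's to_text() output is concatenated into a single TextArtifact.
print(element.to_text_artifact().value)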
42 changes: 23 additions & 19 deletions griptape/common/prompt_stack/prompt_stack.py
@@ -3,38 +3,42 @@
 
 from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact
 from griptape.mixins import SerializableMixin
-from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent
+from griptape.common import PromptStackElement, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent
 
 
 @define
 class PromptStack(SerializableMixin):
     inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True})
 
     def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement:
-        if isinstance(content, str):
-            self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role))
-        elif isinstance(content, TextArtifact):
-            self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role))
-        elif isinstance(content, ListArtifact):
-            contents = []
-            for artifact in content.value:
-                if isinstance(artifact, TextArtifact):
-                    contents.append(TextPromptStackContent(artifact))
-                elif isinstance(artifact, ImageArtifact):
-                    contents.append(ImagePromptStackContent(artifact))
-                else:
-                    raise ValueError(f"Unsupported artifact type: {type(artifact)}")
-            self.inputs.append(PromptStackElement(content=contents, role=role))
-        else:
-            raise ValueError(f"Unsupported content type: {type(content)}")
+        new_content = self.__process_content(content)
+
+        self.inputs.append(PromptStackElement(content=new_content, role=role))
 
         return self.inputs[-1]
 
-    def add_system_input(self, content: str) -> PromptStackElement:
+    def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.SYSTEM_ROLE)
 
     def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.USER_ROLE)
 
-    def add_assistant_input(self, content: str) -> PromptStackElement:
+    def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)
+
+    def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]:
+        if isinstance(content, str):
+            return [TextPromptStackContent(TextArtifact(content))]
+        elif isinstance(content, TextArtifact):
+            return [TextPromptStackContent(content)]
+        elif isinstance(content, ImageArtifact):
+            return [ImagePromptStackContent(content)]
+        elif isinstance(content, ListArtifact):
+            processed_contents = [self.__process_content(artifact) for artifact in content.value]
+            flattened_content = [
+                sub_content for processed_content in processed_contents for sub_content in processed_content
+            ]
+
+            return flattened_content
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
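
A usage sketch of the refactored API: all four content shapes now funnel through the private __process_content helper, and ImageArtifact becomes a first-class input. Import paths follow the file locations in this diff; the ImageArtifact constructor arguments are assumptions.

from griptape.artifacts import TextArtifact, ImageArtifact, ListArtifact
from griptape.common import PromptStack

stack = PromptStack()

stack.add_user_input("plain strings are wrapped in TextPromptStackContent")
stack.add_user_input(TextArtifact("TextArtifacts are wrapped directly"))
stack.add_user_input(ImageArtifact(b"\x89PNG...", format="png", width=512, height=512))
# A ListArtifact is processed recursively and flattened into one element with mixed content:
stack.add_user_input(ListArtifact([TextArtifact("caption this:"), ImageArtifact(b"\x89PNG...", format="png", width=512, height=512)]))

assert len(stack.inputs) == 4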
8 changes: 4 additions & 4 deletions griptape/engines/summary/prompt_summary_engine.py
@@ -64,8 +64,8 @@ def summarize_artifacts_rec(
         return self.prompt_driver.run(
             PromptStack(
                 inputs=[
-                    PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE),
-                    PromptStack.Input(user_prompt, role=PromptStack.USER_ROLE),
+                    PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE),
+                    PromptStackElement(user_prompt, role=PromptStackElement.USER_ROLE),
                 ]
             )
         )
@@ -79,8 +79,8 @@ def summarize_artifacts_rec(
                 self.prompt_driver.run(
                     PromptStack(
                         inputs=[
-                            PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE),
-                            PromptStack.Input(partial_text, role=PromptStack.USER_ROLE),
+                            PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE),
+                            PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE),
                         ]
                     )
                 ).value,
5 changes: 3 additions & 2 deletions griptape/memory/structure/run.py
@@ -1,10 +1,11 @@
 import uuid
 from attrs import define, field, Factory
+from griptape.artifacts.base_artifact import BaseArtifact
 from griptape.mixins import SerializableMixin
 
 
 @define
 class Run(SerializableMixin):
     id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
-    input: str = field(kw_only=True, metadata={"serializable": True})
-    output: str = field(kw_only=True, metadata={"serializable": True})
+    input: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
+    output: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
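
Run now stores BaseArtifact values instead of plain strings, which is what lets non-text modalities survive a round trip through conversation memory. A small sketch; the import path mirrors the file location in this diff:

from griptape.artifacts import TextArtifact
from griptape.memory.structure import Run

run = Run(input=TextArtifact("What is in this image?"), output=TextArtifact("A cat."))

# Callers that previously compared strings now go through .value or .to_text():
assert run.input.value == "What is in this image?"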
7 changes: 1 addition & 6 deletions griptape/structures/agent.py
@@ -51,12 +51,7 @@ def try_run(self, *args) -> Agent:
         self.task.execute()
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.task.input, tuple):
-                input_text = self.task.input[0].to_text()
-            else:
-                input_text = self.task.input.to_text()
-
-            run = Run(input=input_text, output=self.task.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
7 changes: 1 addition & 6 deletions griptape/structures/pipeline.py
@@ -46,12 +46,7 @@ def try_run(self, *args) -> Pipeline:
         self.__run_from_task(self.input_task)
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.input_task.input, tuple):
-                input_text = self.input_task.input[0].to_text()
-            else:
-                input_text = self.input_task.input.to_text()
-
-            run = Run(input=input_text, output=self.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
7 changes: 1 addition & 6 deletions griptape/structures/workflow.py
@@ -110,12 +110,7 @@ def try_run(self, *args) -> Workflow:
                 break
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.input_task.input, tuple):
-                input_text = self.input_task.input[0].to_text()
-            else:
-                input_text = self.input_task.input.to_text()
-
-            run = Run(input=input_text, output=self.output_task.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
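
All three structures (Agent, Pipeline, Workflow) drop the tuple/to_text() special-casing and hand the first task's input artifact and the structure's output artifact straight to Run. A sketch of what a consumer sees afterwards, using the mock driver that appears in this commit's tests:

from griptape.structures import Agent
from tests.mocks.mock_prompt_driver import MockPromptDriver

agent = Agent(prompt_driver=MockPromptDriver())
agent.run("hello")

# Runs now hold artifacts rather than strings, so render text on demand:
last_run = agent.conversation_memory.runs[-1]
print(last_run.input.to_text(), "->", last_run.output.to_text())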
2 changes: 1 addition & 1 deletion griptape/tasks/base_text_input_task.py
@@ -52,7 +52,7 @@ def _process_task_input(
 
             return task_input
         elif isinstance(task_input, Callable):
-            return task_input(self)
+            return self._process_task_input(task_input(self))
         elif isinstance(task_input, str):
            return self._process_task_input(TextArtifact(task_input))
         elif isinstance(task_input, BaseArtifact):
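
The Callable branch now recurses instead of returning the callable's raw result, so a callable may return a str (or an artifact) and still be normalized. A sketch, assuming the task's input property routes through _process_task_input as this hunk suggests:

from griptape.artifacts import TextArtifact
from griptape.tasks import PromptTask

task = PromptTask(input=lambda task: "generated at run time")

# The str returned by the lambda is itself re-processed, so the task's
# input resolves to a TextArtifact rather than leaking through as a bare str:
assert isinstance(task.input, TextArtifact)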
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
@@ -33,7 +33,7 @@ def prompt_stack(self) -> PromptStack:
         stack.add_user_input(self.input)
 
         if self.output:
-            stack.add_assistant_input(self.output.to_text())
+            stack.add_assistant_input(self.output)
 
         if memory:
             # inserting at index 1 to place memory right after system prompt
2 changes: 1 addition & 1 deletion griptape/tasks/toolkit_task.py
@@ -67,7 +67,7 @@ def prompt_stack(self) -> PromptStack:
 
         stack.add_system_input(self.generate_system_template(self))
 
-        stack.add_user_input(self.input.to_text())
+        stack.add_user_input(self.input)
 
         if self.output:
             stack.add_assistant_input(self.output.to_text())
9 changes: 2 additions & 7 deletions griptape/tools/prompt_image_generation_client/tool.py
@@ -30,20 +30,15 @@ class PromptImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool):
                     Literal(
                         "prompts",
                         description="A detailed list of features and descriptions to include in the generated image.",
-                    ): list[str],
-                    Literal(
-                        "negative_prompts",
-                        description="A detailed list of features and descriptions to avoid in the generated image.",
-                    ): list[str],
+                    ): list[str]
                 }
             ),
         }
     )
     def generate_image(self, params: dict[str, dict[str, list[str]]]) -> ImageArtifact | ErrorArtifact:
         prompts = params["values"]["prompts"]
-        negative_prompts = params["values"]["negative_prompts"]
 
-        output_artifact = self.engine.run(prompts=prompts, negative_prompts=negative_prompts)
+        output_artifact = self.engine.run(prompts=prompts)
 
         if self.output_dir or self.output_file:
             self._write_to_file(output_artifact)
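
The negative_prompts field is removed from both the activity schema and the engine call, so an invocation now carries prompts only. An illustrative call shape, with hypothetical prompt values and the tool's engine construction elided:

# The activity payload after the change; "negative_prompts" is no longer part of the schema:
params = {"values": {"prompts": ["a red bicycle", "watercolor style"]}}

image_artifact = tool.generate_image(params)  # tool: a configured PromptImageGenerationClient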
@@ -81,5 +81,5 @@ def test_load(self):
 
         assert new_memory.type == "ConversationMemory"
         assert len(new_memory.runs) == 2
-        assert new_memory.runs[0].input == "test"
-        assert new_memory.runs[0].output == "mock output"
+        assert new_memory.runs[0].input.value == "test"
+        assert new_memory.runs[0].output.value == "mock output"
@@ -52,8 +52,8 @@ def test_load(self):
 
         assert new_memory.type == "ConversationMemory"
         assert len(new_memory.runs) == 2
-        assert new_memory.runs[0].input == "test"
-        assert new_memory.runs[0].output == "mock output"
+        assert new_memory.runs[0].input.value == "test"
+        assert new_memory.runs[0].output.value == "mock output"
         assert new_memory.max_runs == 5
 
     def test_autoload(self):
@@ -71,8 +71,8 @@ def test_autoload(self):
 
         assert autoloaded_memory.type == "ConversationMemory"
         assert len(autoloaded_memory.runs) == 2
-        assert autoloaded_memory.runs[0].input == "test"
-        assert autoloaded_memory.runs[0].output == "mock output"
+        assert autoloaded_memory.runs[0].input.value == "test"
+        assert autoloaded_memory.runs[0].output.value == "mock output"
 
     def __delete_file(self, file_path):
         try:
@@ -3,7 +3,7 @@
 
 from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
 from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver
 
-TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}'
+TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}'
 CONVERSATION_ID = "117151897f344ff684b553d0655d8f39"
 INDEX = "griptape_converstaion"
 HOST = "127.0.0.1"
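
The fixture's input and output are now serialized TextArtifact dicts rather than bare strings. A sketch of the round trip, using from_dict as the tests in this commit do:

import json

memory = BaseConversationMemory.from_dict(json.loads(TEST_CONVERSATION))

assert memory.runs[0].input.value == "Hi There, Hello"
assert memory.runs[0].output.value == "Hello! How can I assist you today?"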
67 changes: 34 additions & 33 deletions tests/unit/memory/structure/test_conversation_memory.py
@@ -6,34 +6,35 @@
 from tests.mocks.mock_prompt_driver import MockPromptDriver
 from tests.mocks.mock_tokenizer import MockTokenizer
 from griptape.tasks import PromptTask
+from griptape.artifacts import TextArtifact
 
 
 class TestConversationMemory:
     def test_add_run(self):
         memory = ConversationMemory()
-        run = Run(input="test", output="test")
+        run = Run(input=TextArtifact("foo"), output=TextArtifact("bar"))
 
         memory.add_run(run)
 
         assert memory.runs[0] == run
 
     def test_to_json(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert json.loads(memory.to_json())["type"] == "ConversationMemory"
-        assert json.loads(memory.to_json())["runs"][0]["input"] == "foo"
+        assert json.loads(memory.to_json())["runs"][0]["input"]["value"] == "foo"
 
     def test_to_dict(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert memory.to_dict()["type"] == "ConversationMemory"
-        assert memory.to_dict()["runs"][0]["input"] == "foo"
+        assert memory.to_dict()["runs"][0]["input"]["value"] == "foo"
 
     def test_to_prompt_stack(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         prompt_stack = memory.to_prompt_stack()
 
@@ -42,19 +43,19 @@ def test_to_prompt_stack(self):
 
     def test_from_dict(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
         memory_dict = memory.to_dict()
 
         assert isinstance(BaseConversationMemory.from_dict(memory_dict), ConversationMemory)
-        assert BaseConversationMemory.from_dict(memory_dict).runs[0].input == "foo"
+        assert BaseConversationMemory.from_dict(memory_dict).runs[0].input.value == "foo"
 
     def test_from_json(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
         memory_dict = memory.to_dict()
 
         assert isinstance(memory.from_dict(memory_dict), ConversationMemory)
-        assert memory.from_dict(memory_dict).runs[0].input == "foo"
+        assert memory.from_dict(memory_dict).runs[0].input.value == "foo"
 
     def test_buffering(self):
         memory = ConversationMemory(max_runs=2)
@@ -70,24 +71,24 @@ def test_buffering(self):
         pipeline.run("run5")
 
         assert len(pipeline.conversation_memory.runs) == 2
-        assert pipeline.conversation_memory.runs[0].input == "run4"
-        assert pipeline.conversation_memory.runs[1].input == "run5"
+        assert pipeline.conversation_memory.runs[0].input.value == "run4"
+        assert pipeline.conversation_memory.runs[1].input.value == "run5"
 
     def test_add_to_prompt_stack_autopruing_disabled(self):
         agent = Agent(prompt_driver=MockPromptDriver())
         memory = ConversationMemory(
             autoprune=False,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
         prompt_stack = PromptStack()
-        prompt_stack.add_user_input("foo")
+        prompt_stack.add_user_input(TextArtifact("foo"))
         prompt_stack.add_assistant_input("bar")
         memory.add_to_prompt_stack(prompt_stack)
 
@@ -99,11 +100,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
         memory = ConversationMemory(
             autoprune=True,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
@@ -120,11 +121,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
         memory = ConversationMemory(
             autoprune=True,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
@@ -144,11 +145,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
             autoprune=True,
             runs=[
                 # All of these sum to 155 tokens with the MockTokenizer.
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent