diff --git a/griptape/common/prompt_stack/elements/prompt_stack_element.py b/griptape/common/prompt_stack/elements/prompt_stack_element.py
index 0ccd687d33..b94c8a5f44 100644
--- a/griptape/common/prompt_stack/elements/prompt_stack_element.py
+++ b/griptape/common/prompt_stack/elements/prompt_stack_element.py
@@ -46,13 +46,9 @@ def to_text(self) -> str:
         return self.to_text_artifact().to_text()
 
     def to_text_artifact(self) -> TextArtifact:
-        if all(isinstance(content, TextPromptStackContent) for content in self.content):
-            artifact = TextArtifact(value="")
+        artifact = TextArtifact(value="")
 
-            for content in self.content:
-                if isinstance(content, TextPromptStackContent):
-                    artifact += content.artifact
+        for content in self.content:
+            artifact.value += content.artifact.to_text()
 
-            return artifact
-        else:
-            raise ValueError("Cannot convert to TextArtifact")
+        return artifact
diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py
index f7cc06a5cc..a82c47216a 100644
--- a/griptape/common/prompt_stack/prompt_stack.py
+++ b/griptape/common/prompt_stack/prompt_stack.py
@@ -3,7 +3,7 @@
 from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact
 from griptape.mixins import SerializableMixin
 
-from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent
+from griptape.common import PromptStackElement, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent
 
 
 @define
@@ -11,30 +11,34 @@ class PromptStack(SerializableMixin):
     inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True})
 
     def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement:
-        if isinstance(content, str):
-            self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role))
-        elif isinstance(content, TextArtifact):
-            self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role))
-        elif isinstance(content, ListArtifact):
-            contents = []
-            for artifact in content.value:
-                if isinstance(artifact, TextArtifact):
-                    contents.append(TextPromptStackContent(artifact))
-                elif isinstance(artifact, ImageArtifact):
-                    contents.append(ImagePromptStackContent(artifact))
-                else:
-                    raise ValueError(f"Unsupported artifact type: {type(artifact)}")
-            self.inputs.append(PromptStackElement(content=contents, role=role))
-        else:
-            raise ValueError(f"Unsupported content type: {type(content)}")
+        new_content = self.__process_content(content)
+
+        self.inputs.append(PromptStackElement(content=new_content, role=role))
 
         return self.inputs[-1]
 
-    def add_system_input(self, content: str) -> PromptStackElement:
+    def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.SYSTEM_ROLE)
 
     def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.USER_ROLE)
 
-    def add_assistant_input(self, content: str) -> PromptStackElement:
+    def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
         return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)
+
+    def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]:
+        if isinstance(content, str):
+            return [TextPromptStackContent(TextArtifact(content))]
+        elif isinstance(content, TextArtifact):
+            return [TextPromptStackContent(content)]
+        elif isinstance(content, ImageArtifact):
+            return [ImagePromptStackContent(content)]
+        elif isinstance(content, ListArtifact):
+            processed_contents = [self.__process_content(artifact) for artifact in content.value]
+            flattened_content = [
+                sub_content for processed_content in processed_contents for sub_content in processed_content
+            ]
+
+            return flattened_content
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py
index e5968c4df9..18b5f3a07b 100644
--- a/griptape/engines/summary/prompt_summary_engine.py
+++ b/griptape/engines/summary/prompt_summary_engine.py
@@ -64,8 +64,8 @@ def summarize_artifacts_rec(
             return self.prompt_driver.run(
                 PromptStack(
                     inputs=[
-                        PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE),
-                        PromptStack.Input(user_prompt, role=PromptStack.USER_ROLE),
+                        PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE),
+                        PromptStackElement(user_prompt, role=PromptStackElement.USER_ROLE),
                     ]
                 )
             )
@@ -79,8 +79,8 @@ def summarize_artifacts_rec(
                 self.prompt_driver.run(
                     PromptStack(
                         inputs=[
-                            PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE),
-                            PromptStack.Input(partial_text, role=PromptStack.USER_ROLE),
+                            PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE),
+                            PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE),
                         ]
                     )
                 ).value,
diff --git a/griptape/memory/structure/run.py b/griptape/memory/structure/run.py
index c5a2b9b559..b91df2ae9c 100644
--- a/griptape/memory/structure/run.py
+++ b/griptape/memory/structure/run.py
@@ -1,10 +1,11 @@
 import uuid
 from attrs import define, field, Factory
+from griptape.artifacts.base_artifact import BaseArtifact
 from griptape.mixins import SerializableMixin
 
 
 @define
 class Run(SerializableMixin):
     id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
-    input: str = field(kw_only=True, metadata={"serializable": True})
-    output: str = field(kw_only=True, metadata={"serializable": True})
+    input: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
+    output: BaseArtifact = field(kw_only=True, metadata={"serializable": True})
diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py
index 12512341ec..f767d74a26 100644
--- a/griptape/structures/agent.py
+++ b/griptape/structures/agent.py
@@ -51,12 +51,7 @@ def try_run(self, *args) -> Agent:
         self.task.execute()
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.task.input, tuple):
-                input_text = self.task.input[0].to_text()
-            else:
-                input_text = self.task.input.to_text()
-
-            run = Run(input=input_text, output=self.task.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py
index d5724244e7..fe8fcbdf18 100644
--- a/griptape/structures/pipeline.py
+++ b/griptape/structures/pipeline.py
@@ -46,12 +46,7 @@ def try_run(self, *args) -> Pipeline:
         self.__run_from_task(self.input_task)
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.input_task.input, tuple):
-                input_text = self.input_task.input[0].to_text()
-            else:
-                input_text = self.input_task.input.to_text()
-
-            run = Run(input=input_text, output=self.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py
index f9570edc00..4486ce77c8 100644
--- a/griptape/structures/workflow.py
+++ b/griptape/structures/workflow.py
@@ -110,12 +110,7 @@ def try_run(self, *args) -> Workflow:
                     break
 
         if self.conversation_memory and self.output is not None:
-            if isinstance(self.input_task.input, tuple):
-                input_text = self.input_task.input[0].to_text()
-            else:
-                input_text = self.input_task.input.to_text()
-
-            run = Run(input=input_text, output=self.output_task.output.to_text())
+            run = Run(input=self.input_task.input, output=self.output)
 
             self.conversation_memory.add_run(run)
 
diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py
index 9281f5a7e4..40b6e75c52 100644
--- a/griptape/tasks/base_text_input_task.py
+++ b/griptape/tasks/base_text_input_task.py
@@ -52,7 +52,7 @@ def _process_task_input(
             return task_input
         elif isinstance(task_input, Callable):
-            return task_input(self)
+            return self._process_task_input(task_input(self))
         elif isinstance(task_input, str):
             return self._process_task_input(TextArtifact(task_input))
         elif isinstance(task_input, BaseArtifact):
diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py
index 5727a499b7..4fd8a34db5 100644
--- a/griptape/tasks/prompt_task.py
+++ b/griptape/tasks/prompt_task.py
@@ -33,7 +33,7 @@ def prompt_stack(self) -> PromptStack:
         stack.add_user_input(self.input)
 
         if self.output:
-            stack.add_assistant_input(self.output.to_text())
+            stack.add_assistant_input(self.output)
 
         if memory:
             # inserting at index 1 to place memory right after system prompt
diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py
index de291d1a92..60eebc4058 100644
--- a/griptape/tasks/toolkit_task.py
+++ b/griptape/tasks/toolkit_task.py
@@ -67,7 +67,7 @@ def prompt_stack(self) -> PromptStack:
         stack.add_system_input(self.generate_system_template(self))
 
-        stack.add_user_input(self.input.to_text())
+        stack.add_user_input(self.input)
 
         if self.output:
             stack.add_assistant_input(self.output.to_text())
diff --git a/griptape/tools/prompt_image_generation_client/tool.py b/griptape/tools/prompt_image_generation_client/tool.py
index 50020a1eab..6f4ce0e1fe 100644
--- a/griptape/tools/prompt_image_generation_client/tool.py
+++ b/griptape/tools/prompt_image_generation_client/tool.py
@@ -30,20 +30,15 @@ class PromptImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool):
                     Literal(
                         "prompts",
                         description="A detailed list of features and descriptions to include in the generated image.",
-                    ): list[str],
-                    Literal(
-                        "negative_prompts",
-                        description="A detailed list of features and descriptions to avoid in the generated image.",
-                    ): list[str],
+                    ): list[str]
                 }
             ),
         }
     )
     def generate_image(self, params: dict[str, dict[str, list[str]]]) -> ImageArtifact | ErrorArtifact:
         prompts = params["values"]["prompts"]
-        negative_prompts = params["values"]["negative_prompts"]
 
-        output_artifact = self.engine.run(prompts=prompts, negative_prompts=negative_prompts)
+        output_artifact = self.engine.run(prompts=prompts)
 
         if self.output_dir or self.output_file:
             self._write_to_file(output_artifact)
diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py
index ef3b0e1df2..80d77d24d1 100644
--- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py
+++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py
@@ -81,5 +81,5 @@ def test_load(self):
         assert new_memory.type == "ConversationMemory"
         assert len(new_memory.runs) == 2
-        assert new_memory.runs[0].input == "test"
-        assert new_memory.runs[0].output == "mock output"
+        assert new_memory.runs[0].input.value == "test"
+        assert new_memory.runs[0].output.value == "mock output"
diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py
index d12d5d3d2e..c794afd0e7 100644
--- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py
+++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py
@@ -52,8 +52,8 @@ def test_load(self):
         assert new_memory.type == "ConversationMemory"
         assert len(new_memory.runs) == 2
-        assert new_memory.runs[0].input == "test"
-        assert new_memory.runs[0].output == "mock output"
+        assert new_memory.runs[0].input.value == "test"
+        assert new_memory.runs[0].output.value == "mock output"
         assert new_memory.max_runs == 5
 
     def test_autoload(self):
@@ -71,8 +71,8 @@ def test_autoload(self):
         assert autoloaded_memory.type == "ConversationMemory"
         assert len(autoloaded_memory.runs) == 2
-        assert autoloaded_memory.runs[0].input == "test"
-        assert autoloaded_memory.runs[0].output == "mock output"
+        assert autoloaded_memory.runs[0].input.value == "test"
+        assert autoloaded_memory.runs[0].output.value == "mock output"
 
     def __delete_file(self, file_path):
         try:
diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py
index dee840508d..7a74b6921a 100644
--- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py
+++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py
@@ -3,7 +3,7 @@
 from griptape.memory.structure.base_conversation_memory import BaseConversationMemory
 from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver
 
-TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}'
+TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}'
 CONVERSATION_ID = "117151897f344ff684b553d0655d8f39"
 INDEX = "griptape_converstaion"
 HOST = "127.0.0.1"
diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py
index 02426cbee7..7ffa578f85 100644
--- a/tests/unit/memory/structure/test_conversation_memory.py
+++ b/tests/unit/memory/structure/test_conversation_memory.py
@@ -6,12 +6,13 @@
 from tests.mocks.mock_prompt_driver import MockPromptDriver
 from tests.mocks.mock_tokenizer import MockTokenizer
 from griptape.tasks import PromptTask
+from griptape.artifacts import TextArtifact
 
 
 class TestConversationMemory:
     def test_add_run(self):
         memory = ConversationMemory()
-        run = Run(input="test", output="test")
+        run = Run(input=TextArtifact("foo"), output=TextArtifact("bar"))
 
         memory.add_run(run)
 
@@ -19,21 +20,21 @@ def test_to_json(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert json.loads(memory.to_json())["type"] == "ConversationMemory"
-        assert json.loads(memory.to_json())["runs"][0]["input"] == "foo"
+        assert json.loads(memory.to_json())["runs"][0]["input"]["value"] == "foo"
 
     def test_to_dict(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert memory.to_dict()["type"] == "ConversationMemory"
-        assert memory.to_dict()["runs"][0]["input"] == "foo"
+        assert memory.to_dict()["runs"][0]["input"]["value"] == "foo"
 
     def test_to_prompt_stack(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         prompt_stack = memory.to_prompt_stack()
 
@@ -42,19 +43,19 @@ def test_from_dict(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         memory_dict = memory.to_dict()
 
         assert isinstance(BaseConversationMemory.from_dict(memory_dict), ConversationMemory)
-        assert BaseConversationMemory.from_dict(memory_dict).runs[0].input == "foo"
+        assert BaseConversationMemory.from_dict(memory_dict).runs[0].input.value == "foo"
 
     def test_from_json(self):
         memory = ConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         memory_dict = memory.to_dict()
 
         assert isinstance(memory.from_dict(memory_dict), ConversationMemory)
-        assert memory.from_dict(memory_dict).runs[0].input == "foo"
+        assert memory.from_dict(memory_dict).runs[0].input.value == "foo"
 
     def test_buffering(self):
         memory = ConversationMemory(max_runs=2)
@@ -70,24 +71,24 @@
         pipeline.run("run5")
 
         assert len(pipeline.conversation_memory.runs) == 2
-        assert pipeline.conversation_memory.runs[0].input == "run4"
-        assert pipeline.conversation_memory.runs[1].input == "run5"
+        assert pipeline.conversation_memory.runs[0].input.value == "run4"
+        assert pipeline.conversation_memory.runs[1].input.value == "run5"
 
     def test_add_to_prompt_stack_autopruing_disabled(self):
         agent = Agent(prompt_driver=MockPromptDriver())
         memory = ConversationMemory(
             autoprune=False,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
         prompt_stack = PromptStack()
-        prompt_stack.add_user_input("foo")
+        prompt_stack.add_user_input(TextArtifact("foo"))
         prompt_stack.add_assistant_input("bar")
         memory.add_to_prompt_stack(prompt_stack)
 
@@ -99,11 +100,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
         memory = ConversationMemory(
             autoprune=True,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
@@ -120,11 +121,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
         memory = ConversationMemory(
             autoprune=True,
             runs=[
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
@@ -144,11 +145,11 @@ def test_add_to_prompt_stack_autopruing_enabled(self):
             autoprune=True,
             runs=[
                 # All of these sum to 155 tokens with the MockTokenizer.
-                Run(input="foo1", output="bar1"),
-                Run(input="foo2", output="bar2"),
-                Run(input="foo3", output="bar3"),
-                Run(input="foo4", output="bar4"),
-                Run(input="foo5", output="bar5"),
+                Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")),
+                Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")),
+                Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")),
+                Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")),
+                Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")),
             ],
         )
         memory.structure = agent
diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py
index 2366892847..5ca99f07ab 100644
--- a/tests/unit/memory/structure/test_summary_conversation_memory.py
+++ b/tests/unit/memory/structure/test_summary_conversation_memory.py
@@ -3,6 +3,7 @@
 from griptape.memory.structure import Run, SummaryConversationMemory
 from griptape.structures import Pipeline
+from griptape.artifacts import TextArtifact
 from griptape.tasks import PromptTask
 from tests.mocks.mock_prompt_driver import MockPromptDriver
 from tests.mocks.mock_structure_config import MockStructureConfig
@@ -40,21 +41,21 @@ def test_after_run(self):
     def test_to_json(self):
         memory = SummaryConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert json.loads(memory.to_json())["type"] == "SummaryConversationMemory"
-        assert json.loads(memory.to_json())["runs"][0]["input"] == "foo"
+        assert json.loads(memory.to_json())["runs"][0]["input"]["value"] == "foo"
 
     def test_to_dict(self):
         memory = SummaryConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         assert memory.to_dict()["type"] == "SummaryConversationMemory"
-        assert memory.to_dict()["runs"][0]["input"] == "foo"
+        assert memory.to_dict()["runs"][0]["input"]["value"] == "foo"
 
     def test_to_prompt_stack(self):
         memory = SummaryConversationMemory(summary="foobar")
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         prompt_stack = memory.to_prompt_stack()
 
@@ -64,12 +65,12 @@ def test_from_dict(self):
         memory = SummaryConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         memory_dict = memory.to_dict()
 
         assert isinstance(memory.from_dict(memory_dict), SummaryConversationMemory)
-        assert memory.from_dict(memory_dict).runs[0].input == "foo"
-        assert memory.from_dict(memory_dict).runs[0].output == "bar"
+        assert memory.from_dict(memory_dict).runs[0].input.value == "foo"
+        assert memory.from_dict(memory_dict).runs[0].output.value == "bar"
         assert memory.from_dict(memory_dict).offset == memory.offset
         assert memory.from_dict(memory_dict).summary == memory.summary
         assert memory.from_dict(memory_dict).summary_index == memory.summary_index
@@ -77,11 +78,11 @@ def test_from_json(self):
         memory = SummaryConversationMemory()
-        memory.add_run(Run(input="foo", output="bar"))
+        memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar")))
 
         memory_dict = memory.to_dict()
 
         assert isinstance(memory.from_dict(memory_dict), SummaryConversationMemory)
-        assert memory.from_dict(memory_dict).runs[0].input == "foo"
+        assert memory.from_dict(memory_dict).runs[0].input.value == "foo"
 
     def test_config_prompt_driver(self):
         memory = SummaryConversationMemory()