From a7ed0729c2a47ee2b32faa6fcdc7906268dfecb7 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Dec 2024 14:50:42 -0800 Subject: [PATCH] Add StructuredOutputTool --- griptape/common/prompt_stack/prompt_stack.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 15 +++++ .../drivers/prompt/cohere_prompt_driver.py | 12 ++++ .../drivers/prompt/ollama_prompt_driver.py | 8 ++- .../prompt/openai_chat_prompt_driver.py | 15 +++++ griptape/structures/agent.py | 23 ++++++- griptape/tasks/actions_subtask.py | 11 +++- griptape/tasks/prompt_task.py | 4 +- griptape/tasks/toolkit_task.py | 2 +- griptape/tools/__init__.py | 2 + griptape/tools/structured_output/__init__.py | 0 griptape/tools/structured_output/tool.py | 20 ++++++ tests/unit/structures/test_agent.py | 8 +++ tests/unit/tasks/test_actions_subtask.py | 66 +++++++++++++++++++ .../unit/tools/test_structured_output_tool.py | 13 ++++ 15 files changed, 195 insertions(+), 8 deletions(-) create mode 100644 griptape/tools/structured_output/__init__.py create mode 100644 griptape/tools/structured_output/tool.py create mode 100644 tests/unit/tools/test_structured_output_tool.py diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3b1b8ef74..4cfc99008 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field +from schema import Schema from griptape.artifacts import ( ActionArtifact, @@ -31,6 +32,7 @@ class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def system_messages(self) -> list[Message]: diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 707f67644..e964f0fe7 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -56,9 +56,17 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: + from griptape.tools.structured_output.tool import StructuredOutputTool + + if not self.use_native_structured_output and prompt_stack.output_schema is not None: + structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + if structured_ouptut_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_ouptut_tool) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: @@ -73,11 +81,18 @@ def after_run(self, result: Message) -> None: @observable(tags=["PromptDriver.run()"]) def run(self, prompt_input: PromptStack | BaseArtifact) -> Message: + from griptape.tools.structured_output.tool import StructuredOutputTool + if isinstance(prompt_input, BaseArtifact): prompt_stack = PromptStack.from_artifact(prompt_input) else: prompt_stack = prompt_input + if not self.use_native_structured_output and prompt_stack.output_schema is not None: + structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + if structured_ouptut_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_ouptut_tool) + for attempt in self.retrying(): with attempt: self.before_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3811db5cd..5fe0400ac 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -113,6 +113,18 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if prompt_stack.tools and self.use_native_tools else {} ), + **( + { + "response_format": { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } + } + if not prompt_stack.tools # Respond format is not supported with tools https://docs.cohere.com/reference/chat#request.body.response_format + and prompt_stack.output_schema is not None + and self.use_native_structured_output + else {} + ), **self.extra_params, } diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 5cbba1fdf..f372f3b54 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -79,7 +80,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) response = self.client.chat(**params) - logger.debug(response) + logger.debug(response.model_dump()) return Message( content=self.__to_prompt_stack_message_content(response), @@ -113,6 +114,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: and not self.stream # Tool calling is only supported when not streaming else {} ), + **( + {"format": prompt_stack.output_schema.json_schema("Output")} + if prompt_stack.output_schema and self.use_native_structured_output + else {} + ), **self.extra_params, } diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eed0e35f0..f0468cd02 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -160,6 +161,20 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), + **( + { + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": prompt_stack.output_schema.json_schema("Output"), + "strict": True, + }, + } + } + if prompt_stack.output_schema is not None and self.use_native_structured_output + else {} + ), **self.extra_params, } diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 128c02faa..42f82bcf8 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import Attribute, Factory, define, evolve, field +from schema import Schema from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable @@ -32,6 +33,7 @@ class Agent(Structure): tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) + output_type: Optional[Union[type, Schema]] = field(default=None, kw_only=True) @fail_fast.validator # pyright: ignore[reportAttributeAccessIssue] def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FBT001 @@ -41,18 +43,27 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() - self.prompt_driver.stream = self.stream + prompt_driver = self.prompt_driver + prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: task = ToolkitTask( self.input, - prompt_driver=self.prompt_driver, + prompt_driver=prompt_driver, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries, + output_schema=self._build_schema_from_type(self.output_type) + if self.output_type is not None + else None, ) else: task = PromptTask( - self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries + self.input, + prompt_driver=prompt_driver, + max_meta_memory_entries=self.max_meta_memory_entries, + output_schema=self._build_schema_from_type(self.output_type) + if self.output_type is not None + else None, ) self.add_task(task) @@ -80,3 +91,9 @@ def try_run(self, *args) -> Agent: self.task.run() return self + + def _build_schema_from_type(self, output_type: type | Schema) -> Schema: + if isinstance(output_type, Schema): + return output_type + else: + return Schema(output_type) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 6f9d70053..c889554fd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -9,12 +9,13 @@ from attrs import define, field from griptape import utils -from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.configs import Defaults from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask +from griptape.tools.structured_output.tool import StructuredOutputTool from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars if TYPE_CHECKING: @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None: self.__init_from_prompt(self.input.to_text()) else: self.__init_from_artifacts(self.input) + + structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)] + if structured_outputs: + output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs] + if len(structured_outputs) > 1: + self.output = ListArtifact(output_values) + else: + self.output = output_values[0] except Exception as e: logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index e1d582f1f..64a3eb314 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Factory, NothingType, define, field +from schema import Schema from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack @@ -38,6 +39,7 @@ class PromptTask(RuleMixin, BaseTask): default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), alias="input", ) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def rulesets(self) -> list: @@ -67,7 +69,7 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], @property def prompt_stack(self) -> PromptStack: - stack = PromptStack() + stack = PromptStack(output_schema=self.output_schema) memory = self.conversation_memory system_template = self.generate_system_template(self) diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index d9104ed68..6b0cb8c9b 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -67,7 +67,7 @@ def tool_output_memory(self) -> list[TaskMemory]: @property def prompt_stack(self) -> PromptStack: - stack = PromptStack(tools=self.tools) + stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None stack.add_system_message(self.generate_system_template(self)) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 67a1712a1..ec9cbd5b7 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -23,6 +23,7 @@ from .extraction.tool import ExtractionTool from .prompt_summary.tool import PromptSummaryTool from .query.tool import QueryTool +from .structured_output.tool import StructuredOutputTool __all__ = [ "BaseTool", @@ -50,4 +51,5 @@ "ExtractionTool", "PromptSummaryTool", "QueryTool", + "StructuredOutputTool", ] diff --git a/griptape/tools/structured_output/__init__.py b/griptape/tools/structured_output/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structured_output/tool.py b/griptape/tools/structured_output/tool.py new file mode 100644 index 000000000..89e638f59 --- /dev/null +++ b/griptape/tools/structured_output/tool.py @@ -0,0 +1,20 @@ +from attrs import define, field +from schema import Schema + +from griptape.artifacts import BaseArtifact, JsonArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructuredOutputTool(BaseTool): + output_schema: Schema = field(kw_only=True) + + @activity( + config={ + "description": "Used to provide the final response which ends this conversation.", + "schema": lambda self: self.output_schema, + } + ) + def provide_output(self, params: dict) -> BaseArtifact: + return JsonArtifact(params["values"]) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 387910f40..eae6930f8 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -284,3 +284,11 @@ def test_stream_mutation(self): assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is True assert agent.tasks[0].prompt_driver is not prompt_driver + + def test_output_type_primitive(self): + from griptape.tools import StructuredOutputTool + + agent = Agent(output_type=str) + + assert isinstance(agent.tools[0], StructuredOutputTool) + assert agent.tools[0].output_schema == agent._build_schema_from_type(str) diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index b9c692315..33df86133 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -4,6 +4,7 @@ from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import ToolAction from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask @@ -257,3 +258,68 @@ def test_origin_task(self): with pytest.raises(Exception, match="ActionSubtask has no origin task."): assert ActionsSubtask("test").origin_task + + def test_structured_output_tool(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool", + path="provide_output", + input={"values": {"test": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, JsonArtifact) + assert subtask.output.value == {"test": "value"} + + def test_structured_output_tool_multiple(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool1", + path="provide_output", + input={"values": {"test1": "value"}}, + ) + ), + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool2", + path="provide_output", + input={"values": {"test2": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask( + tools=[ + StructuredOutputTool(name="StructuredOutputTool1", output_schema=schema.Schema({"test": str})), + StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})), + ] + ) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, ListArtifact) + assert len(subtask.output.value) == 2 + assert subtask.output.value[0].value == {"test1": "value"} + assert subtask.output.value[1].value == {"test2": "value"} diff --git a/tests/unit/tools/test_structured_output_tool.py b/tests/unit/tools/test_structured_output_tool.py new file mode 100644 index 000000000..d310b2f9b --- /dev/null +++ b/tests/unit/tools/test_structured_output_tool.py @@ -0,0 +1,13 @@ +import pytest +import schema + +from griptape.tools import StructuredOutputTool + + +class TestStructuredOutputTool: + @pytest.fixture() + def tool(self): + return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"})) + + def test_provide_output(self, tool): + assert tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"}