From dfc5cbaebb35415c4c1b97d064f7e43148a15acc Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 13 Dec 2024 14:50:42 -0800 Subject: [PATCH] Add Structured Output functionality --- griptape/common/prompt_stack/prompt_stack.py | 5 +- .../prompt/amazon_bedrock_prompt_driver.py | 41 +++++-- .../drivers/prompt/anthropic_prompt_driver.py | 37 ++++-- griptape/drivers/prompt/base_prompt_driver.py | 16 ++- .../drivers/prompt/cohere_prompt_driver.py | 18 ++- .../drivers/prompt/google_prompt_driver.py | 40 +++++-- .../prompt/huggingface_hub_prompt_driver.py | 45 +++++++- .../drivers/prompt/ollama_prompt_driver.py | 25 ++-- .../prompt/openai_chat_prompt_driver.py | 40 +++++-- griptape/rules/json_schema_rule.py | 9 +- griptape/schemas/base_schema.py | 3 + griptape/tasks/actions_subtask.py | 11 +- griptape/tasks/prompt_task.py | 108 ++++++++++------- griptape/tools/__init__.py | 2 + griptape/tools/structured_output/__init__.py | 0 griptape/tools/structured_output/tool.py | 20 ++++ .../test_amazon_bedrock_drivers_config.py | 4 + .../drivers/test_anthropic_drivers_config.py | 2 + .../test_azure_openai_drivers_config.py | 2 + .../drivers/test_cohere_drivers_config.py | 2 + .../configs/drivers/test_drivers_config.py | 2 + .../drivers/test_google_drivers_config.py | 2 + .../drivers/test_openai_driver_config.py | 2 + .../test_amazon_bedrock_prompt_driver.py | 82 +++++++++++-- .../prompt/test_anthropic_prompt_driver.py | 73 +++++++++++- .../prompt/test_cohere_prompt_driver.py | 107 ++++++++++++++++- .../prompt/test_google_prompt_driver.py | 52 +++++++-- .../test_hugging_face_hub_prompt_driver.py | 42 ++++++- .../prompt/test_ollama_prompt_driver.py | 109 ++++++++++++++---- tests/unit/structures/test_structure.py | 2 + tests/unit/tasks/test_actions_subtask.py | 68 ++++++++++- tests/unit/tasks/test_tool_task.py | 2 + tests/unit/tasks/test_toolkit_task.py | 2 + .../unit/tools/test_structured_output_tool.py | 13 +++ 34 files changed, 835 insertions(+), 153 deletions(-) create mode 100644 griptape/tools/structured_output/__init__.py create mode 100644 griptape/tools/structured_output/tool.py create mode 100644 tests/unit/tools/test_structured_output_tool.py diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 3b1b8ef74..752ce8a8d 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -24,6 +24,8 @@ from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: + from schema import Schema + from griptape.tools import BaseTool @@ -31,6 +33,7 @@ class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) tools: list[BaseTool] = field(factory=list, kw_only=True) + output_schema: Optional[Schema] = field(default=None, kw_only=True) @property def system_messages(self) -> list[Message]: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b108180d2..62f6834ec 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_mode: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> Any: return self.session.client("bedrock-runtime") @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] - messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) - return { + params = { "modelId": self.model, "messages": messages, "system": system_messages, @@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}), }, "additionalModelRequestFields": self.additional_model_request_fields, - **( - {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} - if prompt_stack.tools and self.use_native_tools - else {} - ), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["toolConfig"] = { + "tools": [], + "toolChoice": self.tool_choice, + } + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["toolConfig"]["toolChoice"] = {"any": {}} + + params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools) + + return params + def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: return [ { diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 3341006a1..7c1bd422b 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ( @@ -68,6 +68,10 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_mode: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @@ -75,6 +79,13 @@ class AnthropicPromptDriver(BasePromptDriver): def client(self) -> Client: return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) + @native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) @@ -110,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages system_message = system_messages[0].to_text() if system_messages else None - return { + params = { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, @@ -118,15 +129,25 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "top_k": self.top_k, "max_tokens": self.max_tokens, "messages": messages, - **( - {"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice} - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"system": system_message} if system_message else {}), **self.extra_params, } + if prompt_stack.tools and self.use_native_tools: + params["tool_choice"] = self.tool_choice + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): + self._add_structured_output_tool(prompt_stack) + params["tool_choice"] = {"type": "any"} + + params["tools"] = self.__to_anthropic_tools(prompt_stack.tools) + + return params + def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: return [ {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 707f67644..271eaec60 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from attrs import Factory, define, field @@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + native_structured_output_mode: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: @@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... + def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None: + from griptape.tools.structured_output.tool import StructuredOutputTool + + if prompt_stack.output_schema is None: + raise ValueError("PromptStack must have an output schema to use structured output.") + + structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + def __process_run(self, prompt_stack: PromptStack) -> Message: return self.try_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3811db5cd..74c1d8432 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -101,7 +102,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self.__to_cohere_messages(prompt_stack.messages) - return { + params = { "model": self.model, "messages": messages, "temperature": self.temperature, @@ -116,6 +117,21 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_mode == "native": + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } + elif self.native_structured_output_mode == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_cohere_tools(prompt_stack.tools) + + return params + def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2a6bdbf6d..f3a722b7b 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -2,9 +2,9 @@ import json import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from schema import Schema from griptape.artifacts import ActionArtifact, TextArtifact @@ -63,9 +63,20 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_mode: Literal["native", "tool"] = field( + default="tool", kw_only=True, metadata={"serializable": True} + ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + @native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str: + if value == "native": + raise ValueError("GooglePromptDriver does not support `native` structured output mode.") + + return value + @lazy_property() def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") @@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) - return { + params = { "generation_config": types.GenerationConfig( **{ # For some reason, providing stop sequences when streaming breaks native functions @@ -148,16 +159,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, }, ), - **( - { - "tools": self.__to_google_tools(prompt_stack.tools), - "tool_config": {"function_calling_config": {"mode": self.tool_choice}}, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), } + if prompt_stack.tools and self.use_native_tools: + params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): + params["tool_config"]["function_calling_config"]["mode"] = "auto" + self._add_structured_output_tool(prompt_stack) + + params["tools"] = self.__to_google_tools(prompt_stack.tools) + + return params + def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index c2c45c3ae..83d911c2a 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal -from attrs import Factory, define, field +from attrs import Attribute, Factory, define, field from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable from griptape.configs import Defaults @@ -35,6 +35,10 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + native_structured_output_mode: Literal["native", "tool"] = field( + default="native", kw_only=True, metadata={"serializable": True} + ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -51,11 +55,23 @@ def client(self) -> InferenceClient: token=self.api_token, ) + @native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] + def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str: + if value == "tool": + raise ValueError("HuggingFaceHubPromptDriver does not support `tool` structured output mode.") + + return value + @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation( prompt, @@ -75,7 +91,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} - logger.debug((prompt, full_params)) + logger.debug( + { + "prompt": prompt, + **full_params, + } + ) response = self.client.text_generation(prompt, **full_params) @@ -94,12 +115,26 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) def _base_params(self, prompt_stack: PromptStack) -> dict: - return { + params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.native_structured_output_mode == "native" + ): + # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding + output_schema = prompt_stack.output_schema.json_schema("Output Schema") + # Grammar does not support $schema and $id + del output_schema["$schema"] + del output_schema["$id"] + params["grammar"] = {"type": "json", "value": output_schema} + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 5cbba1fdf..c49dc30d8 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() @@ -79,7 +80,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) response = self.client.chat(**params) - logger.debug(response) + logger.debug(response.model_dump()) return Message( content=self.__to_prompt_stack_message_content(response), @@ -102,20 +103,26 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: def _base_params(self, prompt_stack: PromptStack) -> dict: messages = self._prompt_stack_to_messages(prompt_stack) - return { + params = { "messages": messages, "model": self.model, "options": self.options, - **( - {"tools": self.__to_ollama_tools(prompt_stack.tools)} - if prompt_stack.tools - and self.use_native_tools - and not self.stream # Tool calling is only supported when not streaming - else {} - ), **self.extra_params, } + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_mode == "native": + params["format"] = prompt_stack.output_schema.json_schema("Output") + elif self.native_structured_output_mode == "tool": + # TODO: Implement tool choice once supported + self._add_structured_output_tool(prompt_stack) + + # Tool calling is only supported when not streaming + if prompt_stack.tools and self.use_native_tools and not self.stream: + params["tools"] = self.__to_ollama_tools(prompt_stack.tools) + + return params + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: ollama_messages = [] for message in prompt_stack.messages: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eed0e35f0..cb6d01b40 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver): seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] = field( default=Factory( @@ -148,22 +149,41 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, - **( - { - "tools": self.__to_openai_tools(prompt_stack.tools), - "tool_choice": self.tool_choice, - "parallel_tool_calls": self.parallel_tool_calls, - } - if prompt_stack.tools and self.use_native_tools - else {} - ), **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), **self.extra_params, } - if self.response_format is not None: + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_openai_tools(prompt_stack.tools) + params["tool_choice"] = self.tool_choice + params["parallel_tool_calls"] = self.parallel_tool_calls + + if self.native_structured_output_mode == "tool": + params["tool_choice"] = "required" + + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "native" + ): + params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "Output", + "schema": prompt_stack.output_schema.json_schema("Output"), + "strict": True, + }, + } + + if prompt_stack.tools and self.use_native_tools: + params["tools"] = self.__to_openai_tools(prompt_stack.tools) + params["parallel_tool_calls"] = self.parallel_tool_calls + if "tool_choice" not in params: + params["tool_choice"] = self.tool_choice + + if self.response_format is not None and "response_format" not in params: if self.response_format == {"type": "json_object"}: params["response_format"] = self.response_format # JSON mode still requires a system message instructing the LLM to output JSON. diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py index c068eb4a1..84a700ce8 100644 --- a/griptape/rules/json_schema_rule.py +++ b/griptape/rules/json_schema_rule.py @@ -1,17 +1,22 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING, Union from attrs import Factory, define, field from griptape.rules import BaseRule from griptape.utils import J2 +if TYPE_CHECKING: + from schema import Schema + @define() class JsonSchemaRule(BaseRule): - value: dict = field(metadata={"serializable": True}) + value: Union[dict, Schema] = field(metadata={"serializable": True}) generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2"))) def to_text(self) -> str: - return self.generate_template.render(json_schema=json.dumps(self.value)) + value = self.value if isinstance(self.value, dict) else self.value.json_schema("Output Schema") + return self.generate_template.render(json_schema=json.dumps(value)) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 7b23c620f..9217f26c2 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from collections.abc import Sequence from typing import Any + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.common import ( BaseDeltaMessageContent, @@ -215,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseRule": BaseRule, "Ruleset": Ruleset, # Third party modules + "Schema": Schema, "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 6f9d70053..c889554fd 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -9,12 +9,13 @@ from attrs import define, field from griptape import utils -from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.configs import Defaults from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask +from griptape.tools.structured_output.tool import StructuredOutputTool from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars if TYPE_CHECKING: @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None: self.__init_from_prompt(self.input.to_text()) else: self.__init_from_artifacts(self.input) + + structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)] + if structured_outputs: + output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs] + if len(structured_outputs) > 1: + self.output = ListArtifact(output_values) + else: + self.output = output_values[0] except Exception as e: logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 4ed313bf0..5f3a71da6 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field +from schema import Or, Schema from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact @@ -13,13 +14,11 @@ from griptape.memory.structure import Run from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from griptape.mixins.rule_mixin import RuleMixin -from griptape.rules import Ruleset +from griptape.rules import JsonSchemaRule, Ruleset from griptape.tasks import ActionsSubtask, BaseTask from griptape.utils import J2 if TYPE_CHECKING: - from schema import Schema - from griptape.drivers import BasePromptDriver from griptape.memory import TaskMemory from griptape.memory.structure.base_conversation_memory import BaseConversationMemory @@ -92,54 +91,26 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None - system_template = self.generate_system_template(self) - if system_template: - stack.add_system_message(system_template) + system_contents = [TextArtifact(self.generate_system_template(self))] + if self.prompt_driver.use_native_structured_output: + self._add_native_schema_to_prompt_stack(stack) + else: + system_contents.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=self.rulesets))) + + has_system_content = any(system_content.value for system_content in system_contents) + + if has_system_content: + stack.add_system_message(ListArtifact(system_contents)) stack.add_user_message(self.input) if self.output: stack.add_assistant_message(self.output.to_text()) else: - for s in self.subtasks: - if self.prompt_driver.use_native_tools: - action_calls = [ - ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) - for action in s.actions - ] - action_results = [ - ToolAction( - name=action.name, - path=action.path, - tag=action.tag, - output=action.output if action.output is not None else s.output, - ) - for action in s.actions - ] - - stack.add_assistant_message( - ListArtifact( - [ - *([TextArtifact(s.thought)] if s.thought else []), - *[ActionArtifact(a) for a in action_calls], - ], - ), - ) - stack.add_user_message( - ListArtifact( - [ - *[ActionArtifact(a) for a in action_results], - *([] if s.output else [TextArtifact("Please keep going")]), - ], - ), - ) - else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) - + self._add_subtasks_to_prompt_stack(stack) if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_content else 0) return stack @@ -321,3 +292,54 @@ def _process_task_input( return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: return self._process_task_input(TextArtifact(task_input)) + + def _add_native_schema_to_prompt_stack(self, stack: PromptStack) -> None: + json_schema_rules = [ + rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule) + ] + schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)] + + if len(json_schema_rules) != len(schemas): + logger.warning("`JsonSchemaRule` must take a `Schema` to be used in native structured output mode") + + json_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas)) + + if schemas: + stack.output_schema = json_schema + + def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: + for s in self.subtasks: + if self.prompt_driver.use_native_tools: + action_calls = [ + ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) + for action in s.actions + ] + action_results = [ + ToolAction( + name=action.name, + path=action.path, + tag=action.tag, + output=action.output if action.output is not None else s.output, + ) + for action in s.actions + ] + + stack.add_assistant_message( + ListArtifact( + [ + *([TextArtifact(s.thought)] if s.thought else []), + *[ActionArtifact(a) for a in action_calls], + ], + ), + ) + stack.add_user_message( + ListArtifact( + [ + *[ActionArtifact(a) for a in action_results], + *([] if s.output else [TextArtifact("Please keep going")]), + ], + ), + ) + else: + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 67a1712a1..ec9cbd5b7 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -23,6 +23,7 @@ from .extraction.tool import ExtractionTool from .prompt_summary.tool import PromptSummaryTool from .query.tool import QueryTool +from .structured_output.tool import StructuredOutputTool __all__ = [ "BaseTool", @@ -50,4 +51,5 @@ "ExtractionTool", "PromptSummaryTool", "QueryTool", + "StructuredOutputTool", ] diff --git a/griptape/tools/structured_output/__init__.py b/griptape/tools/structured_output/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structured_output/tool.py b/griptape/tools/structured_output/tool.py new file mode 100644 index 000000000..89e638f59 --- /dev/null +++ b/griptape/tools/structured_output/tool.py @@ -0,0 +1,20 @@ +from attrs import define, field +from schema import Schema + +from griptape.artifacts import BaseArtifact, JsonArtifact +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructuredOutputTool(BaseTool): + output_schema: Schema = field(kw_only=True) + + @activity( + config={ + "description": "Used to provide the final response which ends this conversation.", + "schema": lambda self: self.output_schema, + } + ) + def provide_output(self, params: dict) -> BaseArtifact: + return JsonArtifact(params["values"]) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 52408922c..787c033fb 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,8 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -106,6 +108,8 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 8a6f25ef2..5511ae0bf 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "native_structured_output_mode": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 4c44113a0..9b3ce26e5 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,6 +36,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_mode": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 65295da52..9514b851f 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,8 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "native", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index ca3cea60e..f9ad7a188 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,8 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index c1459a400..add44b294 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "tool", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index c71774b26..d1abfa5fe 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,6 +28,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_mode": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index ada192aae..d8691f210 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ErrorArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,29 @@ class TestAmazonBedrockPromptDriver: + BEDROCK_STRUCTURED_OUTPUT_TOOL = { + "toolSpec": { + "description": "Used to provide the final response which ends this conversation.", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "name": "StructuredOutputTool_provide_output", + }, + } BEDROCK_TOOLS = [ { "toolSpec": { @@ -229,6 +253,7 @@ def mock_converse_stream(self, mocker): def prompt_stack(self, request): prompt_stack = PromptStack() prompt_stack.tools = [MockTool()] + prompt_stack.output_schema = Schema({"foo": str}) if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -357,10 +382,14 @@ def messages(self): ] @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -377,7 +406,19 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_mode == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } if use_native_tools else {} ), @@ -394,10 +435,17 @@ def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_converse_stream, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AmazonBedrockPromptDriver( - model="ai21.j2", stream=True, use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="ai21.j2", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -415,8 +463,20 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ additionalModelRequestFields={}, **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), **( - {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} - if prompt_stack.tools and use_native_tools + { + "toolConfig": { + "tools": [ + *self.BEDROCK_TOOLS, + *( + [self.BEDROCK_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_mode == "tool" + else [] + ), + ], + "toolChoice": {"any": {}} if use_native_structured_output else driver.tool_choice, + } + } + if use_native_tools else {} ), foo="bar", @@ -439,3 +499,11 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_ event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_mode(self): + assert AmazonBedrockPromptDriver(model="foo", native_structured_output_mode="tool") + + with pytest.raises( + ValueError, match="AmazonBedrockPromptDriver does not support `native` structured output mode." + ): + AmazonBedrockPromptDriver(model="foo", native_structured_output_mode="native") diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 2b84b5a17..36a1f02ce 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact @@ -141,6 +142,24 @@ class TestAnthropicPromptDriver: }, ] + ANTHROPIC_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "input_schema": { + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "StructuredOutputTool_provide_output", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") @@ -199,6 +218,7 @@ def mock_stream_client(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -343,10 +363,15 @@ def test_init(self): assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, use_native_structured_output): # Given driver = AnthropicPromptDriver( - model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="claude-3-haiku", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -362,7 +387,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_mode == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -376,13 +415,17 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream_run( + self, mock_stream_client, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = AnthropicPromptDriver( model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"foo": "bar"}, ) @@ -401,7 +444,21 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na top_p=0.999, top_k=250, **{"system": "system-input"} if prompt_stack.system_messages else {}, - **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, + **{ + "tools": [ + *self.ANTHROPIC_TOOLS, + *( + [self.ANTHROPIC_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and driver.native_structured_output_mode == "tool" + else [] + ), + ] + if use_native_tools + else {}, + "tool_choice": {"type": "any"} if use_native_structured_output else driver.tool_choice, + } + if use_native_tools + else {}, foo="bar", ) assert event.usage.input_tokens == 5 @@ -426,3 +483,9 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na event = next(stream) assert event.usage.output_tokens == 10 + + def test_verify_native_structured_output_mode(self): + assert AnthropicPromptDriver(model="foo", native_structured_output_mode="tool") + + with pytest.raises(ValueError, match="AnthropicPromptDriver does not support `native` structured output mode."): + AnthropicPromptDriver(model="foo", native_structured_output_mode="native") diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 9b7c24a98..34c988bd2 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from schema import Schema from griptape.artifacts.action_artifact import ActionArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -12,6 +13,36 @@ class TestCoherePromptDriver: + COHERE_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + COHERE_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + } COHERE_TOOLS = [ { "function": { @@ -242,6 +273,7 @@ def mock_tokenizer(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -306,10 +338,25 @@ def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_mode", ["native", "tool"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_mode, + ): # Given driver = CoherePromptDriver( - model="command", api_key="api-key", use_native_tools=use_native_tools, extra_params={"foo": "bar"} + model="command", + api_key="api-key", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_mode=native_structured_output_mode, + extra_params={"foo": "bar"}, ) # When @@ -320,7 +367,26 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_mode == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_mode == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", @@ -340,13 +406,25 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_mode", ["native", "tool"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_mode, + ): # Given driver = CoherePromptDriver( model="command", api_key="api-key", stream=True, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_mode=native_structured_output_mode, extra_params={"foo": "bar"}, ) @@ -359,7 +437,26 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_na model="command", messages=messages, max_tokens=None, - **({"tools": self.COHERE_TOOLS} if use_native_tools else {}), + **{ + "tools": [ + *self.COHERE_TOOLS, + *( + [self.COHERE_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_mode == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{ + "response_format": { + "type": "json_object", + "schema": self.COHERE_STRUCTURED_OUTPUT_SCHEMA, + } + } + if use_native_structured_output and native_structured_output_mode == "native" + else {}, stop_sequences=[], temperature=0.1, foo="bar", diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 72cf51d03..927f680dc 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -4,6 +4,7 @@ from google.generativeai.protos import FunctionCall, FunctionResponse, Part from google.generativeai.types import ContentDict, GenerationConfig from google.protobuf.json_format import MessageToDict +from schema import Schema from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, TextArtifact from griptape.artifacts.list_artifact import ListArtifact @@ -13,6 +14,15 @@ class TestGooglePromptDriver: + GOOGLE_STRUCTURED_OUTPUT_TOOL = { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "properties": {"foo": {"type": "STRING"}}, + "required": ["foo"], + "type": "OBJECT", + }, + } GOOGLE_TOOLS = [ { "name": "MockTool_test", @@ -100,6 +110,7 @@ def mock_stream_generative_model(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") @@ -166,7 +177,10 @@ def test_init(self): assert driver @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run( + self, mock_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -174,6 +188,8 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_mode="tool", extra_params={"max_output_tokens": 10}, ) @@ -195,9 +211,14 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(message.value[0], TextArtifact) assert message.value[0].value == "model-output" @@ -210,7 +231,10 @@ def test_try_run(self, mock_generative_model, prompt_stack, messages, use_native assert message.usage.output_tokens == 10 @pytest.mark.parametrize("use_native_tools", [True, False]) - def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream( + self, mock_stream_generative_model, prompt_stack, messages, use_native_tools, use_native_structured_output + ): # Given driver = GooglePromptDriver( model="gemini-pro", @@ -219,6 +243,7 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, top_p=0.5, top_k=50, use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, extra_params={"max_output_tokens": 10}, ) @@ -242,9 +267,14 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, ) if use_native_tools: tool_declarations = call_args.kwargs["tools"] - assert [ - MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations - ] == self.GOOGLE_TOOLS + tools = [ + *self.GOOGLE_TOOLS, + *([self.GOOGLE_STRUCTURED_OUTPUT_TOOL] if use_native_structured_output else []), + ] + assert [MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations] == tools + + if use_native_structured_output: + assert call_args.kwargs["tool_config"] == {"function_calling_config": {"mode": "auto"}} assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" assert event.usage.input_tokens == 5 @@ -259,3 +289,9 @@ def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, event = next(stream) assert event.usage.output_tokens == 5 + + def test_verify_native_structured_output_mode(self): + assert GooglePromptDriver(model="foo", native_structured_output_mode="tool") + + with pytest.raises(ValueError, match="GooglePromptDriver does not support `native` structured output mode."): + GooglePromptDriver(model="foo", native_structured_output_mode="native") diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 4b7aa4d13..1fdddbbd1 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,10 +1,18 @@ import pytest +from schema import Schema from griptape.common import PromptStack, TextDeltaMessageContent from griptape.drivers import HuggingFaceHubPromptDriver class TestHuggingFaceHubPromptDriver: + HUGGINGFACE_HUB_OUTPUT_SCHEMA = { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + @pytest.fixture() def mock_client(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value @@ -31,6 +39,7 @@ def mock_client_stream(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_assistant_message("assistant-input") @@ -45,9 +54,15 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_run(self, prompt_stack, mock_client, use_native_structured_output): # Given - driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", extra_params={"foo": "bar"}) + driver = HuggingFaceHubPromptDriver( + api_token="api-token", + model="repo-id", + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -58,15 +73,23 @@ def test_try_run(self, prompt_stack, mock_client): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, ) assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + def test_try_stream(self, prompt_stack, mock_client_stream, use_native_structured_output): # Given driver = HuggingFaceHubPromptDriver( - api_token="api-token", model="repo-id", stream=True, extra_params={"foo": "bar"} + api_token="api-token", + model="repo-id", + stream=True, + use_native_structured_output=use_native_structured_output, + extra_params={"foo": "bar"}, ) # When @@ -79,6 +102,9 @@ def test_try_stream(self, prompt_stack, mock_client_stream): return_full_text=False, max_new_tokens=250, foo="bar", + **{"grammar": {"type": "json", "value": self.HUGGINGFACE_HUB_OUTPUT_SCHEMA}} + if use_native_structured_output + else {}, stream=True, ) assert isinstance(event.content, TextDeltaMessageContent) @@ -87,3 +113,11 @@ def test_try_stream(self, prompt_stack, mock_client_stream): event = next(stream) assert event.usage.input_tokens == 3 assert event.usage.output_tokens == 3 + + def test_verify_native_structured_output_mode(self): + assert HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_mode="native") + + with pytest.raises( + ValueError, match="HuggingFaceHubPromptDriver does not support `tool` structured output mode." + ): + HuggingFaceHubPromptDriver(model="foo", api_token="bar", native_structured_output_mode="tool") diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 51a3dbb77..b339a3785 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,4 +1,5 @@ import pytest +from schema import Schema from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack, TextDeltaMessageContent, ToolAction @@ -7,6 +8,27 @@ class TestOllamaPromptDriver: + OLLAMA_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + OLLAMA_STRUCTURED_OUTPUT_TOOL = { + "function": { + "description": "Used to provide the final response which ends this conversation.", + "name": "StructuredOutputTool_provide_output", + "parameters": { + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + }, + }, + "type": "function", + } OLLAMA_TOOLS = [ { "function": { @@ -112,7 +134,9 @@ class TestOllamaPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = { + mock_response = mocker.MagicMock() + + data = { "message": { "content": "model-output", "tool_calls": [ @@ -126,6 +150,10 @@ def mock_client(self, mocker): }, } + mock_response.__getitem__.side_effect = lambda key: data[key] + mock_response.model_dump.return_value = data + mock_client.return_value.chat.return_value = mock_response + return mock_client @pytest.fixture() @@ -138,6 +166,7 @@ def mock_stream_client(self, mocker): @pytest.fixture() def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) prompt_stack.tools = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") @@ -202,10 +231,26 @@ def messages(self): def test_init(self): assert OllamaPromptDriver(model="llama") - @pytest.mark.parametrize("use_native_tools", [True]) - def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_mode", ["native", "tool"]) + def test_try_run( + self, + mock_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_mode, + ): # Given - driver = OllamaPromptDriver(model="llama", extra_params={"foo": "bar"}) + driver = OllamaPromptDriver( + model="llama", + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_mode=native_structured_output_mode, + extra_params={"foo": "bar"}, + ) # When message = driver.try_run(prompt_stack) @@ -219,7 +264,21 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): "stop": [], "num_predict": driver.max_tokens, }, - **{"tools": self.OLLAMA_TOOLS} if use_native_tools else {}, + **{ + "tools": [ + *self.OLLAMA_TOOLS, + *( + [self.OLLAMA_STRUCTURED_OUTPUT_TOOL] + if use_native_structured_output and native_structured_output_mode == "tool" + else [] + ), + ] + } + if use_native_tools + else {}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_mode == "native" + else {}, foo="bar", ) assert isinstance(message.value[0], TextArtifact) @@ -230,33 +289,39 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): assert message.value[1].value.path == "test" assert message.value[1].value.input == {"foo": "bar"} - def test_try_stream_run(self, mock_stream_client): + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("use_native_structured_output", [True, False]) + @pytest.mark.parametrize("native_structured_output_mode", ["native", "tool"]) + def test_try_stream_run( + self, + mock_stream_client, + prompt_stack, + messages, + use_native_tools, + use_native_structured_output, + native_structured_output_mode, + ): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) + driver = OllamaPromptDriver( + model="llama", + stream=True, + use_native_tools=use_native_tools, + use_native_structured_output=use_native_structured_output, + native_structured_output_mode=native_structured_output_mode, + extra_params={"foo": "bar"}, ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, - {"role": "assistant", "content": "assistant-input"}, - ] - driver = OllamaPromptDriver(model="llama", stream=True, extra_params={"foo": "bar"}) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( - messages=expected_messages, + messages=messages, model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + **{"format": self.OLLAMA_STRUCTURED_OUTPUT_SCHEMA} + if use_native_structured_output and native_structured_output_mode == "native" + else {}, stream=True, foo="bar", ) diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3b0e508e0..1b6088d1e 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -82,6 +82,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", }, } ], diff --git a/tests/unit/tasks/test_actions_subtask.py b/tests/unit/tasks/test_actions_subtask.py index e7d44b5af..764c3440c 100644 --- a/tests/unit/tasks/test_actions_subtask.py +++ b/tests/unit/tasks/test_actions_subtask.py @@ -4,9 +4,10 @@ from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact +from griptape.artifacts.json_artifact import JsonArtifact from griptape.common import ToolAction from griptape.structures import Agent -from griptape.tasks import ActionsSubtask, PromptTask +from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask from tests.mocks.mock_tool.tool import MockTool @@ -257,3 +258,68 @@ def test_origin_task(self): with pytest.raises(Exception, match="ActionSubtask has no origin task."): assert ActionsSubtask("test").origin_task + + def test_structured_output_tool(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool", + path="provide_output", + input={"values": {"test": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))]) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, JsonArtifact) + assert subtask.output.value == {"test": "value"} + + def test_structured_output_tool_multiple(self): + import schema + + from griptape.tools.structured_output.tool import StructuredOutputTool + + actions = ListArtifact( + [ + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool1", + path="provide_output", + input={"values": {"test1": "value"}}, + ) + ), + ActionArtifact( + ToolAction( + tag="foo", + name="StructuredOutputTool2", + path="provide_output", + input={"values": {"test2": "value"}}, + ) + ), + ] + ) + + task = ToolkitTask( + tools=[ + StructuredOutputTool(name="StructuredOutputTool1", output_schema=schema.Schema({"test": str})), + StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})), + ] + ) + Agent().add_task(task) + subtask = task.add_subtask(ActionsSubtask(actions)) + + assert isinstance(subtask.output, ListArtifact) + assert len(subtask.output.value) == 2 + assert subtask.output.value[0].value == {"test1": "value"} + assert subtask.output.value[1].value == {"test2": "value"} diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ca0576ebe..ce3530193 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,6 +257,8 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", + "native_structured_output_mode": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3c17ff479..71637b775 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", }, "tools": [ { diff --git a/tests/unit/tools/test_structured_output_tool.py b/tests/unit/tools/test_structured_output_tool.py new file mode 100644 index 000000000..d310b2f9b --- /dev/null +++ b/tests/unit/tools/test_structured_output_tool.py @@ -0,0 +1,13 @@ +import pytest +import schema + +from griptape.tools import StructuredOutputTool + + +class TestStructuredOutputTool: + @pytest.fixture() + def tool(self): + return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"})) + + def test_provide_output(self, tool): + assert tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"}