From 997ef3d68b3b0f1e694c9e599635ca706610e796 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 23 Dec 2024 14:28:19 -0800 Subject: [PATCH] Fix tests --- .../prompt/amazon_bedrock_prompt_driver.py | 27 ++-- .../drivers/prompt/anthropic_prompt_driver.py | 12 +- griptape/drivers/prompt/base_prompt_driver.py | 7 -- .../drivers/prompt/cohere_prompt_driver.py | 31 +++-- .../drivers/prompt/google_prompt_driver.py | 29 +++-- .../prompt/huggingface_hub_prompt_driver.py | 7 +- .../drivers/prompt/ollama_prompt_driver.py | 15 +-- .../prompt/openai_chat_prompt_driver.py | 3 +- griptape/mixins/rule_mixin.py | 3 - griptape/schemas/base_schema.py | 3 + griptape/structures/agent.py | 10 +- griptape/tasks/prompt_task.py | 115 ++++++++++-------- .../test_amazon_bedrock_drivers_config.py | 4 + .../drivers/test_anthropic_drivers_config.py | 2 + .../test_azure_openai_drivers_config.py | 2 + .../drivers/test_cohere_drivers_config.py | 2 + .../configs/drivers/test_drivers_config.py | 2 + .../drivers/test_google_drivers_config.py | 2 + .../drivers/test_openai_driver_config.py | 2 + .../prompt/test_ollama_prompt_driver.py | 8 +- tests/unit/structures/test_agent.py | 8 -- tests/unit/structures/test_structure.py | 2 + tests/unit/tasks/test_tool_task.py | 2 + tests/unit/tasks/test_toolkit_task.py | 2 + 24 files changed, 163 insertions(+), 137 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 1f9d4ed4d..ce8237416 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -55,6 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) native_structured_output_mode: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -105,8 +106,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise Exception("model response is empty") def _base_params(self, prompt_stack: PromptStack) -> dict: - from griptape.tools.structured_output.tool import StructuredOutputTool - system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) @@ -119,23 +118,21 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}), }, "additionalModelRequestFields": self.additional_model_request_fields, + **( + {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}} + if prompt_stack.tools and self.use_native_tools + else {} + ), **self.extra_params, } - if prompt_stack.output_schema is not None and self.use_native_structured_output: - if self.native_structured_output_mode == "tool": - structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - params["toolConfig"] = { - "toolChoice": {"any": {}}, - } - if structured_output_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_output_tool) - else: - raise ValueError(f"Unsupported native structured output mode: {self.native_structured_output_mode}") - - if prompt_stack.tools and self.use_native_tools: + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): params["toolConfig"] = { - "tools": self.__to_bedrock_tools(prompt_stack.tools), + "toolChoice": {"any": {}}, } return params diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index af549bc71..7aaf32573 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -68,6 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) native_structured_output_mode: Literal["native", "tool"] = field( default="tool", kw_only=True, metadata={"serializable": True} ) @@ -120,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages system_message = system_messages[0].to_text() if system_messages else None - return { + params = { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, @@ -137,6 +138,15 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): + params["tool_choice"] = {"type": "any"} + + return params + def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]: return [ {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)} diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index eb8ba431a..4f850dc6d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -77,18 +77,11 @@ def after_run(self, result: Message) -> None: @observable(tags=["PromptDriver.run()"]) def run(self, prompt_input: PromptStack | BaseArtifact) -> Message: - from griptape.tools.structured_output.tool import StructuredOutputTool - if isinstance(prompt_input, BaseArtifact): prompt_stack = PromptStack.from_artifact(prompt_input) else: prompt_stack = prompt_input - if not self.use_native_structured_output and prompt_stack.output_schema is not None: - structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_ouptut_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_ouptut_tool) - for attempt in self.retrying(): with attempt: self.before_run(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 0f796a04d..896c2a6ff 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver): model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), @@ -97,8 +98,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event)) def _base_params(self, prompt_stack: PromptStack) -> dict: - from griptape.tools.structured_output.tool import StructuredOutputTool - tool_results = [] messages = self.__to_cohere_messages(prompt_stack.messages) @@ -110,23 +109,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), + **( + {"tools": self.__to_cohere_tools(prompt_stack.tools)} + if prompt_stack.tools and self.use_native_tools + else {} + ), **self.extra_params, } - if prompt_stack.output_schema is not None: - if self.use_native_structured_output: - params["response_format"] = { - "type": "json_object", - "schema": prompt_stack.output_schema.json_schema("Output"), - } - else: - # This does not work great since Cohere does not support forced tool use. - structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_output_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_output_tool) - - if prompt_stack.tools and self.use_native_tools: - params["tools"] = self.__to_cohere_tools(prompt_stack.tools) + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "native" + ): + params["response_format"] = { + "type": "json_object", + "schema": prompt_stack.output_schema.json_schema("Output"), + } return params diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 8a1dc8fd3..455ce7769 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -63,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @@ -125,8 +126,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ) def _base_params(self, prompt_stack: PromptStack) -> dict: - from griptape.tools.structured_output.tool import StructuredOutputTool - types = import_optional_dependency("google.generativeai.types") protos = import_optional_dependency("google.generativeai.protos") @@ -150,17 +149,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, }, ), + **( + { + "tools": self.__to_google_tools(prompt_stack.tools), + "tool_config": {"function_calling_config": {"mode": self.tool_choice}}, + } + if prompt_stack.tools and self.use_native_tools + else {} + ), } - if prompt_stack.output_schema is not None: - structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - params["tool_config"] = { - "function_calling_config": {"mode": self.tool_choice}, - } - if structured_ouptut_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_ouptut_tool) - - if prompt_stack.tools and self.use_native_tools: - params["tools"] = self.__to_google_tools(prompt_stack.tools) + if prompt_stack.output_schema is not None and self.use_native_structured_output: + if self.native_structured_output_mode == "tool": + params["tool_config"] = { + "function_calling_config": {"mode": "auto"}, + } + elif self.native_structured_output_mode == "native": + # TODO: Add support for native structured output + ... return params diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 7cf751166..ca4f8b1bc 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -111,9 +111,14 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema and self.use_native_structured_output: + if ( + prompt_stack.output_schema + and self.use_native_structured_output + and self.native_structured_output_mode == "native" + ): # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding params["grammar"] = {"type": "json", "value": prompt_stack.output_schema.json_schema("Output Schema")} + # Grammar does not support $schema and $id del params["grammar"]["value"]["$schema"] del params["grammar"]["value"]["$id"] diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 383104fa5..be39558ee 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -101,8 +101,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise Exception("invalid model response") def _base_params(self, prompt_stack: PromptStack) -> dict: - from griptape.tools.structured_output.tool import StructuredOutputTool - messages = self._prompt_stack_to_messages(prompt_stack) params = { @@ -124,13 +122,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } - if prompt_stack.output_schema is not None: - if self.use_native_structured_output: - params["format"] = prompt_stack.output_schema.json_schema("Output") - else: - structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_ouptut_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_ouptut_tool) + if ( + prompt_stack.output_schema is not None + and self.use_native_structured_output + and self.native_structured_output_mode == "tool" + ): + params["format"] = prompt_stack.output_schema.json_schema("Output") return params diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index f2d32a55b..815891fcc 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -149,7 +149,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, - "stop": self.tokenizer.stop_sequences, **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), **({"stream_options": {"include_usage": True}} if self.stream else {}), @@ -166,7 +165,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "strict": True, }, } - else: + elif self.native_structured_output_mode == "tool": params["tool_choice"] = "required" if prompt_stack.tools and self.use_native_tools: diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index 6716f246c..f92679229 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -28,6 +28,3 @@ def rulesets(self) -> list[Ruleset]: rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules)) return rulesets - - def get_rules_for_type(self, rule_type: type[T]) -> list[T]: - return [rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, rule_type)] diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 7b23c620f..9217f26c2 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: from collections.abc import Sequence from typing import Any + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.common import ( BaseDeltaMessageContent, @@ -215,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseRule": BaseRule, "Ruleset": Ruleset, # Third party modules + "Schema": Schema, "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any, "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 55b7c5942..2099c6ee9 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import Attribute, Factory, define, evolve, field -from schema import Schema from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable @@ -12,6 +11,8 @@ from griptape.tasks import PromptTask if TYPE_CHECKING: + from schema import Schema + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask @@ -51,7 +52,6 @@ def __attrs_post_init__(self) -> None: prompt_driver=self.prompt_driver, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries, - output_schema=self._build_schema_from_type(self.output_type) if self.output_type is not None else None, ) self.add_task(task) @@ -78,9 +78,3 @@ def try_run(self, *args) -> Agent: self.task.run() return self - - def _build_schema_from_type(self, output_type: type | Schema) -> Schema: - if isinstance(output_type, Schema): - return output_type - else: - return Schema(output_type) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 1453bd662..9ffa79840 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import NOTHING, Attribute, Factory, NothingType, define, field -from schema import Schema +from schema import Or, Schema from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact @@ -94,25 +94,13 @@ def prompt_stack(self) -> PromptStack: system_contents = [TextArtifact(self.generate_system_template(self))] if self.prompt_driver.use_native_structured_output: - json_schema_rules = self.get_rules_for_type(JsonSchemaRule) - if len(json_schema_rules) > 1: - raise ValueError("Only one JSON Schema rule is allowed per task when using native structured output.") - json_schema_rule = json_schema_rules[0] if json_schema_rules else None - if json_schema_rule is not None: - if isinstance(json_schema_rule.value, Schema): - stack.output_schema = json_schema_rule.value - if self.prompt_driver.native_structured_output_mode == "native": - stack.output_schema = json_schema_rule.value - else: - stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema)) - else: - raise ValueError( - "JSON Schema rule value must be of type Schema when using native structured output." - ) + self._add_native_schema_to_prompt_stack(stack) else: system_contents.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=self.rulesets))) - if system_contents: + has_system_content = any(system_content.value for system_content in system_contents) + + if has_system_content: stack.add_system_message(ListArtifact(system_contents)) stack.add_user_message(self.input) @@ -120,45 +108,10 @@ def prompt_stack(self) -> PromptStack: if self.output: stack.add_assistant_message(self.output.to_text()) else: - for s in self.subtasks: - if self.prompt_driver.use_native_tools: - action_calls = [ - ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) - for action in s.actions - ] - action_results = [ - ToolAction( - name=action.name, - path=action.path, - tag=action.tag, - output=action.output if action.output is not None else s.output, - ) - for action in s.actions - ] - - stack.add_assistant_message( - ListArtifact( - [ - *([TextArtifact(s.thought)] if s.thought else []), - *[ActionArtifact(a) for a in action_calls], - ], - ), - ) - stack.add_user_message( - ListArtifact( - [ - *[ActionArtifact(a) for a in action_results], - *([] if s.output else [TextArtifact("Please keep going")]), - ], - ), - ) - else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) - + self._add_subtasks_to_prompt_stack(stack) if memory is not None: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_contents else 0) + memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if has_system_content else 0) return stack @@ -340,3 +293,57 @@ def _process_task_input( return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: return self._process_task_input(TextArtifact(task_input)) + + def _add_native_schema_to_prompt_stack(self, stack: PromptStack) -> None: + json_schema_rules = [ + rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, JsonSchemaRule) + ] + schemas = [rule.value for rule in json_schema_rules if isinstance(rule.value, Schema)] + + if len(json_schema_rules) != len(schemas): + logger.warning("`JsonSchemaRule` must take a `Schema` to be used in native structured output mode") + + json_schema = schemas[0] if len(schemas) == 1 else Schema(Or(*schemas)) + + if schemas: + stack.output_schema = json_schema + + if self.prompt_driver.native_structured_output_mode == "tool": + stack.tools.append(StructuredOutputTool(output_schema=json_schema)) + + def _add_subtasks_to_prompt_stack(self, stack: PromptStack) -> None: + for s in self.subtasks: + if self.prompt_driver.use_native_tools: + action_calls = [ + ToolAction(name=action.name, path=action.path, tag=action.tag, input=action.input) + for action in s.actions + ] + action_results = [ + ToolAction( + name=action.name, + path=action.path, + tag=action.tag, + output=action.output if action.output is not None else s.output, + ) + for action in s.actions + ] + + stack.add_assistant_message( + ListArtifact( + [ + *([TextArtifact(s.thought)] if s.thought else []), + *[ActionArtifact(a) for a in action_calls], + ], + ), + ) + stack.add_user_message( + ListArtifact( + [ + *[ActionArtifact(a) for a in action_results], + *([] if s.output else [TextArtifact("Please keep going")]), + ], + ), + ) + else: + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) diff --git a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index 52408922c..787c033fb 100644 --- a/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -51,6 +51,8 @@ def test_to_dict(self, config): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "tool", "extra_params": {}, }, "vector_store_driver": { @@ -106,6 +108,8 @@ def test_to_dict_with_values(self, config_with_values): "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "tool", "extra_params": {}, }, "vector_store_driver": { diff --git a/tests/unit/configs/drivers/test_anthropic_drivers_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py index 8a6f25ef2..5511ae0bf 100644 --- a/tests/unit/configs/drivers/test_anthropic_drivers_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_p": 0.999, "top_k": 250, "use_native_tools": True, + "native_structured_output_mode": "tool", + "use_native_structured_output": True, "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py index 4c44113a0..9b3ce26e5 100644 --- a/tests/unit/configs/drivers/test_azure_openai_drivers_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -36,6 +36,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_mode": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_cohere_drivers_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py index 65295da52..9514b851f 100644 --- a/tests/unit/configs/drivers/test_cohere_drivers_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -26,6 +26,8 @@ def test_to_dict(self, config): "model": "command-r", "force_single_step": False, "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "native", "extra_params": {}, }, "embedding_driver": { diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index ca3cea60e..f9ad7a188 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -18,6 +18,8 @@ def test_to_dict(self, config): "max_tokens": None, "stream": False, "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/configs/drivers/test_google_drivers_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py index c1459a400..7b149a923 100644 --- a/tests/unit/configs/drivers/test_google_drivers_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -25,6 +25,8 @@ def test_to_dict(self, config): "top_k": None, "tool_choice": "auto", "use_native_tools": True, + "use_native_structured_output": True, + "native_structured_output_mode": "native", "extra_params": {}, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, diff --git a/tests/unit/configs/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py index c71774b26..d1abfa5fe 100644 --- a/tests/unit/configs/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -28,6 +28,8 @@ def test_to_dict(self, config): "stream": False, "user": "", "use_native_tools": True, + "native_structured_output_mode": "native", + "use_native_structured_output": True, "extra_params": {}, }, "conversation_memory_driver": { diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 51a3dbb77..96ed13148 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -112,7 +112,9 @@ class TestOllamaPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("ollama.Client") - mock_client.return_value.chat.return_value = { + mock_response = mocker.MagicMock() + + data = { "message": { "content": "model-output", "tool_calls": [ @@ -126,6 +128,10 @@ def mock_client(self, mocker): }, } + mock_response.__getitem__.side_effect = lambda key: data[key] + mock_response.model_dump.return_value = data + mock_client.return_value.chat.return_value = mock_response + return mock_client @pytest.fixture() diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 4cc2be43d..ecd1ac6dc 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -284,11 +284,3 @@ def test_stream_mutation(self): assert isinstance(agent.tasks[0], PromptTask) assert agent.tasks[0].prompt_driver.stream is True assert agent.tasks[0].prompt_driver is not prompt_driver - - def test_output_type_primitive(self): - from griptape.tools import StructuredOutputTool - - agent = Agent(output_type=str) - - assert isinstance(agent.tools[0], StructuredOutputTool) - assert agent.tools[0].output_schema == agent._build_schema_from_type(str) diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 3b0e508e0..1b6088d1e 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -82,6 +82,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", }, } ], diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index ca0576ebe..ce3530193 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -257,6 +257,8 @@ def test_to_dict(self): "stream": False, "temperature": 0.1, "type": "MockPromptDriver", + "native_structured_output_mode": "native", + "use_native_structured_output": False, "use_native_tools": False, }, "tool": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 3c17ff479..71637b775 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -399,6 +399,8 @@ def test_to_dict(self): "temperature": 0.1, "type": "MockPromptDriver", "use_native_tools": False, + "use_native_structured_output": False, + "native_structured_output_mode": "native", }, "tools": [ {