From f74e2e69f19bc6b7d8e499a8b5ef08b94c83a73a Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Wed, 18 Dec 2024 15:56:49 -0800
Subject: [PATCH] WIP

---
 .../prompt/amazon_bedrock_prompt_driver.py    | 30 +++++++----
 .../drivers/prompt/anthropic_prompt_driver.py | 14 ++++-
 griptape/drivers/prompt/base_prompt_driver.py |  5 +-
 .../drivers/prompt/cohere_prompt_driver.py    | 38 ++++++-------
 .../drivers/prompt/google_prompt_driver.py    | 24 +++++----
 .../prompt/huggingface_hub_prompt_driver.py   | 25 +++++++--
 .../drivers/prompt/ollama_prompt_driver.py    | 14 ++++-
 .../prompt/openai_chat_prompt_driver.py       | 54 ++++++++-----------
 griptape/mixins/rule_mixin.py                 |  6 +++
 griptape/rules/json_schema_rule.py            |  9 +++-
 griptape/tasks/prompt_task.py                 | 33 +++++++++---
 griptape/tasks/toolkit_task.py                |  3 +-
 .../templates/tasks/toolkit_task/system.j2    | 26 +++++++++
 13 files changed, 192 insertions(+), 89 deletions(-)

diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
index 598a12485..7443e8426 100644
--- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
+++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from attrs import Factory, define, field
 from schema import Schema
@@ -55,6 +55,9 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
         kw_only=True,
     )
     use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
+    native_structured_output_mode: Literal["native", "tool"] = field(
+        default="tool", kw_only=True, metadata={"serializable": True}
+    )
     tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
     _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
 
@@ -116,19 +119,24 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
                 **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
             },
             "additionalModelRequestFields": self.additional_model_request_fields,
-            **(
-                {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
-                if prompt_stack.tools and self.use_native_tools
-                else {}
-            ),
             **self.extra_params,
         }
-        if not self.use_native_structured_output and prompt_stack.output_schema is not None:
-            structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
-            params["tool_choice"] = {"any": {}}
-            if structured_ouptut_tool not in prompt_stack.tools:
-                prompt_stack.tools.append(structured_ouptut_tool)
+
+        if prompt_stack.output_schema is not None and self.use_native_structured_output:
+            if self.native_structured_output_mode == "tool":
+                structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
+                params["toolConfig"] = {
+                    "toolChoice": {"any": {}},
+                }
+                if structured_output_tool not in prompt_stack.tools:
+                    prompt_stack.tools.append(structured_output_tool)
+            else:
+                raise ValueError(f"Unsupported native structured output mode: {self.native_structured_output_mode}")
+
+        if prompt_stack.tools and self.use_native_tools:
+            # Merge rather than overwrite so a toolChoice set above survives.
+            params.setdefault("toolConfig", {})
+            params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
 
         return params
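
A rough usage sketch of the new field, not part of the patch: it shows how a caller might opt into the "tool" mode added above. The model ID and schema shape are illustrative assumptions.

    from schema import Schema

    from griptape.drivers import AmazonBedrockPromptDriver

    # Illustrative model ID; any Converse-capable model would do.
    driver = AmazonBedrockPromptDriver(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
        use_native_structured_output=True,
        native_structured_output_mode="tool",  # the only mode Bedrock supports in this patch
    )

    # With an output schema on the prompt stack, _base_params forces tool use
    # ({"any": {}}) and appends a StructuredOutputTool built from the schema.
    output_schema = Schema({"answer": str})
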
diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py
index 3341006a1..af549bc71 100644
--- a/griptape/drivers/prompt/anthropic_prompt_driver.py
+++ b/griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional
 
-from attrs import Factory, define, field
+from attrs import Attribute, Factory, define, field
 from schema import Schema
 
 from griptape.artifacts import (
@@ -68,6 +68,9 @@ class AnthropicPromptDriver(BasePromptDriver):
     top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
     tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
     use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
+    native_structured_output_mode: Literal["native", "tool"] = field(
+        default="tool", kw_only=True, metadata={"serializable": True}
+    )
     max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
     _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
 
@@ -75,6 +78,13 @@ class AnthropicPromptDriver(BasePromptDriver):
     def client(self) -> Client:
         return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)
 
+    @native_structured_output_mode.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
+    def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
+        if value == "native":
+            raise ValueError("Anthropic does not support native structured output mode.")
+
+        return value
+
     @observable
     def try_run(self, prompt_stack: PromptStack) -> Message:
         params = self._base_params(prompt_stack)
diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py
index 229269f65..eb8ba431a 100644
--- a/griptape/drivers/prompt/base_prompt_driver.py
+++ b/griptape/drivers/prompt/base_prompt_driver.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional
 
 from attrs import Factory, define, field
 
@@ -57,6 +57,9 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
     stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
     use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
     use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
+    native_structured_output_mode: Literal["native", "tool"] = field(
+        default="native", kw_only=True, metadata={"serializable": True}
+    )
     extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
 
     def before_run(self, prompt_stack: PromptStack) -> None:
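
A minimal sketch of the validator's intended behavior, assuming only the constructor arguments shown in this patch (model name illustrative):

    from griptape.drivers import AnthropicPromptDriver

    AnthropicPromptDriver(model="claude-3-5-sonnet-20240620", native_structured_output_mode="tool")  # accepted (default)

    try:
        AnthropicPromptDriver(model="claude-3-5-sonnet-20240620", native_structured_output_mode="native")
    except ValueError as e:
        print(e)  # Anthropic does not support native structured output mode.
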
diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py
index 5fe0400ac..0f796a04d 100644
--- a/griptape/drivers/prompt/cohere_prompt_driver.py
+++ b/griptape/drivers/prompt/cohere_prompt_driver.py
@@ -97,37 +97,39 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
         yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
+        from griptape.tools.structured_output.tool import StructuredOutputTool
+
         tool_results = []
 
         messages = self.__to_cohere_messages(prompt_stack.messages)
 
-        return {
+        params = {
             "model": self.model,
             "messages": messages,
             "temperature": self.temperature,
             "stop_sequences": self.tokenizer.stop_sequences,
             "max_tokens": self.max_tokens,
             **({"tool_results": tool_results} if tool_results else {}),
-            **(
-                {"tools": self.__to_cohere_tools(prompt_stack.tools)}
-                if prompt_stack.tools and self.use_native_tools
-                else {}
-            ),
-            **(
-                {
-                    "response_format": {
-                        "type": "json_object",
-                        "schema": prompt_stack.output_schema.json_schema("Output"),
-                    }
-                }
-                if not prompt_stack.tools  # Respond format is not supported with tools https://docs.cohere.com/reference/chat#request.body.response_format
-                and prompt_stack.output_schema is not None
-                and self.use_native_structured_output
-                else {}
-            ),
             **self.extra_params,
         }
 
+        if prompt_stack.output_schema is not None:
+            if self.use_native_structured_output:
+                params["response_format"] = {
+                    "type": "json_object",
+                    "schema": prompt_stack.output_schema.json_schema("Output"),
+                }
+            else:
+                # This fallback is unreliable, since Cohere does not support forced tool use.
+                structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
+                if structured_output_tool not in prompt_stack.tools:
+                    prompt_stack.tools.append(structured_output_tool)
+
+        if prompt_stack.tools and self.use_native_tools:
+            params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
+
+        return params
+
     def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
         cohere_messages = []
diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py
index 2a6bdbf6d..8a1dc8fd3 100644
--- a/griptape/drivers/prompt/google_prompt_driver.py
+++ b/griptape/drivers/prompt/google_prompt_driver.py
@@ -125,6 +125,8 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
         )
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
+        from griptape.tools.structured_output.tool import StructuredOutputTool
+
         types = import_optional_dependency("google.generativeai.types")
         protos = import_optional_dependency("google.generativeai.protos")
 
@@ -135,7 +137,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
             parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
         )
 
-        return {
+        params = {
             "generation_config": types.GenerationConfig(
                 **{
                     # For some reason, providing stop sequences when streaming breaks native functions
@@ -148,15 +150,19 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
                     **self.extra_params,
                 },
             ),
-            **(
-                {
-                    "tools": self.__to_google_tools(prompt_stack.tools),
-                    "tool_config": {"function_calling_config": {"mode": self.tool_choice}},
-                }
-                if prompt_stack.tools and self.use_native_tools
-                else {}
-            ),
         }
 
+        if prompt_stack.output_schema is not None:
+            structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
+            params["tool_config"] = {
+                "function_calling_config": {"mode": self.tool_choice},
+            }
+            if structured_output_tool not in prompt_stack.tools:
+                prompt_stack.tools.append(structured_output_tool)
+
+        if prompt_stack.tools and self.use_native_tools:
+            params["tools"] = self.__to_google_tools(prompt_stack.tools)
+            # Keep the function-calling mode that plain tool use relied on before.
+            params.setdefault("tool_config", {"function_calling_config": {"mode": self.tool_choice}})
+
+        return params
+
     def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
         types = import_optional_dependency("google.generativeai.types")
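
A sketch of the Cohere native path this enables; the driver arguments are illustrative, and the commented payload mirrors the branch above:

    from schema import Schema

    from griptape.drivers import CoherePromptDriver

    driver = CoherePromptDriver(model="command-r-plus", api_key="...", use_native_structured_output=True)
    # With Schema({"answer": str}) on the prompt stack, _base_params emits:
    # {"response_format": {"type": "json_object", "schema": <schema>.json_schema("Output")}}
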
diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py
index c2c45c3ae..7cf751166 100644
--- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py
+++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -35,6 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
     api_token: str = field(kw_only=True, metadata={"serializable": True})
     max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
     model: str = field(kw_only=True, metadata={"serializable": True})
+    use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
     tokenizer: HuggingFaceTokenizer = field(
         default=Factory(
             lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
@@ -55,7 +56,12 @@ def client(self) -> InferenceClient:
     def try_run(self, prompt_stack: PromptStack) -> Message:
         prompt = self.prompt_stack_to_string(prompt_stack)
         full_params = self._base_params(prompt_stack)
-        logger.debug((prompt, full_params))
+        logger.debug(
+            {
+                "prompt": prompt,
+                **full_params,
+            }
+        )
 
         response = self.client.text_generation(
             prompt,
@@ -75,7 +81,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
         prompt = self.prompt_stack_to_string(prompt_stack)
         full_params = {**self._base_params(prompt_stack), "stream": True}
-        logger.debug((prompt, full_params))
+        logger.debug(
+            {
+                "prompt": prompt,
+                **full_params,
+            }
+        )
 
         response = self.client.text_generation(prompt, **full_params)
 
@@ -94,12 +105,20 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
         return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return {
+        params = {
             "return_full_text": False,
             "max_new_tokens": self.max_tokens,
             **self.extra_params,
         }
 
+        if prompt_stack.output_schema and self.use_native_structured_output:
+            # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
+            params["grammar"] = {"type": "json", "value": prompt_stack.output_schema.json_schema("Output Schema")}
+            # Strip JSON Schema metadata keys; pop() avoids a KeyError if they are absent.
+            params["grammar"]["value"].pop("$schema", None)
+            params["grammar"]["value"].pop("$id", None)
+
+        return params
+
     def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
         messages = []
         for message in prompt_stack.messages:
diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py
index f372f3b54..383104fa5 100644
--- a/griptape/drivers/prompt/ollama_prompt_driver.py
+++ b/griptape/drivers/prompt/ollama_prompt_driver.py
@@ -101,9 +101,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
             raise Exception("invalid model response")
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
+        from griptape.tools.structured_output.tool import StructuredOutputTool
+
         messages = self._prompt_stack_to_messages(prompt_stack)
 
-        return {
+        params = {
             "messages": messages,
             "model": self.model,
             "options": self.options,
@@ -122,6 +124,16 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
             **self.extra_params,
         }
 
+        if prompt_stack.output_schema is not None:
+            if self.use_native_structured_output:
+                params["format"] = prompt_stack.output_schema.json_schema("Output")
+            else:
+                structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
+                if structured_output_tool not in prompt_stack.tools:
+                    prompt_stack.tools.append(structured_output_tool)
+
+        return params
+
     def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
         ollama_messages = []
         for message in prompt_stack.messages:
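
For Ollama the schema travels in the `format` request parameter; a minimal sketch under the same assumptions (model name illustrative):

    from schema import Schema

    from griptape.drivers import OllamaPromptDriver

    driver = OllamaPromptDriver(model="llama3.1", use_native_structured_output=True)
    # With an output schema present, _base_params adds:
    # params["format"] = Schema({"answer": str}).json_schema("Output")
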
diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py
index aefb379b4..f2d32a55b 100644
--- a/griptape/drivers/prompt/openai_chat_prompt_driver.py
+++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -144,50 +144,38 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
                 yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        from griptape.tools.structured_output.tool import StructuredOutputTool
-
-        tools = prompt_stack.tools
-        if not self.use_native_structured_output and prompt_stack.output_schema is not None:
-            structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
-            params["tool_choice"] = "required"
-            if structured_ouptut_tool not in prompt_stack.tools:
-                prompt_stack.tools.append(structured_ouptut_tool)
-
         params = {
             "model": self.model,
             "temperature": self.temperature,
             "user": self.user,
             "seed": self.seed,
-            **(
-                {
-                    "tools": self.__to_openai_tools(prompt_stack.tools),
-                    "tool_choice": self.tool_choice,
-                    "parallel_tool_calls": self.parallel_tool_calls,
-                }
-                if prompt_stack.tools and self.use_native_tools
-                else {}
-            ),
             **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
             **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
             **({"stream_options": {"include_usage": True}} if self.stream else {}),
-            **(
-                {
-                    "response_format": {
-                        "type": "json_schema",
-                        "json_schema": {
-                            "name": "Output",
-                            "schema": prompt_stack.output_schema.json_schema("Output"),
-                            "strict": True,
-                        },
-                    }
-                }
-                if prompt_stack.output_schema is not None and self.use_native_structured_output
-                else {}
-            ),
             **self.extra_params,
         }
 
-        if self.response_format is not None:
+        if prompt_stack.output_schema is not None and self.use_native_structured_output:
+            if self.native_structured_output_mode == "native":
+                params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "Output",
+                        "schema": prompt_stack.output_schema.json_schema("Output"),
+                        "strict": True,
+                    },
+                }
+            else:
+                # "tool" mode: force a tool call; PromptTask supplies the StructuredOutputTool.
+                params["tool_choice"] = "required"
+
+        if prompt_stack.tools and self.use_native_tools:
+            params["tools"] = self.__to_openai_tools(prompt_stack.tools)
+            params["parallel_tool_calls"] = self.parallel_tool_calls
+            if "tool_choice" not in params:
+                params["tool_choice"] = self.tool_choice
+
+        if self.response_format is not None and "response_format" not in params:
             if self.response_format == {"type": "json_object"}:
                 params["response_format"] = self.response_format
                 # JSON mode still requires a system message instructing the LLM to output JSON.
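
The "native" branch above maps directly onto OpenAI's Structured Outputs request shape; a sketch of the fragment it produces for a simple schema:

    from schema import Schema

    schema = Schema({"answer": str})
    response_format = {
        "type": "json_schema",
        "json_schema": {"name": "Output", "schema": schema.json_schema("Output"), "strict": True},
    }
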
diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py
index cae731865..6716f246c 100644
--- a/griptape/mixins/rule_mixin.py
+++ b/griptape/mixins/rule_mixin.py
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
 import uuid
+from typing import TypeVar
 
 from attrs import Factory, define, field
 
 from griptape.mixins.serializable_mixin import SerializableMixin
 from griptape.rules import BaseRule, Ruleset
 
+T = TypeVar("T", bound=BaseRule)
+
 
 @define(slots=False)
 class RuleMixin(SerializableMixin):
@@ -25,3 +28,6 @@ def rulesets(self) -> list[Ruleset]:
             rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules))
 
         return rulesets
+
+    def get_rules_for_type(self, rule_type: type[T]) -> list[T]:
+        return [rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, rule_type)]
diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py
index c068eb4a1..84a700ce8 100644
--- a/griptape/rules/json_schema_rule.py
+++ b/griptape/rules/json_schema_rule.py
@@ -1,17 +1,22 @@
 from __future__ import annotations
 
 import json
+from typing import TYPE_CHECKING, Union
 
 from attrs import Factory, define, field
 
 from griptape.rules import BaseRule
 from griptape.utils import J2
 
+if TYPE_CHECKING:
+    from schema import Schema
+
 
 @define()
 class JsonSchemaRule(BaseRule):
-    value: dict = field(metadata={"serializable": True})
+    value: Union[dict, Schema] = field(metadata={"serializable": True})
     generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2")))
 
     def to_text(self) -> str:
-        return self.generate_template.render(json_schema=json.dumps(self.value))
+        value = self.value if isinstance(self.value, dict) else self.value.json_schema("Output Schema")
+        return self.generate_template.render(json_schema=json.dumps(value))
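
A small sketch of how the two additions compose: a JsonSchemaRule can now wrap a Schema directly, and get_rules_for_type retrieves it from any RuleMixin subclass. The Ruleset wiring is illustrative:

    from schema import Schema

    from griptape.rules import Ruleset
    from griptape.rules.json_schema_rule import JsonSchemaRule

    rule = JsonSchemaRule(Schema({"answer": str}))
    ruleset = Ruleset(name="Output", rules=[rule])
    # On a task carrying this ruleset:
    # task.get_rules_for_type(JsonSchemaRule) -> [rule]
    print(rule.to_text())  # renders the Schema as a JSON schema string
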
diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py
index dee6c2b04..b218f094d 100644
--- a/griptape/tasks/prompt_task.py
+++ b/griptape/tasks/prompt_task.py
@@ -12,7 +12,9 @@
 from griptape.memory.structure import Run
 from griptape.mixins.rule_mixin import RuleMixin
 from griptape.rules import Ruleset
+from griptape.rules.json_schema_rule import JsonSchemaRule
 from griptape.tasks import BaseTask
+from griptape.tools import StructuredOutputTool
 from griptape.utils import J2
 
 if TYPE_CHECKING:
@@ -72,9 +74,28 @@ def prompt_stack(self) -> PromptStack:
         stack = PromptStack(output_schema=self.output_schema)
         memory = self.conversation_memory
 
-        system_template = self.generate_system_template(self)
-        if system_template:
-            stack.add_system_message(system_template)
+        system_contents = [TextArtifact(self.generate_system_template(self))]
+        if self.prompt_driver.use_native_structured_output:
+            json_schema_rules = self.get_rules_for_type(JsonSchemaRule)
+            if len(json_schema_rules) > 1:
+                raise ValueError("Only one JSON Schema rule is allowed per task when using native structured output.")
+            json_schema_rule = json_schema_rules[0] if json_schema_rules else None
+            if json_schema_rule is not None:
+                if isinstance(json_schema_rule.value, Schema):
+                    stack.output_schema = json_schema_rule.value
+                    if self.prompt_driver.native_structured_output_mode == "tool":
+                        stack.tools.append(StructuredOutputTool(output_schema=stack.output_schema))
+                else:
+                    raise ValueError(
+                        "JSON Schema rule value must be of type Schema when using native structured output."
+                    )
+        else:
+            system_contents.append(TextArtifact(J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)))
+
+        if system_contents:
+            stack.add_system_message(ListArtifact(system_contents))
 
         stack.add_user_message(self.input)
 
@@ -83,14 +104,12 @@
 
         if memory is not None and memory is not NOTHING:
             # insert memory into the stack right before the user messages
-            memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0)
+            memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_contents else 0)
 
         return stack
 
     def default_generate_system_template(self, _: PromptTask) -> str:
-        return J2("tasks/prompt_task/system.j2").render(
-            rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
-        )
+        return J2("tasks/prompt_task/system.j2").render()
 
     def before_run(self) -> None:
         super().before_run()
diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py
index 821ee8177..6b0cb8c9b 100644
--- a/griptape/tasks/toolkit_task.py
+++ b/griptape/tasks/toolkit_task.py
@@ -70,8 +70,7 @@ def prompt_stack(self) -> PromptStack:
         stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
         memory = self.structure.conversation_memory if self.structure is not None else None
 
-        if system_prompt := self.generate_system_template(self):
-            stack.add_system_message(system_prompt)
+        stack.add_system_message(self.generate_system_template(self))
 
         stack.add_user_message(self.input)
diff --git a/griptape/templates/tasks/toolkit_task/system.j2 b/griptape/templates/tasks/toolkit_task/system.j2
index e69de29bb..e9ec48739 100644
--- a/griptape/templates/tasks/toolkit_task/system.j2
+++ b/griptape/templates/tasks/toolkit_task/system.j2
@@ -0,0 +1,26 @@
+You can think step-by-step and execute actions sequentially or in parallel to get your final answer.
+{% if not use_native_tools %}
+
+You must use the following format when executing actions:
+
+Thought: <your step-by-step thought process>
+Actions: <JSON list of actions to execute>
+{{ stop_sequence }}: <results of the actions>
+"Thought", "Actions", "{{ stop_sequence }}" must always start on a new line.
+
+You must use the following format when providing your final answer:
+Answer: <final answer>
+{% endif %}
+Repeat executing actions as many times as you need.
+If an action's output contains an error, you MUST ALWAYS try to fix the error by executing another action.
+
+Be truthful. ALWAYS be proactive and NEVER ask the user for more information. Keep using actions until you have your final answer.
+NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER reference tags in other action input values.
+{% if meta_memory %}
+
+{{ meta_memory }}
+{% endif %}
+{% if rulesets %}
+
+{{ rulesets }}
+{% endif %}
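
An end-to-end sketch of how these pieces are meant to compose, assuming the Agent and PromptTask APIs are otherwise unchanged (driver and model are illustrative):

    from schema import Schema

    from griptape.drivers import OpenAiChatPromptDriver
    from griptape.rules.json_schema_rule import JsonSchemaRule
    from griptape.structures import Agent

    agent = Agent(
        prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", use_native_structured_output=True),
        rules=[JsonSchemaRule(Schema({"answer": str}))],
    )
    agent.run("What is the capital of France?")
    # With native structured output enabled, PromptTask moves the rule's Schema onto
    # the prompt stack instead of rendering it into the system prompt.
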