From d82542f19ff24a02d34af217d1f43e5af358e847 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 18 Dec 2024 12:07:47 -0800 Subject: [PATCH] WIP --- .../prompt/amazon_bedrock_prompt_driver.py | 13 ++++++++-- griptape/drivers/prompt/base_prompt_driver.py | 7 ----- .../prompt/openai_chat_prompt_driver.py | 9 +++++++ griptape/tasks/toolkit_task.py | 3 ++- .../templates/tasks/toolkit_task/system.j2 | 26 ------------------- 5 files changed, 22 insertions(+), 36 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b108180d2..598a12485 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -102,11 +102,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise Exception("model response is empty") def _base_params(self, prompt_stack: PromptStack) -> dict: - system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] + from griptape.tools.structured_output.tool import StructuredOutputTool + system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()]) - return { + params = { "modelId": self.model, "messages": messages, "system": system_messages, @@ -123,6 +124,14 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **self.extra_params, } + if not self.use_native_structured_output and prompt_stack.output_schema is not None: + structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + params["tool_choice"] = {"any": {}} + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + + return params + def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]: return [ { diff --git 
a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e964f0fe7..229269f65 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -60,13 +60,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - from griptape.tools.structured_output.tool import StructuredOutputTool - - if not self.use_native_structured_output and prompt_stack.output_schema is not None: - structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) - if structured_ouptut_tool not in prompt_stack.tools: - prompt_stack.tools.append(structured_ouptut_tool) - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index f0468cd02..aefb379b4 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -144,6 +144,17 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta)) def _base_params(self, prompt_stack: PromptStack) -> dict: + from griptape.tools.structured_output.tool import StructuredOutputTool + + tools = prompt_stack.tools + structured_output_params = {} + if not self.use_native_structured_output and prompt_stack.output_schema is not None: + structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema) + structured_output_params = {"tool_choice": "required"} + if structured_output_tool not in prompt_stack.tools: + prompt_stack.tools.append(structured_output_tool) + params = { + **structured_output_params, "model": self.model, "temperature": self.temperature, diff --git 
a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 6b0cb8c9b..821ee8177 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -70,7 +70,8 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools, output_schema=self.output_schema) memory = self.structure.conversation_memory if self.structure is not None else None - stack.add_system_message(self.generate_system_template(self)) + if system_prompt := self.generate_system_template(self): + stack.add_system_message(system_prompt) stack.add_user_message(self.input) diff --git a/griptape/templates/tasks/toolkit_task/system.j2 b/griptape/templates/tasks/toolkit_task/system.j2 index e9ec48739..e69de29bb 100644 --- a/griptape/templates/tasks/toolkit_task/system.j2 +++ b/griptape/templates/tasks/toolkit_task/system.j2 @@ -1,26 +0,0 @@ -You can think step-by-step and execute actions sequentially or in parallel to get your final answer. -{% if not use_native_tools %} - -You must use the following format when executing actions: - -Thought: -Actions: -{{ stop_sequence }}: -"Thought", "Actions", "{{ stop_sequence }}" must always start on a new line. - -You must use the following format when providing your final answer: -Answer: -{% endif %} -Repeat executing actions as many times as you need. -If an action's output contains an error, you MUST ALWAYS try to fix the error by executing another action. - -Be truthful. ALWAYS be proactive and NEVER ask the user for more information input. Keep using actions until you have your final answer. -NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER reference tags in other action input values. -{% if meta_memory %} - -{{ meta_memory }} -{% endif %} -{% if rulesets %} - -{{ rulesets }} -{% endif %}