Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 18, 2024
1 parent 2072b42 commit d82542f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 36 deletions.
13 changes: 11 additions & 2 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise Exception("model response is empty")

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
from griptape.tools.structured_output.tool import StructuredOutputTool

system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

return {
params = {
"modelId": self.model,
"messages": messages,
"system": system_messages,
Expand All @@ -123,6 +124,14 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_choice"] = {"any": {}}
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

return params

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
return [
{
Expand Down
7 changes: 0 additions & 7 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
    """Prepare the prompt stack before the driver runs, then publish a start event.

    When native structured output is disabled but the stack carries an
    ``output_schema``, a ``StructuredOutputTool`` wrapping that schema is
    appended to the stack's tools (if an equal tool is not already present)
    so the model can produce structured output via tool calling instead.

    Args:
        prompt_stack: The prompt stack about to be sent to the model.
    """
    # Imported locally to avoid a circular import between drivers and tools.
    from griptape.tools.structured_output.tool import StructuredOutputTool

    if not self.use_native_structured_output and prompt_stack.output_schema is not None:
        structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
        if structured_output_tool not in prompt_stack.tools:
            prompt_stack.tools.append(structured_output_tool)

    EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

def after_run(self, result: Message) -> None:
Expand Down
9 changes: 9 additions & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

tools = prompt_stack.tools
if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_choice"] = "required"
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

params = {
"model": self.model,
"temperature": self.temperature,
Expand Down
3 changes: 2 additions & 1 deletion griptape/tasks/toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def prompt_stack(self) -> PromptStack:
stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
memory = self.structure.conversation_memory if self.structure is not None else None

stack.add_system_message(self.generate_system_template(self))
if system_prompt := self.generate_system_template(self):
stack.add_system_message(system_prompt)

stack.add_user_message(self.input)

Expand Down
26 changes: 0 additions & 26 deletions griptape/templates/tasks/toolkit_task/system.j2
Original file line number Diff line number Diff line change
@@ -1,26 +0,0 @@
You can think step-by-step and execute actions sequentially or in parallel to get your final answer.
{% if not use_native_tools %}

You must use the following format when executing actions:

Thought: <your step-by-step thought process describing what actions you need to use>
Actions: <JSON array of actions that MUST follow this schema: {{ actions_schema }}>
{{ stop_sequence }}: <action outputs>
"Thought", "Actions", "{{ stop_sequence }}" must always start on a new line.

You must use the following format when providing your final answer:
Answer: <final answer>
{% endif %}
Repeat executing actions as many times as you need.
If an action's output contains an error, you MUST ALWAYS try to fix the error by executing another action.

Be truthful. ALWAYS be proactive and NEVER ask the user for more information. Keep using actions until you have your final answer.
NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER reference tags in other action input values.
{% if meta_memory %}

{{ meta_memory }}
{% endif %}
{% if rulesets %}

{{ rulesets }}
{% endif %}

0 comments on commit d82542f

Please sign in to comment.