Skip to content

Commit

Permalink
Clean up drivers, improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 23, 2024
1 parent 00a270e commit 4766361
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 41 deletions.
81 changes: 40 additions & 41 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,47 +160,7 @@ def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict[str,
messages = []

for input in prompt_stack.inputs:
# Each Tool result requires a separate message
if input.is_tool_result():
actions_artifact = input.content

if isinstance(actions_artifact, ActionsArtifact):
tool_result_messages = [
{
"tool_call_id": tool_call.tag,
"role": self.__to_openai_role(input),
"name": f"{tool_call.name}-{tool_call.path}",
"content": tool_call.output.to_text(),
}
for tool_call in actions_artifact.actions
]
messages.extend(tool_result_messages)
else:
raise ValueError("PromptStack Input content must be an ActionsArtifact")

else:
if input.is_tool_call():
actions_artifact = input.content

if isinstance(actions_artifact, ActionsArtifact):
tool_calls = [
{
"id": action.tag,
"function": {
"name": f"{action.name}-{action.path}",
"arguments": json.dumps(action.input),
},
"type": "function",
}
for action in actions_artifact.actions
]
message = {"role": self.__to_openai_role(input), "tool_calls": tool_calls}
else:
raise ValueError("PromptStack Input content must be an ActionsArtifact")
else:
message = {"role": self.__to_openai_role(input), "content": input.content}

messages.append(message)
messages.extend(self.__to_openai_content(input))

return messages

Expand Down Expand Up @@ -263,6 +223,45 @@ def _extract_ratelimit_metadata(self, response):
self._ratelimit_token_limit = response.headers.get("x-ratelimit-limit-tokens")
self._ratelimit_tokens_remaining = response.headers.get("x-ratelimit-remaining-tokens")

def __to_openai_content(self, input: PromptStack.Input) -> list[dict]:
    """Convert one Prompt Stack input into a list of OpenAI chat messages.

    Tool-call inputs become a single assistant-style message carrying a
    ``tool_calls`` array; tool-result inputs expand to one message per
    executed action (each tied back via ``tool_call_id``); anything else
    maps to a plain role/content message.

    Raises:
        ValueError: if a tool-call or tool-result input's content is not
            an ``ActionsArtifact``.
    """
    if input.is_tool_call():
        artifact = input.content
        if not isinstance(artifact, ActionsArtifact):
            raise ValueError("PromptStack Input content must be an ActionsArtifact")

        # One message bundling every requested function call.
        calls = [
            {
                "id": action.tag,
                "function": {"name": f"{action.name}-{action.path}", "arguments": json.dumps(action.input)},
                "type": "function",
            }
            for action in artifact.actions
        ]

        return [{"role": self.__to_openai_role(input), "tool_calls": calls}]
    elif input.is_tool_result():
        artifact = input.content
        if not isinstance(artifact, ActionsArtifact):
            raise ValueError("PromptStack Input content must be an ActionsArtifact")

        # Each tool result must be reported as its own message.
        return [
            {
                "tool_call_id": action.tag,
                "role": self.__to_openai_role(input),
                "name": f"{action.name}-{action.path}",
                "content": action.output.to_text(),
            }
            for action in artifact.actions
        ]
    else:
        # Plain text (system/user/assistant) input.
        return [{"role": self.__to_openai_role(input), "content": input.content}]

def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
if prompt_input.is_system():
return "system"
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ def test_try_run_with_tools(self, mock_client_with_tools, model):
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
prompt_stack.add_tool_call_input(
content="Thinking",
actions=[
ActionsArtifact.Action(
tag="tool-call-id",
name="ToolName",
path="ActivityName",
input={"parameter-name": "parameter-value"},
)
],
)
prompt_stack.add_tool_call_input(
content=None,
actions=[
Expand All @@ -266,6 +277,18 @@ def test_try_run_with_tools(self, mock_client_with_tools, model):
expected_messages = [
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
{
"content": [
{"text": "Thinking", "type": "text"},
{
"id": "tool-call-id",
"input": {"parameter-name": "parameter-value"},
"name": "ToolName-ActivityName",
"type": "tool_use",
},
],
"role": "assistant",
},
{
"content": [
{
Expand Down Expand Up @@ -381,6 +404,17 @@ def test_try_stream_run_with_tools(self, mock_stream_client_with_tools, model):
prompt_stack.add_system_input("system-input")
prompt_stack.add_user_input("user-input")
prompt_stack.add_assistant_input("assistant-input")
prompt_stack.add_tool_call_input(
content="Thinking",
actions=[
ActionsArtifact.Action(
tag="tool-call-id",
name="ToolName",
path="ActivityName",
input={"parameter-name": "parameter-value"},
)
],
)
prompt_stack.add_tool_call_input(
content=None,
actions=[
Expand Down Expand Up @@ -408,6 +442,18 @@ def test_try_stream_run_with_tools(self, mock_stream_client_with_tools, model):
{"role": "user", "content": "generic-input"},
{"role": "user", "content": "user-input"},
{"role": "assistant", "content": "assistant-input"},
{
"content": [
{"text": "Thinking", "type": "text"},
{
"id": "tool-call-id",
"input": {"parameter-name": "parameter-value"},
"name": "ToolName-ActivityName",
"type": "tool_use",
},
],
"role": "assistant",
},
{
"content": [
{
Expand Down

0 comments on commit 4766361

Please sign in to comment.