Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 9, 2024
1 parent 5c07c99 commit ba8e2f6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
7 changes: 2 additions & 5 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,14 @@ def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
cohere_messages = []

for message in messages:
cohere_message: dict = {"role": self.__to_cohere_role(message)}
cohere_message: dict = {"role": self.__to_cohere_role(message), "message": message.to_text()}

if message.has_any_content_type(ActionResultMessageContent):
cohere_message["tool_results"] = [
self.__to_cohere_message_content(action_result)
for action_result in message.get_content_type(ActionResultMessageContent)
]
else:
cohere_message["message"] = message.to_text()
if message.has_any_content_type(ActionCallMessageContent):
cohere_message["tool_calls"] = [
self.__to_cohere_message_content(action_call)
Expand All @@ -133,9 +132,7 @@ def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
return cohere_messages

def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict:
if isinstance(content, TextMessageContent):
return content.artifact.to_text()
elif isinstance(content, ActionCallMessageContent):
if isinstance(content, ActionCallMessageContent):
action = content.artifact.value

return {"name": action.to_native_tool_name(), "parameters": action.input}
Expand Down
31 changes: 27 additions & 4 deletions tests/unit/drivers/prompt/test_cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,23 @@ def prompt_stack(self, request):
prompt_stack.add_user_message(
ListArtifact(
[
TextArtifact("tool-output"),
TextArtifact("keep going"),
ActionArtifact(
Action(
tag="MockTool_test",
name="MockTool",
path="test",
input={"foo": "bar"},
output=TextArtifact("tool-output"),
)
),
]
)
)
prompt_stack.add_user_message(
ListArtifact(
[
TextArtifact("keep going"),
ActionArtifact(
Action(
tag="MockTool_test",
Expand All @@ -103,7 +119,6 @@ def prompt_stack(self, request):
]
)
)
prompt_stack.add_user_message("user-input")
return prompt_stack

def test_init(self):
Expand All @@ -128,6 +143,7 @@ def test_try_run(self, mock_client, prompt_stack, use_native_tools):
"tool_calls": [{"name": "MockTool_test", "parameters": {"foo": "bar"}}],
},
{
"message": "keep going",
"role": "TOOL",
"tool_results": [
{
Expand All @@ -138,9 +154,12 @@ def test_try_run(self, mock_client, prompt_stack, use_native_tools):
},
],
max_tokens=None,
message="user-input",
message="keep going",
**({"tools": self.COHERE_TOOLS, "force_single_step": False} if use_native_tools else {}),
**({"preamble": "system-input"} if prompt_stack.system_messages else {}),
tool_results=[
{"call": {"name": "MockTool_test", "parameters": {"foo": "bar"}}, "outputs": [{"text": "tool-output"}]}
],
stop_sequences=[],
temperature=0.1,
)
Expand Down Expand Up @@ -177,6 +196,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, use_native_tools
},
{
"role": "TOOL",
"message": "keep going",
"tool_results": [
{
"call": {"name": "MockTool_test", "parameters": {"foo": "bar"}},
Expand All @@ -186,9 +206,12 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, use_native_tools
},
],
max_tokens=None,
message="user-input",
message="keep going",
**({"tools": self.COHERE_TOOLS, "force_single_step": False} if use_native_tools else {}),
**({"preamble": "system-input"} if prompt_stack.system_messages else {}),
tool_results=[
{"call": {"name": "MockTool_test", "parameters": {"foo": "bar"}}, "outputs": [{"text": "tool-output"}]}
],
stop_sequences=[],
temperature=0.1,
)
Expand Down

0 comments on commit ba8e2f6

Please sign in to comment.