Skip to content

Commit

Permalink
Add Prompt Driver tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 6, 2024
1 parent da95012 commit e431f5d
Show file tree
Hide file tree
Showing 11 changed files with 1,007 additions and 208 deletions.
9 changes: 3 additions & 6 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
tool_schema_id: str = field(
default="https://griptape.ai", # Amazon Bedrock requires that this be a valid URL.
kw_only=True,
metadata={"serializable": True},
)

def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.bedrock_client.converse(**self._base_params(prompt_stack))
Expand Down Expand Up @@ -124,7 +119,9 @@ def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
"name": f"{tool.name}_{tool.activity_name(activity)}",
"description": tool.activity_description(activity),
"inputSchema": {
"json": (tool.activity_schema(activity) or Schema({})).json_schema(self.tool_schema_id)
"json": (tool.activity_schema(activity) or Schema({})).json_schema(
"http://json-schema.org/draft-07/schema#"
)
},
}
}
Expand Down
8 changes: 2 additions & 6 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:

if "message" in message[0]:
user_message = message[0]["message"]
elif "tool_results" in message[0]:
if "tool_results" in message[0]:
tool_results = message[0]["tool_results"]

Check warning on line 85 in griptape/drivers/prompt/cohere_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/cohere_prompt_driver.py#L85

Added line #L85 was not covered by tests
else:
raise ValueError("Unsupported message type")

# History messages
history_messages = self.__to_cohere_messages(
Expand All @@ -104,7 +102,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"tool_results": tool_results} if tool_results else {}),
**(
{"tools": self.__to_cohere_tools(prompt_stack.actions), "force_single_step": self.force_single_step}
if self.use_native_tools
if prompt_stack.actions and self.use_native_tools
else {}
),
**({"preamble": preamble} if preamble else {}),
Expand Down Expand Up @@ -160,8 +158,6 @@ def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict
def __to_cohere_role(self, message: Message) -> str:
if message.is_system():
return "SYSTEM"
elif message.is_user():
return "USER"
elif message.is_assistant():
return "CHATBOT"
else:
Expand Down
10 changes: 6 additions & 4 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from attrs import Factory, define, field
from google.generativeai.types import ContentsType


from griptape.common import (
BaseMessageContent,
DeltaMessage,
Expand All @@ -24,13 +25,13 @@
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, GoogleTokenizer
from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively
from schema import Schema

if TYPE_CHECKING:
from google.generativeai import GenerativeModel
from google.generativeai.types import ContentDict, GenerateContentResponse
from google.generativeai.protos import Part
from griptape.tools import BaseTool
from schema import Schema


@define
Expand Down Expand Up @@ -166,9 +167,10 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
tool_declarations = []
for tool in tools:
for activity in tool.activities():
schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")["properties"][
"values"
]
schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")

if "values" in schema["properties"]:
schema = schema["properties"]["values"]

schema = remove_key_in_dict_recursively(schema, "additionalProperties")
tool_declaration = FunctionDeclaration(
Expand Down
34 changes: 19 additions & 15 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,24 +241,28 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
raise ValueError(f"Unsupported content type: {type(content)}")

def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
content = []

if response.content is not None:
return [TextMessageContent(TextArtifact(response.content))]
elif response.tool_calls is not None:
return [
ActionCallMessageContent(
ActionArtifact(
ActionArtifact.Action(
tag=tool_call.id,
name=tool_call.function.name.split("_", 1)[0],
path=tool_call.function.name.split("_", 1)[1],
input=json.loads(tool_call.function.arguments),
content.append(TextMessageContent(TextArtifact(response.content)))
if response.tool_calls is not None:
content.extend(
[
ActionCallMessageContent(
ActionArtifact(
ActionArtifact.Action(
tag=tool_call.id,
name=tool_call.function.name.split("_", 1)[0],
path=tool_call.function.name.split("_", 1)[1],
input=json.loads(tool_call.function.arguments),
)
)
)
)
for tool_call in response.tool_calls
]
else:
raise ValueError(f"Unsupported message type: {response}")
for tool_call in response.tool_calls
]
)

return content

def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
if content_delta.content is not None:
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/config/test_amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_to_dict(self, config):
"temperature": 0.1,
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"tool_schema_id": "https://griptape.ai",
"use_native_tools": True,
},
"vector_store_driver": {
Expand Down Expand Up @@ -106,7 +105,6 @@ def test_to_dict_with_values(self, config_with_values):
"temperature": 0.1,
"type": "AmazonBedrockPromptDriver",
"tool_choice": {"auto": {}},
"tool_schema_id": "https://griptape.ai",
"use_native_tools": True,
},
"vector_store_driver": {
Expand Down
Loading

0 comments on commit e431f5d

Please sign in to comment.