From e431f5db8e9f3c115dbd3b0814852231b22e56b2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 5 Jul 2024 17:00:02 -0700 Subject: [PATCH] Add Prompt Driver tests --- .../prompt/amazon_bedrock_prompt_driver.py | 9 +- .../drivers/prompt/cohere_prompt_driver.py | 8 +- .../drivers/prompt/google_prompt_driver.py | 10 +- .../prompt/openai_chat_prompt_driver.py | 34 +- .../test_amazon_bedrock_structure_config.py | 2 - .../test_amazon_bedrock_prompt_driver.py | 300 ++++++++++++++- .../prompt/test_anthropic_prompt_driver.py | 357 ++++++++++++++---- .../test_azure_openai_chat_prompt_driver.py | 90 ++++- .../prompt/test_cohere_prompt_driver.py | 4 +- .../prompt/test_google_prompt_driver.py | 159 +++++--- .../prompt/test_openai_chat_prompt_driver.py | 242 +++++++++++- 11 files changed, 1007 insertions(+), 208 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 59d396fc6..f0a73f308 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -50,11 +50,6 @@ class AmazonBedrockPromptDriver(BasePromptDriver): ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) - tool_schema_id: str = field( - default="https://griptape.ai", # Amazon Bedrock requires that this be a valid URL. - kw_only=True, - metadata={"serializable": True}, - ) def try_run(self, prompt_stack: PromptStack) -> Message: response = self.bedrock_client.converse(**self._base_params(prompt_stack)) @@ -124,7 +119,9 @@ def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]: "name": f"{tool.name}_{tool.activity_name(activity)}", "description": tool.activity_description(activity), "inputSchema": { - "json": (tool.activity_schema(activity) or Schema({})).json_schema(self.tool_schema_id) + "json": (tool.activity_schema(activity) or Schema({})).json_schema( + "http://json-schema.org/draft-07/schema#" + ) }, } } diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index e0d57617f..46e1cbf9d 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -81,10 +81,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if "message" in message[0]: user_message = message[0]["message"] - elif "tool_results" in message[0]: + if "tool_results" in message[0]: tool_results = message[0]["tool_results"] - else: - raise ValueError("Unsupported message type") # History messages history_messages = self.__to_cohere_messages( @@ -104,7 +102,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"tool_results": tool_results} if tool_results else {}), **( {"tools": self.__to_cohere_tools(prompt_stack.actions), "force_single_step": self.force_single_step} - if self.use_native_tools + if prompt_stack.actions and self.use_native_tools else {} ), **({"preamble": preamble} if preamble else {}), @@ -160,8 +158,6 @@ def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict def __to_cohere_role(self, message: Message) -> str: if message.is_system(): return "SYSTEM" - elif message.is_user(): - return "USER" elif message.is_assistant(): return "CHATBOT" else: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 
7c41cd96a..a54b0a52e 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -7,6 +7,7 @@ from attrs import Factory, define, field from google.generativeai.types import ContentsType + from griptape.common import ( BaseMessageContent, DeltaMessage, @@ -24,13 +25,13 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, GoogleTokenizer from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively +from schema import Schema if TYPE_CHECKING: from google.generativeai import GenerativeModel from google.generativeai.types import ContentDict, GenerateContentResponse from google.generativeai.protos import Part from griptape.tools import BaseTool - from schema import Schema @define @@ -166,9 +167,10 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]: tool_declarations = [] for tool in tools: for activity in tool.activities(): - schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema")["properties"][ - "values" - ] + schema = (tool.activity_schema(activity) or Schema({})).json_schema("Parameters Schema") + + if "values" in schema["properties"]: + schema = schema["properties"]["values"] schema = remove_key_in_dict_recursively(schema, "additionalProperties") tool_declaration = FunctionDeclaration( diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 8d9b0a39f..08d3584f2 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -241,24 +241,28 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict raise ValueError(f"Unsupported content type: {type(content)}") def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]: + content = [] + if response.content is not None: - return [TextMessageContent(TextArtifact(response.content))] - elif response.tool_calls is not None: - return [ - ActionCallMessageContent( - ActionArtifact( - ActionArtifact.Action( - tag=tool_call.id, - name=tool_call.function.name.split("_", 1)[0], - path=tool_call.function.name.split("_", 1)[1], - input=json.loads(tool_call.function.arguments), + content.append(TextMessageContent(TextArtifact(response.content))) + if response.tool_calls is not None: + content.extend( + [ + ActionCallMessageContent( + ActionArtifact( + ActionArtifact.Action( + tag=tool_call.id, + name=tool_call.function.name.split("_", 1)[0], + path=tool_call.function.name.split("_", 1)[1], + input=json.loads(tool_call.function.arguments), + ) ) ) - ) - for tool_call in response.tool_calls - ] - else: - raise ValueError(f"Unsupported message type: {response}") + for tool_call in response.tool_calls + ] + ) + + return content def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent: if content_delta.content is not None: diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index ebfe22300..824e6ce11 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -52,7 +52,6 @@ def test_to_dict(self, config): "temperature": 0.1, "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, - "tool_schema_id": "https://griptape.ai", "use_native_tools": True, }, "vector_store_driver": { 
@@ -106,7 +105,6 @@ def test_to_dict_with_values(self, config_with_values): "temperature": 0.1, "type": "AmazonBedrockPromptDriver", "tool_choice": {"auto": {}}, - "tool_schema_id": "https://griptape.ai", "use_native_tools": True, }, "vector_store_driver": { diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index e5edfd155..9af21c3a5 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,18 +1,154 @@ import pytest -from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact, ErrorArtifact, ActionArtifact from griptape.common import PromptStack -from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent +from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from griptape.drivers import AmazonBedrockPromptDriver +from tests.mocks.mock_tool.tool import MockTool + class TestAmazonBedrockPromptDriver: + BEDROCK_TOOLS = [ + { + "toolSpec": { + "description": "test description: foo", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + } + }, + "name": "MockTool_test", + } + }, + { + "toolSpec": { + "description": "test description: foo", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + } + }, + "name": "MockTool_test_error", + } + }, + { + "toolSpec": { + "description": "test description", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + } + }, + "name": "MockTool_test_list_output", + } + }, + { + "toolSpec": { + "description": "test description", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + } + }, + "name": "MockTool_test_no_schema", + } + }, + { + "toolSpec": { + "description": "test description: foo", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + } + }, + "name": "MockTool_test_str_output", + } + }, + { + "toolSpec": { + "description": "test description", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": 
False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + } + }, + "name": "MockTool_test_without_default_memory", + } + }, + ] + @pytest.fixture def mock_converse(self, mocker): mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse mock_converse.return_value = { - "output": {"message": {"content": [{"text": "model-output"}]}}, + "output": { + "message": { + "content": [ + {"text": "model-output"}, + {"toolUse": {"name": "MockTool_test", "toolUseId": "mock-id", "input": {"foo": "bar"}}}, + ] + } + }, "usage": {"inputTokens": 5, "outputTokens": 10}, } @@ -24,7 +160,15 @@ def mock_converse_stream(self, mocker): mock_converse_stream.return_value = { "stream": [ + {"contentBlockStart": {"contentBlockIndex": 0, "start": {"text": "model-output"}}}, {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "model-output"}}}, + { + "contentBlockStart": { + "contentBlockIndex": 1, + "start": {"toolUse": {"name": "MockTool_test", "toolUseId": "mock-id"}}, + } + }, + {"contentBlockDelta": {"contentBlockIndex": 1, "delta": {"toolUse": {"input": '{"foo": "bar"}'}}}}, {"metadata": {"usage": {"inputTokens": 5, "outputTokens": 10}}}, ] } @@ -34,12 +178,60 @@ def mock_converse_stream(self, mocker): @pytest.fixture(params=[True, False]) def prompt_stack(self, request): prompt_stack = PromptStack() + prompt_stack.actions = [MockTool()] if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) prompt_stack.add_assistant_message("assistant-input") + prompt_stack.add_action_call_message( + "thought", [ActionArtifact.Action(tag="MockTool_test", name="MockTool", path="test", input={"foo": "bar"})] + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=TextArtifact("tool-output"), + ) + ], + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=ListArtifact( + [ + TextArtifact("tool-output"), + ImageArtifact(value=b"image-data", format="png", width=100, height=100), + ] + ), + ) + ], + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=ErrorArtifact("error"), + ) + ], + ) return prompt_stack @@ -47,33 +239,91 @@ def prompt_stack(self, request): def messages(self): return [ {"role": "user", "content": [{"text": "user-input"}]}, - {"role": "user", "content": [{"text": "user-input"}]}, - {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"image-data"}}}]}, + { + "role": "user", + "content": [{"text": "user-input"}, {"image": {"format": "png", "source": {"bytes": b"image-data"}}}], + }, {"role": "assistant", "content": [{"text": "assistant-input"}]}, + { + "content": [ + {"text": "thought"}, + {"toolUse": {"input": {"foo": "bar"}, "name": 
"MockTool_test", "toolUseId": "MockTool_test"}}, + ], + "role": "assistant", + }, + { + "content": [ + { + "toolResult": { + "content": [{"text": "tool-output"}], + "status": "success", + "toolUseId": "MockTool_test", + } + }, + {"text": "keep-going"}, + ], + "role": "user", + }, + { + "content": [ + { + "toolResult": { + "content": [ + {"text": "tool-output"}, + {"image": {"format": "png", "source": {"bytes": b"image-data"}}}, + ], + "status": "success", + "toolUseId": "MockTool_test", + } + }, + {"text": "keep-going"}, + ], + "role": "user", + }, + { + "content": [ + {"toolResult": {"content": [{"text": "error"}], "status": "error", "toolUseId": "MockTool_test"}}, + {"text": "keep-going"}, + ], + "role": "user", + }, ] - def test_try_run(self, mock_converse, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_run(self, mock_converse, prompt_stack, messages, use_native_tools): # Given - driver = AmazonBedrockPromptDriver(model="ai21.j2") + driver = AmazonBedrockPromptDriver(model="ai21.j2", use_native_tools=use_native_tools) # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_converse.assert_called_once_with( modelId=driver.model, messages=messages, - **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, + **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), + **( + {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + if use_native_tools + else {} + ), ) - assert text_artifact.value == "model-output" - assert text_artifact.usage.input_tokens == 5 - assert text_artifact.usage.output_tokens == 10 + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "mock-id" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} + assert message.usage.input_tokens == 5 + assert message.usage.output_tokens == 10 - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages, use_native_tools): # Given - driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True) + driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True, use_native_tools=use_native_tools) # When stream = driver.try_stream(prompt_stack) @@ -83,14 +333,30 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): mock_converse_stream.assert_called_once_with( modelId=driver.model, messages=messages, - **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, + **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), + **( + {"toolConfig": {"tools": self.BEDROCK_TOOLS, "toolChoice": driver.tool_choice}} + if prompt_stack.actions and use_native_tools + else {} + ), ) + event = next(stream) assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" + event = next(stream) + assert isinstance(event.content, 
ActionCallDeltaMessageContent) + assert event.content.tag == "mock-id" + assert event.content.name == "MockTool" + assert event.content.path == "test" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.partial_input == '{"foo": "bar"}' + event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 2f668faea..02f6f46e9 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,20 +1,129 @@ +from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import AnthropicPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent -from griptape.artifacts import TextArtifact, ImageArtifact, ListArtifact +from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent +from griptape.artifacts import TextArtifact, ActionArtifact, ImageArtifact, ListArtifact from unittest.mock import Mock import pytest +from tests.mocks.mock_tool.tool import MockTool + class TestAnthropicPromptDriver: + ANTHROPIC_TOOLS = [ + { + "description": "test description: foo", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + "name": "MockTool_test", + }, + { + "description": "test description: foo", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + "name": "MockTool_test_error", + }, + { + "description": "test description", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + "name": "MockTool_test_list_output", + }, + { + "description": "test description", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + "name": "MockTool_test_no_schema", + }, + { + "description": "test description: foo", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + "name": "MockTool_test_str_output", + }, + { + "description": "test description", + "input_schema": { + "$id": "Input Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + 
"name": "MockTool_test_without_default_memory", + }, + ] + @pytest.fixture def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") + mock_tool_use = Mock(type="tool_use", id="mock-id", input={"foo": "bar"}) + mock_tool_use.name = "MockTool_test" mock_client.return_value = Mock( messages=Mock( create=Mock( return_value=Mock( - usage=Mock(input_tokens=5, output_tokens=10), content=[Mock(type="text", text="model-output")] + usage=Mock(input_tokens=5, output_tokens=10), + content=[Mock(type="text", text="model-output"), mock_tool_use], ) ) ) @@ -26,13 +135,29 @@ def mock_client(self, mocker): def mock_stream_client(self, mocker): mock_stream_client = mocker.patch("anthropic.Anthropic") + mock_tool_call_delta_header = Mock(type="tool_use", id="mock-id") + mock_tool_call_delta_header.name = "MockTool_test" + mock_stream_client.return_value = Mock( messages=Mock( create=Mock( return_value=iter( [ Mock(type="message_start", message=Mock(usage=Mock(input_tokens=5))), - Mock(type="content_block_delta", delta=Mock(type="text_delta", text="model-output")), + Mock( + type="content_block_start", + index=0, + content_block=Mock(type="text", text="model-output"), + ), + Mock( + type="content_block_delta", index=0, delta=Mock(type="text_delta", text="model-output") + ), + Mock(type="content_block_start", index=1, content_block=mock_tool_call_delta_header), + Mock( + type="content_block_delta", + index=1, + delta=Mock(type="input_json_delta", partial_json='{"foo": "bar"}'), + ), Mock(type="message_delta", usage=Mock(output_tokens=10)), ] ) @@ -42,104 +167,171 @@ def mock_stream_client(self, mocker): return mock_stream_client - @pytest.mark.parametrize("model", [("claude-2.1"), ("claude-2.0")]) - def test_init(self, model): - assert AnthropicPromptDriver(model=model, api_key="1234") - - @pytest.mark.parametrize( - "model", - [ - ("claude-instant-1.2"), - ("claude-2.1"), - ("claude-2.0"), - ("claude-3-opus"), - ("claude-3-sonnet"), - ("claude-3-haiku"), - ], - ) - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_try_run(self, mock_client, model, system_enabled): - # Given + @pytest.fixture(params=[True, False]) + def prompt_stack(self, request): prompt_stack = PromptStack() - if system_enabled: + prompt_stack.actions = [MockTool()] + if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) prompt_stack.add_assistant_message("assistant-input") - driver = AnthropicPromptDriver(model=model, api_key="api-key") - expected_messages = [ - {"role": "user", "content": "user-input"}, + prompt_stack.add_action_call_message( + "thought", [ActionArtifact.Action(tag="MockTool_test", name="MockTool", path="test", input={"foo": "bar"})] + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=TextArtifact("tool-output"), + ) + ], + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=ListArtifact( + [ + TextArtifact("tool-output"), + 
ImageArtifact(value=b"image-data", format="png", width=100, height=100), + ] + ), + ) + ], + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=ErrorArtifact("error"), + ) + ], + ) + + return prompt_stack + + @pytest.fixture + def messages(self): + return [ {"role": "user", "content": "user-input"}, { "content": [ + {"type": "text", "text": "user-input"}, { "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, "type": "image", - } + }, ], "role": "user", }, {"role": "assistant", "content": "assistant-input"}, + { + "content": [ + {"text": "thought", "type": "text"}, + {"id": "MockTool_test", "input": {"foo": "bar"}, "name": "MockTool_test", "type": "tool_use"}, + ], + "role": "assistant", + }, + { + "content": [ + { + "content": [{"text": "tool-output", "type": "text"}], + "is_error": False, + "tool_use_id": "MockTool_test", + "type": "tool_result", + }, + {"text": "keep-going", "type": "text"}, + ], + "role": "user", + }, + { + "content": [ + { + "content": [ + {"text": "tool-output", "type": "text"}, + { + "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, + "type": "image", + }, + ], + "is_error": False, + "tool_use_id": "MockTool_test", + "type": "tool_result", + }, + {"text": "keep-going", "type": "text"}, + ], + "role": "user", + }, + { + "content": [ + { + "content": [{"text": "error", "type": "text"}], + "is_error": True, + "tool_use_id": "MockTool_test", + "type": "tool_result", + }, + {"text": "keep-going", "type": "text"}, + ], + "role": "user", + }, ] + def test_init(self): + assert AnthropicPromptDriver(model="claude-3-haiku", api_key="1234") + + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools): + # Given + driver = AnthropicPromptDriver(model="claude-3-haiku", api_key="api-key", use_native_tools=use_native_tools) + # When message = driver.try_run(prompt_stack) # Then mock_client.return_value.messages.create.assert_called_once_with( - messages=expected_messages, + messages=messages, stop_sequences=[], model=driver.model, max_tokens=1000, temperature=0.1, top_p=0.999, top_k=250, - **{"system": "system-input"} if system_enabled else {}, + **{"system": "system-input"} if prompt_stack.system_messages else {}, + **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) - assert message.value == "model-output" + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "mock-id" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} assert message.usage.input_tokens == 5 assert message.usage.output_tokens == 10 - @pytest.mark.parametrize( - "model", - [ - ("claude-instant-1.2"), - ("claude-2.1"), - ("claude-2.0"), - ("claude-3-opus"), - ("claude-3-sonnet"), - ("claude-3-haiku"), - ], - ) - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_try_stream_run(self, mock_stream_client, model, system_enabled): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_stream_run(self, mock_stream_client, prompt_stack, messages, use_native_tools): # Given - prompt_stack = PromptStack() - if system_enabled: 
- prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( - ListArtifact( - [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] - ) + driver = AnthropicPromptDriver( + model="claude-3-haiku", api_key="api-key", stream=True, use_native_tools=use_native_tools ) - prompt_stack.add_assistant_message("assistant-input") - expected_messages = [ - {"role": "user", "content": "user-input"}, - { - "content": [ - {"type": "text", "text": "user-input"}, - { - "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, - "type": "image", - }, - ], - "role": "user", - }, - {"role": "assistant", "content": "assistant-input"}, - ] - driver = AnthropicPromptDriver(model=model, api_key="api-key", stream=True) # When stream = driver.try_stream(prompt_stack) @@ -147,7 +339,7 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Then mock_stream_client.return_value.messages.create.assert_called_once_with( - messages=expected_messages, + messages=messages, stop_sequences=[], model=driver.model, max_tokens=1000, @@ -155,7 +347,8 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): stream=True, top_p=0.999, top_k=250, - **{"system": "system-input"} if system_enabled else {}, + **{"system": "system-input"} if prompt_stack.system_messages else {}, + **{"tools": self.ANTHROPIC_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) assert event.usage.input_tokens == 5 @@ -164,16 +357,18 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): assert event.content.text == "model-output" event = next(stream) - assert event.usage.output_tokens == 10 + assert isinstance(event.content, TextDeltaMessageContent) + assert event.content.text == "model-output" - def test_try_run_throws_when_prompt_stack_is_string(self): - # Given - prompt_stack = "prompt-stack" - driver = AnthropicPromptDriver(model="claude", api_key="api-key") + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.tag == "mock-id" + assert event.content.name == "MockTool" + assert event.content.path == "test" - # When - with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) # pyright: ignore + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.partial_input == '{"foo": "bar"}' - # Then - assert e.value.args[0] == "'str' object has no attribute 'messages'" + event = next(stream) + assert event.usage.output_tokens == 10 diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 25b491bb1..70aae90dd 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,7 +1,8 @@ import pytest +from griptape.artifacts import TextArtifact, ActionArtifact from unittest.mock import Mock from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.common import TextDeltaMessageContent +from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin @@ -9,20 +10,52 @@ class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): @pytest.fixture def mock_chat_completion_create(self, 
mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create + mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id") + mock_function.name = "MockTool_test" mock_chat_create.return_value = Mock( headers={}, - choices=[Mock(message=Mock(content="model-output"))], + choices=[ + Mock(message=Mock(content="model-output", tool_calls=[Mock(id="mock-id", function=mock_function)])) + ], usage=Mock(prompt_tokens=5, completion_tokens=10), ) + return mock_chat_create @pytest.fixture def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create + mock_tool_call_delta_header = Mock() + mock_tool_call_delta_header.name = "MockTool_test" + mock_tool_call_delta_body = Mock(arguments='{"foo": "bar"}') + mock_tool_call_delta_body.name = None + mock_chat_create.return_value = iter( [ - Mock(choices=[Mock(delta=Mock(content="model-output"))], usage=None), + Mock(choices=[Mock(delta=Mock(content="model-output", tool_calls=None))], usage=None), + Mock( + choices=[ + Mock( + delta=Mock( + content=None, + tool_calls=[Mock(index=0, id="mock-id", function=mock_tool_call_delta_header)], + ) + ) + ], + usage=None, + ), + Mock( + choices=[ + Mock( + delta=Mock( + content=None, tool_calls=[Mock(index=0, id=None, function=mock_tool_call_delta_body)] + ) + ) + ], + usage=None, + ), Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)), + Mock(choices=[Mock(delta=Mock(content=None, tool_calls=None))], usage=None), ] ) return mock_chat_create @@ -31,25 +64,44 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", azure_deployment="foobar", model="gpt-4") assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): # Given - driver = AzureOpenAiChatPromptDriver(azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4") + driver = AzureOpenAiChatPromptDriver( + azure_endpoint="endpoint", + azure_deployment="deployment-id", + model="gpt-4", + use_native_tools=use_native_tools, + ) # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( - model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages + model=driver.model, + temperature=driver.temperature, + user=driver.user, + messages=messages, + **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) - assert text_artifact.value == "model-output" - assert text_artifact.usage.input_tokens == 5 - assert text_artifact.usage.output_tokens == 10 + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "mock-id" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, 
use_native_tools): # Given driver = AzureOpenAiChatPromptDriver( - azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", stream=True + azure_endpoint="endpoint", + azure_deployment="deployment-id", + model="gpt-4", + stream=True, + use_native_tools=use_native_tools, ) # When @@ -64,11 +116,25 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream=True, messages=messages, stream_options={"include_usage": True}, + **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.tag == "mock-id" + assert event.content.name == "MockTool" + assert event.content.path == "test" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.partial_input == '{"foo": "bar"}' + event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + event = next(stream) + assert isinstance(event.content, TextDeltaMessageContent) + assert event.content.text == "" diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index d648be2ec..087b1022a 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -118,7 +118,7 @@ def test_try_run(self, mock_client, prompt_stack, use_native_tools): "tool_calls": [{"name": "MockTool_test", "parameters": {"foo": "bar"}}], }, { - "role": "USER", + "role": "TOOL", "tool_results": [ { "call": {"name": "MockTool_test", "parameters": {"foo": "bar"}}, @@ -166,7 +166,7 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack, use_native_tools "tool_calls": [{"name": "MockTool_test", "parameters": {"foo": "bar"}}], }, { - "role": "USER", + "role": "TOOL", "tool_results": [ { "call": {"name": "MockTool_test", "parameters": {"foo": "bar"}}, diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 9e577929d..5a47772e5 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,19 +1,50 @@ from google.generativeai.types import ContentDict, GenerationConfig from google.generativeai.protos import Part -from griptape.artifacts import TextArtifact, ImageArtifact -from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent +from griptape.artifacts import TextArtifact, ImageArtifact, ActionArtifact +from griptape.common import TextDeltaMessageContent, ActionCallDeltaMessageContent from griptape.drivers import GooglePromptDriver from griptape.common import PromptStack from unittest.mock import Mock +from tests.mocks.mock_tool.tool import MockTool +from google.protobuf.json_format import MessageToDict + import pytest class TestGooglePromptDriver: + GOOGLE_TOOLS = [ + { + "name": "MockTool_test", + "description": "test description: foo", + "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, + }, + { + "name": "MockTool_test_error", + "description": "test description: foo", + "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, + }, + {"name": "MockTool_test_list_output", 
"description": "test description", "parameters": {"type": "OBJECT"}}, + {"name": "MockTool_test_no_schema", "description": "test description", "parameters": {"type": "OBJECT"}}, + { + "name": "MockTool_test_str_output", + "description": "test description: foo", + "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, + }, + { + "name": "MockTool_test_without_default_memory", + "description": "test description", + "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, + }, + ] + @pytest.fixture def mock_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") + mock_function_call = Mock(type="tool_use", id="MockTool_test", args={"foo": "bar"}) + mock_function_call.name = "MockTool_test" mock_generative_model.return_value.generate_content.return_value = Mock( - parts=[Mock(text="model-output")], usage_metadata=Mock(prompt_token_count=5, candidates_token_count=10) + parts=[Mock(text="model-output", function_call=None), Mock(text=None, function_call=mock_function_call)], + usage_metadata=Mock(prompt_token_count=5, candidates_token_count=10), ) return mock_generative_model @@ -21,12 +52,18 @@ def mock_generative_model(self, mocker): @pytest.fixture def mock_stream_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") + mock_function_call_delta = Mock(type="tool_use", id="MockTool_test", args={"foo": "bar"}) + mock_function_call_delta.name = "MockTool_test" mock_generative_model.return_value.generate_content.return_value = iter( [ Mock( parts=[Mock(text="model-output")], usage_metadata=Mock(prompt_token_count=5, candidates_token_count=5), ), + Mock( + parts=[Mock(text=None, function_call=mock_function_call_delta)], + usage_metadata=Mock(prompt_token_count=5, candidates_token_count=5), + ), Mock( parts=[Mock(text="model-output")], usage_metadata=Mock(prompt_token_count=5, candidates_token_count=5), @@ -36,82 +73,110 @@ def mock_stream_generative_model(self, mocker): return mock_generative_model - def test_init(self): - driver = GooglePromptDriver(model="gemini-pro", api_key="1234") - assert driver - - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_try_run(self, mock_generative_model, system_enabled): - # Given + @pytest.fixture(params=[True, False]) + def prompt_stack(self, request): prompt_stack = PromptStack() - if system_enabled: + prompt_stack.actions = [MockTool()] + if request.param: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_user_message(TextArtifact("user-input")) prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") - driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) - # When - text_artifact = driver.try_run(prompt_stack) + return prompt_stack - # Then - messages = [ + @pytest.fixture + def messages(self): + return [ {"parts": ["user-input"], "role": "user"}, {"parts": ["user-input"], "role": "user"}, {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, {"parts": ["assistant-input"], "role": "model"}, ] - if system_enabled: + + def test_init(self): + driver = GooglePromptDriver(model="gemini-pro", api_key="1234") + assert driver + + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_run(self, mock_generative_model, 
prompt_stack, messages, use_native_tools): + # Given + driver = GooglePromptDriver( + model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50, use_native_tools=use_native_tools + ) + + # When + message = driver.try_run(prompt_stack) + + # Then + if prompt_stack.system_messages: assert mock_generative_model.return_value._system_instruction == ContentDict( role="system", parts=[Part(text="system-input")] ) - mock_generative_model.return_value.generate_content.assert_called_once_with( - messages, - generation_config=GenerationConfig( - max_output_tokens=None, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] - ), - ) - assert text_artifact.value == "model-output" - assert text_artifact.usage.input_tokens == 5 - assert text_artifact.usage.output_tokens == 10 + mock_generative_model.return_value.generate_content.assert_called_once() + # We can't use assert_called_once_with because we can't compare the FunctionDeclaration objects + call_args = mock_generative_model.return_value.generate_content.call_args + assert messages == call_args.args[0] + generation_config = call_args.kwargs["generation_config"] + assert generation_config == GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]) + if use_native_tools: + tool_declarations = call_args.kwargs["tools"] + assert [ + MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations + ] == self.GOOGLE_TOOLS + + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "MockTool_test" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} + assert message.usage.input_tokens == 5 + assert message.usage.output_tokens == 10 - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_try_stream(self, mock_stream_generative_model, system_enabled): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_stream(self, mock_stream_generative_model, prompt_stack, messages, use_native_tools): # Given - prompt_stack = PromptStack() - if system_enabled: - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") - driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) + driver = GooglePromptDriver( + model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50, use_native_tools=use_native_tools + ) # When stream = driver.try_stream(prompt_stack) # Then event = next(stream) - messages = [ - {"parts": ["user-input"], "role": "user"}, - {"parts": ["user-input"], "role": "user"}, - {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, - {"parts": ["assistant-input"], "role": "model"}, - ] - if system_enabled: + if prompt_stack.system_messages: assert mock_stream_generative_model.return_value._system_instruction == ContentDict( role="system", parts=[Part(text="system-input")] ) - mock_stream_generative_model.return_value.generate_content.assert_called_once_with( - messages, - stream=True, - generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), + # We can't use 
assert_called_once_with because we can't compare the FunctionDeclaration objects + mock_stream_generative_model.return_value.generate_content.assert_called_once() + call_args = mock_stream_generative_model.return_value.generate_content.call_args + + assert messages == call_args.args[0] + assert call_args.kwargs["stream"] is True + assert call_args.kwargs["generation_config"] == GenerationConfig( + temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] ) + if use_native_tools: + tool_declarations = call_args.kwargs["tools"] + assert [ + MessageToDict(tool_declaration.to_proto()._pb) for tool_declaration in tool_declarations + ] == self.GOOGLE_TOOLS assert isinstance(event.content, TextDeltaMessageContent) assert event.content.text == "model-output" assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 5 + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.tag == "MockTool_test" + assert event.content.name == "MockTool" + assert event.content.path == "test" + assert event.content.partial_input == '{"foo": "bar"}' + event = next(stream) assert event.usage.output_tokens == 5 diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 3a86411bc..ca283ba4e 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,22 +1,147 @@ from griptape.artifacts import ImageArtifact, ListArtifact -from griptape.artifacts import TextArtifact +from griptape.artifacts import TextArtifact, ActionArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack, TextDeltaMessageContent +from griptape.common import PromptStack, TextDeltaMessageContent, ActionCallDeltaMessageContent from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer +from tests.mocks.mock_tool.tool import MockTool import pytest class TestOpenAiChatPromptDriverFixtureMixin: + OPENAI_TOOLS = [ + { + "function": { + "description": "test description: foo", + "name": "MockTool_test", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description: foo", + "name": "MockTool_test_error", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_list_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_no_schema", + "parameters": { + "$id": "Parameters Schema", + 
"$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description: foo", + "name": "MockTool_test_str_output", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + }, + { + "function": { + "description": "test description", + "name": "MockTool_test_without_default_memory", + "parameters": { + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + }, + "type": "function", + }, + ] + @pytest.fixture def mock_chat_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create + mock_function = Mock(arguments='{"foo": "bar"}', id="mock-id") + mock_function.name = "MockTool_test" mock_chat_create.return_value = Mock( headers={}, - choices=[Mock(message=Mock(content="model-output"))], + choices=[ + Mock(message=Mock(content="model-output", tool_calls=[Mock(id="mock-id", function=mock_function)])) + ], usage=Mock(prompt_tokens=5, completion_tokens=10), - tool_calls=[], ) return mock_chat_create @@ -24,9 +149,35 @@ def mock_chat_completion_create(self, mocker): @pytest.fixture def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create + mock_tool_call_delta_header = Mock() + mock_tool_call_delta_header.name = "MockTool_test" + mock_tool_call_delta_body = Mock(arguments='{"foo": "bar"}') + mock_tool_call_delta_body.name = None + mock_chat_create.return_value = iter( [ Mock(choices=[Mock(delta=Mock(content="model-output", tool_calls=None))], usage=None), + Mock( + choices=[ + Mock( + delta=Mock( + content=None, + tool_calls=[Mock(index=0, id="mock-id", function=mock_tool_call_delta_header)], + ) + ) + ], + usage=None, + ), + Mock( + choices=[ + Mock( + delta=Mock( + content=None, tool_calls=[Mock(index=0, id=None, function=mock_tool_call_delta_body)] + ) + ) + ], + usage=None, + ), Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)), Mock(choices=[Mock(delta=Mock(content=None, tool_calls=None))], usage=None), ] @@ -36,6 +187,7 @@ def mock_chat_completion_stream_create(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() + prompt_stack.actions = [MockTool()] prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") prompt_stack.add_user_message( @@ -44,6 +196,21 @@ def prompt_stack(self): ) ) prompt_stack.add_assistant_message("assistant-input") + prompt_stack.add_action_call_message( + "thought", [ActionArtifact.Action(tag="MockTool_test", name="MockTool", path="test", input={"foo": "bar"})] + ) + prompt_stack.add_action_result_message( + "keep-going", + [ + ActionArtifact.Action( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=TextArtifact("tool-output"), + ) + ], + ) return 
prompt_stack @pytest.fixture @@ -59,6 +226,18 @@ def messages(self): ], }, {"role": "assistant", "content": "assistant-input"}, + { + "content": [{"text": "thought", "type": "text"}], + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"foo": "bar"}', "name": "MockTool_test"}, + "id": "MockTool_test", + "type": "function", + } + ], + }, + {"content": "tool-output", "role": "tool", "tool_call_id": "MockTool_test"}, ] @@ -99,23 +278,37 @@ class TestOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_native_tools): # Given - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) + driver = OpenAiChatPromptDriver( + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, use_native_tools=use_native_tools + ) # When - event = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( - model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages, seed=driver.seed + model=driver.model, + temperature=driver.temperature, + user=driver.user, + messages=messages, + seed=driver.seed, + **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) - assert event.value == "model-output" + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "mock-id" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object" + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object", use_native_tools=False ) # When @@ -130,13 +323,16 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack seed=driver.seed, response_format={"type": "json_object"}, ) - assert message.value == "model-output" + assert message.value[0].value == "model-output" assert message.usage.input_tokens == 5 assert message.usage.output_tokens == 10 - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): + @pytest.mark.parametrize("use_native_tools", [True, False]) + def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools): # Given - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True) + driver = OpenAiChatPromptDriver( + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True, use_native_tools=use_native_tools + ) # When stream = driver.try_stream(prompt_stack) @@ -151,11 +347,22 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages=messages, seed=driver.seed, stream_options={"include_usage": True}, + **{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {}, ) assert isinstance(event.content, TextDeltaMessageContent) 
assert event.content.text == "model-output" + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.tag == "mock-id" + assert event.content.name == "MockTool" + assert event.content.path == "test" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.partial_input == '{"foo": "bar"}' + event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 @@ -165,7 +372,9 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1) + driver = OpenAiChatPromptDriver( + model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1, use_native_tools=False + ) # When event = driver.try_run(prompt_stack) @@ -179,7 +388,7 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack max_tokens=1, seed=driver.seed, ) - assert event.value == "model-output" + assert event.value[0].value == "model-output" def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, prompt_stack): # Given @@ -198,6 +407,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, + use_native_tools=False, ) # When @@ -213,4 +423,4 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa seed=driver.seed, max_tokens=1, ) - assert event.value == "model-output" + assert event.value[0].value == "model-output"
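
--
Illustrative note (not part of the patch): the behavioral change this patch makes to `openai_chat_prompt_driver.py` is that a chat completion message can now yield both text content and tool calls, instead of returning one or the other and raising on anything else. The sketch below mirrors that logic with plain dataclasses and dicts in place of griptape's `TextMessageContent` / `ActionCallMessageContent` types; the `ToolCall` stand-in and the dict shapes are assumptions for illustration only, not the library's actual API.

```python
# Minimal sketch of the "collect both text and tool calls" behavior, assuming a
# simplified ToolCall stand-in instead of OpenAI's ChatCompletionMessageToolCall.
import json
from dataclasses import dataclass


@dataclass
class ToolCall:
    id: str          # becomes the action "tag"
    name: str        # tool and activity joined with "_", e.g. "MockTool_test"
    arguments: str   # JSON-encoded arguments string


def to_message_content(text: str | None, tool_calls: list[ToolCall] | None) -> list[dict]:
    content: list[dict] = []

    # Text and tool calls are no longer mutually exclusive: append both when present.
    if text is not None:
        content.append({"type": "text", "text": text})

    if tool_calls is not None:
        for call in tool_calls:
            # The first "_" separates the tool name from the activity path.
            name, path = call.name.split("_", 1)
            content.append(
                {
                    "type": "action_call",
                    "tag": call.id,
                    "name": name,
                    "path": path,
                    "input": json.loads(call.arguments),
                }
            )

    return content


if __name__ == "__main__":
    calls = [ToolCall(id="mock-id", name="MockTool_test", arguments='{"foo": "bar"}')]
    print(to_message_content("model-output", calls))
    # -> a text entry followed by an action call with name "MockTool", path "test"
```

The tests added in this patch exercise exactly this shape: they assert that `message.value[0]` is a `TextArtifact` with the streamed text and `message.value[1]` is an `ActionArtifact` whose action carries the parsed tool name, path, and JSON input.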