From 382ca1f6e4c688075de37622ace8acfcfa7e4a2d Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Wed, 2 Oct 2024 10:34:21 -0700
Subject: [PATCH] Add support for openai structured outputs

---
 CHANGELOG.md                                 |  2 +
 MIGRATION.md                                 | 18 +++++++
 .../drivers/src/prompt_drivers_3.py          | 17 ++++--
 .../prompt/openai_chat_prompt_driver.py      | 15 +++---
 .../prompt/test_openai_chat_prompt_driver.py | 52 ++++++++++++++++++-
 tests/utils/structure_tester.py              |  2 +-
 6 files changed, 92 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e73db1340..9b4fb656b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file.
 - `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud.
 - Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias.
+- Basic support for OpenAi Structured Output via the `OpenAiChatPromptDriver.response_format` parameter.
 
 ### Changed
 - **BREAKING**: Renamed parameters on several classes to `client`:
@@ -61,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - **BREAKING**: `CsvExtractionEngine.column_names` is now required.
 - **BREAKING**: Renamed `RuleMixin.all_rulesets` to `RuleMixin.rulesets`.
 - **BREAKING**: Renamed `GriptapeCloudKnowledgeBaseVectorStoreDriver` to `GriptapeCloudVectorStoreDriver`.
+- **BREAKING**: `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`.
 - `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags.
 - `StructureRunTask` now inherits from `PromptTask`.
 - Several places where API clients are initialized are now lazy loaded.
diff --git a/MIGRATION.md b/MIGRATION.md
index d41a26ee9..474a2a9f3 100644
--- a/MIGRATION.md
+++ b/MIGRATION.md
@@ -168,6 +168,24 @@ from griptape.drivers.griptape_cloud_vector_store_driver import GriptapeCloudVec
 driver = GriptapeCloudVectorStoreDriver(...)
 ```
 
+### `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`
+
+`OpenAiChatPromptDriver.response_format` now takes the same `dict` structure that the `openai` SDK accepts.
+
+#### Before
+```python
+driver = OpenAiChatPromptDriver(
+    response_format="json_object"
+)
+```
+
+#### After
+```python
+driver = OpenAiChatPromptDriver(
+    response_format={"type": "json_object"}
+)
+```
+
 ## 0.31.X to 0.32.X
 
 ### Removed `DataframeLoader`
diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_3.py b/docs/griptape-framework/drivers/src/prompt_drivers_3.py
index 8e85ce887..bf37f2d72 100644
--- a/docs/griptape-framework/drivers/src/prompt_drivers_3.py
+++ b/docs/griptape-framework/drivers/src/prompt_drivers_3.py
@@ -1,19 +1,26 @@
 import os
 
+import schema
+
 from griptape.drivers import OpenAiChatPromptDriver
-from griptape.rules import Rule
 from griptape.structures import Agent
 
 agent = Agent(
     prompt_driver=OpenAiChatPromptDriver(
         api_key=os.environ["OPENAI_API_KEY"],
+        model="gpt-4o-2024-08-06",
         temperature=0.1,
-        model="gpt-4o",
-        response_format="json_object",
         seed=42,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "strict": True,
+                "name": "Output",
+                "schema": schema.Schema({"css_code": str, "relevant_emojis": [str]}).json_schema("Output Schema"),
+            },
+        },
     ),
-    input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}",
-    rules=[Rule(value='Write your output in json with a single key called "css_code".')],
+    input="You will be provided with a description of a mood, and your task is to generate the CSS color code for a color that matches it. Description: {{ args[0] }}",
 )
 
 agent.run("Blue sky at dusk.")
diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py
index bab20d3f0..ec10ab72e 100644
--- a/griptape/drivers/prompt/openai_chat_prompt_driver.py
+++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Optional
 
 import openai
 from attrs import Factory, define, field
@@ -62,7 +62,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         kw_only=True,
     )
     user: str = field(default="", kw_only=True, metadata={"serializable": True})
-    response_format: Optional[Literal["json_object"]] = field(
+    response_format: Optional[dict] = field(
         default=None,
         kw_only=True,
         metadata={"serializable": True},
@@ -145,10 +145,13 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
             **({"stream_options": {"include_usage": True}} if self.stream else {}),
         }
 
-        if self.response_format == "json_object":
-            params["response_format"] = {"type": "json_object"}
-            # JSON mode still requires a system message instructing the LLM to output JSON.
-            prompt_stack.add_system_message("Provide your response as a valid JSON object.")
+        if self.response_format is not None:
+            if self.response_format == {"type": "json_object"}:
+                params["response_format"] = self.response_format
+                # JSON mode still requires a system message instructing the LLM to output JSON.
+                prompt_stack.add_system_message("Provide your response as a valid JSON object.")
+            else:
+                params["response_format"] = self.response_format
 
         messages = self.__to_openai_messages(prompt_stack.messages)
diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
index ae42aa3a1..dc0cd0555 100644
--- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
@@ -1,6 +1,7 @@
 from unittest.mock import Mock
 
 import pytest
+import schema
 
 from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact
 from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction
@@ -343,10 +344,12 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
         assert message.value[1].value.path == "test"
         assert message.value[1].value.input == {"foo": "bar"}
 
-    def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages):
+    def test_try_run_response_format_json_object(self, mock_chat_completion_create, prompt_stack, messages):
         # Given
         driver = OpenAiChatPromptDriver(
-            model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object", use_native_tools=False
+            model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
+            response_format={"type": "json_object"},
+            use_native_tools=False,
         )
 
         # When
@@ -365,6 +368,51 @@
         assert message.usage.input_tokens == 5
         assert message.usage.output_tokens == 10
 
+    def test_try_run_response_format_json_schema(self, mock_chat_completion_create, prompt_stack, messages):
+        # Given
+        driver = OpenAiChatPromptDriver(
+            model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "strict": True,
+                    "name": "OutputSchema",
+                    "schema": schema.Schema({"test": str}).json_schema("Output Schema"),
+                },
+            },
+            use_native_tools=False,
+        )
+
+        # When
+        message = driver.try_run(prompt_stack)
+
+        # Then
+        mock_chat_completion_create.assert_called_once_with(
+            model=driver.model,
+            temperature=driver.temperature,
+            user=driver.user,
+            messages=[*messages],
+            seed=driver.seed,
+            response_format={
+                "json_schema": {
+                    "schema": {
+                        "$id": "Output Schema",
+                        "$schema": "http://json-schema.org/draft-07/schema#",
+                        "additionalProperties": False,
+                        "properties": {"test": {"type": "string"}},
+                        "required": ["test"],
+                        "type": "object",
+                    },
+                    "name": "OutputSchema",
+                    "strict": True,
+                },
+                "type": "json_schema",
+            },
+        )
+        assert message.value[0].value == "model-output"
+        assert message.usage.input_tokens == 5
+        assert message.usage.output_tokens == 10
+
     @pytest.mark.parametrize("use_native_tools", [True, False])
     def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools):
         # Given
diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py
index ac5b3c771..c943525b6 100644
--- a/tests/utils/structure_tester.py
+++ b/tests/utils/structure_tester.py
@@ -235,7 +235,7 @@ def verify_structure_output(self, structure) -> dict:
             model="gpt-4o",
             azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
             azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
-            response_format="json_object",
+            response_format={"type": "json_object"},
         )
         output_schema = Schema(
             {
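
Usage note (not part of the patch): a minimal sketch of how the new structured-output mode looks from user code, assuming griptape at this commit, the `schema` package, and an `OPENAI_API_KEY` environment variable. The `Mood` schema name and the `css_code` key are illustrative, not part of the API.

```python
import json
import os

import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Agent

# Structured Output mode: response_format is now passed through to the openai
# SDK as-is, so the model is constrained to the supplied JSON Schema.
agent = Agent(
    prompt_driver=OpenAiChatPromptDriver(
        api_key=os.environ["OPENAI_API_KEY"],
        model="gpt-4o-2024-08-06",
        response_format={
            "type": "json_schema",
            "json_schema": {
                "strict": True,
                "name": "Mood",  # illustrative name, not required by griptape
                "schema": schema.Schema({"css_code": str}).json_schema("Mood"),
            },
        },
    ),
    input="Generate a CSS color code for this mood: {{ args[0] }}",
)

# With strict json_schema output, the response text should always parse as
# JSON that matches the schema above.
result = json.loads(agent.run("Blue sky at dusk.").output.value)
print(result["css_code"])
```

Plain JSON mode remains available via `response_format={"type": "json_object"}`; per the driver change above, that mode also appends a system message instructing the model to emit valid JSON, while `json_schema` mode passes the dict through untouched.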