Skip to content

Commit

Permalink
Add support for OpenAI structured outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 9, 2024
1 parent 12ac9e9 commit 382ca1f
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 14 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file.
- `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud.
- Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias.
- Basic support for OpenAI Structured Outputs via `OpenAiChatPromptDriver.response_format` parameter.

### Changed
- **BREAKING**: Renamed parameters on several classes to `client`:
Expand Down Expand Up @@ -61,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- **BREAKING**: Renamed `RuleMixin.all_rulesets` to `RuleMixin.rulesets`.
- **BREAKING**: Renamed `GriptapeCloudKnowledgeBaseVectorStoreDriver` to `GriptapeCloudVectorStoreDriver`.
- **BREAKING**: `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`.
- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags
- `StructureRunTask` now inherits from `PromptTask`.
- Several places where API clients are initialized are now lazy loaded.
Expand Down
18 changes: 18 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@ from griptape.drivers.griptape_cloud_vector_store_driver import GriptapeCloudVec
driver = GriptapeCloudVectorStoreDriver(...)
```

### `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`.

`OpenAiChatPromptDriver.response_format` is now structured as the `openai` SDK accepts it.

#### Before
```python
driver = OpenAiChatPromptDriver(
response_format="json_object"
)
```

#### After
```python
driver = OpenAiChatPromptDriver(
response_format={"type": "json_object"}
)
```

## 0.31.X to 0.32.X

### Removed `DataframeLoader`
Expand Down
17 changes: 12 additions & 5 deletions docs/griptape-framework/drivers/src/prompt_drivers_3.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import os

import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import Rule
from griptape.structures import Agent

agent = Agent(
prompt_driver=OpenAiChatPromptDriver(
api_key=os.environ["OPENAI_API_KEY"],
model="gpt-4o-2024-08-06",
temperature=0.1,
model="gpt-4o",
response_format="json_object",
seed=42,
response_format={
"type": "json_schema",
"json_schema": {
"strict": True,
"name": "Output",
"schema": schema.Schema({"css_code": str, "relevant_emojies": [str]}).json_schema("Output Schema"),
},
},
),
input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}",
rules=[Rule(value='Write your output in json with a single key called "css_code".')],
input="You will be provided with a description of a mood, and your task is to generate the CSS color code for a color that matches it. Description: {{ args[0] }}",
)

agent.run("Blue sky at dusk.")
15 changes: 9 additions & 6 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

import openai
from attrs import Factory, define, field
Expand Down Expand Up @@ -62,7 +62,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
kw_only=True,
)
user: str = field(default="", kw_only=True, metadata={"serializable": True})
response_format: Optional[Literal["json_object"]] = field(
response_format: Optional[dict] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
Expand Down Expand Up @@ -145,10 +145,13 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"stream_options": {"include_usage": True}} if self.stream else {}),
}

if self.response_format == "json_object":
params["response_format"] = {"type": "json_object"}
# JSON mode still requires a system message instructing the LLM to output JSON.
prompt_stack.add_system_message("Provide your response as a valid JSON object.")
if self.response_format is not None:
if self.response_format == {"type": "json_object"}:
params["response_format"] = self.response_format
# JSON mode still requires a system message instructing the LLM to output JSON.
prompt_stack.add_system_message("Provide your response as a valid JSON object.")
else:
params["response_format"] = self.response_format

messages = self.__to_openai_messages(prompt_stack.messages)

Expand Down
52 changes: 50 additions & 2 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import Mock

import pytest
import schema

from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact
from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction
Expand Down Expand Up @@ -343,10 +344,12 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
assert message.value[1].value.path == "test"
assert message.value[1].value.input == {"foo": "bar"}

def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages):
def test_try_run_response_format_json_object(self, mock_chat_completion_create, prompt_stack, messages):
# Given
driver = OpenAiChatPromptDriver(
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object", use_native_tools=False
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
response_format={"type": "json_object"},
use_native_tools=False,
)

# When
Expand All @@ -365,6 +368,51 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack
assert message.usage.input_tokens == 5
assert message.usage.output_tokens == 10

def test_try_run_response_format_json_schema(self, mock_chat_completion_create, prompt_stack, messages):
    """Verify that a dict-form `response_format` of type `json_schema` is forwarded
    verbatim to the OpenAI `chat.completions.create` call.

    The expected `schema` payload in the assertion is the JSON Schema document
    produced by `schema.Schema({"test": str}).json_schema("Output Schema")`
    (draft-07, `additionalProperties: False`, `$id` set to the given name).
    NOTE(review): `mock_chat_completion_create`, `prompt_stack`, and `messages`
    are fixtures defined elsewhere in this test module — presumably the same
    ones used by the sibling `json_object` test; confirm against the file.
    """
    # Given
    driver = OpenAiChatPromptDriver(
        model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
        response_format={
            "type": "json_schema",
            "json_schema": {
                "strict": True,
                "name": "OutputSchema",
                # Expands to the draft-07 JSON Schema asserted below.
                "schema": schema.Schema({"test": str}).json_schema("Output Schema"),
            },
        },
        use_native_tools=False,
    )

    # When
    message = driver.try_run(prompt_stack)

    # Then
    # Unlike the "json_object" path, no extra system message is expected here:
    # the response_format dict must pass through to the SDK unchanged.
    mock_chat_completion_create.assert_called_once_with(
        model=driver.model,
        temperature=driver.temperature,
        user=driver.user,
        messages=[*messages],
        seed=driver.seed,
        response_format={
            "json_schema": {
                "schema": {
                    "$id": "Output Schema",
                    "$schema": "http://json-schema.org/draft-07/schema#",
                    "additionalProperties": False,
                    "properties": {"test": {"type": "string"}},
                    "required": ["test"],
                    "type": "object",
                },
                "name": "OutputSchema",
                "strict": True,
            },
            "type": "json_schema",
        },
    )
    # Token counts come from the mocked completion response fixture.
    assert message.value[0].value == "model-output"
    assert message.usage.input_tokens == 5
    assert message.usage.output_tokens == 10

@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools):
# Given
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/structure_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def verify_structure_output(self, structure) -> dict:
model="gpt-4o",
azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
response_format="json_object",
response_format={"type": "json_object"},
)
output_schema = Schema(
{
Expand Down

0 comments on commit 382ca1f

Please sign in to comment.