Skip to content

Commit

Permalink
Add StructuredOutputTool
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 18, 2024
1 parent 2b0eb3a commit a7ed072
Show file tree
Hide file tree
Showing 15 changed files with 195 additions and 8 deletions.
4 changes: 3 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from attrs import define, field
from schema import Schema

from griptape.artifacts import (
ActionArtifact,
Expand Down Expand Up @@ -31,6 +32,7 @@
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)

@property
def system_messages(self) -> list[Message]:
Expand Down
15 changes: 15 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,17 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
from griptape.tools.structured_output.tool import StructuredOutputTool

if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

def after_run(self, result: Message) -> None:
Expand All @@ -73,11 +81,18 @@ def after_run(self, result: Message) -> None:

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
from griptape.tools.structured_output.tool import StructuredOutputTool

if isinstance(prompt_input, BaseArtifact):
prompt_stack = PromptStack.from_artifact(prompt_input)
else:
prompt_stack = prompt_input

if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

for attempt in self.retrying():
with attempt:
self.before_run(prompt_stack)
Expand Down
12 changes: 12 additions & 0 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
if prompt_stack.tools and self.use_native_tools
else {}
),
**(
{
"response_format": {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
}
if not prompt_stack.tools # Respond format is not supported with tools https://docs.cohere.com/reference/chat#request.body.response_format
and prompt_stack.output_schema is not None
and self.use_native_structured_output
else {}
),
**self.extra_params,
}

Expand Down
8 changes: 7 additions & 1 deletion griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class OllamaPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
Expand All @@ -79,7 +80,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
logger.debug(params)
response = self.client.chat(**params)
logger.debug(response)
logger.debug(response.model_dump())

return Message(
content=self.__to_prompt_stack_message_content(response),
Expand Down Expand Up @@ -113,6 +114,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
and not self.stream # Tool calling is only supported when not streaming
else {}
),
**(
{"format": prompt_stack.output_schema.json_schema("Output")}
if prompt_stack.output_schema and self.use_native_structured_output
else {}
),
**self.extra_params,
}

Expand Down
15 changes: 15 additions & 0 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(
Expand Down Expand Up @@ -160,6 +161,20 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
**(
{
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": prompt_stack.output_schema.json_schema("Output"),
"strict": True,
},
}
}
if prompt_stack.output_schema is not None and self.use_native_structured_output
else {}
),
**self.extra_params,
}

Expand Down
23 changes: 20 additions & 3 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, evolve, field
from schema import Schema

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
Expand Down Expand Up @@ -32,6 +33,7 @@ class Agent(Structure):
tools: list[BaseTool] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
fail_fast: bool = field(default=False, kw_only=True)
output_type: Optional[Union[type, Schema]] = field(default=None, kw_only=True)

@fail_fast.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FBT001
Expand All @@ -41,18 +43,27 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB
def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()

self.prompt_driver.stream = self.stream
prompt_driver = self.prompt_driver
prompt_driver.stream = self.stream
if len(self.tasks) == 0:
if self.tools:
task = ToolkitTask(
self.input,
prompt_driver=self.prompt_driver,
prompt_driver=prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
output_schema=self._build_schema_from_type(self.output_type)
if self.output_type is not None
else None,
)
else:
task = PromptTask(
self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries
self.input,
prompt_driver=prompt_driver,
max_meta_memory_entries=self.max_meta_memory_entries,
output_schema=self._build_schema_from_type(self.output_type)
if self.output_type is not None
else None,
)

self.add_task(task)
Expand Down Expand Up @@ -80,3 +91,9 @@ def try_run(self, *args) -> Agent:
self.task.run()

return self

def _build_schema_from_type(self, output_type: type | Schema) -> Schema:
if isinstance(output_type, Schema):
return output_type
else:
return Schema(output_type)
11 changes: 10 additions & 1 deletion griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from attrs import define, field

from griptape import utils
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, JsonArtifact, ListArtifact, TextArtifact
from griptape.common import ToolAction
from griptape.configs import Defaults
from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent
from griptape.mixins.actions_subtask_origin_mixin import ActionsSubtaskOriginMixin
from griptape.tasks import BaseTask
from griptape.tools.structured_output.tool import StructuredOutputTool
from griptape.utils import remove_null_values_in_dict_recursively, with_contextvars

if TYPE_CHECKING:
Expand Down Expand Up @@ -87,6 +88,14 @@ def attach_to(self, parent_task: BaseTask) -> None:
self.__init_from_prompt(self.input.to_text())
else:
self.__init_from_artifacts(self.input)

structured_outputs = [a for a in self.actions if isinstance(a.tool, StructuredOutputTool)]
if structured_outputs:
output_values = [JsonArtifact(a.input["values"]) for a in structured_outputs]
if len(structured_outputs) > 1:
self.output = ListArtifact(output_values)
else:
self.output = output_values[0]
except Exception as e:
logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e)

Expand Down
4 changes: 3 additions & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import NOTHING, Factory, NothingType, define, field
from schema import Schema

from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.common import PromptStack
Expand Down Expand Up @@ -38,6 +39,7 @@ class PromptTask(RuleMixin, BaseTask):
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
alias="input",
)
output_schema: Optional[Schema] = field(default=None, kw_only=True)

@property
def rulesets(self) -> list:
Expand Down Expand Up @@ -67,7 +69,7 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],

@property
def prompt_stack(self) -> PromptStack:
stack = PromptStack()
stack = PromptStack(output_schema=self.output_schema)
memory = self.conversation_memory

system_template = self.generate_system_template(self)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def tool_output_memory(self) -> list[TaskMemory]:

@property
def prompt_stack(self) -> PromptStack:
stack = PromptStack(tools=self.tools)
stack = PromptStack(tools=self.tools, output_schema=self.output_schema)
memory = self.structure.conversation_memory if self.structure is not None else None

stack.add_system_message(self.generate_system_template(self))
Expand Down
2 changes: 2 additions & 0 deletions griptape/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .extraction.tool import ExtractionTool
from .prompt_summary.tool import PromptSummaryTool
from .query.tool import QueryTool
from .structured_output.tool import StructuredOutputTool

__all__ = [
"BaseTool",
Expand Down Expand Up @@ -50,4 +51,5 @@
"ExtractionTool",
"PromptSummaryTool",
"QueryTool",
"StructuredOutputTool",
]
Empty file.
20 changes: 20 additions & 0 deletions griptape/tools/structured_output/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from attrs import define, field
from schema import Schema

from griptape.artifacts import BaseArtifact, JsonArtifact
from griptape.tools import BaseTool
from griptape.utils.decorators import activity


@define
class StructuredOutputTool(BaseTool):
output_schema: Schema = field(kw_only=True)

@activity(
config={
"description": "Used to provide the final response which ends this conversation.",
"schema": lambda self: self.output_schema,
}
)
def provide_output(self, params: dict) -> BaseArtifact:
return JsonArtifact(params["values"])
8 changes: 8 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,11 @@ def test_stream_mutation(self):
assert isinstance(agent.tasks[0], PromptTask)
assert agent.tasks[0].prompt_driver.stream is True
assert agent.tasks[0].prompt_driver is not prompt_driver

def test_output_type_primitive(self):
from griptape.tools import StructuredOutputTool

agent = Agent(output_type=str)

assert isinstance(agent.tools[0], StructuredOutputTool)
assert agent.tools[0].output_schema == agent._build_schema_from_type(str)
66 changes: 66 additions & 0 deletions tests/unit/tasks/test_actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from griptape.artifacts import ActionArtifact, ListArtifact, TextArtifact
from griptape.artifacts.error_artifact import ErrorArtifact
from griptape.artifacts.json_artifact import JsonArtifact
from griptape.common import ToolAction
from griptape.structures import Agent
from griptape.tasks import ActionsSubtask, ToolkitTask
Expand Down Expand Up @@ -257,3 +258,68 @@ def test_origin_task(self):

with pytest.raises(Exception, match="ActionSubtask has no origin task."):
assert ActionsSubtask("test").origin_task

def test_structured_output_tool(self):
import schema

from griptape.tools.structured_output.tool import StructuredOutputTool

actions = ListArtifact(
[
ActionArtifact(
ToolAction(
tag="foo",
name="StructuredOutputTool",
path="provide_output",
input={"values": {"test": "value"}},
)
),
]
)

task = ToolkitTask(tools=[StructuredOutputTool(output_schema=schema.Schema({"test": str}))])
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(actions))

assert isinstance(subtask.output, JsonArtifact)
assert subtask.output.value == {"test": "value"}

def test_structured_output_tool_multiple(self):
import schema

from griptape.tools.structured_output.tool import StructuredOutputTool

actions = ListArtifact(
[
ActionArtifact(
ToolAction(
tag="foo",
name="StructuredOutputTool1",
path="provide_output",
input={"values": {"test1": "value"}},
)
),
ActionArtifact(
ToolAction(
tag="foo",
name="StructuredOutputTool2",
path="provide_output",
input={"values": {"test2": "value"}},
)
),
]
)

task = ToolkitTask(
tools=[
StructuredOutputTool(name="StructuredOutputTool1", output_schema=schema.Schema({"test": str})),
StructuredOutputTool(name="StructuredOutputTool2", output_schema=schema.Schema({"test": str})),
]
)
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(actions))

assert isinstance(subtask.output, ListArtifact)
assert len(subtask.output.value) == 2
assert subtask.output.value[0].value == {"test1": "value"}
assert subtask.output.value[1].value == {"test2": "value"}
13 changes: 13 additions & 0 deletions tests/unit/tools/test_structured_output_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
import schema

from griptape.tools import StructuredOutputTool


class TestStructuredOutputTool:
@pytest.fixture()
def tool(self):
return StructuredOutputTool(output_schema=schema.Schema({"foo": "bar"}))

def test_provide_output(self, tool):
assert tool.provide_output({"values": {"foo": "bar"}}).value == {"foo": "bar"}

0 comments on commit a7ed072

Please sign in to comment.