Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 23, 2024
1 parent 56b8fff commit 997ef3d
Show file tree
Hide file tree
Showing 24 changed files with 163 additions and 137 deletions.
27 changes: 12 additions & 15 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -105,8 +106,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise Exception("model response is empty")

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

Expand All @@ -119,23 +118,21 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "tool":
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["toolConfig"] = {
"toolChoice": {"any": {}},
}
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)
else:
raise ValueError(f"Unsupported native structured output mode: {self.native_structured_output_mode}")

if prompt_stack.tools and self.use_native_tools:
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
params["toolConfig"] = {

Check warning on line 134 in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py#L134

Added line #L134 was not covered by tests
"tools": self.__to_bedrock_tools(prompt_stack.tools),
"toolChoice": {"any": {}},
}

return params
Expand Down
12 changes: 11 additions & 1 deletion griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
Expand Down Expand Up @@ -120,7 +121,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = prompt_stack.system_messages
system_message = system_messages[0].to_text() if system_messages else None

return {
params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
Expand All @@ -137,6 +138,15 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
params["tool_choice"] = {"type": "any"}

Check warning on line 146 in griptape/drivers/prompt/anthropic_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/anthropic_prompt_driver.py#L146

Added line #L146 was not covered by tests

return params

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
return [
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
Expand Down
7 changes: 0 additions & 7 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,11 @@ def after_run(self, result: Message) -> None:

@observable(tags=["PromptDriver.run()"])
def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
from griptape.tools.structured_output.tool import StructuredOutputTool

if isinstance(prompt_input, BaseArtifact):
prompt_stack = PromptStack.from_artifact(prompt_input)
else:
prompt_stack = prompt_input

if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

for attempt in self.retrying():
with attempt:
self.before_run(prompt_stack)
Expand Down
31 changes: 15 additions & 16 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver):
model: str = field(metadata={"serializable": True})
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
Expand Down Expand Up @@ -97,8 +98,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

tool_results = []

messages = self.__to_cohere_messages(prompt_stack.messages)
Expand All @@ -110,23 +109,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"stop_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_tokens,
**({"tool_results": tool_results} if tool_results else {}),
**(
{"tools": self.__to_cohere_tools(prompt_stack.tools)}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if self.use_native_structured_output:
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
else:
# This does not work great since Cohere does not support forced tool use.
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "native"
):
params["response_format"] = {

Check warning on line 125 in griptape/drivers/prompt/cohere_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/cohere_prompt_driver.py#L125

Added line #L125 was not covered by tests
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}

return params

Expand Down
29 changes: 17 additions & 12 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class GooglePromptDriver(BasePromptDriver):
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

Expand Down Expand Up @@ -125,8 +126,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

types = import_optional_dependency("google.generativeai.types")
protos = import_optional_dependency("google.generativeai.protos")

Expand All @@ -150,17 +149,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
},
),
**(
{
"tools": self.__to_google_tools(prompt_stack.tools),
"tool_config": {"function_calling_config": {"mode": self.tool_choice}},
}
if prompt_stack.tools and self.use_native_tools
else {}
),
}
if prompt_stack.output_schema is not None:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_config"] = {
"function_calling_config": {"mode": self.tool_choice},
}
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_google_tools(prompt_stack.tools)
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "tool":
params["tool_config"] = {

Check warning on line 163 in griptape/drivers/prompt/google_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/google_prompt_driver.py#L163

Added line #L163 was not covered by tests
"function_calling_config": {"mode": "auto"},
}
elif self.native_structured_output_mode == "native":

Check warning on line 166 in griptape/drivers/prompt/google_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/google_prompt_driver.py#L166

Added line #L166 was not covered by tests
# TODO: Add support for native structured output
...

return params

Expand Down
7 changes: 6 additions & 1 deletion griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,14 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema and self.use_native_structured_output:
if (
prompt_stack.output_schema
and self.use_native_structured_output
and self.native_structured_output_mode == "native"
):
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
params["grammar"] = {"type": "json", "value": prompt_stack.output_schema.json_schema("Output Schema")}

Check warning on line 120 in griptape/drivers/prompt/huggingface_hub_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/huggingface_hub_prompt_driver.py#L120

Added line #L120 was not covered by tests
# Grammar does not support $schema and $id
del params["grammar"]["value"]["$schema"]
del params["grammar"]["value"]["$id"]

Check warning on line 123 in griptape/drivers/prompt/huggingface_hub_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/huggingface_hub_prompt_driver.py#L122-L123

Added lines #L122 - L123 were not covered by tests

Expand Down
15 changes: 6 additions & 9 deletions griptape/drivers/prompt/ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise Exception("invalid model response")

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

messages = self._prompt_stack_to_messages(prompt_stack)

params = {
Expand All @@ -124,13 +122,12 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if self.use_native_structured_output:
params["format"] = prompt_stack.output_schema.json_schema("Output")
else:
structured_ouptut_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_ouptut_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_ouptut_tool)
if (
prompt_stack.output_schema is not None
and self.use_native_structured_output
and self.native_structured_output_mode == "tool"
):
params["format"] = prompt_stack.output_schema.json_schema("Output")

Check warning on line 130 in griptape/drivers/prompt/ollama_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/ollama_prompt_driver.py#L130

Added line #L130 was not covered by tests

return params

Expand Down
3 changes: 1 addition & 2 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"temperature": self.temperature,
"user": self.user,
"seed": self.seed,
"stop": self.tokenizer.stop_sequences,
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
Expand All @@ -166,7 +165,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"strict": True,
},
}
else:
elif self.native_structured_output_mode == "tool":
params["tool_choice"] = "required"

Check warning on line 169 in griptape/drivers/prompt/openai_chat_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/openai_chat_prompt_driver.py#L169

Added line #L169 was not covered by tests

if prompt_stack.tools and self.use_native_tools:
Expand Down
3 changes: 0 additions & 3 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,3 @@ def rulesets(self) -> list[Ruleset]:
rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules))

return rulesets

def get_rules_for_type(self, rule_type: type[T]) -> list[T]:
return [rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, rule_type)]
3 changes: 3 additions & 0 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from collections.abc import Sequence
from typing import Any

from schema import Schema

from griptape.artifacts import BaseArtifact
from griptape.common import (
BaseDeltaMessageContent,
Expand Down Expand Up @@ -215,6 +217,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseRule": BaseRule,
"Ruleset": Ruleset,
# Third party modules
"Schema": Schema,
"Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
"ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any,
"GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
Expand Down
10 changes: 2 additions & 8 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, evolve, field
from schema import Schema

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
Expand All @@ -12,6 +11,8 @@
from griptape.tasks import PromptTask

if TYPE_CHECKING:
from schema import Schema

from griptape.artifacts import BaseArtifact
from griptape.drivers import BasePromptDriver
from griptape.tasks import BaseTask
Expand Down Expand Up @@ -51,7 +52,6 @@ def __attrs_post_init__(self) -> None:
prompt_driver=self.prompt_driver,
tools=self.tools,
max_meta_memory_entries=self.max_meta_memory_entries,
output_schema=self._build_schema_from_type(self.output_type) if self.output_type is not None else None,
)
self.add_task(task)

Expand All @@ -78,9 +78,3 @@ def try_run(self, *args) -> Agent:
self.task.run()

return self

def _build_schema_from_type(self, output_type: type | Schema) -> Schema:
if isinstance(output_type, Schema):
return output_type
else:
return Schema(output_type)
Loading

0 comments on commit 997ef3d

Please sign in to comment.