Commit

WIP

collindutter committed Dec 23, 2024
1 parent a3b3c69 commit 9f00ec3
Showing 12 changed files with 187 additions and 83 deletions.
30 changes: 19 additions & 11 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from attrs import Factory, define, field
from schema import Schema
@@ -55,6 +55,9 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@@ -116,19 +119,24 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_choice"] = {"any": {}}
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "tool":
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["toolConfig"] = {
"toolChoice": {"any": {}},
}
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)
else:
raise ValueError(f"Unsupported native structured output mode: {self.native_structured_output_mode}")

if prompt_stack.tools and self.use_native_tools:
params["toolConfig"] = {
"tools": self.__to_bedrock_tools(prompt_stack.tools),
}

return params

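Taken together, a minimal sketch of how the new Bedrock behavior might be exercised (the driver, its fields, and `prompt_stack.output_schema` come from the diff; the model id, the `PromptStack` wiring, and the direct `_base_params` call are illustrative assumptions, and boto3 must be installed):

```python
from schema import Schema

from griptape.common import PromptStack
from griptape.drivers.prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver

# "tool" is the only mode this driver accepts in this commit; anything else
# raises ValueError inside _base_params.
driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model id
    use_native_structured_output=True,
    native_structured_output_mode="tool",
)

prompt_stack = PromptStack()
prompt_stack.output_schema = Schema({"answer": str})

# Calling the private helper directly just to show the effect: a
# StructuredOutputTool is appended to prompt_stack.tools and the Bedrock
# toolConfig is built around it.
params = driver._base_params(prompt_stack)
```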
14 changes: 12 additions & 2 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -1,9 +1,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
@@ -68,13 +68,23 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
if value == "native":
raise ValueError("Anthropic does not support native structured output mode.")

return value

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
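The new validator rejects the unsupported mode at construction time rather than mid-request. A small sketch (the model id and api_key are placeholder assumptions):

```python
from griptape.drivers.prompt.anthropic_prompt_driver import AnthropicPromptDriver

# "tool", this driver's default, is accepted.
driver = AnthropicPromptDriver(model="claude-3-5-sonnet-20240620", api_key="...")

# "native" is rejected by validate_native_structured_output_mode.
try:
    AnthropicPromptDriver(
        model="claude-3-5-sonnet-20240620",
        api_key="...",
        native_structured_output_mode="native",
    )
except ValueError as e:
    print(e)  # Anthropic does not support native structured output mode.
```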
5 changes: 4 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field

@@ -57,6 +57,9 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
native_structured_output_mode: Literal["native", "tool"] = field(
default="native", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
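The base default is "native", and providers without a native JSON-schema response format override it to "tool" (as the Bedrock and Anthropic hunks above do). A sketch of the resulting defaults, assuming the relevant SDK extras are installed and with placeholder model ids:

```python
from griptape.drivers.prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver
from griptape.drivers.prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver

# Inherits the base default.
print(OpenAiChatPromptDriver(model="gpt-4o").native_structured_output_mode)  # "native"

# Overridden at the subclass level.
print(AmazonBedrockPromptDriver(model="anthropic.claude-3-5-sonnet-20240620-v1:0").native_structured_output_mode)  # "tool"
```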
38 changes: 20 additions & 18 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -97,37 +97,39 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

tool_results = []

messages = self.__to_cohere_messages(prompt_stack.messages)

return {
params = {
"model": self.model,
"messages": messages,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_tokens,
**({"tool_results": tool_results} if tool_results else {}),
**(
{"tools": self.__to_cohere_tools(prompt_stack.tools)}
if prompt_stack.tools and self.use_native_tools
else {}
),
**(
{
"response_format": {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
}
if not prompt_stack.tools # Response format is not supported with tools https://docs.cohere.com/reference/chat#request.body.response_format
and prompt_stack.output_schema is not None
and self.use_native_structured_output
else {}
),
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if self.use_native_structured_output:
params["response_format"] = {
"type": "json_object",
"schema": prompt_stack.output_schema.json_schema("Output"),
}
else:
# This is best-effort, since Cohere does not support forced tool use.
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

return params

def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
cohere_messages = []

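Because the tool fallback is best-effort, the native branch is the dependable path for Cohere. Roughly the `response_format` it attaches (schema contents are illustrative):

```python
from schema import Schema

output_schema = Schema({"answer": str})

# Shape of the response_format _base_params builds when
# use_native_structured_output is enabled:
response_format = {
    "type": "json_object",
    "schema": output_schema.json_schema("Output"),
}
```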
24 changes: 15 additions & 9 deletions griptape/drivers/prompt/google_prompt_driver.py
@@ -125,6 +125,8 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

types = import_optional_dependency("google.generativeai.types")
protos = import_optional_dependency("google.generativeai.protos")

@@ -135,7 +137,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
)

return {
params = {
"generation_config": types.GenerationConfig(
**{
# For some reason, providing stop sequences when streaming breaks native functions
@@ -148,15 +150,19 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
},
),
**(
{
"tools": self.__to_google_tools(prompt_stack.tools),
"tool_config": {"function_calling_config": {"mode": self.tool_choice}},
}
if prompt_stack.tools and self.use_native_tools
else {}
),
}
if prompt_stack.output_schema is not None:
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_config"] = {
"function_calling_config": {"mode": self.tool_choice},
}
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_google_tools(prompt_stack.tools)

return params

def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
types = import_optional_dependency("google.generativeai.types")
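Gemini is handled entirely through function calling here: the `StructuredOutputTool` is injected and `function_calling_config` steers the model toward it. A rough sketch of the fragment `_base_params` adds (the "auto" mode mirrors this driver's default `tool_choice`, an assumption not shown in the hunk):

```python
# Fragment of the Gemini request params once prompt_stack.output_schema is set:
params_fragment = {
    "tool_config": {"function_calling_config": {"mode": "auto"}},
    # "tools" is filled from prompt_stack.tools (now including the
    # StructuredOutputTool) via __to_google_tools when use_native_tools is on.
}
```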
25 changes: 22 additions & 3 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@@ -35,6 +35,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver):
api_token: str = field(kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
default=Factory(
lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
@@ -55,7 +56,12 @@ def client(self) -> InferenceClient:
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = self._base_params(prompt_stack)
logger.debug((prompt, full_params))
logger.debug(
{
"prompt": prompt,
**full_params,
}
)

response = self.client.text_generation(
prompt,
@@ -75,7 +81,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {**self._base_params(prompt_stack), "stream": True}
logger.debug((prompt, full_params))
logger.debug(
{
"prompt": prompt,
**full_params,
}
)

response = self.client.text_generation(prompt, **full_params)

@@ -94,12 +105,20 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
params = {
"return_full_text": False,
"max_new_tokens": self.max_tokens,
**self.extra_params,
}

if prompt_stack.output_schema and self.use_native_structured_output:
# https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
params["grammar"] = {"type": "json", "value": prompt_stack.output_schema.json_schema("Output Schema")}
del params["grammar"]["value"]["$schema"]
del params["grammar"]["value"]["$id"]

return params

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
messages = []
for message in prompt_stack.messages:
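The `grammar` parameter is text-generation-inference's constrained-decoding hook (see the linked cookbook). A sketch of the cleanup the two `del` statements perform, since the `schema` library emits JSON Schema metadata keys the grammar value does not need (schema contents are illustrative):

```python
from schema import Schema

value = Schema({"answer": str}).json_schema("Output Schema")
# json_schema() includes "$id" and "$schema" metadata alongside the actual
# object schema; strip them before sending.
value.pop("$schema")
value.pop("$id")

grammar = {"type": "json", "value": value}
# client.text_generation(prompt, grammar=grammar, ...) then constrains
# decoding to outputs that match the schema.
```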
14 changes: 13 additions & 1 deletion griptape/drivers/prompt/ollama_prompt_driver.py
@@ -101,9 +101,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise Exception("invalid model response")

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

messages = self._prompt_stack_to_messages(prompt_stack)

return {
params = {
"messages": messages,
"model": self.model,
"options": self.options,
@@ -122,6 +124,16 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**self.extra_params,
}

if prompt_stack.output_schema is not None:
if self.use_native_structured_output:
params["format"] = prompt_stack.output_schema.json_schema("Output")
else:
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

return params

def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
ollama_messages = []
for message in prompt_stack.messages:
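The native branch maps onto Ollama's structured outputs, where `format` accepts a JSON schema directly. A standalone sketch using the `ollama` Python client (the client call and model name are assumptions; the driver itself passes the same `format` key through its own client):

```python
from ollama import chat
from schema import Schema

output_format = Schema({"answer": str}).json_schema("Output")

# Equivalent of what the driver sends when use_native_structured_output is on:
response = chat(
    model="llama3.1",  # placeholder model
    messages=[{"role": "user", "content": "Answer in JSON."}],
    format=output_format,
)
```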
54 changes: 21 additions & 33 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -144,50 +144,38 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))

def _base_params(self, prompt_stack: PromptStack) -> dict:
from griptape.tools.structured_output.tool import StructuredOutputTool

tools = prompt_stack.tools
if not self.use_native_structured_output and prompt_stack.output_schema is not None:
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
params["tool_choice"] = "required"
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)

params = {
"model": self.model,
"temperature": self.temperature,
"user": self.user,
"seed": self.seed,
**(
{
"tools": self.__to_openai_tools(prompt_stack.tools),
"tool_choice": self.tool_choice,
"parallel_tool_calls": self.parallel_tool_calls,
}
if prompt_stack.tools and self.use_native_tools
else {}
),
"stop": self.tokenizer.stop_sequences,
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
**(
{
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": prompt_stack.output_schema.json_schema("Output"),
"strict": True,
},
}
}
if prompt_stack.output_schema is not None and self.use_native_structured_output
else {}
),
**self.extra_params,
}

if self.response_format is not None:
if prompt_stack.output_schema is not None and self.use_native_structured_output:
if self.native_structured_output_mode == "native":
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": prompt_stack.output_schema.json_schema("Output"),
"strict": True,
},
}
else:
params["tool_choice"] = "required"

if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_openai_tools(prompt_stack.tools)
params["parallel_tool_calls"] = self.parallel_tool_calls
if "tool_choice" not in params:
params["tool_choice"] = self.tool_choice

if self.response_format is not None and "response_format" not in params:
if self.response_format == {"type": "json_object"}:
params["response_format"] = self.response_format
# JSON mode still requires a system message instructing the LLM to output JSON.
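The rewrite makes the two modes explicit: "native" maps to OpenAI's strict `json_schema` response format, while "tool" falls back to forcing the injected `StructuredOutputTool` via `tool_choice: "required"`. Roughly the native-mode payload (schema contents are illustrative):

```python
from schema import Schema

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Output",
        "schema": Schema({"answer": str}).json_schema("Output"),
        "strict": True,
    },
}
```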
6 changes: 6 additions & 0 deletions griptape/mixins/rule_mixin.py
@@ -1,12 +1,15 @@
from __future__ import annotations

import uuid
from typing import TypeVar

from attrs import Factory, define, field

from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.rules import BaseRule, Ruleset

T = TypeVar("T", bound=BaseRule)


@define(slots=False)
class RuleMixin(SerializableMixin):
@@ -25,3 +28,6 @@ def rulesets(self) -> list[Ruleset]:
rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules))

return rulesets

def get_rules_for_type(self, rule_type: type[T]) -> list[T]:
return [rule for ruleset in self.rulesets for rule in ruleset.rules if isinstance(rule, rule_type)]
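A usage sketch for the new helper; the `Agent` consumer and `JsonSchemaRule` are assumptions about how callers might use it, but any `RuleMixin` subclass works the same way:

```python
from griptape.rules import JsonSchemaRule, Rule
from griptape.structures import Agent

agent = Agent(
    rules=[
        Rule("Answer in one sentence."),
        JsonSchemaRule({"type": "object", "properties": {"answer": {"type": "string"}}}),
    ]
)

# Pull out only the rules of one type, e.g. to derive an output schema.
json_schema_rules = agent.get_rules_for_type(JsonSchemaRule)
```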
