Add anthropic support
collindutter committed Jun 17, 2024
1 parent 6d66095 commit 1484e56
Showing 8 changed files with 224 additions and 38 deletions.
griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from attrs import define, field
-from typing import Sequence
+from collections.abc import Sequence

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent
@@ -11,8 +11,6 @@
class ActionResultPromptStackContent(BasePromptStackContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action_tag: str = field(kw_only=True, metadata={"serializable": True})
-    action_name: str = field(kw_only=True, metadata={"serializable": True})
-    action_path: str = field(kw_only=True, metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionResultPromptStackContent:
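With action_name and action_path removed, an action result content is now keyed only by its action_tag. A minimal sketch of constructing one; the import path matches the one used later in this commit, and the tag value is made up:

from griptape.artifacts import TextArtifact
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent

# The artifact is positional; action_tag is the only keyword-only identifier left.
result_content = ActionResultPromptStackContent(TextArtifact("4"), action_tag="toolu_01")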
60 changes: 50 additions & 10 deletions griptape/common/prompt_stack/prompt_stack.py
@@ -5,19 +5,16 @@

from attrs import define, field

-from griptape.artifacts import ActionCallArtifact, BaseArtifact, ImageArtifact, ListArtifact, TextArtifact
-from griptape.common import (
-    ActionCallPromptStackContent,
-    ImagePromptStackContent,
-    PromptStackElement,
-    TextPromptStackContent,
-)
+from griptape.artifacts import BaseArtifact, ImageArtifact, ListArtifact, TextArtifact
+from griptape.artifacts.action_call_artifact import ActionCallArtifact
+from griptape.common import ImagePromptStackContent, PromptStackElement, TextPromptStackContent, BasePromptStackContent
+from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent
+from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
-    from griptape.tools import BaseTool
+    from griptape.tasks.actions_subtask import ActionsSubtask
+    from griptape.tools import BaseTool


@define
@@ -35,19 +32,62 @@ def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElemen
def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.SYSTEM_ROLE)

-    def add_user_input(self, input_content: str | BaseArtifact) -> PromptStackElement:
-        return self.add_input(input_content, PromptStackElement.USER_ROLE)
+    def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
+        return self.add_input(content, PromptStackElement.USER_ROLE)

def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)

def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]) -> PromptStackElement:
thought_content = self.__process_content(thought) if thought else []

action_calls_content = [
ActionCallPromptStackContent(
ActionCallArtifact(
ActionCallArtifact.ActionCall(
tag=action.tag, name=action.name, path=action.path, input=json.dumps(action.input)
)
)
)
for action in actions
]

self.inputs.append(
PromptStackElement(
content=[*thought_content, *action_calls_content], role=PromptStackElement.ASSISTANT_ROLE
)
)

return self.inputs[-1]

def add_action_result_input(
self, instructions: Optional[str | BaseArtifact], actions: list[ActionsSubtask.Action]
) -> PromptStackElement:
instructions_content = self.__process_content(instructions) if instructions else []

action_results_content = [
ActionResultPromptStackContent(action.output, action_tag=action.tag)
for action in actions
if action.output is not None
]

self.inputs.append(
PromptStackElement(
content=[*instructions_content, *action_results_content], role=PromptStackElement.USER_ROLE
)
)

return self.inputs[-1]

def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]:
if isinstance(content, str):
return [TextPromptStackContent(TextArtifact(content))]
elif isinstance(content, TextArtifact):
return [TextPromptStackContent(content)]
elif isinstance(content, ImageArtifact):
return [ImagePromptStackContent(content)]
elif isinstance(content, ActionCallArtifact):
return [ActionCallPromptStackContent(content)]
elif isinstance(content, ListArtifact):
processed_contents = [self.__process_content(artifact) for artifact in content.value]
flattened_content = [
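A rough usage sketch of the two new PromptStack helpers as ToolkitTask calls them later in this commit. Import paths and the ActionsSubtask.Action keyword arguments are assumptions inferred from the attribute names used above (tag, name, path, input, output), not verified against the real constructor:

from griptape.artifacts import TextArtifact
from griptape.common import PromptStack
from griptape.tasks.actions_subtask import ActionsSubtask

stack = PromptStack()
stack.add_user_input("What is 2 + 2?")

# Hypothetical action; kwargs assumed from the attributes referenced in this diff.
action = ActionsSubtask.Action(tag="toolu_01", name="Calculator", path="calculate", input={"expression": "2 + 2"})

# Assistant turn: optional thought plus one ActionCallPromptStackContent per action.
stack.add_action_call_input("I need the calculator.", [action])

# User turn: one ActionResultPromptStackContent per action that produced an output.
action.output = TextArtifact("4")
stack.add_action_result_input("Please keep going!", [action])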
112 changes: 95 additions & 17 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
@@ -1,11 +1,12 @@
from __future__ import annotations

import json
from collections.abc import Iterator
-from typing import Any, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field

-from griptape.artifacts import TextArtifact
+from griptape.artifacts import ActionCallArtifact, ErrorArtifact, TextArtifact
from griptape.common import (
BaseDeltaPromptStackContent,
BasePromptStackContent,
@@ -16,13 +17,20 @@
PromptStackElement,
TextPromptStackContent,
)
from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
from griptape.common.prompt_stack.contents.delta_action_call_prompt_stack_content import (
DeltaActionCallPromptStackContent,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer
from griptape.utils import import_optional_dependency
from schema import Schema

if TYPE_CHECKING:
-    from anthropic.types import ContentBlockDeltaEvent
-    from anthropic.types import ContentBlock
+    from anthropic import Client
+    from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent
+    from griptape.tools.base_tool import BaseTool


@define
@@ -36,7 +44,7 @@ class AnthropicPromptDriver(BasePromptDriver):

api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
model: str = field(kw_only=True, metadata={"serializable": True})
-    client: Any = field(
+    client: Client = field(
default=Factory(
lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
),
@@ -47,6 +55,8 @@ class AnthropicPromptDriver(BasePromptDriver):
)
top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})

def try_run(self, prompt_stack: PromptStack) -> PromptStackElement:
@@ -64,7 +74,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem
events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

for event in events:
-            if event.type == "content_block_delta":
+            if event.type == "content_block_delta" or event.type == "content_block_start":
yield self.__message_content_delta_to_prompt_stack_content_delta(event)
elif event.type == "message_start":
yield DeltaPromptStackElement(
@@ -79,6 +89,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem
def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]:
return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in elements]

def _prompt_stack_to_tools(self, prompt_stack: PromptStack) -> dict:
return (
{"tools": self.__to_tools(prompt_stack.actions), "tool_choice": self.tool_choice}
if prompt_stack.actions and self.use_native_tools
else {}
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
messages = self._prompt_stack_elements_to_messages([i for i in prompt_stack.inputs if not i.is_system()])

@@ -96,6 +113,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"messages": messages,
**self._prompt_stack_to_tools(prompt_stack),
**({"system": system_message} if system_message else {}),
}

@@ -107,11 +125,26 @@ def __to_role(self, input: PromptStackElement) -> str:
else:
return "user"

def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
return [
{
"name": f"{tool.name}-{tool.activity_name(activity)}",
"description": tool.activity_description(activity),
"input_schema": (tool.activity_schema(activity) or Schema({})).json_schema("Input Schema"),
}
for tool in tools
for activity in tool.activities()
]

def __to_content(self, input: PromptStackElement) -> str | list[dict]:
if all(isinstance(content, TextPromptStackContent) for content in input.content):
return input.to_text_artifact().to_text()
else:
-            return [self.__prompt_stack_content_message_content(content) for content in input.content]
+            content = [self.__prompt_stack_content_message_content(content) for content in input.content]
+            sorted_content = sorted(
+                content, key=lambda message_content: -1 if message_content["type"] == "tool_result" else 1
+            )  # Tool results must come first in the content list
+            return sorted_content

def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict:
if isinstance(content, TextPromptStackContent):
@@ -121,24 +154,69 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent
"type": "image",
"source": {"type": "base64", "media_type": content.artifact.mime_type, "data": content.artifact.base64},
}
elif isinstance(content, ActionCallPromptStackContent):
action = content.artifact.value

return {
"type": "tool_use",
"id": action.tag,
"name": f"{action.name}-{action.path}",
"input": json.loads(action.input),
}
elif isinstance(content, ActionResultPromptStackContent):
artifact = content.artifact

return {
"type": "tool_result",
"tool_use_id": content.action_tag,
"content": artifact.to_text(),
"is_error": isinstance(artifact, ErrorArtifact),
}
else:
raise ValueError(f"Unsupported prompt content type: {type(content)}")

def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BasePromptStackContent:
-        content_type = content.type

-        if content_type == "text":
+        if content.type == "text":
return TextPromptStackContent(TextArtifact(content.text))
elif content.type == "tool_use":
return ActionCallPromptStackContent(
artifact=ActionCallArtifact(
value=ActionCallArtifact.ActionCall(
tag=content.id,
name=content.name.split("-")[0],
path=content.name.split("-")[1],
input=json.dumps(content.input),
)
)
)
else:
raise ValueError(f"Unsupported message content type: {content_type}")
raise ValueError(f"Unsupported message content type: {content.type}")

def __message_content_delta_to_prompt_stack_content_delta(
-        self, content_delta: ContentBlockDeltaEvent
+        self, event: ContentBlockDeltaEvent | ContentBlockStartEvent
) -> BaseDeltaPromptStackContent:
-        index = content_delta.index
-        delta_type = content_delta.delta.type
+        index = event.index

if event.type == "content_block_start":
content_block = event.content_block

-        if delta_type == "text_delta":
-            return DeltaTextPromptStackContent(content_delta.delta.text, index=index)
if content_block.type == "tool_use":
return DeltaActionCallPromptStackContent(
index=index,
tag=content_block.id,
name=content_block.name.split("-")[0],
path=content_block.name.split("-")[1],
)
elif content_block.type == "text":
return DeltaTextPromptStackContent(content_block.text, index=index)
else:
raise ValueError(f"Unsupported message content start type : {content_block.type}")
else:
raise ValueError(f"Unsupported message content delta type : {delta_type}")
delta = event.delta

if delta.type == "text_delta":
return DeltaTextPromptStackContent(delta.text, index=index)
elif delta.type == "input_json_delta":
return DeltaActionCallPromptStackContent(delta_input=delta.partial_json, index=index)
else:
raise ValueError(f"Unsupported message content delta type : {delta.type}")
1 change: 0 additions & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
@@ -132,7 +132,6 @@ def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
{
"role": self.__to_role(input),
"content": self.__prompt_stack_content_message_content(action_result),
"name": f"{action_result.action_name}-{action_result.action_path}",
"tool_call_id": action_result.action_tag,
}
)
2 changes: 1 addition & 1 deletion griptape/mixins/exponential_backoff_mixin.py
@@ -9,7 +9,7 @@
class ExponentialBackoffMixin(ABC):
min_retry_delay: float = field(default=2, kw_only=True)
max_retry_delay: float = field(default=10, kw_only=True)
-    max_attempts: int = field(default=10, kw_only=True)
+    max_attempts: int = field(default=0, kw_only=True)
after_hook: Callable = field(default=lambda s: logging.warning(s), kw_only=True)
ignored_exception_types: tuple[type[Exception], ...] = field(factory=tuple, kw_only=True)

2 changes: 1 addition & 1 deletion griptape/tasks/toolkit_task.py
@@ -75,7 +75,7 @@ def prompt_stack(self) -> PromptStack:
for s in self.subtasks:
if self.prompt_driver.use_native_tools:
stack.add_action_call_input(s.thought, s.actions)
-                stack.add_action_result_input(self.generate_user_subtask_template(s), s.actions)
+                stack.add_action_result_input(None if s.output else "Please keep going!", s.actions)
else:
stack.add_assistant_input(self.generate_assistant_subtask_template(s))
stack.add_user_input(self.generate_user_subtask_template(s))