From a78c8f22101c8b375e5d93bb105647e16900a394 Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Mon, 17 Jun 2024 16:41:24 -0700
Subject: [PATCH] Add bedrock support
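
Add native tool calling to AmazonBedrockPromptDriver via the Converse API
(toolConfig / toolUse / toolResult blocks), add per-driver
__artifact_to_message_content helpers to the Bedrock and Anthropic drivers,
and make ListArtifact generic over its item type.

Rough usage sketch (not part of the diff below; the model id and the
Agent/Calculator wiring are illustrative only):

    from griptape.drivers import AmazonBedrockPromptDriver
    from griptape.structures import Agent
    from griptape.tools import Calculator

    agent = Agent(
        prompt_driver=AmazonBedrockPromptDriver(
            model="anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative model id
            use_native_tools=True,  # new field, defaults to True
        ),
        tools=[Calculator()],
    )

    # Each tool activity is exposed to Bedrock as a toolSpec named
    # "{tool.name}_{activity}"; tool output is sent back as toolResult blocks.
    agent.run("What is 7 * 12?")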
---
 griptape/artifacts/list_artifact.py           |   8 +-
 .../action_result_prompt_stack_content.py     |   5 +-
 griptape/common/prompt_stack/prompt_stack.py  |  46 +++---
 .../prompt/amazon_bedrock_prompt_driver.py    | 142 ++++++++++++++++--
 .../drivers/prompt/anthropic_prompt_driver.py |  44 ++++--
 griptape/tasks/toolkit_task.py                |   4 +-
 6 files changed, 198 insertions(+), 51 deletions(-)

diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py
index 68b377df2a..d7ae3b117d 100644
--- a/griptape/artifacts/list_artifact.py
+++ b/griptape/artifacts/list_artifact.py
@@ -1,12 +1,16 @@
+from __future__ import annotations
 from typing import Optional
 from collections.abc import Sequence
 from attrs import field, define
 from griptape.artifacts import BaseArtifact
+from typing import TypeVar, Generic
+
+T = TypeVar("T", bound=BaseArtifact)
 
 
 @define
-class ListArtifact(BaseArtifact):
-    value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
+class ListArtifact(BaseArtifact, Generic[T]):
+    value: Sequence[T] = field(factory=list, metadata={"serializable": True})
     item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
     validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})
 
diff --git a/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py b/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py
index e8dfccbe03..ae99b0ecb1 100644
--- a/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py
+++ b/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from attrs import define, field
 from collections.abc import Sequence
 
-from griptape.artifacts.base_artifact import BaseArtifact
+from attrs import define, field
+
+from griptape.artifacts import BaseArtifact
 from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent
 
 
diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py
index 7740af609c..6a41c6b2c2 100644
--- a/griptape/common/prompt_stack/prompt_stack.py
+++ b/griptape/common/prompt_stack/prompt_stack.py
@@ -22,24 +22,24 @@ class PromptStack(SerializableMixin):
     inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True})
     actions: list[BaseTool] = field(factory=list, kw_only=True)
 
-    def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement:
-        new_content = self.__process_content(content)
+    def add_input(self, artifact: str | BaseArtifact, role: str) -> PromptStackElement:
+        content = self.__process_artifact(artifact)
 
-        self.inputs.append(PromptStackElement(content=new_content, role=role))
+        self.inputs.append(PromptStackElement(content=content, role=role))
 
         return self.inputs[-1]
 
-    def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement:
-        return self.add_input(content, PromptStackElement.SYSTEM_ROLE)
+    def add_system_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
+        return self.add_input(artifact, PromptStackElement.SYSTEM_ROLE)
 
-    def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
-        return self.add_input(content, PromptStackElement.USER_ROLE)
+    def add_user_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
+        return self.add_input(artifact, PromptStackElement.USER_ROLE)
 
-    def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
-        return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)
+    def add_assistant_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
+        return self.add_input(artifact, PromptStackElement.ASSISTANT_ROLE)
 
     def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]) -> PromptStackElement:
-        thought_content = self.__process_content(thought) if thought else []
+        thought_content = self.__process_artifact(thought) if thought else []
 
         action_calls_content = [
             ActionCallPromptStackContent(
@@ -63,7 +63,7 @@ def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSub
     def add_action_result_input(
         self, instructions: Optional[str | BaseArtifact], actions: list[ActionsSubtask.Action]
     ) -> PromptStackElement:
-        instructions_content = self.__process_content(instructions) if instructions else []
+        instructions_content = self.__process_artifact(instructions) if instructions else []
 
         action_results_content = [
             ActionResultPromptStackContent(action.output, action_tag=action.tag)
@@ -79,21 +79,21 @@ def add_action_result_input(
 
         return self.inputs[-1]
 
-    def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]:
-        if isinstance(content, str):
-            return [TextPromptStackContent(TextArtifact(content))]
-        elif isinstance(content, TextArtifact):
-            return [TextPromptStackContent(content)]
-        elif isinstance(content, ImageArtifact):
-            return [ImagePromptStackContent(content)]
-        elif isinstance(content, ActionCallArtifact):
-            return [ActionCallPromptStackContent(content)]
-        elif isinstance(content, ListArtifact):
-            processed_contents = [self.__process_content(artifact) for artifact in content.value]
+    def __process_artifact(self, artifact: str | BaseArtifact) -> list[BasePromptStackContent]:
+        if isinstance(artifact, str):
+            return [TextPromptStackContent(TextArtifact(artifact))]
+        elif isinstance(artifact, TextArtifact):
+            return [TextPromptStackContent(artifact)]
+        elif isinstance(artifact, ImageArtifact):
+            return [ImagePromptStackContent(artifact)]
+        elif isinstance(artifact, ActionCallArtifact):
+            return [ActionCallPromptStackContent(artifact)]
+        elif isinstance(artifact, ListArtifact):
+            processed_contents = [self.__process_artifact(artifact) for artifact in artifact.value]
             flattened_content = [
                 sub_content for processed_content in processed_contents for sub_content in processed_content
             ]
 
             return flattened_content
         else:
-            raise ValueError(f"Unsupported content type: {type(content)}")
+            raise ValueError(f"Unsupported artifact type: {type(artifact)}")
diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
index d3c13a5da3..d209037d90 100644
--- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
+++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@@ -2,10 +2,15 @@
 
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any
+import json
 
 from attrs import Factory, define, field
 
-from griptape.artifacts import TextArtifact
+from griptape.artifacts import TextArtifact, ActionCallArtifact, ImageArtifact
+from griptape.artifacts.base_artifact import BaseArtifact
+from griptape.artifacts.error_artifact import ErrorArtifact
+from griptape.artifacts.info_artifact import InfoArtifact
+from griptape.artifacts.list_artifact import ListArtifact
 from griptape.common import (
     BaseDeltaPromptStackContent,
     DeltaPromptStackElement,
@@ -15,13 +20,20 @@
     TextPromptStackContent,
     ImagePromptStackContent,
 )
+from griptape.common import ActionCallPromptStackContent
+from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
+from griptape.common.prompt_stack.contents.delta_action_call_prompt_stack_content import (
+    DeltaActionCallPromptStackContent,
+)
 from griptape.drivers import BasePromptDriver
 from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
 from griptape.utils import import_optional_dependency
+from schema import Schema
 
 if TYPE_CHECKING:
     import boto3
 
+    from griptape.tools import BaseTool
     from griptape.common import PromptStack
 
 
@@ -35,6 +47,8 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
     tokenizer: BaseTokenizer = field(
         default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
     )
+    use_native_tools: bool = field(default=True, kw_only=True)
+    tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": False})
 
     def try_run(self, prompt_stack: PromptStack) -> PromptStackElement:
         response = self.bedrock_client.converse(**self._base_params(prompt_stack))
@@ -43,7 +57,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement:
         output_message = response["output"]["message"]
 
         return PromptStackElement(
-            content=[TextPromptStackContent(TextArtifact(content["text"])) for content in output_message["content"]],
+            content=[self.__message_content_to_prompt_stack_content(content) for content in output_message["content"]],
             role=PromptStackElement.ASSISTANT_ROLE,
             usage=PromptStackElement.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
         )
@@ -55,12 +69,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem
         if stream is not None:
             for event in stream:
                 if "messageStart" in event:
-                    yield DeltaPromptStackElement(role=event["messageStart"]["role"])
-                elif "contentBlockDelta" in event:
-                    content_block_delta = event["contentBlockDelta"]
-                    yield DeltaTextPromptStackContent(
-                        content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"]
-                    )
+                    yield DeltaPromptStackElement(role=PromptStackElement.ASSISTANT_ROLE)
+                elif "contentBlockDelta" in event or "contentBlockStart" in event:
+                    yield self.__message_content_delta_to_prompt_stack_content_delta(event)
                 elif "metadata" in event:
                     usage = event["metadata"]["usage"]
                     yield DeltaPromptStackElement(
@@ -80,6 +91,13 @@ def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement])
             for input in elements
         ]
 
+    def _prompt_stack_to_tools(self, prompt_stack: PromptStack) -> dict:
+        return (
+            {"toolConfig": {"tools": self.__to_tools(prompt_stack.actions), "toolChoice": self.tool_choice}}
+            if prompt_stack.actions and self.use_native_tools
+            else {}
+        )
+
     def _base_params(self, prompt_stack: PromptStack) -> dict:
         system_messages = [
             {"text": input.to_text_artifact().to_text()} for input in prompt_stack.inputs if input.is_system()
@@ -95,16 +113,103 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
             "system": system_messages,
             "inferenceConfig": {"temperature": self.temperature},
             "additionalModelRequestFields": self.additional_model_request_fields,
+            **self._prompt_stack_to_tools(prompt_stack),
         }
 
+    def __message_content_to_prompt_stack_content(self, content: dict) -> BasePromptStackContent:
+        if "text" in content:
+            return TextPromptStackContent(TextArtifact(content["text"]))
+        elif "toolUse" in content:
+            name, path = content["toolUse"]["name"].split("_", 1)
+            return ActionCallPromptStackContent(
+                artifact=ActionCallArtifact(
+                    value=ActionCallArtifact.ActionCall(
+                        tag=content["toolUse"]["toolUseId"],
+                        name=name,
+                        path=path,
+                        input=json.dumps(content["toolUse"]["input"]),
+                    )
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported message content type: {content}")
+
+    def __message_content_delta_to_prompt_stack_content_delta(self, event: dict) -> BaseDeltaPromptStackContent:
+        if "contentBlockStart" in event:
+            content_block = event["contentBlockStart"]["start"]
+
+            if "toolUse" in content_block:
+                name, path = content_block["toolUse"]["name"].split("_", 1)
+
+                return DeltaActionCallPromptStackContent(
+                    index=event["contentBlockStart"]["contentBlockIndex"],
+                    tag=content_block["toolUse"]["toolUseId"],
+                    name=name,
+                    path=path,
+                )
+            else:
+                raise ValueError(f"Unsupported message content type: {event}")
+        elif "contentBlockDelta" in event:
+            content_block_delta = event["contentBlockDelta"]
+
+            if "text" in content_block_delta["delta"]:
+                return DeltaTextPromptStackContent(
+                    content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"]
+                )
+            elif "toolUse" in content_block_delta["delta"]:
+                return DeltaActionCallPromptStackContent(
+                    index=content_block_delta["contentBlockIndex"],
+                    delta_input=content_block_delta["delta"]["toolUse"]["input"],
+                )
+            else:
+                raise ValueError(f"Unsupported message content type: {event}")
+        else:
+            raise ValueError(f"Unsupported message content type: {event}")
+
     def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict:
         if isinstance(content, TextPromptStackContent):
-            return {"text": content.artifact.to_text()}
+            return self.__artifact_to_message_content(content.artifact)
         elif isinstance(content, ImagePromptStackContent):
-            return {"image": {"format": content.artifact.format, "source": {"bytes": content.artifact.value}}}
+            return self.__artifact_to_message_content(content.artifact)
+        elif isinstance(content, ActionCallPromptStackContent):
+            action_call = content.artifact.value
+
+            return {
+                "toolUse": {
+                    "toolUseId": action_call.tag,
+                    "name": f"{action_call.name}_{action_call.path}",
+                    "input": json.loads(action_call.input),
+                }
+            }
+        elif isinstance(content, ActionResultPromptStackContent):
+            artifact = content.artifact
+
+            return {
+                "toolResult": {
+                    "toolUseId": content.action_tag,
+                    "content": [self.__artifact_to_message_content(artifact) for artifact in artifact.value]
+                    if isinstance(artifact, ListArtifact)
+                    else [self.__artifact_to_message_content(artifact)],
+                    "status": "error" if isinstance(artifact, ErrorArtifact) else "success",
+                }
+            }
         else:
             raise ValueError(f"Unsupported content type: {type(content)}")
 
+    def __artifact_to_message_content(self, artifact: BaseArtifact) -> dict:
+        if isinstance(artifact, ImageArtifact):
+            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
+        elif (
+            isinstance(artifact, TextArtifact)
+            or isinstance(artifact, ErrorArtifact)
+            or isinstance(artifact, InfoArtifact)
+        ):
+            return {"text": artifact.to_text()}
+        elif isinstance(artifact, ErrorArtifact):
+            return {"text": artifact.to_text()}
+        else:
+            raise ValueError(f"Unsupported artifact type: {type(artifact)}")
+
     def __to_role(self, input: PromptStackElement) -> str:
         if input.is_system():
             return "system"
@@ -112,3 +217,20 @@ def __to_role(self, input: PromptStackElement) -> str:
             return "assistant"
         else:
             return "user"
+
+    def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
+        return [
+            {
+                "toolSpec": {
+                    "name": f"{tool.name}_{tool.activity_name(activity)}",
+                    "description": tool.activity_description(activity),
+                    "inputSchema": {
+                        "json": (tool.activity_schema(activity) or Schema({})).json_schema(
+                            "https://griptape.ai"
+                        )  # TODO: Allow for non-griptape ids
+                    },
+                }
+            }
+            for tool in tools
+            for activity in tool.activities()
+        ]
"assistant" else: return "user" + + def __to_tools(self, tools: list[BaseTool]) -> list[dict]: + return [ + { + "toolSpec": { + "name": f"{tool.name}_{tool.activity_name(activity)}", + "description": tool.activity_description(activity), + "inputSchema": { + "json": (tool.activity_schema(activity) or Schema({})).json_schema( + "https://griptape.ai" + ) # TODO: Allow for non-griptape ids + }, + } + } + for tool in tools + for activity in tool.activities() + ] diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 19d97fb720..1570e444e9 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -7,6 +7,10 @@ from attrs import Factory, define, field from griptape.artifacts import ActionCallArtifact, ErrorArtifact, TextArtifact +from griptape.artifacts.base_artifact import BaseArtifact +from griptape.artifacts.image_artifact import ImageArtifact +from griptape.artifacts.info_artifact import InfoArtifact +from griptape.artifacts.list_artifact import ListArtifact from griptape.common import ( BaseDeltaPromptStackContent, BasePromptStackContent, @@ -17,11 +21,9 @@ PromptStackElement, TextPromptStackContent, ) -from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent -from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent -from griptape.common.prompt_stack.contents.delta_action_call_prompt_stack_content import ( - DeltaActionCallPromptStackContent, -) +from griptape.common import ActionCallPromptStackContent +from griptape.common import ActionResultPromptStackContent +from griptape.common import DeltaActionCallPromptStackContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency @@ -140,20 +142,17 @@ def __to_content(self, input: PromptStackElement) -> str | list[dict]: if all(isinstance(content, TextPromptStackContent) for content in input.content): return input.to_text_artifact().to_text() else: - content = [self.__prompt_stack_content_message_content(content) for content in input.content] + content = [self.__prompt_stack_content_to_message_content(content) for content in input.content] sorted_content = sorted( content, key=lambda message_content: -1 if message_content["type"] == "tool_result" else 1 ) # Tool results must come first in the content list return sorted_content - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + def __prompt_stack_content_to_message_content(self, content: BasePromptStackContent) -> dict: if isinstance(content, TextPromptStackContent): - return {"type": "text", "text": content.artifact.to_text()} + return self.__artifact_to_message_content(content.artifact) elif isinstance(content, ImagePromptStackContent): - return { - "type": "image", - "source": {"type": "base64", "media_type": content.artifact.mime_type, "data": content.artifact.base64}, - } + return self.__artifact_to_message_content(content.artifact) elif isinstance(content, ActionCallPromptStackContent): action = content.artifact.value @@ -169,12 +168,31 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent return { "type": "tool_result", "tool_use_id": content.action_tag, - "content": artifact.to_text(), + "content": [self.__artifact_to_message_content(artifact) for artifact in 
artifact.value] + if isinstance(artifact, ListArtifact) + else [self.__artifact_to_message_content(artifact)], "is_error": isinstance(artifact, ErrorArtifact), } else: raise ValueError(f"Unsupported prompt content type: {type(content)}") + def __artifact_to_message_content(self, artifact: BaseArtifact) -> dict: + if isinstance(artifact, ImageArtifact): + return { + "type": "image", + "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64}, + } + elif ( + isinstance(artifact, TextArtifact) + or isinstance(artifact, ErrorArtifact) + or isinstance(artifact, InfoArtifact) + ): + return {"text": artifact.to_text()} + elif isinstance(artifact, ErrorArtifact): + return {"text": artifact.to_text()} + else: + raise ValueError(f"Unsupported artifact type: {type(artifact)}") + def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BasePromptStackContent: if content.type == "text": return TextPromptStackContent(TextArtifact(content.text)) diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 479968239b..33e5643ddc 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -75,7 +75,9 @@ def prompt_stack(self) -> PromptStack: for s in self.subtasks: if self.prompt_driver.use_native_tools: stack.add_action_call_input(s.thought, s.actions) - stack.add_action_result_input(None if s.output else "Please keep going!", s.actions) + stack.add_action_result_input( + None if s.output else "Please keep going!", s.actions + ) # TODO: Instructions may not be necessary else: stack.add_assistant_input(self.generate_assistant_subtask_template(s)) stack.add_user_input(self.generate_user_subtask_template(s))