Skip to content

Commit

Permalink
Add bedrock support
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 17, 2024
1 parent 1484e56 commit db726c7
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 52 deletions.
8 changes: 6 additions & 2 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations
from typing import Optional
from collections.abc import Sequence
from attrs import field, define
from griptape.artifacts import BaseArtifact
from typing import TypeVar, Generic

T = TypeVar("T", bound=BaseArtifact)


@define
class ListArtifact(BaseArtifact):
value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
class ListArtifact(BaseArtifact, Generic[T]):
value: Sequence[T] = field(factory=list, metadata={"serializable": True})
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from attrs import define, field
from collections.abc import Sequence

from griptape.artifacts.base_artifact import BaseArtifact
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent


Expand Down
46 changes: 23 additions & 23 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@ class PromptStack(SerializableMixin):
inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True})
actions: list[BaseTool] = field(factory=list, kw_only=True)

def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement:
new_content = self.__process_content(content)
def add_input(self, artifact: str | BaseArtifact, role: str) -> PromptStackElement:
content = self.__process_artifact(artifact)

self.inputs.append(PromptStackElement(content=new_content, role=role))
self.inputs.append(PromptStackElement(content=content, role=role))

return self.inputs[-1]

def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.SYSTEM_ROLE)
def add_system_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
return self.add_input(artifact, PromptStackElement.SYSTEM_ROLE)

def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.USER_ROLE)
def add_user_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
return self.add_input(artifact, PromptStackElement.USER_ROLE)

def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)
def add_assistant_input(self, artifact: str | BaseArtifact) -> PromptStackElement:
return self.add_input(artifact, PromptStackElement.ASSISTANT_ROLE)

def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]) -> PromptStackElement:
thought_content = self.__process_content(thought) if thought else []
thought_content = self.__process_artifact(thought) if thought else []

action_calls_content = [
ActionCallPromptStackContent(
Expand All @@ -63,7 +63,7 @@ def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSub
def add_action_result_input(
self, instructions: Optional[str | BaseArtifact], actions: list[ActionsSubtask.Action]
) -> PromptStackElement:
instructions_content = self.__process_content(instructions) if instructions else []
instructions_content = self.__process_artifact(instructions) if instructions else []

action_results_content = [
ActionResultPromptStackContent(action.output, action_tag=action.tag)
Expand All @@ -79,21 +79,21 @@ def add_action_result_input(

return self.inputs[-1]

def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]:
if isinstance(content, str):
return [TextPromptStackContent(TextArtifact(content))]
elif isinstance(content, TextArtifact):
return [TextPromptStackContent(content)]
elif isinstance(content, ImageArtifact):
return [ImagePromptStackContent(content)]
elif isinstance(content, ActionCallArtifact):
return [ActionCallPromptStackContent(content)]
elif isinstance(content, ListArtifact):
processed_contents = [self.__process_content(artifact) for artifact in content.value]
def __process_artifact(self, artifact: str | BaseArtifact) -> list[BasePromptStackContent]:
if isinstance(artifact, str):
return [TextPromptStackContent(TextArtifact(artifact))]
elif isinstance(artifact, TextArtifact):
return [TextPromptStackContent(artifact)]
elif isinstance(artifact, ImageArtifact):
return [ImagePromptStackContent(artifact)]
elif isinstance(artifact, ActionCallArtifact):
return [ActionCallPromptStackContent(artifact)]
elif isinstance(artifact, ListArtifact):
processed_contents = [self.__process_artifact(artifact) for artifact in artifact.value]
flattened_content = [
sub_content for processed_content in processed_contents for sub_content in processed_content
]

return flattened_content
else:
raise ValueError(f"Unsupported content type: {type(content)}")
raise ValueError(f"Unsupported artifact type: {type(artifact)}")
142 changes: 132 additions & 10 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from collections.abc import Iterator
from typing import TYPE_CHECKING, Any
import json

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import TextArtifact, ActionCallArtifact, ImageArtifact
from griptape.artifacts.base_artifact import BaseArtifact
from griptape.artifacts.error_artifact import ErrorArtifact
from griptape.artifacts.info_artifact import InfoArtifact
from griptape.artifacts.list_artifact import ListArtifact
from griptape.common import (
BaseDeltaPromptStackContent,
DeltaPromptStackElement,
Expand All @@ -15,13 +20,20 @@
TextPromptStackContent,
ImagePromptStackContent,
)
from griptape.common import ActionCallPromptStackContent
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
from griptape.common.prompt_stack.contents.delta_action_call_prompt_stack_content import (
DeltaActionCallPromptStackContent,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
from griptape.utils import import_optional_dependency
from schema import Schema

if TYPE_CHECKING:
import boto3

from griptape.tools import BaseTool
from griptape.common import PromptStack


Expand All @@ -35,6 +47,8 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
)
use_native_tools: bool = field(default=True, kw_only=True)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": False})

def try_run(self, prompt_stack: PromptStack) -> PromptStackElement:
response = self.bedrock_client.converse(**self._base_params(prompt_stack))
Expand All @@ -43,7 +57,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement:
output_message = response["output"]["message"]

return PromptStackElement(
content=[TextPromptStackContent(TextArtifact(content["text"])) for content in output_message["content"]],
content=[self.__message_content_to_prompt_stack_content(content) for content in output_message["content"]],
role=PromptStackElement.ASSISTANT_ROLE,
usage=PromptStackElement.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
)
Expand All @@ -55,12 +69,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem
if stream is not None:
for event in stream:
if "messageStart" in event:
yield DeltaPromptStackElement(role=event["messageStart"]["role"])
elif "contentBlockDelta" in event:
content_block_delta = event["contentBlockDelta"]
yield DeltaTextPromptStackContent(
content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"]
)
yield DeltaPromptStackElement(role=PromptStackElement.ASSISTANT_ROLE)
elif "contentBlockDelta" in event or "contentBlockStart" in event:
yield self.__message_content_delta_to_prompt_stack_content_delta(event)
elif "metadata" in event:
usage = event["metadata"]["usage"]
yield DeltaPromptStackElement(
Expand All @@ -80,6 +91,13 @@ def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement])
for input in elements
]

def _prompt_stack_to_tools(self, prompt_stack: PromptStack) -> dict:
return (
{"toolConfig": {"tools": self.__to_tools(prompt_stack.actions), "toolChoice": self.tool_choice}}
if prompt_stack.actions and self.use_native_tools
else {}
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [
{"text": input.to_text_artifact().to_text()} for input in prompt_stack.inputs if input.is_system()
Expand All @@ -95,20 +113,124 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"system": system_messages,
"inferenceConfig": {"temperature": self.temperature},
"additionalModelRequestFields": self.additional_model_request_fields,
**self._prompt_stack_to_tools(prompt_stack),
}

def __message_content_to_prompt_stack_content(self, content: dict) -> BasePromptStackContent:
if "text" in content:
return TextPromptStackContent(TextArtifact(content["text"]))
elif "toolUse" in content:
name, path = content["toolUse"]["name"].split("_", 1)
return ActionCallPromptStackContent(
artifact=ActionCallArtifact(
value=ActionCallArtifact.ActionCall(
tag=content["toolUse"]["toolUseId"],
name=name,
path=path,
input=json.dumps(content["toolUse"]["input"]),
)
)
)
else:
raise ValueError(f"Unsupported message content type: {content}")

def __message_content_delta_to_prompt_stack_content_delta(self, event: dict) -> BaseDeltaPromptStackContent:
if "contentBlockStart" in event:
content_block = event["contentBlockStart"]["start"]

if "toolUse" in content_block:
name, path = content_block["toolUse"]["name"].split("_", 1)

return DeltaActionCallPromptStackContent(
index=event["contentBlockStart"]["contentBlockIndex"],
tag=content_block["toolUse"]["toolUseId"],
name=name,
path=path,
)
else:
raise ValueError(f"Unsupported message content type: {event}")
elif "contentBlockDelta" in event:
content_block_delta = event["contentBlockDelta"]

if "text" in content_block_delta["delta"]:
return DeltaTextPromptStackContent(
content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"]
)
elif "toolUse" in content_block_delta["delta"]:
return DeltaActionCallPromptStackContent(
index=content_block_delta["contentBlockIndex"],
delta_input=content_block_delta["delta"]["toolUse"]["input"],
)
else:
raise ValueError(f"Unsupported message content type: {event}")
else:
raise ValueError(f"Unsupported message content type: {event}")

def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict:
if isinstance(content, TextPromptStackContent):
return {"text": content.artifact.to_text()}
return self.__artifact_to_message_content(content.artifact)
elif isinstance(content, ImagePromptStackContent):
return {"image": {"format": content.artifact.format, "source": {"bytes": content.artifact.value}}}
return self.__artifact_to_message_content(content.artifact)
elif isinstance(content, ActionCallPromptStackContent):
action_call = content.artifact.value

return {
"toolUse": {
"toolUseId": action_call.tag,
"name": f"{action_call.name}_{action_call.path}",
"input": json.loads(action_call.input),
}
}
elif isinstance(content, ActionResultPromptStackContent):
artifact = content.artifact

return {
"toolResult": {
"toolUseId": content.action_tag,
"content": [self.__artifact_to_message_content(artifact) for artifact in artifact.value]
if isinstance(artifact, ListArtifact)
else [self.__artifact_to_message_content(artifact)],
"status": "error" if isinstance(artifact, ErrorArtifact) else "success",
}
}
else:
raise ValueError(f"Unsupported content type: {type(content)}")

def __artifact_to_message_content(self, artifact: BaseArtifact) -> dict:
if isinstance(artifact, ImageArtifact):
return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
elif (
isinstance(artifact, TextArtifact)
or isinstance(artifact, ErrorArtifact)
or isinstance(artifact, InfoArtifact)
):
return {"text": artifact.to_text()}
elif isinstance(artifact, ErrorArtifact):
return {"text": artifact.to_text()}
else:
raise ValueError(f"Unsupported artifact type: {type(artifact)}")

def __to_role(self, input: PromptStackElement) -> str:
if input.is_system():
return "system"
elif input.is_assistant():
return "assistant"
else:
return "user"

def __to_tools(self, tools: list[BaseTool]) -> list[dict]:
return [
{
"toolSpec": {
"name": f"{tool.name}_{tool.activity_name(activity)}",
"description": tool.activity_description(activity),
"inputSchema": {
"json": (tool.activity_schema(activity) or Schema({})).json_schema(
"https://griptape.ai"
) # TODO: Allow for non-griptape ids
},
}
}
for tool in tools
for activity in tool.activities()
]
Loading

0 comments on commit db726c7

Please sign in to comment.