Skip to content

Commit

Permalink
Create BaseAction and ToolAction
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 10, 2024
1 parent e2d1a03 commit a06442c
Show file tree
Hide file tree
Showing 29 changed files with 135 additions and 127 deletions.
4 changes: 2 additions & 2 deletions griptape/artifacts/action_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import Action
from griptape.common import ToolAction


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
value: Action = field(metadata={"serializable": True})
value: ToolAction = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
6 changes: 4 additions & 2 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .action import Action
from .actions.base_action import BaseAction
from .actions.tool_action import ToolAction

from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
Expand Down Expand Up @@ -32,5 +33,6 @@
"ActionResultMessageContent",
"PromptStack",
"Reference",
"Action",
"BaseAction",
"ToolAction",
]
Empty file.
5 changes: 5 additions & 0 deletions griptape/common/actions/base_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from griptape.mixins import SerializableMixin
from abc import ABC


class BaseAction(SerializableMixin, ABC): ...
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin
from griptape.common import BaseAction

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define(kw_only=True)
class Action(SerializableMixin):
"""Represents an instance of an LLM calling a Action.
class ToolAction(BaseAction):
"""Represents an instance of an LLM using a Tool.
Attributes:
tag: The tag (unique identifier) of the action.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from attrs import define, field

from griptape.common import Action
from griptape.common import ToolAction
from griptape.artifacts import ActionArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ActionCallDeltaMessageContent

Expand Down Expand Up @@ -37,10 +37,10 @@ def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionCallMes
try:
parsed_input = json.loads(input)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON input for Action") from exc
action = Action(tag=tag, name=name, path=path, input=parsed_input)
raise ValueError("Invalid JSON input for ToolAction") from exc
action = ToolAction(tag=tag, name=name, path=path, input=parsed_input)
else:
raise ValueError("Missing required fields for Action")
raise ValueError("Missing required fields for ToolAction")

artifact = ActionArtifact(value=action)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, Action
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ToolAction


@define
class ActionResultMessageContent(BaseMessageContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action: Action = field(metadata={"serializable": True})
action: ToolAction = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionResultMessageContent:
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Message,
TextDeltaMessageContent,
TextMessageContent,
Action,
ToolAction,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
Expand Down Expand Up @@ -177,10 +177,10 @@ def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent
if "text" in content:
return TextMessageContent(TextArtifact(content["text"]))
elif "toolUse" in content:
name, path = Action.from_native_tool_name(content["toolUse"]["name"])
name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])
return ActionCallMessageContent(
artifact=ActionArtifact(
value=Action(
value=ToolAction(
tag=content["toolUse"]["toolUseId"], name=name, path=path, input=content["toolUse"]["input"]
)
)
Expand All @@ -193,7 +193,7 @@ def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessa
content_block = event["contentBlockStart"]["start"]

if "toolUse" in content_block:
name, path = Action.from_native_tool_name(content_block["toolUse"]["name"])
name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

return ActionCallDeltaMessageContent(
index=event["contentBlockStart"]["contentBlockIndex"],
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ImageMessageContent,
PromptStack,
Message,
Action,
ToolAction,
TextMessageContent,
)
from griptape.drivers import BasePromptDriver
Expand Down Expand Up @@ -183,11 +183,11 @@ def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessag
if content.type == "text":
return TextMessageContent(TextArtifact(content.text))
elif content.type == "tool_use":
name, path = Action.from_native_tool_name(content.name)
name, path = ToolAction.from_native_tool_name(content.name)

return ActionCallMessageContent(
artifact=ActionArtifact(
value=Action(tag=content.id, name=name, path=path, input=content.input) # pyright: ignore[reportArgumentType]
value=ToolAction(tag=content.id, name=name, path=path, input=content.input) # pyright: ignore[reportArgumentType]
)
)
else:
Expand All @@ -200,7 +200,7 @@ def __to_prompt_stack_delta_message_content(
content_block = event.content_block

if content_block.type == "tool_use":
name, path = Action.from_native_tool_name(content_block.name)
name, path = ToolAction.from_native_tool_name(content_block.name)

return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
elif content_block.type == "text":
Expand Down
10 changes: 5 additions & 5 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Message,
TextMessageContent,
ActionResultMessageContent,
Action,
ToolAction,
)
from griptape.utils import import_optional_dependency
from griptape.tokenizers import BaseTokenizer
Expand Down Expand Up @@ -208,10 +208,10 @@ def __to_prompt_stack_message_content(self, response: NonStreamedChatResponse) -
[
ActionCallMessageContent(
ActionArtifact(
Action(
ToolAction(
tag=tool_call.name,
name=Action.from_native_tool_name(tool_call.name)[0],
path=Action.from_native_tool_name(tool_call.name)[1],
name=ToolAction.from_native_tool_name(tool_call.name)[0],
path=ToolAction.from_native_tool_name(tool_call.name)[1],
input=tool_call.parameters,
)
)
Expand All @@ -229,7 +229,7 @@ def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessag
if event.tool_call_delta is not None:
tool_call_delta = event.tool_call_delta
if tool_call_delta.name is not None:
name, path = Action.from_native_tool_name(tool_call_delta.name)
name, path = ToolAction.from_native_tool_name(tool_call_delta.name)

return ActionCallDeltaMessageContent(tag=tool_call_delta.name, name=name, path=path)
else:
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ActionResultMessageContent,
ActionCallDeltaMessageContent,
BaseDeltaMessageContent,
Action,
ToolAction,
)
from griptape.artifacts import TextArtifact, ActionArtifact
from griptape.drivers import BasePromptDriver
Expand Down Expand Up @@ -219,11 +219,11 @@ def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent
elif content.function_call:
function_call = content.function_call

name, path = Action.from_native_tool_name(function_call.name)
name, path = ToolAction.from_native_tool_name(function_call.name)

args = {k: v for k, v in function_call.args.items()}
return ActionCallMessageContent(
artifact=ActionArtifact(value=Action(tag=function_call.name, name=name, path=path, input=args))
artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args))
)
else:
raise ValueError(f"Unsupported message content type {content}")
Expand All @@ -234,7 +234,7 @@ def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMes
elif content.function_call:
function_call = content.function_call

name, path = Action.from_native_tool_name(function_call.name)
name, path = ToolAction.from_native_tool_name(function_call.name)

args = {k: v for k, v in function_call.args.items()}
return ActionCallDeltaMessageContent(
Expand Down
18 changes: 9 additions & 9 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
PromptStack,
Message,
TextMessageContent,
Action,
ToolAction,
)
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
Expand Down Expand Up @@ -152,7 +152,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
if message.is_text():
openai_messages.append({"role": message.role, "content": message.to_text()})
elif message.has_any_content_type(ActionResultMessageContent):
# Action results need to be expanded into separate messages.
# ToolAction results need to be expanded into separate messages.
openai_messages.extend(
[
{
Expand All @@ -164,7 +164,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
]
)
else:
# Action calls are attached to the assistant message that originally generated them.
# ToolAction calls are attached to the assistant message that originally generated them.
action_call_content = []
non_action_call_content = []
for content in message.content:
Expand All @@ -178,7 +178,7 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
"role": self.__to_openai_role(message),
"content": [
self.__to_openai_message_content(content)
for content in non_action_call_content # Action calls do not belong in the content
for content in non_action_call_content # ToolAction calls do not belong in the content
],
**(
{
Expand Down Expand Up @@ -250,10 +250,10 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
[
ActionCallMessageContent(
ActionArtifact(
Action(
ToolAction(
tag=tool_call.id,
name=Action.from_native_tool_name(tool_call.function.name)[0],
path=Action.from_native_tool_name(tool_call.function.name)[1],
name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
input=json.loads(tool_call.function.arguments),
)
)
Expand All @@ -279,8 +279,8 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
return ActionCallDeltaMessageContent(
index=index,
tag=tool_call.id,
name=Action.from_native_tool_name(tool_call.function.name)[0],
path=Action.from_native_tool_name(tool_call.function.name)[1],
name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
)
else:
return ActionCallDeltaMessageContent(index=index, partial_input=tool_call.function.arguments)
Expand Down
4 changes: 2 additions & 2 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
# These modules are required to avoid `NameError`s when resolving types.
from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver
from griptape.structures import Structure
from griptape.common import PromptStack, Message, Reference, Action
from griptape.common import PromptStack, Message, Reference, ToolAction
from griptape.tokenizers.base_tokenizer import BaseTokenizer
from griptape.tools import BaseTool
from typing import Any
Expand All @@ -131,7 +131,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseTokenizer": BaseTokenizer,
"boto3": boto3,
"Client": Client,
"Action": Action,
"ToolAction": ToolAction,
"GenerativeModel": GenerativeModel,
"Reference": Reference,
"BaseArtifact": BaseArtifact,
Expand Down
20 changes: 10 additions & 10 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import schema
from attrs import define, field
from griptape import utils
from griptape.common import Action
from griptape.common import ToolAction
from griptape.utils import remove_null_values_in_dict_recursively
from griptape.mixins import ActionsSubtaskOriginMixin
from griptape.tasks import BaseTask
Expand All @@ -25,7 +25,7 @@ class ActionsSubtask(BaseTask):

parent_task_id: Optional[str] = field(default=None, kw_only=True)
thought: Optional[str] = field(default=None, kw_only=True)
actions: list[Action] = field(factory=list, kw_only=True)
actions: list[ToolAction] = field(factory=list, kw_only=True)
output: Optional[BaseArtifact] = field(default=None, init=False)
_input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field(
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
Expand Down Expand Up @@ -74,7 +74,7 @@ def attach_to(self, parent_task: BaseTask):
except Exception as e:
self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}")

self.output = ErrorArtifact(f"Action input parsing error: {e}", exception=e)
self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e)

def before_run(self) -> None:
self.structure.publish_event(
Expand Down Expand Up @@ -116,13 +116,13 @@ def run(self) -> BaseArtifact:
else:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[Action]) -> list[tuple[str, BaseArtifact]]:
def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
with self.futures_executor_fn() as executor:
results = utils.execute_futures_dict({a.tag: executor.submit(self.execute_action, a) for a in actions})

return [r for r in results.values()]

def execute_action(self, action: Action) -> tuple[str, BaseArtifact]:
def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
if action.path is not None:
output = action.tool.execute(getattr(action.tool, action.path), self, action)
Expand Down Expand Up @@ -208,7 +208,7 @@ def __init_from_prompt(self, value: str) -> None:

def __init_from_artifacts(self, artifacts: ListArtifact) -> None:
"""Parses the input Artifacts to extract the thought and actions.
Text Artifacts are used to extract the thought, and Action Artifacts are used to extract the actions.
Text Artifacts are used to extract the thought, and ToolAction Artifacts are used to extract the actions.
Args:
artifacts: The input Artifacts.
Expand Down Expand Up @@ -238,7 +238,7 @@ def __parse_actions(self, actions_matches: list[str]) -> None:

self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e)

def __process_action_object(self, action_object: dict) -> Action:
def __process_action_object(self, action_object: dict) -> ToolAction:
# Load action tag; throw exception if the key is not present
action_tag = action_object["tag"]

Expand Down Expand Up @@ -269,19 +269,19 @@ def __process_action_object(self, action_object: dict) -> Action:
else:
raise Exception("ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin.")

action = Action(tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool)
action = ToolAction(tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool)

if action.tool and action.input:
self.__validate_action(action)

return action

def __validate_action(self, action: Action) -> None:
def __validate_action(self, action: ToolAction) -> None:
try:
if action.path is not None:
activity = getattr(action.tool, action.path)
else:
raise Exception("Action path not found.")
raise Exception("ToolAction path not found.")

if activity is not None:
activity_schema = action.tool.activity_schema(activity)
Expand Down
Loading

0 comments on commit a06442c

Please sign in to comment.