Skip to content

Commit

Permalink
Implement native function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jun 14, 2024
1 parent 573d81f commit 1c8648b
Show file tree
Hide file tree
Showing 15 changed files with 471 additions and 119 deletions.
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .action_call_artifact import ActionCallArtifact


__all__ = [
Expand All @@ -21,4 +22,5 @@
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ActionCallArtifact",
]
49 changes: 49 additions & 0 deletions griptape/artifacts/action_call_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
pass


@define()
class ActionCallArtifact(BaseArtifact, SerializableMixin):
"""Represents an instance of an LLM calling a Action.
Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
"""

@define(kw_only=True)
class ActionCall(SerializableMixin):
tag: str = field(metadata={"serializable": True})
name: str = field(metadata={"serializable": True})
path: str = field(default=None, metadata={"serializable": True})
input: str = field(default={}, metadata={"serializable": True})

def __str__(self) -> str:
value = self.to_dict()

input = value.pop("input")
formatted_json = (
"{" + ", ".join([f'"{k}": {json.dumps(v)}' for k, v in value.items()]) + f', "input": {input}' + "}"
)

return formatted_json

def to_dict(self) -> dict:
return {"tag": self.tag, "name": self.name, "path": self.path, "input": self.input}

value: ActionCall = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionCallArtifact:
raise NotImplementedError
6 changes: 6 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from .prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent
from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent
from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent
from .prompt_stack.contents.delta_action_call_prompt_stack_content import DeltaActionCallPromptStackContent
from .prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent
from .prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent

from .prompt_stack.elements.base_prompt_stack_element import BasePromptStackElement
from .prompt_stack.elements.delta_prompt_stack_element import DeltaPromptStackElement
Expand All @@ -19,5 +22,8 @@
"DeltaTextPromptStackContent",
"TextPromptStackContent",
"ImagePromptStackContent",
"DeltaActionCallPromptStackContent",
"ActionCallPromptStackContent",
"ActionResultPromptStackContent",
"PromptStack",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from attrs import define, field
from typing import Sequence

from griptape.artifacts.action_call_artifact import ActionCallArtifact
from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent
from griptape.common import DeltaActionCallPromptStackContent


@define
class ActionCallPromptStackContent(BasePromptStackContent):
artifact: ActionCallArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionCallPromptStackContent:
action_call_deltas = [delta for delta in deltas if isinstance(delta, DeltaActionCallPromptStackContent)]

tag = None
name = None
path = None
input = ""

for delta in action_call_deltas:
if delta.tag is not None:
tag = delta.tag
if delta.name is not None:
name = delta.name
if delta.path is not None:
path = delta.path
if delta.delta_input is not None:
input += delta.delta_input

if tag is not None and name is not None and path is not None:
action = ActionCallArtifact.ActionCall(tag=tag, name=name, path=path, input=input)
else:
raise ValueError("Missing required fields for ActionCallArtifact.Action")

artifact = ActionCallArtifact(value=action)

return cls(artifact=artifact)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from attrs import define, field
from typing import Sequence

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent


@define
class ActionResultPromptStackContent(BasePromptStackContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action_tag: str = field(kw_only=True, metadata={"serializable": True})
action_name: str = field(kw_only=True, metadata={"serializable": True})
action_path: str = field(kw_only=True, metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionResultPromptStackContent:
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from abc import ABC
from abc import ABC, abstractmethod
from collections.abc import Sequence

from attrs import define, field
Expand Down Expand Up @@ -28,4 +28,5 @@ def __len__(self) -> int:
return len(self.artifact)

@classmethod
@abstractmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> BasePromptStackContent: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations
from attrs import define, field
from typing import Optional

from griptape.common import BaseDeltaPromptStackContent


@define
class DeltaActionCallPromptStackContent(BaseDeltaPromptStackContent):
tag: Optional[str] = field(default=None, metadata={"serializable": True})
name: Optional[str] = field(default=None, metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
delta_input: Optional[str] = field(default=None, metadata={"serializable": True})

def __str__(self) -> str:
output = ""

if self.name is not None:
output += f"{self.name}"
if self.path is not None:
output += f".{self.path}"
if self.tag is not None:
output += f" ({self.tag})"
if self.delta_input is not None:
output += f" {self.delta_input}"

return output
32 changes: 25 additions & 7 deletions griptape/common/prompt_stack/elements/prompt_stack_element.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Sequence

from attrs import Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import BasePromptStackContent, TextPromptStackContent
from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
from griptape.mixins.serializable_mixin import SerializableMixin

from .base_prompt_stack_element import BasePromptStackElement
Expand All @@ -22,12 +24,12 @@ class Usage(SerializableMixin):
def total_tokens(self) -> float:
return (self.input_tokens or 0) + (self.output_tokens or 0)

def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any):
def __init__(self, content: str | Sequence[BasePromptStackContent], **kwargs: Any):
if isinstance(content, str):
content = [TextPromptStackContent(TextArtifact(value=content))]
self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue]

content: list[BasePromptStackContent] = field(metadata={"serializable": True})
content: Sequence[BasePromptStackContent] = field(metadata={"serializable": True})
usage: Usage = field(
kw_only=True, default=Factory(lambda: PromptStackElement.Usage()), metadata={"serializable": True}
)
Expand All @@ -45,10 +47,26 @@ def __str__(self) -> str:
def to_text(self) -> str:
return self.to_text_artifact().to_text()

def has_action_results(self) -> bool:
return any(isinstance(content, ActionResultPromptStackContent) for content in self.content)

def has_action_calls(self) -> bool:
return any(isinstance(content, ActionCallPromptStackContent) for content in self.content)

def to_text_artifact(self) -> TextArtifact:
artifact = TextArtifact(value="")
action_call_contents = [
content for content in self.content if isinstance(content, ActionCallPromptStackContent)
]
text_contents = [content for content in self.content if isinstance(content, TextPromptStackContent)]

for content in self.content:
artifact.value += content.artifact.to_text()
text_output = "".join([content.artifact.value for content in text_contents])
if action_call_contents:
actions_output = [str(action.artifact.value) for action in action_call_contents]
output = "Actions: [" + ", ".join(actions_output) + "]"

if text_output:
output = f"Thought: {text_output}\n{output}"
else:
output = text_output

return artifact
return TextArtifact(value=output)
106 changes: 84 additions & 22 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,104 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact
from griptape.artifacts import ActionCallArtifact, BaseArtifact, ImageArtifact, ListArtifact, TextArtifact
from griptape.common import (
ActionCallPromptStackContent,
ImagePromptStackContent,
PromptStackElement,
TextPromptStackContent,
)
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent
from griptape.mixins import SerializableMixin
from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent

if TYPE_CHECKING:
from griptape.tools import BaseTool
from griptape.tasks.actions_subtask import ActionsSubtask


@define
class PromptStack(SerializableMixin):
inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True})
actions: list[BaseTool] = field(factory=list, kw_only=True)

def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement:
if isinstance(content, str):
self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role))
elif isinstance(content, TextArtifact):
self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role))
elif isinstance(content, ImageArtifact):
self.inputs.append(PromptStackElement(content=[ImagePromptStackContent(content)], role=role))
elif isinstance(content, ListArtifact):
contents = []
for artifact in content.value:
def add_input(self, input_content: str | BaseArtifact, role: str) -> PromptStackElement:
content = []
if isinstance(input_content, str):
content.append(TextPromptStackContent(TextArtifact(input_content)))
elif isinstance(input_content, TextArtifact):
content.append(TextPromptStackContent(input_content))
elif isinstance(input_content, ImageArtifact):
content.append(ImagePromptStackContent(input_content))
elif isinstance(input_content, ListArtifact):
for artifact in input_content.value:
if isinstance(artifact, TextArtifact):
contents.append(TextPromptStackContent(artifact))
content.append(TextPromptStackContent(artifact))
elif isinstance(artifact, ImageArtifact):
contents.append(ImagePromptStackContent(artifact))
content.append(ImagePromptStackContent(artifact))
elif isinstance(artifact, ActionCallArtifact):
action = artifact.value
content.append(
ActionCallPromptStackContent(
ActionCallArtifact(
value=ActionCallArtifact.ActionCall(
tag=action.tag, name=action.name, input=json.dumps(action.input)
)
)
)
)
else:
raise ValueError(f"Unsupported artifact type: {type(artifact)}")
self.inputs.append(PromptStackElement(content=contents, role=role))
else:
raise ValueError(f"Unsupported content type: {type(content)}")
raise ValueError(f"Unsupported input_content type: {type(input_content)}")

self.inputs.append(PromptStackElement(content=content, role=role))

return self.inputs[-1]

def add_system_input(self, content: str) -> PromptStackElement:
return self.add_input(content, PromptStackElement.SYSTEM_ROLE)
def add_system_input(self, input_content: str) -> PromptStackElement:
return self.add_input(input_content, PromptStackElement.SYSTEM_ROLE)

def add_user_input(self, input_content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(input_content, PromptStackElement.USER_ROLE)

def add_assistant_input(self, input_content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(input_content, PromptStackElement.ASSISTANT_ROLE)

def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.USER_ROLE)
def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]):
artifact = ListArtifact(
[
ActionCallArtifact(
value=ActionCallArtifact.ActionCall(
tag=action.tag, name=action.name, path=action.path, input=json.dumps(action.input)
)
)
for action in actions
]
)

def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement:
return self.add_input(content, PromptStackElement.ASSISTANT_ROLE)
if thought:
artifact = ListArtifact([TextArtifact(thought), *artifact.value])

return self.add_input(artifact, PromptStackElement.ASSISTANT_ROLE)

def add_action_result_input(
self, input_content: Optional[str], actions: list[ActionsSubtask.Action]
) -> PromptStackElement:
element = PromptStackElement(
content=[
ActionResultPromptStackContent(
artifact=action.output, action_tag=action.tag, action_name=action.name, action_path=action.path
)
for action in actions
if action.output
],
role=PromptStackElement.USER_ROLE,
)

self.inputs.append(element)

return self.inputs[-1]
Loading

0 comments on commit 1c8648b

Please sign in to comment.