-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
573d81f
commit 1c8648b
Showing
15 changed files
with
471 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import BaseArtifact | ||
from griptape.mixins import SerializableMixin | ||
|
||
if TYPE_CHECKING: | ||
pass | ||
|
||
|
||
@define() | ||
class ActionCallArtifact(BaseArtifact, SerializableMixin): | ||
"""Represents an instance of an LLM calling a Action. | ||
Attributes: | ||
tag: The tag (unique identifier) of the action. | ||
name: The name (Tool name) of the action. | ||
path: The path (Tool activity name) of the action. | ||
input: The input (Tool params) of the action. | ||
""" | ||
|
||
@define(kw_only=True) | ||
class ActionCall(SerializableMixin): | ||
tag: str = field(metadata={"serializable": True}) | ||
name: str = field(metadata={"serializable": True}) | ||
path: str = field(default=None, metadata={"serializable": True}) | ||
input: str = field(default={}, metadata={"serializable": True}) | ||
|
||
def __str__(self) -> str: | ||
value = self.to_dict() | ||
|
||
input = value.pop("input") | ||
formatted_json = ( | ||
"{" + ", ".join([f'"{k}": {json.dumps(v)}' for k, v in value.items()]) + f', "input": {input}' + "}" | ||
) | ||
|
||
return formatted_json | ||
|
||
def to_dict(self) -> dict: | ||
return {"tag": self.tag, "name": self.name, "path": self.path, "input": self.input} | ||
|
||
value: ActionCall = field(metadata={"serializable": True}) | ||
|
||
def __add__(self, other: BaseArtifact) -> ActionCallArtifact: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
griptape/common/prompt_stack/contents/action_call_prompt_stack_content.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
from typing import Sequence | ||
|
||
from griptape.artifacts.action_call_artifact import ActionCallArtifact | ||
from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent | ||
from griptape.common import DeltaActionCallPromptStackContent | ||
|
||
|
||
@define | ||
class ActionCallPromptStackContent(BasePromptStackContent): | ||
artifact: ActionCallArtifact = field(metadata={"serializable": True}) | ||
|
||
@classmethod | ||
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionCallPromptStackContent: | ||
action_call_deltas = [delta for delta in deltas if isinstance(delta, DeltaActionCallPromptStackContent)] | ||
|
||
tag = None | ||
name = None | ||
path = None | ||
input = "" | ||
|
||
for delta in action_call_deltas: | ||
if delta.tag is not None: | ||
tag = delta.tag | ||
if delta.name is not None: | ||
name = delta.name | ||
if delta.path is not None: | ||
path = delta.path | ||
if delta.delta_input is not None: | ||
input += delta.delta_input | ||
|
||
if tag is not None and name is not None and path is not None: | ||
action = ActionCallArtifact.ActionCall(tag=tag, name=name, path=path, input=input) | ||
else: | ||
raise ValueError("Missing required fields for ActionCallArtifact.Action") | ||
|
||
artifact = ActionCallArtifact(value=action) | ||
|
||
return cls(artifact=artifact) |
19 changes: 19 additions & 0 deletions
19
griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
from typing import Sequence | ||
|
||
from griptape.artifacts.base_artifact import BaseArtifact | ||
from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent | ||
|
||
|
||
@define | ||
class ActionResultPromptStackContent(BasePromptStackContent): | ||
artifact: BaseArtifact = field(metadata={"serializable": True}) | ||
action_tag: str = field(kw_only=True, metadata={"serializable": True}) | ||
action_name: str = field(kw_only=True, metadata={"serializable": True}) | ||
action_path: str = field(kw_only=True, metadata={"serializable": True}) | ||
|
||
@classmethod | ||
def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionResultPromptStackContent: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
griptape/common/prompt_stack/contents/delta_action_call_prompt_stack_content.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from __future__ import annotations | ||
from attrs import define, field | ||
from typing import Optional | ||
|
||
from griptape.common import BaseDeltaPromptStackContent | ||
|
||
|
||
@define | ||
class DeltaActionCallPromptStackContent(BaseDeltaPromptStackContent): | ||
tag: Optional[str] = field(default=None, metadata={"serializable": True}) | ||
name: Optional[str] = field(default=None, metadata={"serializable": True}) | ||
path: Optional[str] = field(default=None, metadata={"serializable": True}) | ||
delta_input: Optional[str] = field(default=None, metadata={"serializable": True}) | ||
|
||
def __str__(self) -> str: | ||
output = "" | ||
|
||
if self.name is not None: | ||
output += f"{self.name}" | ||
if self.path is not None: | ||
output += f".{self.path}" | ||
if self.tag is not None: | ||
output += f" ({self.tag})" | ||
if self.delta_input is not None: | ||
output += f" {self.delta_input}" | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,104 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact | ||
from griptape.artifacts import ActionCallArtifact, BaseArtifact, ImageArtifact, ListArtifact, TextArtifact | ||
from griptape.common import ( | ||
ActionCallPromptStackContent, | ||
ImagePromptStackContent, | ||
PromptStackElement, | ||
TextPromptStackContent, | ||
) | ||
from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent | ||
from griptape.mixins import SerializableMixin | ||
from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tools import BaseTool | ||
from griptape.tasks.actions_subtask import ActionsSubtask | ||
|
||
|
||
@define | ||
class PromptStack(SerializableMixin): | ||
inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True}) | ||
actions: list[BaseTool] = field(factory=list, kw_only=True) | ||
|
||
def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement: | ||
if isinstance(content, str): | ||
self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role)) | ||
elif isinstance(content, TextArtifact): | ||
self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role)) | ||
elif isinstance(content, ImageArtifact): | ||
self.inputs.append(PromptStackElement(content=[ImagePromptStackContent(content)], role=role)) | ||
elif isinstance(content, ListArtifact): | ||
contents = [] | ||
for artifact in content.value: | ||
def add_input(self, input_content: str | BaseArtifact, role: str) -> PromptStackElement: | ||
content = [] | ||
if isinstance(input_content, str): | ||
content.append(TextPromptStackContent(TextArtifact(input_content))) | ||
elif isinstance(input_content, TextArtifact): | ||
content.append(TextPromptStackContent(input_content)) | ||
elif isinstance(input_content, ImageArtifact): | ||
content.append(ImagePromptStackContent(input_content)) | ||
elif isinstance(input_content, ListArtifact): | ||
for artifact in input_content.value: | ||
if isinstance(artifact, TextArtifact): | ||
contents.append(TextPromptStackContent(artifact)) | ||
content.append(TextPromptStackContent(artifact)) | ||
elif isinstance(artifact, ImageArtifact): | ||
contents.append(ImagePromptStackContent(artifact)) | ||
content.append(ImagePromptStackContent(artifact)) | ||
elif isinstance(artifact, ActionCallArtifact): | ||
action = artifact.value | ||
content.append( | ||
ActionCallPromptStackContent( | ||
ActionCallArtifact( | ||
value=ActionCallArtifact.ActionCall( | ||
tag=action.tag, name=action.name, input=json.dumps(action.input) | ||
) | ||
) | ||
) | ||
) | ||
else: | ||
raise ValueError(f"Unsupported artifact type: {type(artifact)}") | ||
self.inputs.append(PromptStackElement(content=contents, role=role)) | ||
else: | ||
raise ValueError(f"Unsupported content type: {type(content)}") | ||
raise ValueError(f"Unsupported input_content type: {type(input_content)}") | ||
|
||
self.inputs.append(PromptStackElement(content=content, role=role)) | ||
|
||
return self.inputs[-1] | ||
|
||
def add_system_input(self, content: str) -> PromptStackElement: | ||
return self.add_input(content, PromptStackElement.SYSTEM_ROLE) | ||
def add_system_input(self, input_content: str) -> PromptStackElement: | ||
return self.add_input(input_content, PromptStackElement.SYSTEM_ROLE) | ||
|
||
def add_user_input(self, input_content: str | BaseArtifact) -> PromptStackElement: | ||
return self.add_input(input_content, PromptStackElement.USER_ROLE) | ||
|
||
def add_assistant_input(self, input_content: str | BaseArtifact) -> PromptStackElement: | ||
return self.add_input(input_content, PromptStackElement.ASSISTANT_ROLE) | ||
|
||
def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement: | ||
return self.add_input(content, PromptStackElement.USER_ROLE) | ||
def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]): | ||
artifact = ListArtifact( | ||
[ | ||
ActionCallArtifact( | ||
value=ActionCallArtifact.ActionCall( | ||
tag=action.tag, name=action.name, path=action.path, input=json.dumps(action.input) | ||
) | ||
) | ||
for action in actions | ||
] | ||
) | ||
|
||
def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement: | ||
return self.add_input(content, PromptStackElement.ASSISTANT_ROLE) | ||
if thought: | ||
artifact = ListArtifact([TextArtifact(thought), *artifact.value]) | ||
|
||
return self.add_input(artifact, PromptStackElement.ASSISTANT_ROLE) | ||
|
||
def add_action_result_input( | ||
self, input_content: Optional[str], actions: list[ActionsSubtask.Action] | ||
) -> PromptStackElement: | ||
element = PromptStackElement( | ||
content=[ | ||
ActionResultPromptStackContent( | ||
artifact=action.output, action_tag=action.tag, action_name=action.name, action_path=action.path | ||
) | ||
for action in actions | ||
if action.output | ||
], | ||
role=PromptStackElement.USER_ROLE, | ||
) | ||
|
||
self.inputs.append(element) | ||
|
||
return self.inputs[-1] |
Oops, something went wrong.