diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index e57177ac4d..cb50811343 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -8,6 +8,7 @@ from .media_artifact import MediaArtifact from .image_artifact import ImageArtifact from .audio_artifact import AudioArtifact +from .action_call_artifact import ActionCallArtifact __all__ = [ @@ -21,4 +22,5 @@ "MediaArtifact", "ImageArtifact", "AudioArtifact", + "ActionCallArtifact", ] diff --git a/griptape/artifacts/action_call_artifact.py b/griptape/artifacts/action_call_artifact.py new file mode 100644 index 0000000000..54e80f304f --- /dev/null +++ b/griptape/artifacts/action_call_artifact.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from attrs import define, field + +from griptape.artifacts import BaseArtifact +from griptape.mixins import SerializableMixin + +if TYPE_CHECKING: + pass + + +@define() +class ActionCallArtifact(BaseArtifact, SerializableMixin): + """Represents an instance of an LLM calling an Action. + + Attributes: + tag: The tag (unique identifier) of the action. + name: The name (Tool name) of the action. + path: The path (Tool activity name) of the action. + input: The input (Tool params) of the action. 
+ """ + + @define(kw_only=True) + class ActionCall(SerializableMixin): + tag: str = field(metadata={"serializable": True}) + name: str = field(metadata={"serializable": True}) + path: str = field(default=None, metadata={"serializable": True}) + input: str = field(default="{}", metadata={"serializable": True}) + + def __str__(self) -> str: + value = self.to_dict() + + input = value.pop("input") + formatted_json = ( + "{" + ", ".join([f'"{k}": {json.dumps(v)}' for k, v in value.items()]) + f', "input": {input}' + "}" + ) + + return formatted_json + + def to_dict(self) -> dict: + return {"tag": self.tag, "name": self.name, "path": self.path, "input": self.input} + + value: ActionCall = field(metadata={"serializable": True}) + + def __add__(self, other: BaseArtifact) -> ActionCallArtifact: + raise NotImplementedError diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py index 303c52db6d..c215141a10 100644 --- a/griptape/common/__init__.py +++ b/griptape/common/__init__.py @@ -3,6 +3,9 @@ from .prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent +from .prompt_stack.contents.delta_action_call_prompt_stack_content import DeltaActionCallPromptStackContent +from .prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent +from .prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent from .prompt_stack.elements.base_prompt_stack_element import BasePromptStackElement from .prompt_stack.elements.delta_prompt_stack_element import DeltaPromptStackElement @@ -19,5 +22,8 @@ "DeltaTextPromptStackContent", "TextPromptStackContent", "ImagePromptStackContent", + "DeltaActionCallPromptStackContent", + "ActionCallPromptStackContent", + "ActionResultPromptStackContent", "PromptStack", ] diff --git 
a/griptape/common/prompt_stack/contents/action_call_prompt_stack_content.py b/griptape/common/prompt_stack/contents/action_call_prompt_stack_content.py new file mode 100644 index 0000000000..5217ae430b --- /dev/null +++ b/griptape/common/prompt_stack/contents/action_call_prompt_stack_content.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from attrs import define, field +from typing import Sequence + +from griptape.artifacts.action_call_artifact import ActionCallArtifact +from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent +from griptape.common import DeltaActionCallPromptStackContent + + +@define +class ActionCallPromptStackContent(BasePromptStackContent): + artifact: ActionCallArtifact = field(metadata={"serializable": True}) + + @classmethod + def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionCallPromptStackContent: + action_call_deltas = [delta for delta in deltas if isinstance(delta, DeltaActionCallPromptStackContent)] + + tag = None + name = None + path = None + input = "" + + for delta in action_call_deltas: + if delta.tag is not None: + tag = delta.tag + if delta.name is not None: + name = delta.name + if delta.path is not None: + path = delta.path + if delta.delta_input is not None: + input += delta.delta_input + + if tag is not None and name is not None and path is not None: + action = ActionCallArtifact.ActionCall(tag=tag, name=name, path=path, input=input) + else: + raise ValueError("Missing required fields for ActionCallArtifact.Action") + + artifact = ActionCallArtifact(value=action) + + return cls(artifact=artifact) diff --git a/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py b/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py new file mode 100644 index 0000000000..bbea0606bb --- /dev/null +++ b/griptape/common/prompt_stack/contents/action_result_prompt_stack_content.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from 
attrs import define, field +from typing import Sequence + +from griptape.artifacts.base_artifact import BaseArtifact +from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent + + +@define +class ActionResultPromptStackContent(BasePromptStackContent): + artifact: BaseArtifact = field(metadata={"serializable": True}) + action_tag: str = field(kw_only=True, metadata={"serializable": True}) + action_name: str = field(kw_only=True, metadata={"serializable": True}) + action_path: str = field(kw_only=True, metadata={"serializable": True}) + + @classmethod + def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ActionResultPromptStackContent: + raise NotImplementedError diff --git a/griptape/common/prompt_stack/contents/base_prompt_stack_content.py b/griptape/common/prompt_stack/contents/base_prompt_stack_content.py index 74d94bc6d2..6ec4f0ca38 100644 --- a/griptape/common/prompt_stack/contents/base_prompt_stack_content.py +++ b/griptape/common/prompt_stack/contents/base_prompt_stack_content.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Sequence from attrs import define, field @@ -28,4 +28,5 @@ def __len__(self) -> int: return len(self.artifact) @classmethod + @abstractmethod def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> BasePromptStackContent: ... 
diff --git a/griptape/common/prompt_stack/contents/delta_action_call_prompt_stack_content.py b/griptape/common/prompt_stack/contents/delta_action_call_prompt_stack_content.py new file mode 100644 index 0000000000..729778f637 --- /dev/null +++ b/griptape/common/prompt_stack/contents/delta_action_call_prompt_stack_content.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from attrs import define, field +from typing import Optional + +from griptape.common import BaseDeltaPromptStackContent + + +@define +class DeltaActionCallPromptStackContent(BaseDeltaPromptStackContent): + tag: Optional[str] = field(default=None, metadata={"serializable": True}) + name: Optional[str] = field(default=None, metadata={"serializable": True}) + path: Optional[str] = field(default=None, metadata={"serializable": True}) + delta_input: Optional[str] = field(default=None, metadata={"serializable": True}) + + def __str__(self) -> str: + output = "" + + if self.name is not None: + output += f"{self.name}" + if self.path is not None: + output += f".{self.path}" + if self.tag is not None: + output += f" ({self.tag})" + if self.delta_input is not None: + output += f" {self.delta_input}" + + return output diff --git a/griptape/common/prompt_stack/elements/prompt_stack_element.py b/griptape/common/prompt_stack/elements/prompt_stack_element.py index b94c8a5f44..b86184b9e1 100644 --- a/griptape/common/prompt_stack/elements/prompt_stack_element.py +++ b/griptape/common/prompt_stack/elements/prompt_stack_element.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Sequence from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.common import BasePromptStackContent, TextPromptStackContent +from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent +from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import 
ActionResultPromptStackContent from griptape.mixins.serializable_mixin import SerializableMixin from .base_prompt_stack_element import BasePromptStackElement @@ -22,12 +24,12 @@ class Usage(SerializableMixin): def total_tokens(self) -> float: return (self.input_tokens or 0) + (self.output_tokens or 0) - def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any): + def __init__(self, content: str | Sequence[BasePromptStackContent], **kwargs: Any): if isinstance(content, str): content = [TextPromptStackContent(TextArtifact(value=content))] self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue] - content: list[BasePromptStackContent] = field(metadata={"serializable": True}) + content: Sequence[BasePromptStackContent] = field(metadata={"serializable": True}) usage: Usage = field( kw_only=True, default=Factory(lambda: PromptStackElement.Usage()), metadata={"serializable": True} ) @@ -45,10 +47,26 @@ def __str__(self) -> str: def to_text(self) -> str: return self.to_text_artifact().to_text() + def has_action_results(self) -> bool: + return any(isinstance(content, ActionResultPromptStackContent) for content in self.content) + + def has_action_calls(self) -> bool: + return any(isinstance(content, ActionCallPromptStackContent) for content in self.content) + def to_text_artifact(self) -> TextArtifact: - artifact = TextArtifact(value="") + action_call_contents = [ + content for content in self.content if isinstance(content, ActionCallPromptStackContent) + ] + text_contents = [content for content in self.content if isinstance(content, TextPromptStackContent)] - for content in self.content: - artifact.value += content.artifact.to_text() + text_output = "".join([content.artifact.value for content in text_contents]) + if action_call_contents: + actions_output = [str(action.artifact.value) for action in action_call_contents] + output = "Actions: [" + ", ".join(actions_output) + "]" + + if text_output: + output = f"Thought: 
{text_output}\n{output}" + else: + output = text_output - return artifact + return TextArtifact(value=output) diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index 9ac1819177..afab9a104c 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -1,42 +1,104 @@ from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Optional + from attrs import define, field -from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact +from griptape.artifacts import ActionCallArtifact, BaseArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ( + ActionCallPromptStackContent, + ImagePromptStackContent, + PromptStackElement, + TextPromptStackContent, +) +from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent from griptape.mixins import SerializableMixin -from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent + +if TYPE_CHECKING: + from griptape.tools import BaseTool + from griptape.tasks.actions_subtask import ActionsSubtask @define class PromptStack(SerializableMixin): inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True}) + actions: list[BaseTool] = field(factory=list, kw_only=True) - def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement: - if isinstance(content, str): - self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role)) - elif isinstance(content, TextArtifact): - self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role)) - elif isinstance(content, ImageArtifact): - self.inputs.append(PromptStackElement(content=[ImagePromptStackContent(content)], role=role)) - elif isinstance(content, ListArtifact): - contents = [] - for artifact 
in content.value: + def add_input(self, input_content: str | BaseArtifact, role: str) -> PromptStackElement: + content = [] + if isinstance(input_content, str): + content.append(TextPromptStackContent(TextArtifact(input_content))) + elif isinstance(input_content, TextArtifact): + content.append(TextPromptStackContent(input_content)) + elif isinstance(input_content, ImageArtifact): + content.append(ImagePromptStackContent(input_content)) + elif isinstance(input_content, ListArtifact): + for artifact in input_content.value: if isinstance(artifact, TextArtifact): - contents.append(TextPromptStackContent(artifact)) + content.append(TextPromptStackContent(artifact)) elif isinstance(artifact, ImageArtifact): - contents.append(ImagePromptStackContent(artifact)) + content.append(ImagePromptStackContent(artifact)) + elif isinstance(artifact, ActionCallArtifact): + action = artifact.value + content.append( + ActionCallPromptStackContent( + ActionCallArtifact( + value=ActionCallArtifact.ActionCall( + tag=action.tag, name=action.name, input=json.dumps(action.input) + ) + ) + ) + ) else: raise ValueError(f"Unsupported artifact type: {type(artifact)}") - self.inputs.append(PromptStackElement(content=contents, role=role)) else: - raise ValueError(f"Unsupported content type: {type(content)}") + raise ValueError(f"Unsupported input_content type: {type(input_content)}") + + self.inputs.append(PromptStackElement(content=content, role=role)) return self.inputs[-1] - def add_system_input(self, content: str) -> PromptStackElement: - return self.add_input(content, PromptStackElement.SYSTEM_ROLE) + def add_system_input(self, input_content: str) -> PromptStackElement: + return self.add_input(input_content, PromptStackElement.SYSTEM_ROLE) + + def add_user_input(self, input_content: str | BaseArtifact) -> PromptStackElement: + return self.add_input(input_content, PromptStackElement.USER_ROLE) + + def add_assistant_input(self, input_content: str | BaseArtifact) -> PromptStackElement: + return 
self.add_input(input_content, PromptStackElement.ASSISTANT_ROLE) - def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement: - return self.add_input(content, PromptStackElement.USER_ROLE) + def add_action_call_input(self, thought: Optional[str], actions: list[ActionsSubtask.Action]): + artifact = ListArtifact( + [ + ActionCallArtifact( + value=ActionCallArtifact.ActionCall( + tag=action.tag, name=action.name, path=action.path, input=json.dumps(action.input) + ) + ) + for action in actions + ] + ) - def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement: - return self.add_input(content, PromptStackElement.ASSISTANT_ROLE) + if thought: + artifact = ListArtifact([TextArtifact(thought), *artifact.value]) + + return self.add_input(artifact, PromptStackElement.ASSISTANT_ROLE) + + def add_action_result_input( + self, input_content: Optional[str], actions: list[ActionsSubtask.Action] + ) -> PromptStackElement: + element = PromptStackElement( + content=[ + ActionResultPromptStackContent( + artifact=action.output, action_tag=action.tag, action_name=action.name, action_path=action.path + ) + for action in actions + if action.output + ], + role=PromptStackElement.USER_ROLE, + ) + + self.inputs.append(element) + + return self.inputs[-1] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 89539028e4..c7b6452179 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -8,7 +8,9 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ( + ActionCallPromptStackContent, BaseDeltaPromptStackContent, + DeltaActionCallPromptStackContent, DeltaPromptStackElement, DeltaTextPromptStackContent, PromptStack, @@ -36,6 +38,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: The model name. tokenizer: An instance of `BaseTokenizer` to when calculating tokens. 
stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided. + use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model. """ temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True}) @@ -47,6 +50,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(metadata={"serializable": True}) tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: @@ -134,12 +138,20 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackElement: if isinstance(delta, DeltaTextPromptStackContent): self.structure.publish_event(CompletionChunkEvent(token=delta.text)) + elif isinstance(delta, DeltaActionCallPromptStackContent): + if delta.tag is not None and delta.name is not None and delta.path is not None: + self.structure.publish_event(CompletionChunkEvent(token=str(delta))) + elif delta.delta_input is not None: + self.structure.publish_event(CompletionChunkEvent(token=delta.delta_input)) content = [] - for index, deltas in delta_contents.items(): - text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)] + for index, delta_content in delta_contents.items(): + text_deltas = [delta for delta in delta_content if isinstance(delta, DeltaTextPromptStackContent)] + action_deltas = [delta for delta in delta_content if isinstance(delta, DeltaActionCallPromptStackContent)] if text_deltas: content.append(TextPromptStackContent.from_deltas(text_deltas)) + if action_deltas: + content.append(ActionCallPromptStackContent.from_deltas(action_deltas)) result = PromptStackElement( content=content, diff --git 
a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index c35186976f..9c31892745 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -1,12 +1,15 @@ from __future__ import annotations +import json from collections.abc import Iterator from typing import Literal, Optional, TYPE_CHECKING import openai from attrs import Factory, define, field +from schema import Schema from griptape.artifacts import TextArtifact +from griptape.artifacts.action_call_artifact import ActionCallArtifact from griptape.common import ( BaseDeltaPromptStackContent, BasePromptStackContent, @@ -17,6 +20,11 @@ PromptStackElement, TextPromptStackContent, ) +from griptape.common.prompt_stack.contents.action_call_prompt_stack_content import ActionCallPromptStackContent +from griptape.common.prompt_stack.contents.action_result_prompt_stack_content import ActionResultPromptStackContent +from griptape.common.prompt_stack.contents.delta_action_call_prompt_stack_content import ( + DeltaActionCallPromptStackContent, +) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -24,6 +32,7 @@ if TYPE_CHECKING: from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.chat.chat_completion_chunk import ChoiceDelta + from griptape.tools import BaseTool @define @@ -60,6 +69,8 @@ class OpenAiChatPromptDriver(BasePromptDriver): default=None, kw_only=True, metadata={"serializable": True} ) seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) + tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False}) + use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory( lambda: ( @@ -81,8 +92,8 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: message = result.choices[0].message return PromptStackElement( - content=[self.__message_to_prompt_stack_content(message)], - role=message.role, + content=self.__message_to_prompt_stack_content(message), + role=PromptStackElement.ASSISTANT_ROLE, usage=PromptStackElement.Usage( input_tokens=result.usage.prompt_tokens, output_tokens=result.usage.completion_tokens ), @@ -112,7 +123,55 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem raise Exception("Completion with more than one choice is not supported yet.") def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: - return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in prompt_stack.inputs] + messages = [] + + for input in prompt_stack.inputs: + if input.has_action_results(): + # Action results need to be expanded into separate messages. + for action_result in input.content: + if isinstance(action_result, ActionResultPromptStackContent): + messages.append( + { + "role": self.__to_role(input), + "content": self.__prompt_stack_content_message_content(action_result), + "name": f"{action_result.action_name}-{action_result.action_path}", + "tool_call_id": action_result.action_tag, + } + ) + else: + # Action calls are attached to the assistant message that originally generated them. 
+ messages.append( + { + "role": self.__to_role(input), + "content": [ + self.__prompt_stack_content_message_content(content) + for content in input.content + if not isinstance( # Action calls do not belong in the content + content, ActionCallPromptStackContent + ) + ], + **( + { + "tool_calls": [ + self.__prompt_stack_content_message_content(action_call) + for action_call in input.content + if isinstance(action_call, ActionCallPromptStackContent) + ] + } + if input.has_action_calls() + else {} + ), + } + ) + + return messages + + def _prompt_stack_to_tools(self, prompt_stack: PromptStack) -> dict: + return ( + {"tools": self.__to_tools(prompt_stack.actions), "tool_choice": self.tool_choice} + if prompt_stack.actions and self.use_native_tools + else {} + ) def _base_params(self, prompt_stack: PromptStack) -> dict: params = { @@ -120,6 +179,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "user": self.user, "seed": self.seed, + **self._prompt_stack_to_tools(prompt_stack), **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), } @@ -141,15 +201,26 @@ def __to_role(self, input: PromptStackElement) -> str: elif input.is_assistant(): return "assistant" else: - return "user" + if input.has_action_results(): + return "tool" + else: + return "user" - def __to_content(self, input: PromptStackElement) -> str | list[dict]: - if all(isinstance(content, TextPromptStackContent) for content in input.content): - return input.to_text_artifact().to_text() - else: - return [self.__prompt_stack_content_message_content(content) for content in input.content] + def __to_tools(self, tools: list[BaseTool]) -> list[dict]: + return [ + { + "function": { + "name": f"{tool.name}-{tool.activity_name(activity)}", + "description": tool.activity_description(activity), + "parameters": (tool.activity_schema(activity) or 
Schema({})).json_schema("Parameters Schema"), + }, + "type": "function", + } + for tool in tools + for activity in tool.activities() + ] - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> str | dict: if isinstance(content, TextPromptStackContent): return {"type": "text", "text": content.artifact.to_text()} elif isinstance(content, ImagePromptStackContent): @@ -157,12 +228,36 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent "type": "image_url", "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"}, } + elif isinstance(content, ActionCallPromptStackContent): + action = content.artifact.value + + return { + "type": "function", + "id": action.tag, + "function": {"name": f"{action.name}-{action.path}", "arguments": json.dumps(action.input)}, + } + elif isinstance(content, ActionResultPromptStackContent): + return content.artifact.to_text() else: raise ValueError(f"Unsupported content type: {type(content)}") - def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> BasePromptStackContent: + def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> list[BasePromptStackContent]: if message.content is not None: - return TextPromptStackContent(TextArtifact(message.content)) + return [TextPromptStackContent(TextArtifact(message.content))] + elif message.tool_calls is not None: + return [ + ActionCallPromptStackContent( + ActionCallArtifact( + ActionCallArtifact.ActionCall( + tag=tool_call.id, + name=tool_call.function.name.split("-")[0], + path=tool_call.function.name.split("-")[1], + input=tool_call.function.arguments, + ) + ) + ) + for tool_call in message.tool_calls + ] else: raise ValueError(f"Unsupported message type: {message}") @@ -171,5 +266,24 @@ def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDel 
delta_content = content_delta.content return DeltaTextPromptStackContent(delta_content, role=content_delta.role) + elif content_delta.tool_calls is not None: + tool_calls = content_delta.tool_calls + + if len(tool_calls) == 1: + tool_call = tool_calls[0] + index = tool_call.index + + # Tool call delta either contains the function header or the partial input. + if tool_call.id is not None: + return DeltaActionCallPromptStackContent( + index=index, + tag=tool_call.id, + name=tool_call.function.name.split("-")[0], + path=tool_call.function.name.split("-")[1], + ) + else: + return DeltaActionCallPromptStackContent(index=index, delta_input=tool_call.function.arguments) + else: + raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}") else: return DeltaTextPromptStackContent("", role=content_delta.role) diff --git a/griptape/mixins/exponential_backoff_mixin.py b/griptape/mixins/exponential_backoff_mixin.py index 5045575f12..b4d6449ab2 100644 --- a/griptape/mixins/exponential_backoff_mixin.py +++ b/griptape/mixins/exponential_backoff_mixin.py @@ -9,7 +9,7 @@ class ExponentialBackoffMixin(ABC): min_retry_delay: float = field(default=2, kw_only=True) max_retry_delay: float = field(default=10, kw_only=True) - max_attempts: int = field(default=10, kw_only=True) + max_attempts: int = field(default=10, kw_only=True) after_hook: Callable = field(default=lambda s: logging.warning(s), kw_only=True) ignored_exception_types: tuple[type[Exception], ...] 
= field(factory=tuple, kw_only=True) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index e61307c64b..c29afa0e88 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -107,6 +107,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: from griptape.structures import Structure from griptape.common import PromptStack, PromptStackElement from griptape.tokenizers.base_tokenizer import BaseTokenizer + from griptape.tools import BaseTool from typing import Any boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any @@ -117,6 +118,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: localns={ "PromptStack": PromptStack, "Usage": PromptStackElement.Usage, + "BaseTool": BaseTool, "Structure": Structure, "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BasePromptDriver": BasePromptDriver, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 4aa2b27830..e760d1a0c7 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -23,9 +23,10 @@ class ActionsSubtask(BaseTextInputTask): class Action: tag: str = field() name: str = field() - path: Optional[str] = field(default=None) + path: str = field(default=None) input: dict = field() tool: Optional[BaseTool] = field(default=None) + output: Optional[BaseArtifact] = field(default=None) THOUGHT_PATTERN = r"(?s)^Thought:\s*(.*?)$" ACTIONS_PATTERN = r"(?s)Actions:[^\[]*(\[.*\])" @@ -34,8 +35,8 @@ class Action: parent_task_id: Optional[str] = field(default=None, kw_only=True) thought: Optional[str] = field(default=None, kw_only=True) actions: list[Action] = field(factory=list, kw_only=True) + output: Optional[BaseArtifact] = field(default=None, init=False) - _input: Optional[str | TextArtifact | Callable[[BaseTask], TextArtifact]] = field(default=None) _memory: Optional[TaskMemory] = None @property @@ -77,12 +78,12 @@ def before_run(self) -> None: 
subtask_actions=self.actions_to_dicts(), ) ) - self.structure.logger.info(f"Subtask {self.id}\n{self.input.to_text()}") + self.structure.logger.info(f"Subtask {self.id}\n{self.actions_to_json()}") def run(self) -> BaseArtifact: try: - if any(a.name == "error" for a in self.actions): - errors = [a.input["error"] for a in self.actions if a.name == "error"] + if any(isinstance(a.output, ErrorArtifact) for a in self.actions): + errors = [a.output.value for a in self.actions if isinstance(a.output, ErrorArtifact)] self.output = ErrorArtifact("\n\n".join(errors)) else: @@ -105,7 +106,7 @@ def run(self) -> BaseArtifact: else: return ErrorArtifact("no tool output") - def execute_actions(self, actions: list[Action]) -> list[tuple[str, BaseArtifact]]: + def execute_actions(self, actions: list[ActionsSubtask.Action]) -> list[tuple[str, BaseArtifact]]: results = utils.execute_futures_dict( {a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions} ) @@ -120,6 +121,7 @@ def execute_action(self, action: Action) -> tuple[str, BaseArtifact]: output = ErrorArtifact("action path not found") else: output = ErrorArtifact("action name not found") + action.output = output return action.tag, output @@ -163,7 +165,7 @@ def actions_to_dicts(self) -> list[dict]: return json_list def actions_to_json(self) -> str: - return json.dumps(self.actions_to_dicts()) + return json.dumps(self.actions_to_dicts(), indent=2) def _process_task_input( self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] @@ -190,69 +192,53 @@ def __init_from_prompt(self, value: str) -> None: def __parse_actions(self, actions_matches: list[str]) -> None: if len(actions_matches) == 0: return - try: data = actions_matches[-1] - actions_list: list = json.loads(data, strict=False) - - if isinstance(self.origin_task, ActionsSubtaskOriginMixin): - self.origin_task.actions_schema().validate(actions_list) - - for action_object in actions_list: - # Load action name; throw exception if 
the key is not present - action_tag = action_object["tag"] - - # Load action name; throw exception if the key is not present - action_name = action_object["name"] - - # Load action method; throw exception if the key is not present - action_path = action_object["path"] - - # Load optional input value; don't throw exceptions if key is not present - if "input" in action_object: - # The schema library has a bug, where something like `Or(str, None)` doesn't get - # correctly translated into JSON schema. For some optional input fields LLMs sometimes - # still provide null value, which trips up the validator. The temporary solution that - # works is to strip all key-values where value is null. - action_input = remove_null_values_in_dict_recursively(action_object["input"]) - else: - action_input = {} - - # Load the action itself - if isinstance(self.origin_task, ActionsSubtaskOriginMixin): - tool = self.origin_task.find_tool(action_name) - else: - raise Exception( - "ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin." - ) - - new_action = ActionsSubtask.Action( - tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool - ) - - if new_action.tool: - if new_action.input: - self.__validate_action(new_action) - - # Don't forget to add it to the subtask actions list! 
- self.actions.append(new_action) - except SyntaxError as e: - self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}") + actions_list: list[dict] = json.loads(data, strict=False) - self.actions.append(self.__error_to_action(f"syntax error: {e}")) - except schema.SchemaError as e: - self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}") + self.actions = [self.__process_action_object(action_object) for action_object in actions_list] + except json.JSONDecodeError as e: + self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid actions JSON: {e}") - self.actions.append(self.__error_to_action(f"Action JSON validation error: {e}")) - except Exception as e: - self.structure.logger.error(f"Subtask {self.origin_task.id}\nError parsing tool action: {e}") + self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e) - self.actions.append(self.__error_to_action(f"Action input parsing error: {e}")) + def __process_action_object(self, action_object: dict) -> ActionsSubtask.Action: + # Load action name; throw exception if the key is not present + action_tag = action_object["tag"] - def __error_to_action(self, error: str) -> Action: - return ActionsSubtask.Action(tag="error", name="error", input={"error": error}) + # Load action name; throw exception if the key is not present + action_name = action_object["name"] - def __validate_action(self, action: Action) -> None: + # Load action method; throw exception if the key is not present + action_path = action_object["path"] + + # Load optional input value; don't throw exceptions if key is not present + if "input" in action_object: + # The schema library has a bug, where something like `Or(str, None)` doesn't get + # correctly translated into JSON schema. For some optional input fields LLMs sometimes + # still provide null value, which trips up the validator. The temporary solution that + # works is to strip all key-values where value is null. 
+ action_input = remove_null_values_in_dict_recursively(action_object["input"]) + else: + action_input = {} + + # Load the action itself + if isinstance(self.origin_task, ActionsSubtaskOriginMixin): + tool = self.origin_task.find_tool(action_name) + else: + raise Exception("ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin.") + + new_action = ActionsSubtask.Action( + tag=action_tag, name=action_name, path=action_path, input=action_input, tool=tool + ) + + if new_action.tool: + if new_action.input: + self.__validate_action(new_action) + + return new_action + + def __validate_action(self, action: ActionsSubtask.Action) -> None: try: if action.path is not None: activity = getattr(action.tool, action.path) @@ -267,6 +253,10 @@ def __validate_action(self, action: Action) -> None: if activity_schema: activity_schema.validate(action.input) except schema.SchemaError as e: - self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid activity input JSON: {e}") + self.structure.logger.error(f"Subtask {self.origin_task.id}\nInvalid action JSON: {e}") + + action.output = ErrorArtifact(f"Activity input JSON validation error: {e}", exception=e) + except SyntaxError as e: + self.structure.logger.error(f"Subtask {self.origin_task.id}\nSyntax error: {e}") - self.actions.append(self.__error_to_action(f"Activity input JSON validation error: {e}")) + action.output = ErrorArtifact(f"Syntax error: {e}", exception=e) diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index de291d1a92..020517b6c7 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -62,7 +62,7 @@ def tool_output_memory(self) -> list[TaskMemory]: @property def prompt_stack(self) -> PromptStack: - stack = PromptStack() + stack = PromptStack(actions=self.tools) memory = self.structure.conversation_memory stack.add_system_input(self.generate_system_template(self)) @@ -73,8 +73,12 @@ def prompt_stack(self) -> PromptStack: 
stack.add_assistant_input(self.output.to_text()) else: for s in self.subtasks: - stack.add_assistant_input(self.generate_assistant_subtask_template(s)) - stack.add_user_input(self.generate_user_subtask_template(s)) + if self.prompt_driver.use_native_tools: + stack.add_action_call_input(s.thought, s.actions) + stack.add_action_result_input(self.generate_user_subtask_template(s), s.actions) + else: + stack.add_assistant_input(self.generate_assistant_subtask_template(s)) + stack.add_user_input(self.generate_user_subtask_template(s)) if memory: # inserting at index 1 to place memory right after system prompt @@ -99,6 +103,7 @@ def default_system_template_generator(self, _: PromptTask) -> str: action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), + use_native_tools=self.prompt_driver.use_native_tools, stop_sequence=self.response_stop_sequence, ) diff --git a/griptape/templates/tasks/toolkit_task/system.j2 b/griptape/templates/tasks/toolkit_task/system.j2 index 8ddfcf407b..09c4e9af98 100644 --- a/griptape/templates/tasks/toolkit_task/system.j2 +++ b/griptape/templates/tasks/toolkit_task/system.j2 @@ -1,14 +1,20 @@ -Think step-by-step and execute actions sequentially or in parallel. You must use the following format when executing actions: +You can think step-by-step and execute actions sequentially or in parallel to get your final answer. +{% if not use_native_tools %} + +You must use the following format when executing actions: Thought: Actions: {{ stop_sequence }}: ...repeat Thought/Actions/{{ stop_sequence }} as many times as you need -Answer: +"Thought", "Actions", "{{ stop_sequence }}" must always start on a new line. If {{ stop_sequence }} contains an error, you MUST ALWAYS try to fix the error with another Thought/Actions/{{ stop_sequence }}. 
-"Thought", "Actions", "{{ stop_sequence }}", and "Answer" MUST ALWAYS start on a new line. If {{ stop_sequence }} contains an error, you MUST ALWAYS try to fix the error with another Thought/Actions/{{ stop_sequence }}. NEVER make up actions. Actions must ALWAYS be a plain JSON array of objects. ALWAYS use double quotes for keys and string values in JSON objects. NEVER make up facts. Be truthful. ALWAYS be proactive and NEVER ask the user for more information input. Keep going until you have the final answer. +{% endif %} +You must use the following format when providing your final answer: +Answer: -You have access ONLY to the actions with the following names: [{{ action_names }}]. You can use multiple actions in a sequence or in parallel to get the final answer. NEVER make up action names or action paths. NEVER reference tags in other action input values. +Be truthful. ALWAYS be proactive and NEVER ask the user for more information input. Keep using actions until you have your final answer. +NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER reference tags in other action input values. Actions might store their output in memory as artifacts (with `memory_name` and `artifact_namespace`). If action output is stored in memory, ALWAYS try to pass it to another action. NEVER make up memory names or artifact namespaces. {% if meta_memory %}