Skip to content

Commit

Permalink
Feature/native functions (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Jul 10, 2024
1 parent d987f1a commit 8fa35a5
Show file tree
Hide file tree
Showing 64 changed files with 2,930 additions and 575 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.

### Changed
Expand Down
18 changes: 18 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Griptape offers the following Prompt Drivers for interacting with LLMs.

### OpenAI Chat

!!! info
This driver uses [OpenAi function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [OpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/openai_chat_prompt_driver.md) connects to the [OpenAI Chat](https://platform.openai.com/docs/guides/chat) API.

```python
Expand Down Expand Up @@ -96,6 +99,9 @@ agent.run("Blue sky at dusk.")

### Azure OpenAI Chat

!!! info
This driver uses [Azure OpenAi function calling](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AzureOpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/azure_openai_chat_prompt_driver.md) connects to Azure OpenAI [Chat Completion](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) APIs.

```python
Expand Down Expand Up @@ -132,6 +138,9 @@ The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_d
!!! info
This driver requires the `drivers-prompt-cohere` [extra](../index.md#extras).

!!! info
This driver uses [Cohere tool use](https://docs.cohere.com/docs/tools) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

```python
import os
from griptape.structures import Agent
Expand All @@ -155,6 +164,9 @@ agent.run('What is the sentiment of this review? Review: "I really enjoyed this
!!! info
This driver requires the `drivers-prompt-anthropic` [extra](../index.md#extras).

!!! info
This driver uses [Anthropic tool use](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AnthropicPromptDriver](../../reference/griptape/drivers/prompt/anthropic_prompt_driver.md) connects to the Anthropic [Messages](https://docs.anthropic.com/claude/reference/messages_post) API.

```python
Expand All @@ -180,6 +192,9 @@ agent.run('Where is the best place to see cherry blossums in Japan?')
!!! info
This driver requires the `drivers-prompt-google` [extra](../index.md#extras).

!!! info
This driver uses [Gemini function calling](https://ai.google.dev/gemini-api/docs/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [GooglePromptDriver](../../reference/griptape/drivers/prompt/google_prompt_driver.md) connects to the [Google Generative AI](https://ai.google.dev/tutorials/python_quickstart#generate_text_from_text_inputs) API.

```python
Expand All @@ -205,6 +220,9 @@ agent.run('Briefly explain how a computer works to a young child.')
!!! info
This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).

!!! info
This drivers uses [Bedrock tool use](https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/)'s [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).

All models supported by the Converse API are available for use with this driver.
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener

def handler(event: BaseEvent):
if isinstance(event, StartPromptEvent):
print("Prompt Stack PromptStack:")
print("Prompt Stack Messages:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.content}")
print("Final Prompt String:")
Expand Down
7 changes: 3 additions & 4 deletions docs/griptape-framework/tools/index.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
## Overview

One of the most powerful features of Griptape is the ability for Toolkit Tasks to generate _chains of thought_ (CoT) and use tools that can interact with the outside world. We use the [ReAct](https://arxiv.org/abs/2210.03629) technique to implement CoT reasoning and acting in the underlying LLMs without using any fine-tuning.

Griptape implements the reasoning loop in the Toolkit Tasks and integrates Griptape Tools natively.
One of the most powerful features of Griptape is the ability to use tools that can interact with the outside world.
Many of our [Prompt Drivers](../drivers/prompt-drivers.md) leverage the native function calling built into the LLMs. For LLMs that don't support this, Griptape provides its own implementation using the [ReAct](https://arxiv.org/abs/2210.03629) technique.

## Tools
Here is an example of a pipeline using tools:
Here is an example of a Pipeline using Tools:

```python
from griptape.tasks import ToolkitTask
Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact


__all__ = [
Expand All @@ -23,4 +24,5 @@
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ActionArtifact",
]
18 changes: 18 additions & 0 deletions griptape/artifacts/action_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from attrs import define, field
from typing import TYPE_CHECKING

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import ToolAction


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
value: ToolAction = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
12 changes: 12 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .actions.base_action import BaseAction
from .actions.tool_action import ToolAction

from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent

from .prompt_stack.messages.base_message import BaseMessage
from .prompt_stack.messages.delta_message import DeltaMessage
Expand All @@ -12,6 +18,7 @@

from .reference import Reference


__all__ = [
"BaseMessage",
"BaseDeltaMessageContent",
Expand All @@ -21,6 +28,11 @@
"TextDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"ActionCallDeltaMessageContent",
"ActionCallMessageContent",
"ActionResultMessageContent",
"PromptStack",
"Reference",
"BaseAction",
"ToolAction",
]
Empty file.
5 changes: 5 additions & 0 deletions griptape/common/actions/base_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from griptape.mixins import SerializableMixin
from abc import ABC


class BaseAction(SerializableMixin, ABC): ...
54 changes: 54 additions & 0 deletions griptape/common/actions/tool_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseAction

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define(kw_only=True)
class ToolAction(BaseAction):
"""Represents an instance of an LLM using a Tool.
Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
tool: The matched Tool of the action.
output: The output (Tool result) of the action.
"""

tag: str = field(metadata={"serializable": True})
name: str = field(metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
input: dict = field(factory=dict, metadata={"serializable": True})
tool: Optional[BaseTool] = field(default=None)
output: Optional[BaseArtifact] = field(default=None)

def __str__(self) -> str:
return json.dumps(self.to_dict())

def to_native_tool_name(self) -> str:
parts = [self.name]

if self.path is not None:
parts.append(self.path)

return "_".join(parts)

@classmethod
def from_native_tool_name(cls, native_tool_name: str) -> tuple[str, Optional[str]]:
parts = native_tool_name.split("_", 1)

if len(parts) == 1:
name, path = parts[0], None
else:
name, path = parts

return name, path
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations
from attrs import define, field
from typing import Optional

from griptape.common import BaseDeltaMessageContent


@define
class ActionCallDeltaMessageContent(BaseDeltaMessageContent):
tag: Optional[str] = field(default=None, metadata={"serializable": True})
name: Optional[str] = field(default=None, metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
partial_input: Optional[str] = field(default=None, metadata={"serializable": True})

def __str__(self) -> str:
parts = []

if self.name:
parts.append(self.name)
if self.path:
parts.append(f".{self.path}")
if self.tag:
parts.append(f" ({self.tag})")

if self.partial_input:
if parts:
parts.append(f" {self.partial_input}")
else:
parts.append(self.partial_input)

return "".join(parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import json
from collections.abc import Sequence

from attrs import define, field

from griptape.common import ToolAction
from griptape.artifacts import ActionArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ActionCallDeltaMessageContent


@define
class ActionCallMessageContent(BaseMessageContent):
artifact: ActionArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionCallMessageContent:
action_call_deltas = [delta for delta in deltas if isinstance(delta, ActionCallDeltaMessageContent)]

tag = None
name = None
path = None
input = ""

for delta in action_call_deltas:
if delta.tag is not None:
tag = delta.tag
if delta.name is not None:
name = delta.name
if delta.path is not None:
path = delta.path
if delta.partial_input is not None:
input += delta.partial_input

if tag is not None and name is not None and path is not None:
try:
parsed_input = json.loads(input)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON input for ToolAction") from exc
action = ToolAction(tag=tag, name=name, path=path, input=parsed_input)
else:
raise ValueError("Missing required fields for ToolAction")

artifact = ActionArtifact(value=action)

return cls(artifact=artifact)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ToolAction


@define
class ActionResultMessageContent(BaseMessageContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action: ToolAction = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionResultMessageContent:
raise NotImplementedError
10 changes: 6 additions & 4 deletions griptape/common/prompt_stack/contents/base_message_content.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
from attrs import define, field
Expand All @@ -14,9 +15,6 @@
class BaseMessageContent(ABC, SerializableMixin):
artifact: BaseArtifact = field(metadata={"serializable": True})

def to_text(self) -> str:
return str(self.artifact)

def __str__(self) -> str:
return self.artifact.to_text()

Expand All @@ -26,5 +24,9 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
return len(self.artifact)

def to_text(self) -> str:
return str(self.artifact)

@classmethod
@abstractmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> BaseMessageContent: ...
6 changes: 3 additions & 3 deletions griptape/common/prompt_stack/messages/delta_message.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent

from griptape.common import BaseDeltaMessageContent

from .base_message import BaseMessage


@define
class DeltaMessage(BaseMessage):
role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[TextDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[BaseDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
Loading

0 comments on commit 8fa35a5

Please sign in to comment.