Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/native functions #867

Merged
merged 90 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
46fef79
Refactor prompt stack
collindutter Jun 13, 2024
a405af0
Add support for more modalities to conversation memory
collindutter Jun 14, 2024
1740d1d
Update default artifact
collindutter Jun 14, 2024
c91f7ea
Fix bad merge
collindutter Jun 18, 2024
ad3a0e3
Rename Prompt Stack Element to Prompt Stack Message
collindutter Jun 18, 2024
4665a8f
Fix Ollama
collindutter Jun 19, 2024
19929a5
Clean up roles
collindutter Jun 19, 2024
95c5849
Rename deltas
collindutter Jun 19, 2024
0b4b455
PR cleanup
collindutter Jun 19, 2024
05269ea
Change task hierarchy
collindutter Jun 19, 2024
9381332
Update changelog
collindutter Jun 19, 2024
c725b94
Regenerate lock file
collindutter Jun 20, 2024
6914d40
Add back missing logs
collindutter Jun 20, 2024
0054312
Fix doc var names
collindutter Jun 20, 2024
f1e8088
Clean up message building
collindutter Jun 20, 2024
032bc04
Add tests
collindutter Jun 20, 2024
c4b9351
Add image input support to ollama
collindutter Jun 20, 2024
f4f0eed
Fix tests
collindutter Jun 20, 2024
e8a2dcb
Refactor prompt stack
collindutter Jun 13, 2024
a499374
Implement native function calling
collindutter Jun 13, 2024
be58aee
Add anthropic support
collindutter Jun 17, 2024
e83c143
Add bedrock support
collindutter Jun 17, 2024
06290e5
Add cohere support
collindutter Jun 18, 2024
e0d09b1
Rename Prompt Stack Element to Prompt Stack Message
collindutter Jun 18, 2024
05b45ae
Partial google support
collindutter Jun 18, 2024
c0ef8f4
Refactor action artifacts
collindutter Jun 19, 2024
c66d056
Better google function calling
collindutter Jun 19, 2024
1637f4f
Rename deltas, clean up artifacts
collindutter Jun 19, 2024
76da060
Remove list artifact generics
collindutter Jun 19, 2024
dc7af54
Regenerate lock file
collindutter Jun 20, 2024
e0023a5
Fix bad merge
collindutter Jun 21, 2024
a9bdb2f
Update anthropic
collindutter Jun 21, 2024
70778f7
Rename PromptStackMessage to Message
collindutter Jul 2, 2024
af0f5e3
Rename PromptStackContent to MessageContent
collindutter Jul 2, 2024
eeaf2df
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
b8ac3b3
Fix pyright
collindutter Jul 2, 2024
88f85d9
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
8a05aa6
Regenerate lock file
collindutter Jul 2, 2024
55a535f
Clean up from bad merge
collindutter Jul 2, 2024
9224a0e
WIP
collindutter Jul 2, 2024
e1fc021
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
ed8b935
Update pyright, enable experimental features
collindutter Jul 3, 2024
c414839
Update pyright, fix new pyright errors
collindutter Jul 3, 2024
8630c4d
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
5a7ae37
WIP
collindutter Jul 3, 2024
63b8da7
Regenerate lock file
collindutter Jul 3, 2024
730156c
Fix some tests
collindutter Jul 3, 2024
b9029a0
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
b91631e
Fix more tests
collindutter Jul 3, 2024
0e1aacc
Regenerate lock file
collindutter Jul 3, 2024
4c18b29
Remove print
collindutter Jul 3, 2024
b0f9e36
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
47c0995
Merge branch 'dev' into feature/native-functions
collindutter Jul 5, 2024
9605a8c
Clean up/standardize prompt driver classes
collindutter Jul 5, 2024
85bcc86
Add cohere test coverage
collindutter Jul 5, 2024
f270170
Merge branch 'dev' into feature/native-functions
collindutter Jul 5, 2024
da95012
Fix lint
collindutter Jul 5, 2024
e431f5d
Add Prompt Driver tests
collindutter Jul 6, 2024
569792c
Simplify types
collindutter Jul 6, 2024
f1c5abd
Remove extra code
collindutter Jul 6, 2024
dc46fa3
Maybe fix serialization
collindutter Jul 6, 2024
6e0d0cc
Rename variable
collindutter Jul 8, 2024
6766313
Add unit tests
collindutter Jul 8, 2024
59ab49f
Add more unit tests
collindutter Jul 8, 2024
55d2cda
Simplify prompt stack methods
collindutter Jul 8, 2024
2bf3f77
Merge branch 'dev' into feature/native-functions
collindutter Jul 8, 2024
c64d0af
Add test
collindutter Jul 8, 2024
061cc1b
Rename actions to tools
collindutter Jul 8, 2024
db2136f
Move ActionArtifact to common module
collindutter Jul 8, 2024
950c3b3
Create util for native tool naming
collindutter Jul 8, 2024
78e162d
Add actions tests
collindutter Jul 8, 2024
b3fa757
Fix toolkittask
collindutter Jul 8, 2024
888f6e2
Update google dep
collindutter Jul 8, 2024
f0198f0
Link to issue
collindutter Jul 8, 2024
a2bdbae
Rename method
collindutter Jul 8, 2024
0e2939a
Fix stream output
collindutter Jul 8, 2024
33aa25a
Clean up toolkit task
collindutter Jul 8, 2024
e76f906
Implement native function calling in tooltask
collindutter Jul 8, 2024
5c07c99
Update changelog
collindutter Jul 9, 2024
ba8e2f6
Improve test coverage
collindutter Jul 9, 2024
68f865f
Improve test coverage
collindutter Jul 9, 2024
4be252b
Merge branch 'dev' into feature/native-functions
collindutter Jul 9, 2024
09086ff
Update docs
collindutter Jul 9, 2024
47940f1
Merge branch 'dev' into feature/native-functions
collindutter Jul 9, 2024
29481f7
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
97d620f
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
d55ce6e
Fix docstring location
collindutter Jul 10, 2024
e2d1a03
Properly to_dict
collindutter Jul 10, 2024
9c44a62
Create BaseAction and ToolAction
collindutter Jul 10, 2024
3d4fd4a
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.

### Changed
Expand Down
18 changes: 18 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Griptape offers the following Prompt Drivers for interacting with LLMs.

### OpenAI Chat

!!! info
This driver uses [OpenAi function calling](https://platform.openai.com/docs/guides/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [OpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/openai_chat_prompt_driver.md) connects to the [OpenAI Chat](https://platform.openai.com/docs/guides/chat) API.

```python
Expand Down Expand Up @@ -96,6 +99,9 @@ agent.run("Blue sky at dusk.")

### Azure OpenAI Chat

!!! info
This driver uses [Azure OpenAi function calling](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AzureOpenAiChatPromptDriver](../../reference/griptape/drivers/prompt/azure_openai_chat_prompt_driver.md) connects to Azure OpenAI [Chat Completion](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) APIs.

```python
Expand Down Expand Up @@ -132,6 +138,9 @@ The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_d
!!! info
This driver requires the `drivers-prompt-cohere` [extra](../index.md#extras).

!!! info
This driver uses [Cohere tool use](https://docs.cohere.com/docs/tools) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

```python
import os
from griptape.structures import Agent
Expand All @@ -155,6 +164,9 @@ agent.run('What is the sentiment of this review? Review: "I really enjoyed this
!!! info
This driver requires the `drivers-prompt-anthropic` [extra](../index.md#extras).

!!! info
This driver uses [Anthropic tool use](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AnthropicPromptDriver](../../reference/griptape/drivers/prompt/anthropic_prompt_driver.md) connects to the Anthropic [Messages](https://docs.anthropic.com/claude/reference/messages_post) API.

```python
Expand All @@ -180,6 +192,9 @@ agent.run('Where is the best place to see cherry blossums in Japan?')
!!! info
This driver requires the `drivers-prompt-google` [extra](../index.md#extras).

!!! info
This driver uses [Gemini function calling](https://ai.google.dev/gemini-api/docs/function-calling) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [GooglePromptDriver](../../reference/griptape/drivers/prompt/google_prompt_driver.md) connects to the [Google Generative AI](https://ai.google.dev/tutorials/python_quickstart#generate_text_from_text_inputs) API.

```python
Expand All @@ -205,6 +220,9 @@ agent.run('Briefly explain how a computer works to a young child.')
!!! info
This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras).

!!! info
This drivers uses [Bedrock tool use](https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html) when using [Tools](../tools/index.md). You can change this to use Griptape's own tool calling implementation by setting `use_native_tools` to `False`.

The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/)'s [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html).

All models supported by the Converse API are available for use with this driver.
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener

def handler(event: BaseEvent):
if isinstance(event, StartPromptEvent):
print("Prompt Stack PromptStack:")
print("Prompt Stack Messages:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.content}")
print("Final Prompt String:")
Expand Down
7 changes: 3 additions & 4 deletions docs/griptape-framework/tools/index.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
## Overview

One of the most powerful features of Griptape is the ability for Toolkit Tasks to generate _chains of thought_ (CoT) and use tools that can interact with the outside world. We use the [ReAct](https://arxiv.org/abs/2210.03629) technique to implement CoT reasoning and acting in the underlying LLMs without using any fine-tuning.

Griptape implements the reasoning loop in the Toolkit Tasks and integrates Griptape Tools natively.
One of the most powerful features of Griptape is the ability to use tools that can interact with the outside world.
Many of our [Prompt Drivers](../drivers/prompt-drivers.md) leverage the native function calling built into the LLMs. For LLMs that don't support this, Griptape provides its own implementation using the [ReAct](https://arxiv.org/abs/2210.03629) technique.

## Tools
Here is an example of a pipeline using tools:
Here is an example of a Pipeline using Tools:

```python
from griptape.tasks import ToolkitTask
Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact


__all__ = [
Expand All @@ -23,4 +24,5 @@
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ActionArtifact",
]
18 changes: 18 additions & 0 deletions griptape/artifacts/action_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from attrs import define, field
from typing import TYPE_CHECKING

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import ToolAction


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
vasinov marked this conversation as resolved.
Show resolved Hide resolved
value: ToolAction = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
12 changes: 12 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .actions.base_action import BaseAction
from .actions.tool_action import ToolAction

from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent

from .prompt_stack.messages.base_message import BaseMessage
from .prompt_stack.messages.delta_message import DeltaMessage
Expand All @@ -12,6 +18,7 @@

from .reference import Reference


__all__ = [
"BaseMessage",
"BaseDeltaMessageContent",
Expand All @@ -21,6 +28,11 @@
"TextDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"ActionCallDeltaMessageContent",
"ActionCallMessageContent",
"ActionResultMessageContent",
"PromptStack",
"Reference",
"BaseAction",
"ToolAction",
]
Empty file.
5 changes: 5 additions & 0 deletions griptape/common/actions/base_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from griptape.mixins import SerializableMixin
from abc import ABC


class BaseAction(SerializableMixin, ABC): ...
54 changes: 54 additions & 0 deletions griptape/common/actions/tool_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseAction

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define(kw_only=True)
class ToolAction(BaseAction):
"""Represents an instance of an LLM using a Tool.

Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
tool: The matched Tool of the action.
output: The output (Tool result) of the action.
"""

tag: str = field(metadata={"serializable": True})
name: str = field(metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
input: dict = field(factory=dict, metadata={"serializable": True})
tool: Optional[BaseTool] = field(default=None)
output: Optional[BaseArtifact] = field(default=None)

def __str__(self) -> str:
return json.dumps(self.to_dict())

def to_native_tool_name(self) -> str:
parts = [self.name]

if self.path is not None:
parts.append(self.path)

return "_".join(parts)

@classmethod
def from_native_tool_name(cls, native_tool_name: str) -> tuple[str, Optional[str]]:
parts = native_tool_name.split("_", 1)

if len(parts) == 1:
name, path = parts[0], None
else:
name, path = parts

return name, path
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations
from attrs import define, field
from typing import Optional

from griptape.common import BaseDeltaMessageContent


@define
class ActionCallDeltaMessageContent(BaseDeltaMessageContent):
vasinov marked this conversation as resolved.
Show resolved Hide resolved
tag: Optional[str] = field(default=None, metadata={"serializable": True})
name: Optional[str] = field(default=None, metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
partial_input: Optional[str] = field(default=None, metadata={"serializable": True})

def __str__(self) -> str:
parts = []

if self.name:
parts.append(self.name)
if self.path:
parts.append(f".{self.path}")
if self.tag:
parts.append(f" ({self.tag})")

if self.partial_input:
if parts:
parts.append(f" {self.partial_input}")
else:
parts.append(self.partial_input)

return "".join(parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import json
from collections.abc import Sequence

from attrs import define, field

from griptape.common import ToolAction
from griptape.artifacts import ActionArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ActionCallDeltaMessageContent


@define
class ActionCallMessageContent(BaseMessageContent):
artifact: ActionArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionCallMessageContent:
action_call_deltas = [delta for delta in deltas if isinstance(delta, ActionCallDeltaMessageContent)]

tag = None
name = None
path = None
input = ""

for delta in action_call_deltas:
if delta.tag is not None:
tag = delta.tag
if delta.name is not None:
name = delta.name
if delta.path is not None:
path = delta.path
if delta.partial_input is not None:
input += delta.partial_input

if tag is not None and name is not None and path is not None:
try:
parsed_input = json.loads(input)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON input for ToolAction") from exc
action = ToolAction(tag=tag, name=name, path=path, input=parsed_input)
else:
raise ValueError("Missing required fields for ToolAction")

artifact = ActionArtifact(value=action)

return cls(artifact=artifact)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ToolAction


@define
class ActionResultMessageContent(BaseMessageContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action: ToolAction = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionResultMessageContent:
raise NotImplementedError
10 changes: 6 additions & 4 deletions griptape/common/prompt_stack/contents/base_message_content.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
from attrs import define, field
Expand All @@ -14,9 +15,6 @@
class BaseMessageContent(ABC, SerializableMixin):
artifact: BaseArtifact = field(metadata={"serializable": True})

def to_text(self) -> str:
return str(self.artifact)

def __str__(self) -> str:
return self.artifact.to_text()

Expand All @@ -26,5 +24,9 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
return len(self.artifact)

def to_text(self) -> str:
return str(self.artifact)

@classmethod
@abstractmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> BaseMessageContent: ...
6 changes: 3 additions & 3 deletions griptape/common/prompt_stack/messages/delta_message.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent

from griptape.common import BaseDeltaMessageContent

from .base_message import BaseMessage


@define
class DeltaMessage(BaseMessage):
role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[TextDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[BaseDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
Loading
Loading