Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/native functions #867

Merged
merged 90 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 79 commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
46fef79
Refactor prompt stack
collindutter Jun 13, 2024
a405af0
Add support for more modalities to conversation memory
collindutter Jun 14, 2024
1740d1d
Update default artifact
collindutter Jun 14, 2024
c91f7ea
Fix bad merge
collindutter Jun 18, 2024
ad3a0e3
Rename Prompt Stack Element to Prompt Stack Message
collindutter Jun 18, 2024
4665a8f
Fix Ollama
collindutter Jun 19, 2024
19929a5
Clean up roles
collindutter Jun 19, 2024
95c5849
Rename deltas
collindutter Jun 19, 2024
0b4b455
PR cleanup
collindutter Jun 19, 2024
05269ea
Change task hierarchy
collindutter Jun 19, 2024
9381332
Update changelog
collindutter Jun 19, 2024
c725b94
Regenerate lock file
collindutter Jun 20, 2024
6914d40
Add back missing logs
collindutter Jun 20, 2024
0054312
Fix doc var names
collindutter Jun 20, 2024
f1e8088
Clean up message building
collindutter Jun 20, 2024
032bc04
Add tests
collindutter Jun 20, 2024
c4b9351
Add image input support to ollama
collindutter Jun 20, 2024
f4f0eed
Fix tests
collindutter Jun 20, 2024
e8a2dcb
Refactor prompt stack
collindutter Jun 13, 2024
a499374
Implement native function calling
collindutter Jun 13, 2024
be58aee
Add anthropic support
collindutter Jun 17, 2024
e83c143
Add bedrock support
collindutter Jun 17, 2024
06290e5
Add cohere support
collindutter Jun 18, 2024
e0d09b1
Rename Prompt Stack Element to Prompt Stack Message
collindutter Jun 18, 2024
05b45ae
Partial google support
collindutter Jun 18, 2024
c0ef8f4
Refactor action artifacts
collindutter Jun 19, 2024
c66d056
Better google function calling
collindutter Jun 19, 2024
1637f4f
Rename deltas, clean up artifacts
collindutter Jun 19, 2024
76da060
Remove list artifact generics
collindutter Jun 19, 2024
dc7af54
Regenerate lock file
collindutter Jun 20, 2024
e0023a5
Fix bad merge
collindutter Jun 21, 2024
a9bdb2f
Update anthropic
collindutter Jun 21, 2024
70778f7
Rename PromptStackMessage to Message
collindutter Jul 2, 2024
af0f5e3
Rename PromptStackContent to MessageContent
collindutter Jul 2, 2024
eeaf2df
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
b8ac3b3
Fix pyright
collindutter Jul 2, 2024
88f85d9
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
8a05aa6
Regenerate lock file
collindutter Jul 2, 2024
55a535f
Clean up from bad merge
collindutter Jul 2, 2024
9224a0e
WIP
collindutter Jul 2, 2024
e1fc021
Merge branch 'dev' into feature/native-functions
collindutter Jul 2, 2024
ed8b935
Update pyright, enable experimental features
collindutter Jul 3, 2024
c414839
Update pyright, fix new pyright errors
collindutter Jul 3, 2024
8630c4d
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
5a7ae37
WIP
collindutter Jul 3, 2024
63b8da7
Regenerate lock file
collindutter Jul 3, 2024
730156c
Fix some tests
collindutter Jul 3, 2024
b9029a0
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
b91631e
Fix more tests
collindutter Jul 3, 2024
0e1aacc
Regenerate lock file
collindutter Jul 3, 2024
4c18b29
Remove print
collindutter Jul 3, 2024
b0f9e36
Merge branch 'dev' into feature/native-functions
collindutter Jul 3, 2024
47c0995
Merge branch 'dev' into feature/native-functions
collindutter Jul 5, 2024
9605a8c
Clean up/standardize prompt driver classes
collindutter Jul 5, 2024
85bcc86
Add cohere test coverage
collindutter Jul 5, 2024
f270170
Merge branch 'dev' into feature/native-functions
collindutter Jul 5, 2024
da95012
Fix lint
collindutter Jul 5, 2024
e431f5d
Add Prompt Driver tests
collindutter Jul 6, 2024
569792c
Simplify types
collindutter Jul 6, 2024
f1c5abd
Remove extra code
collindutter Jul 6, 2024
dc46fa3
Maybe fix serialization
collindutter Jul 6, 2024
6e0d0cc
Rename variable
collindutter Jul 8, 2024
6766313
Add unit tests
collindutter Jul 8, 2024
59ab49f
Add more unit tests
collindutter Jul 8, 2024
55d2cda
Simplify prompt stack methods
collindutter Jul 8, 2024
2bf3f77
Merge branch 'dev' into feature/native-functions
collindutter Jul 8, 2024
c64d0af
Add test
collindutter Jul 8, 2024
061cc1b
Rename actions to tools
collindutter Jul 8, 2024
db2136f
Move ActionArtifact to common module
collindutter Jul 8, 2024
950c3b3
Create util for native tool naming
collindutter Jul 8, 2024
78e162d
Add actions tests
collindutter Jul 8, 2024
b3fa757
Fix toolkittask
collindutter Jul 8, 2024
888f6e2
Update google dep
collindutter Jul 8, 2024
f0198f0
Link to issue
collindutter Jul 8, 2024
a2bdbae
Rename method
collindutter Jul 8, 2024
0e2939a
Fix stream output
collindutter Jul 8, 2024
33aa25a
Clean up toolkit task
collindutter Jul 8, 2024
e76f906
Implement native function calling in tooltask
collindutter Jul 8, 2024
5c07c99
Update changelog
collindutter Jul 9, 2024
ba8e2f6
Improve test coverage
collindutter Jul 9, 2024
68f865f
Improve test coverage
collindutter Jul 9, 2024
4be252b
Merge branch 'dev' into feature/native-functions
collindutter Jul 9, 2024
09086ff
Update docs
collindutter Jul 9, 2024
47940f1
Merge branch 'dev' into feature/native-functions
collindutter Jul 9, 2024
29481f7
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
97d620f
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
d55ce6e
Fix docstring location
collindutter Jul 10, 2024
e2d1a03
Properly to_dict
collindutter Jul 10, 2024
9c44a62
Create BaseAction and ToolAction
collindutter Jul 10, 2024
3d4fd4a
Merge branch 'dev' into feature/native-functions
collindutter Jul 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input.
- Parameters `sort_key` and `sort_key_value` on `AmazonDynamoDbConversationMemoryDriver` for tables with sort keys.
- `Reference` for supporting artifact citations in loaders and RAG engine modules.
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉


### Changed
- **BREAKING**: Moved/renamed `griptape.utils.PromptStack` to `griptape.common.PromptStack`.
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener

def handler(event: BaseEvent):
if isinstance(event, StartPromptEvent):
print("Prompt Stack PromptStack:")
print("Prompt Stack Messages:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.content}")
print("Final Prompt String:")
Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact


__all__ = [
Expand All @@ -23,4 +24,5 @@
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ActionArtifact",
]
29 changes: 29 additions & 0 deletions griptape/artifacts/action_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from attrs import define, field
from typing import TYPE_CHECKING

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.common import Action


@define()
class ActionArtifact(BaseArtifact, SerializableMixin):
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
vasinov marked this conversation as resolved.
Show resolved Hide resolved
"""Represents an instance of an LLM calling a Action.

Attributes:
tag: The tag (unique identifier) of the action.
name: The name (Tool name) of the action.
path: The path (Tool activity name) of the action.
input: The input (Tool params) of the action.
tool: The matched Tool of the action.
output: The output (Tool result) of the action.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring should be in Action.


value: Action = field(metadata={"serializable": True})

def __add__(self, other: BaseArtifact) -> ActionArtifact:
raise NotImplementedError
10 changes: 10 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from .action import Action

from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent

from .prompt_stack.messages.base_message import BaseMessage
from .prompt_stack.messages.delta_message import DeltaMessage
Expand All @@ -12,6 +17,7 @@

from .reference import Reference


__all__ = [
"BaseMessage",
"BaseDeltaMessageContent",
Expand All @@ -21,6 +27,10 @@
"TextDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"ActionCallDeltaMessageContent",
"ActionCallMessageContent",
"ActionResultMessageContent",
"PromptStack",
"Reference",
"Action",
]
46 changes: 46 additions & 0 deletions griptape/common/action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.mixins import SerializableMixin

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define(kw_only=True)
class Action(SerializableMixin):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some prefix to this? I'm worried that we might have some other "action" in the future, unrelated to LLM actions, which will be confusing naming-wise.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ToolAction?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored into BaseAction and ToolAction

tag: str = field(metadata={"serializable": True})
name: str = field(metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
input: dict = field(factory=dict, metadata={"serializable": True})
tool: Optional[BaseTool] = field(default=None)
output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True})

def __str__(self) -> str:
return json.dumps(self.to_dict())

def to_dict(self) -> dict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you need to override this? Can't we just set tool and output to non-serializable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, updated.

return {"tag": self.tag, "name": self.name, "path": self.path, "input": self.input}

def to_native_tool_name(self) -> str:
parts = [self.name]

if self.path is not None:
parts.append(self.path)

return "_".join(parts)

@classmethod
def from_native_tool_name(cls, native_tool_name: str) -> tuple[str, Optional[str]]:
parts = native_tool_name.split("_", 1)

if len(parts) == 1:
name, path = parts[0], None
else:
name, path = parts

return name, path
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations
from attrs import define, field
from typing import Optional

from griptape.common import BaseDeltaMessageContent


@define
class ActionCallDeltaMessageContent(BaseDeltaMessageContent):
vasinov marked this conversation as resolved.
Show resolved Hide resolved
tag: Optional[str] = field(default=None, metadata={"serializable": True})
name: Optional[str] = field(default=None, metadata={"serializable": True})
path: Optional[str] = field(default=None, metadata={"serializable": True})
partial_input: Optional[str] = field(default=None, metadata={"serializable": True})

def __str__(self) -> str:
parts = []

if self.name:
parts.append(self.name)
if self.path:
parts.append(f".{self.path}")
if self.tag:
parts.append(f" ({self.tag})")

if self.partial_input:
if parts:
parts.append(f" {self.partial_input}")
else:
parts.append(self.partial_input)

return "".join(parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import json
from collections.abc import Sequence

from attrs import define, field

from griptape.common import Action
from griptape.artifacts import ActionArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, ActionCallDeltaMessageContent


@define
class ActionCallMessageContent(BaseMessageContent):
artifact: ActionArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionCallMessageContent:
action_call_deltas = [delta for delta in deltas if isinstance(delta, ActionCallDeltaMessageContent)]

tag = None
name = None
path = None
input = ""

for delta in action_call_deltas:
if delta.tag is not None:
tag = delta.tag
if delta.name is not None:
name = delta.name
if delta.path is not None:
path = delta.path
if delta.partial_input is not None:
input += delta.partial_input

if tag is not None and name is not None and path is not None:
try:
parsed_input = json.loads(input)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON input for Action") from exc
action = Action(tag=tag, name=name, path=path, input=parsed_input)
else:
raise ValueError("Missing required fields for Action")

artifact = ActionArtifact(value=action)

return cls(artifact=artifact)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent, Action


@define
class ActionResultMessageContent(BaseMessageContent):
artifact: BaseArtifact = field(metadata={"serializable": True})
action: Action = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ActionResultMessageContent:
raise NotImplementedError
10 changes: 6 additions & 4 deletions griptape/common/prompt_stack/contents/base_message_content.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
from attrs import define, field
Expand All @@ -14,9 +15,6 @@
class BaseMessageContent(ABC, SerializableMixin):
artifact: BaseArtifact = field(metadata={"serializable": True})

def to_text(self) -> str:
return str(self.artifact)

def __str__(self) -> str:
return self.artifact.to_text()

Expand All @@ -26,5 +24,9 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
return len(self.artifact)

def to_text(self) -> str:
return str(self.artifact)

@classmethod
@abstractmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> BaseMessageContent: ...
6 changes: 3 additions & 3 deletions griptape/common/prompt_stack/messages/delta_message.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent

from griptape.common import BaseDeltaMessageContent

from .base_message import BaseMessage


@define
class DeltaMessage(BaseMessage):
role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[TextDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[BaseDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
18 changes: 16 additions & 2 deletions griptape/common/prompt_stack/messages/message.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

from typing import Any
from typing import Any, TypeVar

from attrs import define, field

from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.artifacts import TextArtifact, ListArtifact, BaseArtifact
from griptape.common import BaseMessageContent, TextMessageContent

from .base_message import BaseMessage

T = TypeVar("T", bound=BaseMessageContent)


@define
class Message(BaseMessage):
Expand All @@ -26,6 +28,18 @@ def value(self) -> Any:
def __str__(self) -> str:
return self.to_text()

def has_all_content_type(self, content_type: type[T]) -> bool:
return all(isinstance(content, content_type) for content in self.content)

def has_any_content_type(self, content_type: type[T]) -> bool:
return any(isinstance(content, content_type) for content in self.content)

def get_content_type(self, content_type: type[T]) -> list[T]:
return [content for content in self.content if isinstance(content, content_type)]

def is_text(self) -> bool:
return all(isinstance(content, TextMessageContent) for content in self.content)

def to_text(self) -> str:
return "".join(
[content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)]
Expand Down
33 changes: 27 additions & 6 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define, field

from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact
from griptape.artifacts import ActionArtifact, BaseArtifact, ImageArtifact, ListArtifact, TextArtifact
from griptape.common import (
ActionCallMessageContent,
ActionResultMessageContent,
BaseMessageContent,
ImageMessageContent,
Message,
TextMessageContent,
)
from griptape.mixins import SerializableMixin
from griptape.common import Message, TextMessageContent, BaseMessageContent, ImageMessageContent

if TYPE_CHECKING:
from griptape.tools import BaseTool


@define
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)

@property
def system_messages(self) -> list[Message]:
Expand All @@ -23,9 +37,9 @@ def assistant_messages(self) -> list[Message]:
return [message for message in self.messages if message.is_assistant()]

def add_message(self, artifact: str | BaseArtifact, role: str) -> Message:
new_content = self.__process_artifact(artifact)
content = self.__to_message_content(artifact)

self.messages.append(Message(content=new_content, role=role))
self.messages.append(Message(content=content, role=role))

return self.messages[-1]

Expand All @@ -38,15 +52,22 @@ def add_user_message(self, artifact: str | BaseArtifact) -> Message:
def add_assistant_message(self, artifact: str | BaseArtifact) -> Message:
return self.add_message(artifact, Message.ASSISTANT_ROLE)

def __process_artifact(self, artifact: str | BaseArtifact) -> list[BaseMessageContent]:
def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessageContent]:
if isinstance(artifact, str):
return [TextMessageContent(TextArtifact(artifact))]
elif isinstance(artifact, TextArtifact):
return [TextMessageContent(artifact)]
elif isinstance(artifact, ImageArtifact):
return [ImageMessageContent(artifact)]
elif isinstance(artifact, ActionArtifact):
action = artifact.value
output = action.output
if output is None:
return [ActionCallMessageContent(artifact)]
else:
return [ActionResultMessageContent(output, action=action)]
elif isinstance(artifact, ListArtifact):
processed_contents = [self.__process_artifact(artifact) for artifact in artifact.value]
processed_contents = [self.__to_message_content(artifact) for artifact in artifact.value]
flattened_content = [
sub_content for processed_content in processed_contents for sub_content in processed_content
]
Expand Down
Loading
Loading