Skip to content

Commit

Permalink
Add ability to pass callable to activity schema
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 9, 2024
1 parent 12ac9e9 commit 58fd5e6
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file.
- `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud.
- Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias.
- Ability to pass callable to `activity.schema` for dynamic schema generation.

### Changed
- **BREAKING**: Renamed parameters on several classes to `client`:
Expand Down
8 changes: 6 additions & 2 deletions griptape/mixins/activity_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ def activity_schema(self, activity: Callable) -> Optional[Schema]:
if activity is None or not getattr(activity, "is_activity", False):
raise Exception("This method is not an activity.")
if getattr(activity, "config")["schema"] is not None:
# Need to deepcopy to avoid modifying the original schema
config_schema = deepcopy(getattr(activity, "config")["schema"])
config_schema = getattr(activity, "config")["schema"]
if isinstance(config_schema, Callable):
config_schema = config_schema(self)
else:
# Need to deepcopy to avoid modifying the original schema
config_schema = deepcopy(getattr(activity, "config")["schema"])
activity_name = self.activity_name(activity)

if self.extra_schema_properties is not None and activity_name in self.extra_schema_properties:
Expand Down
7 changes: 6 additions & 1 deletion griptape/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import schema
from schema import Schema

CONFIG_SCHEMA = Schema({"description": str, schema.Optional("schema"): Schema})
CONFIG_SCHEMA = Schema(
{
"description": str,
schema.Optional("schema"): lambda data: isinstance(data, (Schema, Callable)),
}
)


def activity(config: dict) -> Any:
Expand Down
15 changes: 14 additions & 1 deletion tests/mocks/mock_tool/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from attrs import define, field
from attrs import Factory, define, field
from schema import Literal, Schema

from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
Expand All @@ -11,6 +11,7 @@ class MockTool(BaseTool):
test_field: str = field(default="test", kw_only=True)
test_int: int = field(default=5, kw_only=True)
test_dict: dict = field(factory=dict, kw_only=True)
custom_schema: dict = field(default=Factory(lambda: {"test": str}), kw_only=True)

@activity(
config={
Expand Down Expand Up @@ -52,6 +53,15 @@ def test_str_output(self, value: dict) -> str:
def test_no_schema(self, value: dict) -> str:
return "no schema"

@activity(
config={
"description": "test description",
"schema": lambda _self: _self.build_custom_schema(),
}
)
def test_callable_schema(self) -> TextArtifact:
return TextArtifact("ack")

@activity(config={"description": "test description"})
def test_list_output(self, value: dict) -> ListArtifact:
return ListArtifact([TextArtifact("foo"), TextArtifact("bar")])
Expand All @@ -64,3 +74,6 @@ def test_without_default_memory(self, value: dict) -> str:

def foo(self) -> str:
return "foo"

def build_custom_schema(self) -> Schema:
return Schema(self.custom_schema, description="Test input")
24 changes: 24 additions & 0 deletions tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ class TestAmazonBedrockPromptDriver:
"name": "MockTool_test",
}
},
{
"toolSpec": {
"description": "test description",
"inputSchema": {
"json": {
"$id": "http://json-schema.org/draft-07/schema#",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
"additionalProperties": False,
"description": "Test input",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"type": "object",
}
},
"required": ["values"],
"type": "object",
}
},
"name": "MockTool_test_callable_schema",
}
},
{
"toolSpec": {
"description": "test description: foo",
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ class TestAnthropicPromptDriver:
},
"name": "MockTool_test",
},
{
"description": "test description",
"input_schema": {
"additionalProperties": False,
"properties": {
"values": {
"additionalProperties": False,
"description": "Test input",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"type": "object",
}
},
"required": ["values"],
"type": "object",
},
"name": "MockTool_test_callable_schema",
},
{
"description": "test description: foo",
"input_schema": {
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/drivers/prompt/test_cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class TestCoherePromptDriver:
"name": "MockTool_test",
"parameter_definitions": {"test": {"required": True, "type": "string"}},
},
{
"description": "test description",
"name": "MockTool_test_callable_schema",
"parameter_definitions": {"test": {"required": True, "type": "string"}},
},
{
"description": "test description: foo",
"name": "MockTool_test_error",
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/drivers/prompt/test_google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class TestGooglePromptDriver:
"description": "test description: foo",
"parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]},
},
{
"name": "MockTool_test_callable_schema",
"description": "test description",
"parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]},
},
{
"name": "MockTool_test_error",
"description": "test description: foo",
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/drivers/prompt/test_ollama_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ class TestOllamaPromptDriver:
},
"type": "function",
},
{
"function": {
"description": "test description",
"name": "MockTool_test_callable_schema",
"parameters": {
"additionalProperties": False,
"description": "Test input",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"type": "object",
},
},
"type": "function",
},
{
"function": {
"description": "test description: foo",
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,29 @@ class TestOpenAiChatPromptDriverFixtureMixin:
},
"type": "function",
},
{
"function": {
"name": "MockTool_test_callable_schema",
"description": "test description",
"parameters": {
"type": "object",
"properties": {
"values": {
"description": "Test input",
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"additionalProperties": False,
}
},
"required": ["values"],
"additionalProperties": False,
"$id": "Parameters Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
},
},
"type": "function",
},
{
"function": {
"description": "test description: foo",
Expand Down
25 changes: 23 additions & 2 deletions tests/unit/mixins/test_activity_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_find_activity(self):
assert tool.find_activity("test_str_output") is None

def test_activities(self, tool):
assert len(tool.activities()) == 7
assert len(tool.activities()) == 8
assert tool.activities()[0] == tool.test

def test_allowlist_and_denylist_validation(self):
Expand All @@ -47,7 +47,7 @@ def test_allowlist(self):
def test_denylist(self):
tool = MockTool(test_field="hello", test_int=5, denylist=["test"])

assert len(tool.activities()) == 6
assert len(tool.activities()) == 7

def test_invalid_allowlist(self):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -101,3 +101,24 @@ def test_extra_schema_properties(self):
"additionalProperties": False,
"type": "object",
}

def test_callable_schema(self):
tool = MockTool(custom_schema={"test": str})
schema = tool.activity_schema(tool.test_callable_schema).json_schema("InputSchema")

assert schema == {
"$id": "InputSchema",
"$schema": "http://json-schema.org/draft-07/schema#",
"properties": {
"values": {
"description": "Test input",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"additionalProperties": False,
"type": "object",
}
},
"required": ["values"],
"additionalProperties": False,
"type": "object",
}
24 changes: 24 additions & 0 deletions tests/unit/tasks/test_tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ class TestToolTask:
"required": ["name", "path", "input", "tag"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
"name": {"const": "MockTool"},
"path": {"description": "test description", "const": "test_callable_schema"},
"input": {
"type": "object",
"properties": {
"values": {
"description": "Test input",
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"additionalProperties": False,
}
},
"required": ["values"],
"additionalProperties": False,
},
"tag": {"description": "Unique tag name for action execution.", "type": "string"},
},
"required": ["name", "path", "input", "tag"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/tasks/test_toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ class TestToolkitSubtask:
"required": ["name", "path", "input", "tag"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
"name": {"const": "MockTool"},
"path": {"description": "test description", "const": "test_callable_schema"},
"input": {
"type": "object",
"properties": {
"values": {
"description": "Test input",
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"additionalProperties": False,
}
},
"required": ["values"],
"additionalProperties": False,
},
"tag": {"description": "Unique tag name for action execution.", "type": "string"},
},
"required": ["name", "path", "input", "tag"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/tools/test_base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ class TestBaseTool:
"required": ["name", "path", "input"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
"name": {"const": "MockTool"},
"path": {"description": "test description", "const": "test_callable_schema"},
"input": {
"type": "object",
"properties": {
"values": {
"description": "Test input",
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"],
"additionalProperties": False,
}
},
"required": ["values"],
"additionalProperties": False,
},
},
"required": ["name", "path", "input"],
"additionalProperties": False,
},
{
"type": "object",
"properties": {
Expand Down

0 comments on commit 58fd5e6

Please sign in to comment.