From 58fd5e611c9af3c9d109ee87c6bee271c0b1440e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 8 Oct 2024 14:45:09 -0700 Subject: [PATCH] Add ability to pass callable to activity schema --- CHANGELOG.md | 1 + griptape/mixins/activity_mixin.py | 8 ++++-- griptape/utils/decorators.py | 7 +++++- tests/mocks/mock_tool/tool.py | 15 ++++++++++- .../test_amazon_bedrock_prompt_driver.py | 24 ++++++++++++++++++ .../prompt/test_anthropic_prompt_driver.py | 18 +++++++++++++ .../prompt/test_cohere_prompt_driver.py | 5 ++++ .../prompt/test_google_prompt_driver.py | 5 ++++ .../prompt/test_ollama_prompt_driver.py | 14 +++++++++++ .../prompt/test_openai_chat_prompt_driver.py | 23 +++++++++++++++++ tests/unit/mixins/test_activity_mixin.py | 25 +++++++++++++++++-- tests/unit/tasks/test_tool_task.py | 24 ++++++++++++++++++ tests/unit/tasks/test_toolkit_task.py | 24 ++++++++++++++++++ tests/unit/tools/test_base_tool.py | 23 +++++++++++++++++ 14 files changed, 210 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e73db1340b..d6aabba5a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file. - `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud. - Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias. +- Ability to pass callable to `activity.schema` for dynamic schema generation. ### Changed - **BREAKING**: Renamed parameters on several classes to `client`: diff --git a/griptape/mixins/activity_mixin.py b/griptape/mixins/activity_mixin.py index 61e8076b19..497dffe623 100644 --- a/griptape/mixins/activity_mixin.py +++ b/griptape/mixins/activity_mixin.py @@ -88,8 +88,12 @@ def activity_schema(self, activity: Callable) -> Optional[Schema]: if activity is None or not getattr(activity, "is_activity", False): raise Exception("This method is not an activity.") if getattr(activity, "config")["schema"] is not None: - # Need to deepcopy to avoid modifying the original schema - config_schema = deepcopy(getattr(activity, "config")["schema"]) + config_schema = getattr(activity, "config")["schema"] + if isinstance(config_schema, Callable): + config_schema = config_schema(self) + else: + # Need to deepcopy to avoid modifying the original schema + config_schema = deepcopy(getattr(activity, "config")["schema"]) activity_name = self.activity_name(activity) if self.extra_schema_properties is not None and activity_name in self.extra_schema_properties: diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index 2ea2966939..356f4eec00 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -6,7 +6,12 @@ import schema from schema import Schema -CONFIG_SCHEMA = Schema({"description": str, schema.Optional("schema"): Schema}) +CONFIG_SCHEMA = Schema( + { + "description": str, + schema.Optional("schema"): lambda data: isinstance(data, (Schema, Callable)), + } +) def activity(config: dict) -> Any: diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index 7d09f391e1..9c2241636d 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -1,4 +1,4 @@ -from attrs import define, field +from attrs import Factory, define, field from schema import Literal, Schema from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact @@ -11,6 +11,7 @@ class MockTool(BaseTool): test_field: str = field(default="test", kw_only=True) test_int: int = field(default=5, kw_only=True) test_dict: dict = field(factory=dict, kw_only=True) + custom_schema: dict = field(default=Factory(lambda: {"test": str}), kw_only=True) @activity( config={ @@ -52,6 +53,15 @@ def test_str_output(self, value: dict) -> str: def test_no_schema(self, value: dict) -> str: return "no schema" + @activity( + config={ + "description": "test description", + "schema": lambda _self: _self.build_custom_schema(), + } + ) + def test_callable_schema(self) -> TextArtifact: + return TextArtifact("ack") + @activity(config={"description": "test description"}) def test_list_output(self, value: dict) -> ListArtifact: return ListArtifact([TextArtifact("foo"), TextArtifact("bar")]) @@ -64,3 +74,6 @@ def test_without_default_memory(self, value: dict) -> str: def foo(self) -> str: return "foo" + + def build_custom_schema(self) -> Schema: + return Schema(self.custom_schema, description="Test input") diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index c36c46074f..40b0a8a0e1 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -32,6 +32,30 @@ class TestAmazonBedrockPromptDriver: "name": "MockTool_test", } }, + { + "toolSpec": { + "description": "test description", + "inputSchema": { + "json": { + "$id": "http://json-schema.org/draft-07/schema#", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + } + }, + "name": "MockTool_test_callable_schema", + } + }, { "toolSpec": { "description": "test description: foo", diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index c8e71705a4..40c983f7d2 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -29,6 +29,24 @@ class TestAnthropicPromptDriver: }, "name": "MockTool_test", }, + { + "description": "test description", + "input_schema": { + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + } + }, + "required": ["values"], + "type": "object", + }, + "name": "MockTool_test_callable_schema", + }, { "description": "test description: foo", "input_schema": { diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index c642b7ee00..e110d94698 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -17,6 +17,11 @@ class TestCoherePromptDriver: "name": "MockTool_test", "parameter_definitions": {"test": {"required": True, "type": "string"}}, }, + { + "description": "test description", + "name": "MockTool_test_callable_schema", + "parameter_definitions": {"test": {"required": True, "type": "string"}}, + }, { "description": "test description: foo", "name": "MockTool_test_error", diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 5d01217d9b..776664eb19 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -19,6 +19,11 @@ class TestGooglePromptDriver: "description": "test description: foo", "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, }, + { + "name": "MockTool_test_callable_schema", + "description": "test description", + "parameters": {"type": "OBJECT", "properties": {"test": {"type": "STRING"}}, "required": ["test"]}, + }, { "name": "MockTool_test_error", "description": "test description: foo", diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 0fc9e0f097..e4e9c47126 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -22,6 +22,20 @@ class TestOllamaPromptDriver: }, "type": "function", }, + { + "function": { + "description": "test description", + "name": "MockTool_test_callable_schema", + "parameters": { + "additionalProperties": False, + "description": "Test input", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "type": "object", + }, + }, + "type": "function", + }, { "function": { "description": "test description: foo", diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index ae42aa3a1c..ce8ac84f53 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -35,6 +35,29 @@ class TestOpenAiChatPromptDriverFixtureMixin: }, "type": "function", }, + { + "function": { + "name": "MockTool_test_callable_schema", + "description": "test description", + "parameters": { + "type": "object", + "properties": { + "values": { + "description": "Test input", + "type": "object", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + } + }, + "required": ["values"], + "additionalProperties": False, + "$id": "Parameters Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + "type": "function", + }, { "function": { "description": "test description: foo", diff --git a/tests/unit/mixins/test_activity_mixin.py b/tests/unit/mixins/test_activity_mixin.py index 1d684e2a5d..07c1738892 100644 --- a/tests/unit/mixins/test_activity_mixin.py +++ b/tests/unit/mixins/test_activity_mixin.py @@ -32,7 +32,7 @@ def test_find_activity(self): assert tool.find_activity("test_str_output") is None def test_activities(self, tool): - assert len(tool.activities()) == 7 + assert len(tool.activities()) == 8 assert tool.activities()[0] == tool.test def test_allowlist_and_denylist_validation(self): @@ -47,7 +47,7 @@ def test_allowlist(self): def test_denylist(self): tool = MockTool(test_field="hello", test_int=5, denylist=["test"]) - assert len(tool.activities()) == 6 + assert len(tool.activities()) == 7 def test_invalid_allowlist(self): with pytest.raises(ValueError): @@ -101,3 +101,24 @@ def test_extra_schema_properties(self): "additionalProperties": False, "type": "object", } + + def test_callable_schema(self): + tool = MockTool(custom_schema={"test": str}) + schema = tool.activity_schema(tool.test_callable_schema).json_schema("InputSchema") + + assert schema == { + "$id": "InputSchema", + "$schema": "http://json-schema.org/draft-07/schema#", + "properties": { + "values": { + "description": "Test input", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + "type": "object", + } + }, + "required": ["values"], + "additionalProperties": False, + "type": "object", + } diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index dbb76a9432..f92f6a8870 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -40,6 +40,30 @@ class TestToolTask: "required": ["name", "path", "input", "tag"], "additionalProperties": False, }, + { + "type": "object", + "properties": { + "name": {"const": "MockTool"}, + "path": {"description": "test description", "const": "test_callable_schema"}, + "input": { + "type": "object", + "properties": { + "values": { + "description": "Test input", + "type": "object", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + } + }, + "required": ["values"], + "additionalProperties": False, + }, + "tag": {"description": "Unique tag name for action execution.", "type": "string"}, + }, + "required": ["name", "path", "input", "tag"], + "additionalProperties": False, + }, { "type": "object", "properties": { diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 6b238c399d..24c715423c 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -36,6 +36,30 @@ class TestToolkitSubtask: "required": ["name", "path", "input", "tag"], "additionalProperties": False, }, + { + "type": "object", + "properties": { + "name": {"const": "MockTool"}, + "path": {"description": "test description", "const": "test_callable_schema"}, + "input": { + "type": "object", + "properties": { + "values": { + "description": "Test input", + "type": "object", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + } + }, + "required": ["values"], + "additionalProperties": False, + }, + "tag": {"description": "Unique tag name for action execution.", "type": "string"}, + }, + "required": ["name", "path", "input", "tag"], + "additionalProperties": False, + }, { "type": "object", "properties": { diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 318c2b3c34..60c9f68253 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -37,6 +37,29 @@ class TestBaseTool: "required": ["name", "path", "input"], "additionalProperties": False, }, + { + "type": "object", + "properties": { + "name": {"const": "MockTool"}, + "path": {"description": "test description", "const": "test_callable_schema"}, + "input": { + "type": "object", + "properties": { + "values": { + "description": "Test input", + "type": "object", + "properties": {"test": {"type": "string"}}, + "required": ["test"], + "additionalProperties": False, + } + }, + "required": ["values"], + "additionalProperties": False, + }, + }, + "required": ["name", "path", "input"], + "additionalProperties": False, + }, { "type": "object", "properties": {