From cecc138ce3fd4c21cdfa7062c8e4164e2e0e1ae6 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 17 Dec 2024 17:25:00 -0800 Subject: [PATCH] Serialize PromptTask Prompt Driver --- CHANGELOG.md | 1 + griptape/tasks/prompt_task.py | 2 +- tests/unit/structures/test_structure.py | 9 +++++++++ tests/unit/tasks/test_tool_task.py | 9 +++++++++ tests/unit/tasks/test_toolkit_task.py | 9 +++++++++ 5 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 386edce98..b207c1e6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rulesets can now be serialized and deserialized. - `ToolkitTask` now serializes its `tools` field. +- `PromptTask.prompt_driver` is now serialized. ### Fixed diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index e1d582f1f..7216536b8 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -25,7 +25,7 @@ @define class PromptTask(RuleMixin, BaseTask): prompt_driver: BasePromptDriver = field( - default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True} ) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_generate_system_template, takes_self=True), diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 756a011e6..088e60d19 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -73,6 +73,14 @@ def test_to_dict(self): "max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries, "context": agent.tasks[0].context, "rulesets": [], + "prompt_driver": { + "extra_params": {}, + "max_tokens": None, + "stream": False, + "temperature": 0.1, + "type": "MockPromptDriver", + "use_native_tools": False, + }, } ], "rulesets": [], @@ -111,6 +119,7 @@ def test_from_dict(self): serialized_agent = agent.to_dict() assert isinstance(serialized_agent, dict) + serialized_agent["tasks"][0]["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver" deserialized_agent = Agent.from_dict(serialized_agent) assert isinstance(deserialized_agent, Agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 598f85ed3..ca0576ebe 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -251,6 +251,14 @@ def test_to_dict(self): "max_meta_memory_entries": task.max_meta_memory_entries, "context": task.context, "rulesets": [], + "prompt_driver": { + "extra_params": {}, + "max_tokens": None, + "stream": False, + "temperature": 0.1, + "type": "MockPromptDriver", + "use_native_tools": False, + }, "tool": { "type": task.tool.type, "name": task.tool.name, @@ -270,6 +278,7 @@ def test_from_dict(self): serialized_tool_task = task.to_dict() serialized_tool_task["tool"]["module_name"] = "tests.mocks.mock_tool.tool" + serialized_tool_task["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver" assert isinstance(serialized_tool_task, dict) deserialized_tool_task = ToolTask.from_dict(serialized_tool_task) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 655cf3cb6..a8c51d7c1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -390,6 +390,14 @@ def test_to_dict(self): "max_meta_memory_entries": 20, "context": {}, "rulesets": [], + "prompt_driver": { + "extra_params": {}, + "max_tokens": None, + "stream": False, + "temperature": 0.1, + "type": "MockPromptDriver", + "use_native_tools": False, + }, "tools": [ { "type": "MockTool", @@ -410,5 +418,6 @@ def test_from_dict(self): task = ToolkitTask("test", tools=[tool]) serialized_task = task.to_dict() serialized_task["tools"][0]["module_name"] = "tests.mocks.mock_tool.tool" + serialized_task["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver" assert ToolkitTask.from_dict(serialized_task).to_dict() == task.to_dict()