Skip to content

Commit

Permalink
Serialize PromptTask Prompt Driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 18, 2024
1 parent 0224cc2 commit cecc138
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Rulesets can now be serialized and deserialized.
- `ToolkitTask` now serializes its `tools` field.
- `PromptTask.prompt_driver` is now serialized.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@define
class PromptTask(RuleMixin, BaseTask):
prompt_driver: BasePromptDriver = field(
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True, metadata={"serializable": True}
)
generate_system_template: Callable[[PromptTask], str] = field(
default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def test_to_dict(self):
"max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries,
"context": agent.tasks[0].context,
"rulesets": [],
"prompt_driver": {
"extra_params": {},
"max_tokens": None,
"stream": False,
"temperature": 0.1,
"type": "MockPromptDriver",
"use_native_tools": False,
},
}
],
"rulesets": [],
Expand Down Expand Up @@ -111,6 +119,7 @@ def test_from_dict(self):
serialized_agent = agent.to_dict()
assert isinstance(serialized_agent, dict)

serialized_agent["tasks"][0]["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver"
deserialized_agent = Agent.from_dict(serialized_agent)
assert isinstance(deserialized_agent, Agent)

Expand Down
9 changes: 9 additions & 0 deletions tests/unit/tasks/test_tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ def test_to_dict(self):
"max_meta_memory_entries": task.max_meta_memory_entries,
"context": task.context,
"rulesets": [],
"prompt_driver": {
"extra_params": {},
"max_tokens": None,
"stream": False,
"temperature": 0.1,
"type": "MockPromptDriver",
"use_native_tools": False,
},
"tool": {
"type": task.tool.type,
"name": task.tool.name,
Expand All @@ -270,6 +278,7 @@ def test_from_dict(self):

serialized_tool_task = task.to_dict()
serialized_tool_task["tool"]["module_name"] = "tests.mocks.mock_tool.tool"
serialized_tool_task["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver"
assert isinstance(serialized_tool_task, dict)

deserialized_tool_task = ToolTask.from_dict(serialized_tool_task)
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/tasks/test_toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ def test_to_dict(self):
"max_meta_memory_entries": 20,
"context": {},
"rulesets": [],
"prompt_driver": {
"extra_params": {},
"max_tokens": None,
"stream": False,
"temperature": 0.1,
"type": "MockPromptDriver",
"use_native_tools": False,
},
"tools": [
{
"type": "MockTool",
Expand All @@ -410,5 +418,6 @@ def test_from_dict(self):
task = ToolkitTask("test", tools=[tool])
serialized_task = task.to_dict()
serialized_task["tools"][0]["module_name"] = "tests.mocks.mock_tool.tool"
serialized_task["prompt_driver"]["module_name"] = "tests.mocks.mock_prompt_driver"

assert ToolkitTask.from_dict(serialized_task).to_dict() == task.to_dict()

0 comments on commit cecc138

Please sign in to comment.