Skip to content

Commit

Permalink
Implement (de)serialization in rulesets
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 16, 2024
1 parent 9aaceaa commit e4bc671
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BranchTask` for selecting which Tasks (if any) to run based on a condition.
- Support for `BranchTask` in `StructureVisualizer`.

### Changed

- Rulesets can now be serialized and deserialized.

### Fixed

- Exception when calling `Structure.to_json()` after it has run.
Expand Down
13 changes: 9 additions & 4 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from __future__ import annotations

from attrs import define, field
import uuid

from attrs import Factory, define, field

from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.rules import BaseRule, Ruleset


@define(slots=False)
class RuleMixin:
class RuleMixin(SerializableMixin):
DEFAULT_RULESET_NAME = "Default Ruleset"

_rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets")
_rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets", metadata={"serializable": True})
rules: list[BaseRule] = field(factory=list, kw_only=True)
_default_ruleset_name: str = field(default=Factory(lambda: RuleMixin.DEFAULT_RULESET_NAME), kw_only=True)
_default_ruleset_id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)

@property
def rulesets(self) -> list[Ruleset]:
rulesets = self._rulesets.copy()

if self.rules:
rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.rules))
rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules))

return rulesets
8 changes: 5 additions & 3 deletions griptape/rules/base_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from attrs import define, field

from griptape.mixins.serializable_mixin import SerializableMixin

@define(frozen=True)
class BaseRule(ABC):
value: Any = field()

@define()
class BaseRule(ABC, SerializableMixin):
value: Any = field(metadata={"serializable": True})
meta: dict[str, Any] = field(factory=dict, kw_only=True)

def __str__(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions griptape/rules/json_schema_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from griptape.utils import J2


@define(frozen=True)
@define()
class JsonSchemaRule(BaseRule):
value: dict = field()
value: dict = field(metadata={"serializable": True})
generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2")))

def to_text(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions griptape/rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from griptape.rules import BaseRule


@define(frozen=True)
@define()
class Rule(BaseRule):
value: str = field()
value: str = field(metadata={"serializable": True})

def to_text(self) -> str:
return self.value
9 changes: 5 additions & 4 deletions griptape/rules/ruleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from attrs import Factory, define, field

from griptape.configs import Defaults
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -15,17 +16,17 @@


@define
class Ruleset:
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
class Ruleset(SerializableMixin):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
name: str = field(
default=Factory(lambda self: self.id, takes_self=True),
metadata={"serializable": True},
)
ruleset_driver: BaseRulesetDriver = field(
default=Factory(lambda: Defaults.drivers_config.ruleset_driver), kw_only=True
)
meta: dict[str, Any] = field(factory=dict, kw_only=True)
rules: Sequence[BaseRule] = field(factory=list)
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
rules: Sequence[BaseRule] = field(factory=list, metadata={"serializable": True})

def __attrs_post_init__(self) -> None:
rules, meta = self.ruleset_driver.load(self.name)
Expand Down
4 changes: 4 additions & 0 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from griptape.memory import TaskMemory
from griptape.memory.structure import BaseConversationMemory, Run
from griptape.memory.task.storage import BaseArtifactStorage
from griptape.rules.base_rule import BaseRule
from griptape.rules.ruleset import Ruleset
from griptape.structures import Structure
from griptape.tasks import BaseTask
from griptape.tokenizers import BaseTokenizer
Expand Down Expand Up @@ -210,6 +212,8 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"State": BaseTask.State,
"BaseConversationMemory": BaseConversationMemory,
"BaseArtifactStorage": BaseArtifactStorage,
"BaseRule": BaseRule,
"Ruleset": Ruleset,
# Third party modules
"Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
"GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
Expand Down
75 changes: 74 additions & 1 deletion tests/unit/mixins/test_rule_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import Rule, Ruleset
from griptape.rules import JsonSchemaRule, Rule, Ruleset
from griptape.structures import Agent
from griptape.tasks import PromptTask

Expand Down Expand Up @@ -41,3 +41,76 @@ def test_inherits_structure_rulesets(self):
agent.add_task(task)

assert task.rulesets == [ruleset1, ruleset2]

def test_to_dict(self):
mixin = RuleMixin(
rules=[
Rule("foo"),
JsonSchemaRule(
{
"type": "object",
"properties": {
"foo": {"type": "string"},
},
"required": ["foo"],
}
),
],
rulesets=[Ruleset("bar", [Rule("baz")])],
)

assert mixin.to_dict() == {
"rulesets": [
{
"id": mixin.rulesets[0].id,
"meta": {},
"name": "bar",
"rules": [{"type": "Rule", "value": "baz"}],
"type": "Ruleset",
},
{
"name": "Default Ruleset",
"id": mixin.rulesets[1].id,
"meta": {},
"rules": [
{"type": "Rule", "value": "foo"},
{
"type": "JsonSchemaRule",
"value": {
"properties": {"foo": {"type": "string"}},
"required": ["foo"],
"type": "object",
},
},
],
"type": "Ruleset",
},
],
"type": "RuleMixin",
}

def test_from_dict(self):
mixin = RuleMixin(
rules=[
Rule("foo"),
JsonSchemaRule(
{
"type": "object",
"properties": {
"foo": {"type": "string"},
},
"required": ["foo"],
}
),
],
rulesets=[Ruleset("bar", [Rule("baz")])],
)

new_mixin = RuleMixin.from_dict(mixin.to_dict())

for idx, _ in enumerate(new_mixin.rulesets):
rules = mixin.rulesets[idx].rules
new_rules = new_mixin.rulesets[idx].rules
for idx, _ in enumerate(rules):
assert rules[idx].value == new_rules[idx].value
assert rules[idx].meta == new_rules[idx].meta
2 changes: 2 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def test_to_dict(self):
"child_ids": agent.tasks[0].child_ids,
"max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries,
"context": agent.tasks[0].context,
"rulesets": [],
}
],
"rulesets": [],
"conversation_memory": {
"type": agent.conversation_memory.type,
"runs": agent.conversation_memory.runs,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/tasks/test_tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def test_to_dict(self):
"child_ids": task.child_ids,
"max_meta_memory_entries": task.max_meta_memory_entries,
"context": task.context,
"rulesets": [],
"tool": {
"type": task.tool.type,
"name": task.tool.name,
Expand Down

0 comments on commit e4bc671

Please sign in to comment.