Skip to content

Commit

Permalink
Refactor RuleMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 4, 2024
1 parent 192515c commit 94b3d23
Show file tree
Hide file tree
Showing 21 changed files with 109 additions and 130 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Anthropic native Tool calling.
- Empty `ActionsSubtask.thought` being logged.
- `RuleMixin` no longer prevents setting `rulesets` _and_ `rules` at the same time.
- `PromptTask` will merge in its Structure's Rulesets and Rules.
- `PromptTask` not checking whether Structure was set before building Prompt Stack.

## [0.32.0] - 2024-09-17

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def run(self, context: RagContext) -> BaseArtifact:
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

if len(self.all_rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)
if len(self.rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)

if self.metadata is not None:
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)
Expand Down
48 changes: 7 additions & 41 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from attrs import Attribute, define, field
from attrs import define, field

from griptape.rules import BaseRule, Ruleset

if TYPE_CHECKING:
from griptape.structures import Structure


@define(slots=False)
class RuleMixin:
DEFAULT_RULESET_NAME = "Default Ruleset"
ADDITIONAL_RULESET_NAME = "Additional Ruleset"

rulesets: list[Ruleset] = field(factory=list, kw_only=True)
_rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets")
rules: list[BaseRule] = field(factory=list, kw_only=True)
structure: Optional[Structure] = field(default=None, kw_only=True)

@rulesets.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None:
if not rulesets:
return

if self.rules:
raise ValueError("Can't have both rulesets and rules specified.")

@rules.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_rules(self, _: Attribute, rules: list[BaseRule]) -> None:
if not rules:
return

if self.rulesets:
raise ValueError("Can't have both rules and rulesets specified.")

@property
def all_rulesets(self) -> list[Ruleset]:
structure_rulesets = []

if self.structure:
if self.structure.rulesets:
structure_rulesets = self.structure.rulesets
elif self.structure.rules:
structure_rulesets = [Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.structure.rules)]
def rulesets(self) -> list[Ruleset]:
rulesets = self._rulesets

task_rulesets = []
if self.rulesets:
task_rulesets = self.rulesets
elif self.rules:
task_ruleset_name = self.ADDITIONAL_RULESET_NAME if structure_rulesets else self.DEFAULT_RULESET_NAME

task_rulesets = [Ruleset(name=task_ruleset_name, rules=self.rules)]
if self.rules:
rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=self.rules))

return structure_rulesets + task_rulesets
return rulesets
24 changes: 3 additions & 21 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,24 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from attrs import Attribute, Factory, define, field
from attrs import Factory, define, field

from griptape.common import observable
from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent
from griptape.memory import TaskMemory
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run
from griptape.mixins.rule_mixin import RuleMixin

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.memory.structure import BaseConversationMemory
from griptape.rules import BaseRule, Rule, Ruleset
from griptape.tasks import BaseTask


@define
class Structure(ABC):
class Structure(ABC, RuleMixin):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
rulesets: list[Ruleset] = field(factory=list, kw_only=True)
rules: list[BaseRule] = field(factory=list, kw_only=True)
_tasks: list[BaseTask | list[BaseTask]] = field(factory=list, kw_only=True, alias="tasks")
conversation_memory: Optional[BaseConversationMemory] = field(
default=Factory(lambda: ConversationMemory()),
Expand All @@ -37,22 +35,6 @@ class Structure(ABC):
fail_fast: bool = field(default=True, kw_only=True)
_execution_args: tuple = ()

@rulesets.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None:
if not rulesets:
return

if self.rules:
raise ValueError("can't have both rulesets and rules specified")

@rules.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
if not rules:
return

if self.rulesets:
raise ValueError("can't have both rules and rulesets specified")

def __attrs_post_init__(self) -> None:
tasks = self._tasks.copy()
self._tasks.clear()
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def parents_output_text(self) -> str:

@property
def meta_memories(self) -> list[BaseMetaEntry]:
if self.structure and self.structure.meta_memory:
if self.structure is not None and self.structure.meta_memory:
if self.max_meta_memory_entries:
return self.structure.meta_memory.entries[: self.max_meta_memory_entries]
else:
Expand Down Expand Up @@ -191,7 +191,7 @@ def run(self) -> BaseArtifact: ...

@property
def full_context(self) -> dict[str, Any]:
if self.structure:
if self.structure is not None:
structure_context = self.structure.context(self)

structure_context.update(self.context)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/inpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self) -> ImageArtifact:
prompts=[prompt_artifact.to_text()],
image=image_artifact,
mask=mask_artifact,
rulesets=self.all_rulesets,
rulesets=self.rulesets,
negative_rulesets=self.negative_rulesets,
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/outpainting_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self) -> ImageArtifact:
prompts=[prompt_artifact.to_text()],
image=image_artifact,
mask=mask_artifact,
rulesets=self.all_rulesets,
rulesets=self.rulesets,
negative_rulesets=self.negative_rulesets,
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/prompt_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def input(self, value: TextArtifact) -> None:
def run(self) -> ImageArtifact:
image_artifact = self.image_generation_engine.run(
prompts=[self.input.to_text()],
rulesets=self.all_rulesets,
rulesets=self.rulesets,
negative_rulesets=self.negative_rulesets,
)

Expand Down
21 changes: 19 additions & 2 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from griptape.common import PromptStack
from griptape.configs import Defaults
from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import Ruleset
from griptape.tasks import BaseTask
from griptape.utils import J2

Expand All @@ -32,6 +33,22 @@ class PromptTask(RuleMixin, BaseTask):
alias="input",
)

@property
def rulesets(self) -> list:
default_rules = self.rules
rulesets = self._rulesets

if self.structure is not None:
if self.structure._rulesets:
rulesets = self.structure._rulesets + self._rulesets
if self.structure.rules:
default_rules = self.structure.rules + self.rules

if default_rules:
rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=default_rules))

return rulesets

@property
def input(self) -> BaseArtifact:
return self._process_task_input(self._input)
Expand All @@ -45,7 +62,7 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask],
@property
def prompt_stack(self) -> PromptStack:
stack = PromptStack()
memory = self.structure.conversation_memory
memory = self.structure.conversation_memory if self.structure is not None else None

system_template = self.generate_system_template(self)
if system_template:
Expand All @@ -64,7 +81,7 @@ def prompt_stack(self) -> PromptStack:

def default_system_template_generator(self, _: PromptTask) -> str:
return J2("tasks/prompt_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
)

def before_run(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/text_summary_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ class TextSummaryTask(BaseTextInputTask):
summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True)

def run(self) -> TextArtifact:
return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets))
return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.rulesets))
2 changes: 1 addition & 1 deletion griptape/tasks/text_to_speech_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def input(self, value: TextArtifact) -> None:
self._input = value

def run(self) -> AudioArtifact:
audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets)
audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.rulesets)

if self.output_dir or self.output_file:
self._write_to_file(audio_artifact)
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def preprocess(self, structure: Structure) -> ToolTask:

def default_system_template_generator(self, _: PromptTask) -> str:
return J2("tasks/tool_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
action_schema=utils.minify_json(json.dumps(self.tool.schema())),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
use_native_tools=self.prompt_driver.use_native_tools,
Expand Down
6 changes: 3 additions & 3 deletions griptape/tasks/toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def tool_output_memory(self) -> list[TaskMemory]:
@property
def prompt_stack(self) -> PromptStack:
stack = PromptStack(tools=self.tools)
memory = self.structure.conversation_memory
memory = self.structure.conversation_memory if self.structure is not None else None

stack.add_system_message(self.generate_system_template(self))

Expand Down Expand Up @@ -113,7 +113,7 @@ def prompt_stack(self) -> PromptStack:
stack.add_assistant_message(self.generate_assistant_subtask_template(s))
stack.add_user_message(self.generate_user_subtask_template(s))

if memory:
if memory is not None:
# inserting at index 1 to place memory right after system prompt
memory.add_to_prompt_stack(self.prompt_driver, stack, 1)

Expand All @@ -132,7 +132,7 @@ def default_system_template_generator(self, _: PromptTask) -> str:
schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually.

return J2("tasks/toolkit_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
action_names=str.join(", ", [tool.name for tool in self.tools]),
actions_schema=utils.minify_json(json.dumps(schema)),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/variation_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(self) -> ImageArtifact:
output_image_artifact = self.image_generation_engine.run(
prompts=[prompt_artifact.to_text()],
image=image_artifact,
rulesets=self.all_rulesets,
rulesets=self.rulesets,
negative_rulesets=self.negative_rulesets,
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/prompt_summary/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def summarize(self, params: dict) -> BaseArtifact:
else:
return ErrorArtifact("memory not found")

return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.all_rulesets)
return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.rulesets)
10 changes: 5 additions & 5 deletions tests/unit/mixins/test_rule_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from griptape.mixins.rule_mixin import RuleMixin
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent
Expand All @@ -23,8 +21,10 @@ def test_rulesets(self):
assert mixin.rulesets == [ruleset]

def test_rules_and_rulesets(self):
with pytest.raises(ValueError):
RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])])
mixin = RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])])

assert mixin.rules == [Rule("foo")]
assert mixin.rulesets == [Ruleset("bar", [Rule("baz")]), Ruleset("Default Ruleset", [Rule("foo")])]

def test_inherits_structure_rulesets(self):
# Tests that a task using the mixin inherits rulesets from its structure.
Expand All @@ -35,4 +35,4 @@ def test_inherits_structure_rulesets(self):
task = PromptTask(rulesets=[ruleset2])
agent.add_task(task)

assert task.all_rulesets == [ruleset1, ruleset2]
assert task.rulesets == [ruleset1, ruleset2]
25 changes: 15 additions & 10 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,32 @@ def test_rulesets(self):
agent.add_task(PromptTask(rulesets=[Ruleset("Bar", [Rule("bar test")])]))

assert isinstance(agent.task, PromptTask)
assert len(agent.task.all_rulesets) == 2
assert agent.task.all_rulesets[0].name == "Foo"
assert agent.task.all_rulesets[1].name == "Bar"
assert len(agent.task.rulesets) == 2
assert agent.task.rulesets[0].name == "Foo"
assert agent.task.rulesets[1].name == "Bar"

def test_rules(self):
agent = Agent(rules=[Rule("foo test")])

agent.add_task(PromptTask(rules=[Rule("bar test")]))

assert isinstance(agent.task, PromptTask)
assert len(agent.task.all_rulesets) == 2
assert agent.task.all_rulesets[0].name == "Default Ruleset"
assert agent.task.all_rulesets[1].name == "Additional Ruleset"
assert len(agent.task.rulesets) == 1
assert agent.task.rulesets[0].name == "Default Ruleset"
assert len(agent.task.rulesets[0].rules) == 2
assert agent.task.rulesets[0].rules[0].value == "foo test"
assert agent.task.rulesets[0].rules[1].value == "bar test"

def test_rules_and_rulesets(self):
with pytest.raises(ValueError):
Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])
agent = Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])
assert len(agent.rulesets) == 2
assert len(agent.rules) == 1

agent = Agent()
with pytest.raises(ValueError):
agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]))
agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]))
assert isinstance(agent.task, PromptTask)
assert len(agent.task.rulesets) == 2
assert len(agent.task.rules) == 1

def test_with_task_memory(self):
agent = Agent(tools=[MockTool(off_prompt=True)])
Expand Down
Loading

0 comments on commit 94b3d23

Please sign in to comment.