Skip to content

Commit

Permalink
Add back all_rulesets in favor of rulesets property cleverness
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 10, 2025
1 parent 028f35c commit 21c7295
Show file tree
Hide file tree
Showing 19 changed files with 102 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def run(self, context: RagContext) -> BaseArtifact:
def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

if len(self.rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)
if len(self.all_rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)

if self.metadata is not None:
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)
Expand Down
8 changes: 4 additions & 4 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
class RuleMixin(SerializableMixin):
DEFAULT_RULESET_NAME = "Default Ruleset"

_rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="rulesets", metadata={"serializable": True})
rules: list[BaseRule] = field(factory=list, kw_only=True)
rulesets: list[Ruleset] = field(factory=list, kw_only=True, metadata={"serializable": True})
rules: list[BaseRule] = field(factory=list, kw_only=True, metadata={"serializable": True})
_default_ruleset_name: str = field(default=Factory(lambda: RuleMixin.DEFAULT_RULESET_NAME), kw_only=True)
_default_ruleset_id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)

@property
def rulesets(self) -> list[Ruleset]:
rulesets = self._rulesets.copy()
def all_rulesets(self) -> list[Ruleset]:
rulesets = self.rulesets.copy()

if self.rules:
rulesets.append(Ruleset(id=self._default_ruleset_id, name=self._default_ruleset_name, rules=self.rules))
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/base_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _read_from_file(self, path: str) -> ImageArtifact:
return ImageLoader().load(Path(path))

def _get_prompts(self, prompt: str) -> list[str]:
return [prompt, *[rule.value for ruleset in self.rulesets for rule in ruleset.rules]]
return [prompt, *[rule.value for ruleset in self.all_rulesets for rule in ruleset.rules]]

def _get_negative_prompts(self) -> list[str]:
return [rule.value for ruleset in self.negative_rulesets for rule in ruleset.rules]
4 changes: 3 additions & 1 deletion griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ class ExtractionTask(BaseTextInputTask):
args: dict = field(kw_only=True, factory=dict)

def try_run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract_artifacts(ListArtifact([self.input]), rulesets=self.rulesets, **self.args)
return self.extraction_engine.extract_artifacts(
ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args
)
10 changes: 5 additions & 5 deletions griptape/tasks/prompt_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ class PromptTask(BaseTask, RuleMixin, ActionsSubtaskOriginMixin):
response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True)

@property
def rulesets(self) -> list:
def all_rulesets(self) -> list:
default_rules = self.rules
rulesets = self._rulesets.copy()
rulesets = self.rulesets.copy()

if self.structure is not None:
if self.structure._rulesets:
rulesets = self.structure._rulesets + self._rulesets
if self.structure.rulesets:
rulesets = self.structure.rulesets + self.rulesets
if self.structure.rules:
default_rules = self.structure.rules + self.rules

Expand Down Expand Up @@ -213,7 +213,7 @@ def default_generate_system_template(self, _: PromptTask) -> str:
schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually.

return J2("tasks/prompt_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets),
action_names=str.join(", ", [tool.name for tool in self.tools]),
actions_schema=utils.minify_json(json.dumps(schema)),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/text_summary_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ class TextSummaryTask(BaseTextInputTask):
summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True)

def try_run(self) -> TextArtifact:
return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.rulesets))
return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets))
2 changes: 1 addition & 1 deletion griptape/tasks/tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def preprocess(self, structure: Structure) -> ToolTask:

def default_generate_system_template(self, _: PromptTask) -> str:
return J2("tasks/tool_task/system.j2").render(
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets),
action_schema=utils.minify_json(json.dumps(self.tool.schema())),
meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories),
use_native_tools=self.prompt_driver.use_native_tools,
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/prompt_summary/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def summarize(self, params: dict) -> BaseArtifact:
else:
return ErrorArtifact("memory not found")

return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.rulesets)
return self.prompt_summary_engine.summarize_artifacts(artifacts, rulesets=self.all_rulesets)
2 changes: 1 addition & 1 deletion griptape/tools/query/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class QueryTool(BaseTool, RuleMixin):
lambda self: RagEngine(
response_stage=ResponseRagStage(
response_modules=[
PromptResponseRagModule(prompt_driver=self.prompt_driver, rulesets=self.rulesets)
PromptResponseRagModule(prompt_driver=self.prompt_driver, rulesets=self.all_rulesets)
],
),
),
Expand Down
51 changes: 23 additions & 28 deletions tests/unit/mixins/test_rule_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@ def test_rules(self):
assert mixin.rules == [rule]

def test_rulesets(self):
ruleset = Ruleset("foo", [Rule("bar")])
mixin = RuleMixin(rulesets=[ruleset])
ruleset1 = Ruleset("foo", [Rule("bar")])
mixin = RuleMixin(rulesets=[ruleset1])

assert mixin.rulesets == [ruleset]
mixin.rulesets.append(Ruleset("baz", [Rule("qux")]))
assert mixin.rulesets == [ruleset]
assert mixin.rulesets == [ruleset1]
ruleset2 = Ruleset("baz", [Rule("qux")])
mixin.rulesets.append(ruleset2)
assert mixin.rulesets == [ruleset1, ruleset2]

def test_rules_and_rulesets(self):
mixin = RuleMixin(rules=[Rule("foo")], rulesets=[Ruleset("bar", [Rule("baz")])])

assert mixin.rules == [Rule("foo")]
assert mixin.rulesets[0].name == "bar"
assert mixin.rulesets[0].rules == [Rule("baz")]
assert mixin.rulesets[1].name == "Default Ruleset"
assert mixin.rulesets[1].rules == [Rule("foo")]
assert mixin.all_rulesets[0].name == "bar"
assert mixin.all_rulesets[0].rules == [Rule("baz")]
assert mixin.all_rulesets[1].name == "Default Ruleset"
assert mixin.all_rulesets[1].rules == [Rule("foo")]

def test_inherits_structure_rulesets(self):
# Tests that a task using the mixin inherits rulesets from its structure.
Expand All @@ -40,7 +41,7 @@ def test_inherits_structure_rulesets(self):
task = PromptTask(rulesets=[ruleset2])
agent.add_task(task)

assert task.rulesets == [ruleset1, ruleset2]
assert task.all_rulesets == [ruleset1, ruleset2]

def test_to_dict(self):
mixin = RuleMixin(
Expand All @@ -60,31 +61,25 @@ def test_to_dict(self):
)

assert mixin.to_dict() == {
"rules": [
{"type": "Rule", "value": "foo"},
{
"type": "JsonSchemaRule",
"value": {
"properties": {"foo": {"type": "string"}},
"required": ["foo"],
"type": "object",
},
},
],
"rulesets": [
{
"id": mixin.rulesets[0].id,
"id": mixin.all_rulesets[0].id,
"meta": {},
"name": "bar",
"rules": [{"type": "Rule", "value": "baz"}],
"type": "Ruleset",
},
{
"name": "Default Ruleset",
"id": mixin.rulesets[1].id,
"meta": {},
"rules": [
{"type": "Rule", "value": "foo"},
{
"type": "JsonSchemaRule",
"value": {
"properties": {"foo": {"type": "string"}},
"required": ["foo"],
"type": "object",
},
},
],
"type": "Ruleset",
},
],
"type": "RuleMixin",
}
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def test_init(self):
assert agent.prompt_driver is driver
assert isinstance(agent.task, PromptTask)
assert isinstance(agent.task, PromptTask)
assert agent.rulesets[0].name == "TestRuleset"
assert agent.rulesets[0].rules[0].value == "test"
assert agent.all_rulesets[0].name == "TestRuleset"
assert agent.all_rulesets[0].rules[0].value == "test"
assert isinstance(agent.conversation_memory, ConversationMemory)
assert isinstance(Agent(tools=[MockTool()]).task, PromptTask)

Expand All @@ -32,31 +32,31 @@ def test_rulesets(self):
agent.add_task(PromptTask(rulesets=[Ruleset("Bar", [Rule("bar test")])]))

assert isinstance(agent.task, PromptTask)
assert len(agent.task.rulesets) == 2
assert agent.task.rulesets[0].name == "Foo"
assert agent.task.rulesets[1].name == "Bar"
assert len(agent.task.all_rulesets) == 2
assert agent.task.all_rulesets[0].name == "Foo"
assert agent.task.all_rulesets[1].name == "Bar"

def test_rules(self):
agent = Agent(rules=[Rule("foo test")])

agent.add_task(PromptTask(rules=[Rule("bar test")]))

assert isinstance(agent.task, PromptTask)
assert len(agent.task.rulesets) == 1
assert agent.task.rulesets[0].name == "Default Ruleset"
assert len(agent.task.rulesets[0].rules) == 2
assert agent.task.rulesets[0].rules[0].value == "foo test"
assert agent.task.rulesets[0].rules[1].value == "bar test"
assert len(agent.task.all_rulesets) == 1
assert agent.task.all_rulesets[0].name == "Default Ruleset"
assert len(agent.task.all_rulesets[0].rules) == 2
assert agent.task.all_rulesets[0].rules[0].value == "foo test"
assert agent.task.all_rulesets[0].rules[1].value == "bar test"

def test_rules_and_rulesets(self):
agent = Agent(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])
assert len(agent.rulesets) == 2
assert len(agent.all_rulesets) == 2
assert len(agent.rules) == 1

agent = Agent()
agent.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]))
assert isinstance(agent.task, PromptTask)
assert len(agent.task.rulesets) == 2
assert len(agent.task.all_rulesets) == 2
assert len(agent.task.rules) == 1

def test_with_task_memory(self):
Expand Down
31 changes: 16 additions & 15 deletions tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_init(self):

assert pipeline.input_task is None
assert pipeline.output_task is None
assert pipeline.rulesets[0].name == "TestRuleset"
assert pipeline.rulesets[0].rules[0].value == "test"
assert pipeline.all_rulesets[0].name == "TestRuleset"
assert pipeline.all_rulesets[0].rules[0].value == "test"
assert pipeline.conversation_memory is not None

def test_rulesets(self):
Expand All @@ -45,38 +45,39 @@ def test_rulesets(self):
)

assert isinstance(pipeline.tasks[0], PromptTask)
assert len(pipeline.tasks[0].rulesets) == 2
assert pipeline.tasks[0].rulesets[0].name == "Foo"
assert pipeline.tasks[0].rulesets[1].name == "Bar"
assert len(pipeline.tasks[0].all_rulesets) == 2
assert pipeline.tasks[0].all_rulesets[0].name == "Foo"
assert pipeline.tasks[0].all_rulesets[1].name == "Bar"

assert isinstance(pipeline.tasks[1], PromptTask)
assert len(pipeline.tasks[1].rulesets) == 2
assert pipeline.tasks[1].rulesets[0].name == "Foo"
assert pipeline.tasks[1].rulesets[1].name == "Baz"
assert len(pipeline.tasks[1].all_rulesets) == 2
assert pipeline.tasks[1].all_rulesets[0].name == "Foo"
assert pipeline.tasks[1].all_rulesets[1].name == "Baz"

def test_rules(self):
pipeline = Pipeline(rules=[Rule("foo test")])

pipeline.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")]))

assert isinstance(pipeline.tasks[0], PromptTask)
assert len(pipeline.tasks[0].rulesets) == 1
assert pipeline.tasks[0].rulesets[0].name == "Default Ruleset"
assert len(pipeline.tasks[0].rulesets[0].rules) == 2
assert len(pipeline.tasks[0].all_rulesets) == 1
assert pipeline.tasks[0].all_rulesets[0].name == "Default Ruleset"
assert len(pipeline.tasks[0].all_rulesets[0].rules) == 2

assert isinstance(pipeline.tasks[1], PromptTask)
assert pipeline.tasks[1].rulesets[0].name == "Default Ruleset"
assert len(pipeline.tasks[1].rulesets[0].rules) == 2
assert pipeline.tasks[1].all_rulesets[0].name == "Default Ruleset"
assert len(pipeline.tasks[1].all_rulesets[0].rules) == 2

def test_rules_and_rulesets(self):
pipeline = Pipeline(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])
assert len(pipeline.rulesets) == 2
assert len(pipeline.all_rulesets) == 2
assert len(pipeline.rulesets) == 1
assert len(pipeline.rules) == 1

pipeline = Pipeline()
pipeline.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]))
assert isinstance(pipeline.tasks[0], PromptTask)
assert len(pipeline.tasks[0].rulesets) == 2
assert len(pipeline.tasks[0].all_rulesets) == 2
assert len(pipeline.tasks[0].rules) == 1

def test_with_no_task_memory(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_to_dict(self):
"max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries,
"context": agent.tasks[0].context,
"rulesets": [],
"rules": [],
"max_subtasks": 20,
"tools": [],
"prompt_driver": {
Expand All @@ -88,6 +89,7 @@ def test_to_dict(self):
}
],
"rulesets": [],
"rules": [],
"conversation_memory": {
"type": agent.conversation_memory.type,
"runs": agent.conversation_memory.runs,
Expand Down
28 changes: 14 additions & 14 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,39 @@ def test_rulesets(self):
)

assert isinstance(workflow.tasks[0], PromptTask)
assert len(workflow.tasks[0].rulesets) == 2
assert workflow.tasks[0].rulesets[0].name == "Foo"
assert workflow.tasks[0].rulesets[1].name == "Bar"
assert len(workflow.tasks[0].all_rulesets) == 2
assert workflow.tasks[0].all_rulesets[0].name == "Foo"
assert workflow.tasks[0].all_rulesets[1].name == "Bar"

assert isinstance(workflow.tasks[1], PromptTask)
assert len(workflow.tasks[1].rulesets) == 2
assert workflow.tasks[1].rulesets[0].name == "Foo"
assert workflow.tasks[1].rulesets[1].name == "Baz"
assert len(workflow.tasks[1].all_rulesets) == 2
assert workflow.tasks[1].all_rulesets[0].name == "Foo"
assert workflow.tasks[1].all_rulesets[1].name == "Baz"

def test_rules(self):
workflow = Workflow(rules=[Rule("foo test")])

workflow.add_tasks(PromptTask(rules=[Rule("bar test")]), PromptTask(rules=[Rule("baz test")]))

assert isinstance(workflow.tasks[0], PromptTask)
assert len(workflow.tasks[0].rulesets) == 1
assert workflow.tasks[0].rulesets[0].name == "Default Ruleset"
assert len(workflow.tasks[0].rulesets[0].rules) == 2
assert len(workflow.tasks[0].all_rulesets) == 1
assert workflow.tasks[0].all_rulesets[0].name == "Default Ruleset"
assert len(workflow.tasks[0].all_rulesets[0].rules) == 2

assert isinstance(workflow.tasks[1], PromptTask)
assert len(workflow.tasks[1].rulesets) == 1
assert workflow.tasks[1].rulesets[0].name == "Default Ruleset"
assert len(workflow.tasks[1].rulesets[0].rules) == 2
assert len(workflow.tasks[1].all_rulesets) == 1
assert workflow.tasks[1].all_rulesets[0].name == "Default Ruleset"
assert len(workflow.tasks[1].all_rulesets[0].rules) == 2

def test_rules_and_rulesets(self):
workflow = Workflow(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])])
assert len(workflow.rulesets) == 2
assert len(workflow.all_rulesets) == 2
assert len(workflow.rules) == 1

workflow = Workflow()
workflow.add_task(PromptTask(rules=[Rule("foo test")], rulesets=[Ruleset("Bar", [Rule("bar test")])]))
assert isinstance(workflow.tasks[0], PromptTask)
assert len(workflow.tasks[0].rulesets) == 2
assert len(workflow.tasks[0].all_rulesets) == 2
assert len(workflow.tasks[0].rules) == 1

def test_with_no_task_memory(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def test_rulesets(self):
rulesets=[Ruleset("Foo", [Rule("foo test")]), Ruleset("Bar", [Rule("bar test")])]
)

assert len(prompt_task.rulesets) == 2
assert prompt_task.rulesets[0].name == "Foo"
assert prompt_task.rulesets[1].name == "Bar"
assert len(prompt_task.all_rulesets) == 2
assert prompt_task.all_rulesets[0].name == "Foo"
assert prompt_task.all_rulesets[1].name == "Bar"

def test_rules(self):
prompt_task = MockTextInputTask(rules=[Rule("foo test"), Rule("bar test")])

assert prompt_task.rulesets[0].name == "Default Ruleset"
assert prompt_task.all_rulesets[0].name == "Default Ruleset"
Loading

0 comments on commit 21c7295

Please sign in to comment.