From f061d1e3b02a661ae3a7ebcd933689c1926fafff Mon Sep 17 00:00:00 2001 From: dylanholmes <4370153+dylanholmes@users.noreply.github.com> Date: Tue, 4 Jun 2024 11:12:55 +0200 Subject: [PATCH] Add task relationship APIs --- CHANGELOG.md | 10 + docs/examples/multi-agent-workflow.md | 44 +- .../structures/workflows.md | 33 +- griptape/structures/structure.py | 22 + griptape/structures/workflow.py | 11 +- griptape/tasks/actions_subtask.py | 18 - griptape/tasks/base_task.py | 20 + tests/unit/structures/test_workflow.py | 525 +++++++++++++++--- tests/unit/tasks/test_toolkit_task.py | 4 +- 9 files changed, 550 insertions(+), 137 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d053ac1fb1..15fd14b30f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed +- `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks. + ### Added - `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. - Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run. +- `BaseTask.add_child()` to add a child task to a parent task. +- `BaseTask.add_children()` to add multiple child tasks to a parent task. +- `BaseTask.add_parent()` to add a parent task to a child task. +- `BaseTask.add_parents()` to add multiple parent tasks to a child task. +- `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. @@ -18,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Field `azure_ad_token` on all Azure Drivers is no longer serializable. - Default standard OpenAI and Azure OpenAI image query model to `gpt-4o`. - Error message to be more helpful when importing optional dependencies. +- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`. +- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`. ## [0.25.1] - 2024-05-15 diff --git a/docs/examples/multi-agent-workflow.md b/docs/examples/multi-agent-workflow.md index 044f4abadb..2678bf3ab0 100644 --- a/docs/examples/multi-agent-workflow.md +++ b/docs/examples/multi-agent-workflow.md @@ -155,35 +155,33 @@ if __name__ == "__main__": ), ), ) - end_task = team.add_task( - PromptTask( - 'State "All Done!"', - ) - ) - team.insert_tasks( - research_task, - [ - StructureRunTask( - ( - """Using insights provided, develop an engaging blog + writer_tasks = team.add_tasks(*[ + StructureRunTask( + ( + """Using insights provided, develop an engaging blog post that highlights the most significant AI advancements. Your post should be informative yet accessible, catering to a tech-savvy audience. Make it sound cool, avoid complex words so it doesn't sound like AI. Insights: {{ parent_outputs["research"] }}""", - ), - driver=LocalStructureRunDriver( - structure_factory_fn=lambda: build_writer( - role=writer["role"], - goal=writer["goal"], - backstory=writer["backstory"], - ) - ), - ) - for writer in WRITERS - ], - end_task, + ), + driver=LocalStructureRunDriver( + structure_factory_fn=lambda: build_writer( + role=writer["role"], + goal=writer["goal"], + backstory=writer["backstory"], + ) + ), + parent_ids=[research_task.id], + ) + for writer in WRITERS + ]) + end_task = team.add_task( + PromptTask( + 'State "All Done!"', + parent_ids=[writer_task.id for writer_task in writer_tasks], + ) ) team.run() diff --git a/docs/griptape-framework/structures/workflows.md b/docs/griptape-framework/structures/workflows.md index 3c2bac25a6..bc2af13d35 100644 --- a/docs/griptape-framework/structures/workflows.md +++ b/docs/griptape-framework/structures/workflows.md @@ -18,7 +18,14 @@ Let's build a simple workflow. Let's say, we want to write a story in a fantasy from griptape.tasks import PromptTask from griptape.structures import Workflow -workflow = Workflow() + +world_task = PromptTask( + "Create a fictional world based on the following key words {{ keywords|join(', ') }}", + context={ + "keywords": ["fantasy", "ocean", "tidal lock"] + }, + id="world" +) def character_task(task_id, character_name) -> PromptTask: return PromptTask( @@ -26,30 +33,20 @@ def character_task(task_id, character_name) -> PromptTask: context={ "name": character_name }, - id=task_id + id=task_id, + parent_ids=["world"] ) -world_task = PromptTask( - "Create a fictional world based on the following key words {{ keywords|join(', ') }}", - context={ - "keywords": ["fantasy", "ocean", "tidal lock"] - }, - id="world" -) -workflow.add_task(world_task) +scotty_task = character_task("scotty", "Scotty") +annie_task = character_task("annie", "Annie") story_task = PromptTask( "Based on the following description of the world and characters, write a short story:\n{{ parent_outputs['world'] }}\n{{ parent_outputs['scotty'] }}\n{{ parent_outputs['annie'] }}", - id="story" + id="story", + parent_ids=["world", "scotty", "annie"] ) -workflow.add_task(story_task) - -character_task_1 = character_task("scotty", "Scotty") -character_task_2 = character_task("annie", "Annie") -# Note the preserve_relationship flag. This ensures that world_task remains a parent of -# story_task so its output can be referenced in the story_task prompt. -workflow.insert_tasks(world_task, [character_task_1, character_task_2], story_task, preserve_relationship=True) +workflow = Workflow(tasks=[world_task, story_task, scotty_task, annie_task, story_task]) workflow.run() ``` diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 8b71dd905c..84058dc7a0 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -209,6 +209,26 @@ def publish_event(self, event: BaseEvent, flush: bool = False) -> None: def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} + def resolve_relationships(self) -> None: + task_by_id = {task.id: task for task in self.tasks} + + for task in self.tasks: + # Ensure parents include this task as a child + for parent_id in task.parent_ids: + if parent_id not in task_by_id: + raise ValueError(f"Task with id {parent_id} doesn't exist.") + parent = task_by_id[parent_id] + if task.id not in parent.child_ids: + parent.child_ids.append(task.id) + + # Ensure children include this task as a parent + for child_id in task.child_ids: + if child_id not in task_by_id: + raise ValueError(f"Task with id {child_id} doesn't exist.") + child = task_by_id[child_id] + if task.id not in child.parent_ids: + child.parent_ids.append(task.id) + def before_run(self) -> None: self.publish_event( StartStructureRunEvent( @@ -216,6 +236,8 @@ def before_run(self) -> None: ) ) + self.resolve_relationships() + def after_run(self) -> None: self.publish_event( FinishStructureRunEvent( diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index e60efa4258..5927a1f4f6 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -16,10 +16,6 @@ class Workflow(Structure): def add_task(self, task: BaseTask) -> BaseTask: task.preprocess(self) - if self.output_task: - self.output_task.child_ids.append(task.id) - task.parent_ids.append(self.output_task.id) - self.tasks.append(task) return task @@ -77,6 +73,7 @@ def insert_task( if parent_task.id in child_task.parent_ids: child_task.parent_ids.remove(parent_task.id) + last_parent_index = -1 for parent_task in parent_tasks: # Link the new task to the parent task if parent_task.id not in task.parent_ids: @@ -85,7 +82,11 @@ def insert_task( parent_task.child_ids.append(task.id) parent_index = self.tasks.index(parent_task) - self.tasks.insert(parent_index + 1, task) + if parent_index > last_parent_index: + last_parent_index = parent_index + + # Insert the new task once, just after the last parent task + self.tasks.insert(last_parent_index + 1, task) return task diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index d47d1df32d..dfc54aca6c 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -172,24 +172,6 @@ def actions_to_dicts(self) -> list[dict]: def actions_to_json(self) -> str: return json.dumps(self.actions_to_dicts()) - def add_child(self, child: ActionsSubtask) -> ActionsSubtask: - if child.id not in self.child_ids: - self.child_ids.append(child.id) - - if self.id not in child.parent_ids: - child.parent_ids.append(self.id) - - return child - - def add_parent(self, parent: ActionsSubtask) -> ActionsSubtask: - if parent.id not in self.parent_ids: - self.parent_ids.append(parent.id) - - if self.id not in parent.child_ids: - parent.child_ids.append(self.id) - - return parent - def __init_from_prompt(self, value: str) -> None: thought_matches = re.findall(self.THOUGHT_PATTERN, value, re.MULTILINE) actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 771fe4dc89..f6cdddf0d6 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -60,6 +60,26 @@ def meta_memories(self) -> list[BaseMetaEntry]: def __str__(self) -> str: return str(self.output.value) + def add_parents(self, *parents: str | BaseTask) -> None: + for parent in parents: + self.add_parent(parent) + + def add_parent(self, parent: str | BaseTask) -> None: + parent_id = parent if isinstance(parent, str) else parent.id + + if parent_id not in self.parent_ids: + self.parent_ids.append(parent_id) + + def add_children(self, *children: str | BaseTask) -> None: + for child in children: + self.add_child(child) + + def add_child(self, child: str | BaseTask) -> None: + child_id = child if isinstance(child, str) else child.id + + if child_id not in self.child_ids: + self.child_ids.append(child_id) + def preprocess(self, structure: Structure) -> BaseTask: self.structure = structure diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 970d43e74d..103a8aa9cc 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -140,10 +140,10 @@ def test_tasks_initialization(self): assert workflow.tasks[1].id == "test2" assert workflow.tasks[2].id == "test3" assert len(first_task.parents) == 0 - assert len(first_task.children) == 1 - assert len(second_task.parents) == 1 - assert len(second_task.children) == 1 - assert len(third_task.parents) == 1 + assert len(first_task.children) == 0 + assert len(second_task.parents) == 0 + assert len(second_task.children) == 0 + assert len(third_task.parents) == 0 assert len(third_task.children) == 0 def test_add_task(self): @@ -161,8 +161,8 @@ def test_add_task(self): assert first_task.structure == workflow assert second_task.structure == workflow assert len(first_task.parents) == 0 - assert len(first_task.children) == 1 - assert len(second_task.parents) == 1 + assert len(first_task.children) == 0 + assert len(second_task.parents) == 0 assert len(second_task.children) == 0 def test_add_tasks(self): @@ -179,8 +179,8 @@ def test_add_tasks(self): assert first_task.structure == workflow assert second_task.structure == workflow assert len(first_task.parents) == 0 - assert len(first_task.children) == 1 - assert len(second_task.parents) == 1 + assert len(first_task.children) == 0 + assert len(second_task.parents) == 0 assert len(second_task.children) == 0 def test_run(self): @@ -210,7 +210,111 @@ def test_run_with_args(self): assert task.input.to_text() == "-" - def test_run_topology_1(self): + @pytest.mark.parametrize( + "tasks", + [ + [PromptTask(id="task1", parent_ids=["missing"])], + [PromptTask(id="task1", child_ids=["missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["task1", "missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["task1"], child_ids=["missing"])], + ], + ) + def test_run_raises_on_missing_parent_or_child_id(self, tasks): + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + + with pytest.raises(ValueError) as e: + workflow.run() + + assert e.value.args[0] == "Task with id missing doesn't exist." + + def test_run_topology_1_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task1"]), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task2", "task3"]), + PromptTask("test2", id="task2", child_ids=["task4"]), + PromptTask("test3", id="task3", child_ids=["task4"]), + PromptTask("test4", id="task4"), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task3"]), + PromptTask("test2", id="task2", parent_ids=["task1"], child_ids=["task4"]), + PromptTask("test3", id="task3"), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_parents(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task2.add_parent(task1) + task3.add_parent("task1") + task4.add_parents(task2, "task3") + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_children(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task1.add_children(task2, task3) + task2.add_child(task4) + task3.add_child(task4) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_mixed(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task1.add_children(task2, task3) + task4.add_parents(task2, task3) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_insert(self): task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") @@ -225,23 +329,89 @@ def test_run_topology_1(self): workflow.run() - assert task1.state == BaseTask.State.FINISHED - assert task1.parent_ids == [] - assert task1.child_ids == ["task2", "task3"] + self._validate_topology_1(workflow) + + def test_run_topology_2_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("testa", id="taska"), + PromptTask("testb", id="taskb", parent_ids=["taska"]), + PromptTask("testc", id="taskc", parent_ids=["taska"]), + PromptTask("testd", id="taskd", parent_ids=["taska", "taskb", "taskc"]), + PromptTask("teste", id="taske", parent_ids=["taska", "taskd", "taskc"]), + ], + ) - assert task2.state == BaseTask.State.FINISHED - assert task2.parent_ids == ["task1"] - assert task2.child_ids == ["task4"] + workflow.run() - assert task3.state == BaseTask.State.FINISHED - assert task3.parent_ids == ["task1"] - assert task3.child_ids == ["task4"] + self._validate_topology_2(workflow) + + def test_run_topology_2_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("testa", id="taska", child_ids=["taskb", "taskc", "taskd", "taske"]), + PromptTask("testb", id="taskb", child_ids=["taskd"]), + PromptTask("testc", id="taskc", child_ids=["taskd", "taske"]), + PromptTask("testd", id="taskd", child_ids=["taske"]), + PromptTask("teste", id="taske", child_ids=[]), + ], + ) - assert task4.state == BaseTask.State.FINISHED - assert task4.parent_ids == ["task2", "task3"] - assert task4.child_ids == [] + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_parents(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + taskb.add_parent(taska) + taskc.add_parent("taska") + taskd.add_parents(taska, taskb, taskc) + taske.add_parents("taska", taskd, "taskc") + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_children(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + taska.add_children(taskb, taskc, taskd, taske) + taskb.add_child(taskd) + taskc.add_children(taskd, taske) + taskd.add_child(taske) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) - def test_run_topology_2(self): + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_mixed(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + taska.add_children(taskb, taskc, taskd, taske) + taskb.add_child(taskd) + taskd.add_parent(taskc) + taske.add_parents("taska", taskd, "taskc") + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_insert(self): """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" taska = PromptTask("testa", id="taska") taskb = PromptTask("testb", id="taskb") @@ -249,36 +419,63 @@ def test_run_topology_2(self): taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") workflow = Workflow(prompt_driver=MockPromptDriver()) - workflow.add_task(taska) workflow.add_task(taske) + taske.add_parent(taska) workflow.insert_tasks(taska, taskd, taske, preserve_relationship=True) workflow.insert_tasks(taska, [taskc], [taskd, taske], preserve_relationship=True) workflow.insert_tasks(taska, taskb, taskd, preserve_relationship=True) workflow.run() - assert taska.state == BaseTask.State.FINISHED - assert taska.parent_ids == [] - assert set(taska.child_ids) == {"taskb", "taskd", "taskc", "taske"} + self._validate_topology_2(workflow) + + def test_run_topology_3_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task4"]), + PromptTask("test4", id="task4", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task2"]), + ], + ) - assert taskb.state == BaseTask.State.FINISHED - assert taskb.parent_ids == ["taska"] - assert taskb.child_ids == ["taskd"] + workflow.run() - assert taskc.state == BaseTask.State.FINISHED - assert taskc.parent_ids == ["taska"] - assert set(taskc.child_ids) == {"taskd", "taske"} + self._validate_topology_3(workflow) + + def test_run_topology_3_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task4"]), + PromptTask("test2", id="task2", child_ids=["task3"]), + PromptTask("test4", id="task4", child_ids=["task2"]), + PromptTask("test3", id="task3", child_ids=[]), + ], + ) - assert taskd.state == BaseTask.State.FINISHED - assert set(taskd.parent_ids) == {"taskb", "taska", "taskc"} - assert taskd.child_ids == ["taske"] + workflow.run() - assert taske.state == BaseTask.State.FINISHED - assert set(taske.parent_ids) == {"taskd", "taskc", "taska"} - assert taske.child_ids == [] + self._validate_topology_3(workflow) + + def test_run_topology_3_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task4"], child_ids=["task3"]), + PromptTask("test4", id="task4", parent_ids=["task1"], child_ids=["task2"]), + PromptTask("test3", id="task3"), + ], + ) + + workflow.run() - def test_run_topology_3(self): + self._validate_topology_3(workflow) + + def test_run_topology_3_imperative_insert(self): task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") @@ -288,28 +485,75 @@ def test_run_topology_3(self): workflow + task1 workflow + task2 workflow + task3 + task2.add_parent(task1) + task3.add_parent(task2) workflow.insert_tasks(task1, task4, task2) workflow.run() - assert task1.state == BaseTask.State.FINISHED - assert task1.parent_ids == [] - assert task1.child_ids == ["task4"] + self._validate_topology_3(workflow) + + def test_run_topology_4_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info"), + PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"]), + PromptTask(id="movie_info_2", parent_ids=["collect_movie_info"]), + PromptTask(id="movie_info_3", parent_ids=["collect_movie_info"]), + PromptTask(id="compare_movies", parent_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), + PromptTask(id="send_email_task", parent_ids=["compare_movies"]), + PromptTask(id="save_to_disk", parent_ids=["compare_movies"]), + PromptTask(id="publish_website", parent_ids=["compare_movies"]), + PromptTask(id="summarize_to_slack", parent_ids=["send_email_task", "save_to_disk", "publish_website"]), + ], + ) - assert task2.state == BaseTask.State.FINISHED - assert task2.parent_ids == ["task4"] - assert task2.child_ids == ["task3"] + workflow.run() - assert task3.state == BaseTask.State.FINISHED - assert task3.parent_ids == ["task2"] - assert task3.child_ids == [] + self._validate_topology_4(workflow) + + def test_run_topology_4_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info", child_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), + PromptTask(id="movie_info_1", child_ids=["compare_movies"]), + PromptTask(id="movie_info_2", child_ids=["compare_movies"]), + PromptTask(id="movie_info_3", child_ids=["compare_movies"]), + PromptTask(id="compare_movies", child_ids=["send_email_task", "save_to_disk", "publish_website"]), + PromptTask(id="send_email_task", child_ids=["summarize_to_slack"]), + PromptTask(id="save_to_disk", child_ids=["summarize_to_slack"]), + PromptTask(id="publish_website", child_ids=["summarize_to_slack"]), + PromptTask(id="summarize_to_slack", child_ids=[]), + ], + ) - assert task4.state == BaseTask.State.FINISHED - assert task4.parent_ids == ["task1"] - assert task4.child_ids == ["task2"] + workflow.run() - def test_run_topology_4(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + self._validate_topology_4(workflow) + + def test_run_topology_4_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info"), + PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="movie_info_2", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="movie_info_3", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="compare_movies"), + PromptTask(id="send_email_task", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="save_to_disk", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="publish_website", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="summarize_to_slack"), + ], + ) + + workflow.run() + + self._validate_topology_4(workflow) + + def test_run_topology_4_imperative_insert(self): collect_movie_info = PromptTask(id="collect_movie_info") summarize_to_slack = PromptTask(id="summarize_to_slack") movie_info_1 = PromptTask(id="movie_info_1") @@ -321,30 +565,34 @@ def test_run_topology_4(self): publish_website = PromptTask(id="publish_website") movie_info_3 = PromptTask(id="movie_info_3") + workflow = Workflow(prompt_driver=MockPromptDriver()) workflow.add_tasks(collect_movie_info, summarize_to_slack) workflow.insert_tasks(collect_movie_info, [movie_info_1, movie_info_2, movie_info_3], summarize_to_slack) workflow.insert_tasks([movie_info_1, movie_info_2, movie_info_3], compare_movies, summarize_to_slack) workflow.insert_tasks(compare_movies, [send_email_task, save_to_disk, publish_website], summarize_to_slack) - assert set(collect_movie_info.child_ids) == {"movie_info_1", "movie_info_2", "movie_info_3"} - - assert set(movie_info_1.parent_ids) == {"collect_movie_info"} - assert set(movie_info_2.parent_ids) == {"collect_movie_info"} - assert set(movie_info_3.parent_ids) == {"collect_movie_info"} - assert set(movie_info_1.child_ids) == {"compare_movies"} - assert set(movie_info_2.child_ids) == {"compare_movies"} - assert set(movie_info_3.child_ids) == {"compare_movies"} - - assert set(compare_movies.parent_ids) == {"movie_info_1", "movie_info_2", "movie_info_3"} - assert set(compare_movies.child_ids) == {"send_email_task", "save_to_disk", "publish_website"} - - assert set(send_email_task.parent_ids) == {"compare_movies"} - assert set(save_to_disk.parent_ids) == {"compare_movies"} - assert set(publish_website.parent_ids) == {"compare_movies"} - - assert set(send_email_task.child_ids) == {"summarize_to_slack"} - assert set(save_to_disk.child_ids) == {"summarize_to_slack"} - assert set(publish_website.child_ids) == {"summarize_to_slack"} + self._validate_topology_4(workflow) + + @pytest.mark.parametrize( + "tasks", + [ + [PromptTask(id="a", parent_ids=["a"])], + [PromptTask(id="a"), PromptTask(id="b", parent_ids=["a", "b"])], + [PromptTask(id="a", parent_ids=["b"]), PromptTask(id="b", parent_ids=["a"])], + [ + PromptTask(id="a", parent_ids=["c"]), + PromptTask(id="b", parent_ids=["a"]), + PromptTask(id="c", parent_ids=["b"]), + ], + ], + ) + def test_run_raises_on_cycle(self, tasks): + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + + with pytest.raises(ValueError) as e: + workflow.run() + + assert e.value.args[0] == "nodes are in a cycle" def test_input_task(self): task1 = PromptTask("prompt1") @@ -417,6 +665,9 @@ def test_context(self): workflow + task workflow + child + task.add_parent(parent) + task.add_child(child) + context = workflow.context(task) assert context["parent_outputs"] == {parent.id: ""} @@ -439,3 +690,133 @@ def test_deprecation(self): with pytest.deprecated_call(): Workflow(stream=True) + + @staticmethod + def _validate_topology_1(workflow): + assert len(workflow.tasks) == 4 + assert workflow.input_task.id == "task1" + assert workflow.output_task.id == "task4" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + task1 = workflow.find_task("task1") + assert task1.state == BaseTask.State.FINISHED + assert task1.parent_ids == [] + assert sorted(task1.child_ids) == ["task2", "task3"] + + task2 = workflow.find_task("task2") + assert task2.state == BaseTask.State.FINISHED + assert task2.parent_ids == ["task1"] + assert task2.child_ids == ["task4"] + + task3 = workflow.find_task("task3") + assert task3.state == BaseTask.State.FINISHED + assert task3.parent_ids == ["task1"] + assert task3.child_ids == ["task4"] + + task4 = workflow.find_task("task4") + assert task4.state == BaseTask.State.FINISHED + assert sorted(task4.parent_ids) == ["task2", "task3"] + assert task4.child_ids == [] + + @staticmethod + def _validate_topology_2(workflow): + """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" + assert len(workflow.tasks) == 5 + assert workflow.input_task.id == "taska" + assert workflow.output_task.id == "taske" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + taska = workflow.find_task("taska") + assert taska.state == BaseTask.State.FINISHED + assert taska.parent_ids == [] + assert sorted(taska.child_ids) == ["taskb", "taskc", "taskd", "taske"] + + taskb = workflow.find_task("taskb") + assert taskb.state == BaseTask.State.FINISHED + assert taskb.parent_ids == ["taska"] + assert taskb.child_ids == ["taskd"] + + taskc = workflow.find_task("taskc") + assert taskc.state == BaseTask.State.FINISHED + assert taskc.parent_ids == ["taska"] + assert sorted(taskc.child_ids) == ["taskd", "taske"] + + taskd = workflow.find_task("taskd") + assert taskd.state == BaseTask.State.FINISHED + assert sorted(taskd.parent_ids) == ["taska", "taskb", "taskc"] + assert taskd.child_ids == ["taske"] + + taske = workflow.find_task("taske") + assert taske.state == BaseTask.State.FINISHED + assert sorted(taske.parent_ids) == ["taska", "taskc", "taskd"] + assert taske.child_ids == [] + + @staticmethod + def _validate_topology_3(workflow): + assert len(workflow.tasks) == 4 + assert workflow.input_task.id == "task1" + assert workflow.output_task.id == "task3" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + task1 = workflow.find_task("task1") + assert task1.state == BaseTask.State.FINISHED + assert task1.parent_ids == [] + assert task1.child_ids == ["task4"] + + task2 = workflow.find_task("task2") + assert task2.state == BaseTask.State.FINISHED + assert task2.parent_ids == ["task4"] + assert task2.child_ids == ["task3"] + + task3 = workflow.find_task("task3") + assert task3.state == BaseTask.State.FINISHED + assert task3.parent_ids == ["task2"] + assert task3.child_ids == [] + + task4 = workflow.find_task("task4") + assert task4.state == BaseTask.State.FINISHED + assert task4.parent_ids == ["task1"] + assert task4.child_ids == ["task2"] + + @staticmethod + def _validate_topology_4(workflow): + assert len(workflow.tasks) == 9 + assert workflow.input_task.id == "collect_movie_info" + assert workflow.output_task.id == "summarize_to_slack" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + collect_movie_info = workflow.find_task("collect_movie_info") + assert collect_movie_info.parent_ids == [] + assert sorted(collect_movie_info.child_ids) == ["movie_info_1", "movie_info_2", "movie_info_3"] + + movie_info_1 = workflow.find_task("movie_info_1") + assert movie_info_1.parent_ids == ["collect_movie_info"] + assert movie_info_1.child_ids == ["compare_movies"] + + movie_info_2 = workflow.find_task("movie_info_2") + assert movie_info_2.parent_ids == ["collect_movie_info"] + assert movie_info_2.child_ids == ["compare_movies"] + + movie_info_3 = workflow.find_task("movie_info_3") + assert movie_info_3.parent_ids == ["collect_movie_info"] + assert movie_info_3.child_ids == ["compare_movies"] + + compare_movies = workflow.find_task("compare_movies") + assert sorted(compare_movies.parent_ids) == ["movie_info_1", "movie_info_2", "movie_info_3"] + assert sorted(compare_movies.child_ids) == ["publish_website", "save_to_disk", "send_email_task"] + + send_email_task = workflow.find_task("send_email_task") + assert send_email_task.parent_ids == ["compare_movies"] + assert send_email_task.child_ids == ["summarize_to_slack"] + + save_to_disk = workflow.find_task("save_to_disk") + assert save_to_disk.parent_ids == ["compare_movies"] + assert save_to_disk.child_ids == ["summarize_to_slack"] + + publish_website = workflow.find_task("publish_website") + assert publish_website.parent_ids == ["compare_movies"] + assert publish_website.child_ids == ["summarize_to_slack"] diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index bba94f82da..f4933e3193 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -237,11 +237,13 @@ def test_add_subtask(self): "test2", actions=[ActionsSubtask.Action(tag="foo", name="test", path="test", input={"values": {"f": "b"}})] ) - Agent().add_task(task) + agent = Agent(tasks=[task]) task.add_subtask(subtask1) task.add_subtask(subtask2) + agent.resolve_relationships() + assert len(task.subtasks) == 2 assert len(subtask1.children) == 1