From 6ceadf6446e33ae3a7769f41afd22b23fba977cf Mon Sep 17 00:00:00 2001 From: Matt Vallillo Date: Tue, 18 Jun 2024 11:38:16 -0500 Subject: [PATCH] Add 'self.order_tasks()' before getting 'output_task' (#873) --- griptape/structures/workflow.py | 6 +++++- tests/unit/structures/test_workflow.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index f9570edc0..6552fba89 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -1,7 +1,7 @@ from __future__ import annotations import concurrent.futures as futures from graphlib import TopologicalSorter -from typing import Any +from typing import Any, Optional from attrs import define, field, Factory from griptape.artifacts import ErrorArtifact from griptape.structures import Structure @@ -13,6 +13,10 @@ class Workflow(Structure): futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + @property + def output_task(self) -> Optional[BaseTask]: + return self.order_tasks()[-1] if self.tasks else None + def add_task(self, task: BaseTask) -> BaseTask: task.preprocess(self) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 14661036d..bbcf5138e 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -637,6 +637,15 @@ def test_output_task(self): assert task4 == workflow.output_task + task4.add_parents([task2, task3]) + task1.add_children([task2, task3]) + + # task4 is the final task, but its defined at index 0 + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3]) + + # ouput_task topologically should be task4 + assert task4 == workflow.output_task + def test_to_graph(self): task1 = PromptTask("prompt1", id="task1") task2 = PromptTask("prompt2", id="task2")