Skip to content

Commit

Permalink
Add 'self.order_tasks()' before getting 'output_task' (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo authored Jun 18, 2024
1 parent 53ca5b1 commit 6ceadf6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
6 changes: 5 additions & 1 deletion griptape/structures/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
import concurrent.futures as futures
from graphlib import TopologicalSorter
from typing import Any
from typing import Any, Optional
from attrs import define, field, Factory
from griptape.artifacts import ErrorArtifact
from griptape.structures import Structure
Expand All @@ -13,6 +13,10 @@
class Workflow(Structure):
futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

@property
def output_task(self) -> Optional[BaseTask]:
return self.order_tasks()[-1] if self.tasks else None

def add_task(self, task: BaseTask) -> BaseTask:
task.preprocess(self)

Expand Down
9 changes: 9 additions & 0 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,15 @@ def test_output_task(self):

assert task4 == workflow.output_task

task4.add_parents([task2, task3])
task1.add_children([task2, task3])

# task4 is the final task, but its defined at index 0
workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3])

# ouput_task topologically should be task4
assert task4 == workflow.output_task

def test_to_graph(self):
task1 = PromptTask("prompt1", id="task1")
task2 = PromptTask("prompt2", id="task2")
Expand Down

0 comments on commit 6ceadf6

Please sign in to comment.