Skip to content

Commit

Permalink
Feature/workflow improvements (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Sep 25, 2024
1 parent 2d04000 commit 53bc38b
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 23 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Added
- `Workflow.input_tasks` and `Workflow.output_tasks` to access the input and output tasks of a Workflow.
- Ability to pass nested list of `Tasks` to `Structure.tasks` allowing for more complex declarative Structure definitions.


## Added
- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s.
Expand Down
6 changes: 3 additions & 3 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def task(self) -> BaseTask:
return self.tasks[0]

def add_task(self, task: BaseTask) -> BaseTask:
self.tasks.clear()
self._tasks.clear()

task.preprocess(self)

self.tasks.append(task)
self._tasks.append(task)

return task

def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
if len(tasks) > 1:
raise ValueError("Agents can only have one task.")
return super().add_tasks(*tasks)
Expand Down
4 changes: 2 additions & 2 deletions griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def add_task(self, task: BaseTask) -> BaseTask:
self.output_task.child_ids.append(task.id)
task.parent_ids.append(self.output_task.id)

self.tasks.append(task)
self._tasks.append(task)

return task

Expand All @@ -45,7 +45,7 @@ def insert_task(self, parent_task: BaseTask, task: BaseTask) -> BaseTask:
parent_task.child_ids.append(task.id)

parent_index = self.tasks.index(parent_task)
self.tasks.insert(parent_index + 1, task)
self._tasks.insert(parent_index + 1, task)

return task

Expand Down
31 changes: 24 additions & 7 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Structure(ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
rulesets: list[Ruleset] = field(factory=list, kw_only=True)
rules: list[BaseRule] = field(factory=list, kw_only=True)
tasks: list[BaseTask] = field(factory=list, kw_only=True)
_tasks: list[BaseTask | list[BaseTask]] = field(factory=list, kw_only=True, alias="tasks")
conversation_memory: Optional[BaseConversationMemory] = field(
default=Factory(lambda: ConversationMemory()),
kw_only=True,
Expand Down Expand Up @@ -54,12 +54,23 @@ def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
raise ValueError("can't have both rules and rulesets specified")

def __attrs_post_init__(self) -> None:
tasks = self.tasks.copy()
self.tasks.clear()
tasks = self._tasks.copy()
self._tasks.clear()
self.add_tasks(*tasks)

def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]:
return self.add_tasks(*other) if isinstance(other, list) else self + [other]
def __add__(self, other: BaseTask | list[BaseTask | list[BaseTask]]) -> list[BaseTask]:
return self.add_tasks(*other) if isinstance(other, list) else self.add_tasks(other)

@property
def tasks(self) -> list[BaseTask]:
tasks = []

for task in self._tasks:
if isinstance(task, list):
tasks.extend(task)
else:
tasks.append(task)
return tasks

@property
def execution_args(self) -> tuple:
Expand Down Expand Up @@ -98,8 +109,14 @@ def try_find_task(self, task_id: str) -> Optional[BaseTask]:
return task
return None

def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
return [self.add_task(s) for s in tasks]
def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
added_tasks = []
for task in tasks:
if isinstance(task, list):
added_tasks.extend(self.add_tasks(*task))
else:
added_tasks.append(self.add_task(task))
return added_tasks

def context(self, task: BaseTask) -> dict[str, Any]:
return {"args": self.execution_args, "structure": self}
Expand Down
12 changes: 10 additions & 2 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ def input_task(self) -> Optional[BaseTask]:
def output_task(self) -> Optional[BaseTask]:
return self.order_tasks()[-1] if self.tasks else None

@property
def input_tasks(self) -> list[BaseTask]:
return [task for task in self.tasks if not task.parents]

@property
def output_tasks(self) -> list[BaseTask]:
return [task for task in self.tasks if not task.children]

def add_task(self, task: BaseTask) -> BaseTask:
if (existing_task := self.try_find_task(task.id)) is not None:
return existing_task

task.preprocess(self)

self.tasks.append(task)
self._tasks.append(task)

return task

Expand Down Expand Up @@ -82,7 +90,7 @@ def insert_task(
last_parent_index = self.__link_task_to_parents(task, parent_tasks)

# Insert the new task once, just after the last parent task
self.tasks.insert(last_parent_index + 1, task)
self._tasks.insert(last_parent_index + 1, task)

return task

Expand Down
13 changes: 5 additions & 8 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,14 @@ def test_add_tasks(self):

agent = Agent(prompt_driver=MockPromptDriver())

try:
with pytest.raises(ValueError):
agent.add_tasks(first_task, second_task)
raise AssertionError()
except ValueError:
assert True

try:
with pytest.raises(ValueError):
agent + [first_task, second_task]
raise AssertionError()
except ValueError:
assert True

with pytest.raises(ValueError):
agent.add_tasks([first_task, second_task])

def test_prompt_stack_without_memory(self):
agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None, rules=[Rule("test")])
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,22 @@ def test_add_tasks(self):
assert len(second_task.parents) == 1
assert len(second_task.children) == 0

def test_nested_tasks(self):
pipeline = Pipeline(
tasks=[
[
PromptTask("parent", id=f"parent_{i}"),
PromptTask("child", id=f"child_{i}", parent_ids=[f"parent_{i}"]),
PromptTask("grandchild", id=f"grandchild_{i}", parent_ids=[f"child_{i}"]),
]
for i in range(3)
]
)

pipeline.run()
assert pipeline.output_task.id == "grandchild_2"
assert len(pipeline.tasks) == 9

def test_insert_task_in_middle(self):
first_task = PromptTask("test1", id="test1")
second_task = PromptTask("test2", id="test2")
Expand Down Expand Up @@ -374,7 +390,7 @@ def test_add_duplicate_task_directly(self):
pipeline = Pipeline()

pipeline + task
pipeline.tasks.append(task)
pipeline._tasks.append(task)

with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."):
pipeline.run()
63 changes: 63 additions & 0 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,69 @@ def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting

assert workflow.output is not None

def test_nested_tasks(self):
workflow = Workflow(
tasks=[
[
PromptTask("parent", id=f"parent_{i}"),
PromptTask("child", id=f"child_{i}", parent_ids=[f"parent_{i}"]),
PromptTask("grandchild", id=f"grandchild_{i}", parent_ids=[f"child_{i}"]),
]
for i in range(3)
],
)

workflow.run()

output_ids = [task.id for task in workflow.output_tasks]
assert output_ids == ["grandchild_0", "grandchild_1", "grandchild_2"]
assert len(workflow.tasks) == 9

def test_nested_tasks_property(self):
workflow = Workflow()
workflow._tasks = [
[
PromptTask("parent", id=f"parent_{i}"),
PromptTask("child", id=f"child_{i}", parent_ids=[f"parent_{i}"]),
PromptTask("grandchild", id=f"grandchild_{i}", parent_ids=[f"child_{i}"]),
]
for i in range(3)
]

assert len(workflow.tasks) == 9

def test_output_tasks(self):
parent = PromptTask("parent")
child = PromptTask("child")
grandchild = PromptTask("grandchild")
workflow = Workflow(
tasks=[
[parent, child, grandchild],
]
)

workflow + parent
parent.add_child(child)
child.add_child(grandchild)

assert workflow.output_tasks == [grandchild]

def test_input_tasks(self):
parent = PromptTask("parent")
child = PromptTask("child")
grandchild = PromptTask("grandchild")
workflow = Workflow(
tasks=[
[parent, child, grandchild],
]
)

workflow + parent
parent.add_child(child)
child.add_child(grandchild)

assert workflow.input_tasks == [parent]

@staticmethod
def _validate_topology_1(workflow) -> None:
assert len(workflow.tasks) == 4
Expand Down

0 comments on commit 53bc38b

Please sign in to comment.