Skip to content

Commit

Permalink
squash
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Aug 21, 2024
1 parent e00d10e commit 9738ee7
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 89 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `OpenAiAudioTranscriptionDriver` for integration with OpenAI's speech-to-text models, including Whisper.
- Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run.
- `PusherEventListenerDriver` to enable sending of framework events over a Pusher WebSocket.
- Parameter `futures_executor_fn` and methods `try_run` and `order_tasks` to `Structure` from `Workflow`.

### Changed
- **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name.
Expand Down
16 changes: 4 additions & 12 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from attrs import Attribute, Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.common import observable
from griptape.configs import Defaults
from griptape.memory.structure import Run
from griptape.tools import BaseTool
from griptape.structures import Structure
from griptape.tasks import PromptTask, ToolkitTask

Expand Down Expand Up @@ -73,13 +72,6 @@ def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
raise ValueError("Agents can only have one task.")
return super().add_tasks(*tasks)

@observable
def try_run(self, *args) -> Agent:
self.task.execute()

if self.conversation_memory and self.output is not None:
run = Run(input=self.input_task.input, output=self.output)

self.conversation_memory.add_run(run)

return self
def resolve_relationships(self) -> None:
if len(self.tasks) > 1:
raise ValueError("Agents can only have one task.")
35 changes: 16 additions & 19 deletions griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from griptape.artifacts import ErrorArtifact
from griptape.common import observable
from griptape.memory.structure import Run
from typing import TYPE_CHECKING, Any
from attrs import define
from griptape.structures import Structure

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,17 +51,6 @@ def insert_task(self, parent_task: BaseTask, task: BaseTask) -> BaseTask:

return task

@observable
def try_run(self, *args) -> Pipeline:
self.__run_from_task(self.input_task)

if self.conversation_memory and self.output is not None:
run = Run(input=self.input_task.input, output=self.output)

self.conversation_memory.add_run(run)

return self

def context(self, task: BaseTask) -> dict[str, Any]:
context = super().context(task)

Expand All @@ -73,11 +64,17 @@ def context(self, task: BaseTask) -> dict[str, Any]:

return context

def __run_from_task(self, task: Optional[BaseTask]) -> None:
if task is None:
return
else:
if isinstance(task.execute(), ErrorArtifact) and self.fail_fast:
return
else:
self.__run_from_task(next(iter(task.children), None))
def resolve_relationships(self) -> None:
for i, task in enumerate(self.tasks):
if i > 0 and self.tasks[i - 1].id not in task.parent_ids:
task.parent_ids.append(self.tasks[i - 1].id)
if i < len(self.tasks) - 1 and self.tasks[i + 1].id not in task.child_ids:
task.child_ids.append(self.tasks[i + 1].id)
if i == 0 and len(task.parent_ids) > 0:
raise ValueError("The first task in a pipeline cannot have a parent.")
if len(task.parent_ids) > 1 or len(task.child_ids) > 1:
raise ValueError("Pipelines can only have one parent and one child per task.")
if i == len(self.tasks) - 1 and len(task.child_ids) > 0:
raise ValueError("The last task in a pipeline cannot have a child.")

super().resolve_relationships()
56 changes: 54 additions & 2 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from concurrent import futures
from graphlib import TopologicalSorter
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
Expand All @@ -11,6 +13,14 @@
from griptape.memory import TaskMemory
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory
from typing import TYPE_CHECKING, Any, Callable, Optional
from attrs import Factory, define, field
from griptape.artifacts import BaseArtifact, ErrorArtifact

from griptape.events import FinishStructureRunEvent, StartStructureRunEvent
from griptape.memory import TaskMemory
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
Expand All @@ -35,6 +45,10 @@ class Structure(ABC):
)
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True)
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), kw_only=True
)

_execution_args: tuple = ()

@rulesets.validator # pyright: ignore[reportAttributeAccessIssue]
Expand Down Expand Up @@ -168,5 +182,43 @@ def run(self, *args) -> Structure:

return result

@abstractmethod
def try_run(self, *args) -> Structure: ...
def try_run(self, *args) -> Structure:
exit_loop = False

while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()

for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor_fn().submit(task.execute)
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break

if self.conversation_memory and self.output is not None:
run = Run(input=self.input_task.input, output=self.output)

self.conversation_memory.add_run(run)

return self

def order_tasks(self) -> list[BaseTask]:
return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]

def to_graph(self) -> dict[str, set[str]]:
graph: dict[str, set[str]] = {}

for key_task in self.tasks:
graph[key_task.id] = set()

for value_task in self.tasks:
if key_task.id in value_task.child_ids:
graph[key_task.id].add(value_task.id)

return graph
59 changes: 3 additions & 56 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from attrs import define
from graphlib import TopologicalSorter

from typing import Any, Optional
from attrs import define
from griptape.artifacts import ErrorArtifact
from griptape.common import observable
from griptape.memory.structure import Run
Expand All @@ -17,7 +19,7 @@


@define
class Workflow(Structure, FuturesExecutorMixin):
class Workflow(Structure):
@property
def input_task(self) -> Optional[BaseTask]:
return self.order_tasks()[0] if self.tasks else None
Expand Down Expand Up @@ -126,58 +128,3 @@ def context(self, task: BaseTask) -> dict[str, Any]:
)

return context

def to_graph(self) -> dict[str, set[str]]:
graph: dict[str, set[str]] = {}

for key_task in self.tasks:
graph[key_task.id] = set()

for value_task in self.tasks:
if key_task.id in value_task.child_ids:
graph[key_task.id].add(value_task.id)

return graph

def order_tasks(self) -> list[BaseTask]:
return [self.find_task(task_id) for task_id in TopologicalSorter(self.to_graph()).static_order()]

def __link_task_to_children(self, task: BaseTask, child_tasks: list[BaseTask]) -> None:
for child_task in child_tasks:
# Link the new task to the child task
if child_task.id not in task.child_ids:
task.child_ids.append(child_task.id)
if task.id not in child_task.parent_ids:
child_task.parent_ids.append(task.id)

def __remove_old_parent_child_relationships(
self,
parent_tasks: list[BaseTask],
child_tasks: list[BaseTask],
) -> None:
for parent_task in parent_tasks:
for child_task in child_tasks:
# Remove the old parent/child relationship
if child_task.id in parent_task.child_ids:
parent_task.child_ids.remove(child_task.id)
if parent_task.id in child_task.parent_ids:
child_task.parent_ids.remove(parent_task.id)

def __link_task_to_parents(self, task: BaseTask, parent_tasks: list[BaseTask]) -> int:
last_parent_index = -1
for parent_task in parent_tasks:
# Link the new task to the parent task
if parent_task.id not in task.parent_ids:
task.parent_ids.append(parent_task.id)
if task.id not in parent_task.child_ids:
parent_task.child_ids.append(task.id)

try:
parent_index = self.tasks.index(parent_task)
except ValueError as exc:
raise ValueError(f"Parent task {parent_task.id} not found in workflow.") from exc
else:
if parent_index > last_parent_index:
last_parent_index = parent_index

return last_parent_index
6 changes: 6 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,9 @@ def finished_tasks(self):
def test_fail_fast(self):
with pytest.raises(ValueError):
Agent(prompt_driver=MockPromptDriver(), fail_fast=True)

def test_fail_too_many_tasks(self):
with pytest.raises(ValueError, match="Agents can only have one task."):
agent = Agent()
agent.tasks = [PromptTask("input"), PromptTask("input")]
agent.run()
30 changes: 30 additions & 0 deletions tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,33 @@ def test_add_duplicate_task_directly(self):

with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."):
pipeline.run()

def test_invalid_task_relationships(self):
pipeline = Pipeline(prompt_driver=MockPromptDriver())
tasks = [PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), PromptTask("test3", id="task3")]
new_task = PromptTask("test4", id="task4")
tasks[1].add_child(new_task)
pipeline.tasks = [*tasks, new_task]

with pytest.raises(ValueError, match="Pipelines can only have one parent and one child per task."):
pipeline.resolve_relationships()

def test_invalid_task_relationships_first_task(self):
pipeline = Pipeline(prompt_driver=MockPromptDriver())
tasks = [PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), PromptTask("test3", id="task3")]
new_task = PromptTask("test4", id="task4")
tasks[0].add_parent(new_task)
pipeline.tasks = [*tasks, new_task]

with pytest.raises(ValueError, match="The first task in a pipeline cannot have a parent."):
pipeline.resolve_relationships()

def test_invalid_task_relationships_last_task(self):
pipeline = Pipeline(prompt_driver=MockPromptDriver())
tasks = [PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), PromptTask("test3", id="task3")]
new_task = PromptTask("test4", id="task4")
new_task.add_child("task2")
pipeline.tasks = [*tasks, new_task]

with pytest.raises(ValueError, match="The last task in a pipeline cannot have a child."):
pipeline.resolve_relationships()

0 comments on commit 9738ee7

Please sign in to comment.