Skip to content

Commit

Permalink
Add and integrate FuturesExecutorMixin (#1069)
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov authored Aug 16, 2024
1 parent 24a6824 commit 14a0f0d
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for bitshift composition in `BaseTask` for adding parent/child tasks.
- `JsonArtifact` for handling de/serialization of values.
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand Down
14 changes: 5 additions & 9 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,20 @@

import logging
from abc import ABC, abstractmethod
from concurrent import futures
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.mixins import FuturesExecutorMixin

if TYPE_CHECKING:
from griptape.events import BaseEvent

logger = logging.getLogger(__name__)


@define
class BaseEventListenerDriver(ABC):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
kw_only=True,
)
class BaseEventListenerDriver(FuturesExecutorMixin, ABC):
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)

Expand All @@ -29,8 +26,7 @@ def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None:
with self.futures_executor_fn() as executor:
executor.submit(self._safe_try_publish_event, event, flush=flush)
self.futures_executor.submit(self._safe_try_publish_event, event, flush=flush)

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...
Expand Down
48 changes: 21 additions & 27 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@

import uuid
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
from attrs import define, field

from griptape import utils
from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.mixins import SerializableMixin
from griptape.mixins import FuturesExecutorMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.drivers import BaseEmbeddingDriver


@define
class BaseVectorStoreDriver(SerializableMixin, ABC):
class BaseVectorStoreDriver(SerializableMixin, FuturesExecutorMixin, ABC):
DEFAULT_QUERY_COUNT = 5

@dataclass
Expand All @@ -36,10 +35,6 @@ def to_artifact(self) -> BaseArtifact:
return BaseArtifact.from_json(self.meta["artifact"]) # pyright: ignore[reportOptionalSubscript]

embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
kw_only=True,
)

def upsert_text_artifacts(
self,
Expand All @@ -48,24 +43,23 @@ def upsert_text_artifacts(
meta: Optional[dict] = None,
**kwargs,
) -> None:
with self.futures_executor_fn() as executor:
if isinstance(artifacts, list):
utils.execute_futures_list(
[
executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs)
for a in artifacts
],
)
else:
utils.execute_futures_dict(
{
namespace: executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
},
)
if isinstance(artifacts, list):
utils.execute_futures_list(
[
self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs)
for a in artifacts
],
)
else:
utils.execute_futures_dict(
{
namespace: self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
},
)

def upsert_text_artifact(
self,
Expand Down
9 changes: 3 additions & 6 deletions griptape/engines/rag/modules/base_rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,22 @@

import uuid
from abc import ABC
from concurrent import futures
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field

from griptape.common import Message, PromptStack
from griptape.mixins import FuturesExecutorMixin

if TYPE_CHECKING:
from griptape.engines.rag import RagContext


@define(kw_only=True)
class BaseRagModule(ABC):
class BaseRagModule(FuturesExecutorMixin, ABC):
name: str = field(
default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
)
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)

def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack:
messages = []
Expand Down
11 changes: 3 additions & 8 deletions griptape/engines/rag/stages/base_rag_stage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from concurrent import futures
from typing import Callable

from attrs import Factory, define, field
from attrs import define

from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRagModule
from griptape.mixins import FuturesExecutorMixin


@define(kw_only=True)
class BaseRagStage(ABC):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)

class BaseRagStage(FuturesExecutorMixin, ABC):
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...

Expand Down
5 changes: 3 additions & 2 deletions griptape/engines/rag/stages/response_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))

with self.futures_executor_fn() as executor:
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules])
results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.response_modules]
)

context.outputs = results

Expand Down
5 changes: 3 additions & 2 deletions griptape/engines/rag/stages/retrieval_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def modules(self) -> list[BaseRagModule]:
def run(self, context: RagContext) -> RagContext:
logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules))

with self.futures_executor_fn() as executor:
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules])
results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.retrieval_modules]
)

# flatten the list of lists
results = list(itertools.chain.from_iterable(results))
Expand Down
22 changes: 10 additions & 12 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from concurrent import futures
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
from attrs import define, field

from griptape.mixins import FuturesExecutorMixin
from griptape.utils.futures import execute_futures_dict
from griptape.utils.hash import bytes_to_hash, str_to_hash

Expand All @@ -16,11 +16,7 @@


@define
class BaseLoader(ABC):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
kw_only=True,
)
class BaseLoader(FuturesExecutorMixin, ABC):
encoding: Optional[str] = field(default=None, kw_only=True)

@abstractmethod
Expand All @@ -36,10 +32,12 @@ def load_collection(
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}

with self.futures_executor_fn() as executor:
return execute_futures_dict(
{key: executor.submit(self.load, source, *args, **kwargs) for key, source in sources_by_key.items()},
)
return execute_futures_dict(
{
key: self.futures_executor.submit(self.load, source, *args, **kwargs)
for key, source in sources_by_key.items()
},
)

def to_key(self, source: Any, *args, **kwargs) -> str:
if isinstance(source, bytes):
Expand Down
2 changes: 2 additions & 0 deletions griptape/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .rule_mixin import RuleMixin
from .serializable_mixin import SerializableMixin
from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin
from .futures_executor_mixin import FuturesExecutorMixin

__all__ = [
"ActivityMixin",
Expand All @@ -12,4 +13,5 @@
"RuleMixin",
"BlobArtifactFileOutputMixin",
"SerializableMixin",
"FuturesExecutorMixin",
]
32 changes: 32 additions & 0 deletions griptape/mixins/futures_executor_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import threading
from abc import ABC
from concurrent import futures
from typing import Callable, Optional

from attrs import Factory, define, field


@define(slots=False, kw_only=True)
class FuturesExecutorMixin(ABC):
    """Provide a lazily created, reusable ``concurrent.futures`` executor.

    The executor is built by ``futures_executor_fn`` the first time the
    ``futures_executor`` property is read, cached on the instance, and shut
    down when the instance is finalized.
    """

    # Zero-argument factory that builds the executor; the default builds a
    # ThreadPoolExecutor with its default settings.
    futures_executor_fn: Callable[[], futures.Executor] = field(
        default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
    )
    # Guards the first-use creation of the executor in `futures_executor`.
    thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

    # Cached executor; stays None until `futures_executor` is first read.
    _futures_executor: Optional[futures.Executor] = field(init=False, default=None)

    @property
    def futures_executor(self) -> futures.Executor:
        """Return the cached executor, creating it on first access.

        Creation happens under ``thread_lock`` so that concurrent first reads
        do not each build their own executor.
        """
        with self.thread_lock:
            if self._futures_executor is None:
                self._futures_executor = self.futures_executor_fn()

            return self._futures_executor

    def __del__(self) -> None:
        """Shut down the cached executor, if one was ever created.

        No lock is taken here: a finalizer can run on any thread while
        arbitrary code is executing, so acquiring ``thread_lock`` could
        deadlock if that lock is already held by the interrupted code.
        """
        # getattr guards against a partially initialised instance, where
        # __init__ raised before the attribute was assigned.
        executor = getattr(self, "_futures_executor", None)
        if executor is not None:
            # Detach first so a repeated finalization cannot shut down twice.
            self._futures_executor = None
            executor.shutdown(wait=True)
37 changes: 16 additions & 21 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
from __future__ import annotations

import concurrent.futures as futures
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

from attrs import Factory, define, field
from attrs import define
from graphlib import TopologicalSorter

from griptape.artifacts import ErrorArtifact
from griptape.common import observable
from griptape.memory.structure import Run
from griptape.mixins import FuturesExecutorMixin
from griptape.structures import Structure

if TYPE_CHECKING:
from griptape.tasks import BaseTask


@define
class Workflow(Structure):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
kw_only=True,
)

class Workflow(Structure, FuturesExecutorMixin):
@property
def input_task(self) -> Optional[BaseTask]:
return self.order_tasks()[0] if self.tasks else None
Expand Down Expand Up @@ -94,22 +90,21 @@ def insert_task(
def try_run(self, *args) -> Workflow:
exit_loop = False

with self.futures_executor_fn() as executor:
while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()
while not self.is_finished() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()

for task in ordered_tasks:
if task.can_execute():
future = executor.submit(task.execute)
futures_list[future] = task
for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor.submit(task.execute)
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True
# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break
break

if self.conversation_memory and self.output is not None:
run = Run(input=self.input_task.input, output=self.output)
Expand Down
5 changes: 3 additions & 2 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
with self.futures_executor_fn() as executor:
results = utils.execute_futures_dict({a.tag: executor.submit(self.execute_action, a) for a in actions})
results = utils.execute_futures_dict(
{a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions}
)

return list(results.values())

Expand Down
Loading

0 comments on commit 14a0f0d

Please sign in to comment.