diff --git a/CHANGELOG.md b/CHANGELOG.md index 54cecbdc1..03ead3395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for bitshift composition in `BaseTask` for adding parent/child tasks. - `JsonArtifact` for handling de/seralization of values. - `Chat.logger_level` for setting what the `Chat` utility sets the logger level to. +- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index 9f7cb79fb..0af57f0f3 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -2,11 +2,12 @@ import logging from abc import ABC, abstractmethod -from concurrent import futures -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from attrs import Factory, define, field +from griptape.mixins import FuturesExecutorMixin + if TYPE_CHECKING: from griptape.events import BaseEvent @@ -14,11 +15,7 @@ @define -class BaseEventListenerDriver(ABC): - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - kw_only=True, - ) +class BaseEventListenerDriver(FuturesExecutorMixin, ABC): batched: bool = field(default=True, kw_only=True) batch_size: int = field(default=10, kw_only=True) @@ -29,8 +26,7 @@ def batch(self) -> list[dict]: return self._batch def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None: - with self.futures_executor_fn() as executor: - executor.submit(self._safe_try_publish_event, event, flush=flush) + self.futures_executor.submit(self._safe_try_publish_event, event, flush=flush) @abstractmethod def try_publish_event_payload(self, event_payload: dict) -> None: ... diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index ed1f2d589..7ebccdcad 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -2,22 +2,21 @@ import uuid from abc import ABC, abstractmethod -from concurrent import futures from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import SerializableMixin +from griptape.mixins import FuturesExecutorMixin, SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, FuturesExecutorMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass @@ -36,10 +35,6 @@ def to_artifact(self) -> BaseArtifact: return BaseArtifact.from_json(self.meta["artifact"]) # pyright: ignore[reportOptionalSubscript] embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - kw_only=True, - ) def upsert_text_artifacts( self, @@ -48,24 +43,23 @@ def upsert_text_artifacts( meta: Optional[dict] = None, **kwargs, ) -> None: - with self.futures_executor_fn() as executor: - if isinstance(artifacts, list): - utils.execute_futures_list( - [ - executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs) - for a in artifacts - ], - ) - else: - utils.execute_futures_dict( - { - namespace: executor.submit( - self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs - ) - for namespace, artifact_list in artifacts.items() - for a in artifact_list - }, - ) + if isinstance(artifacts, list): + utils.execute_futures_list( + [ + self.futures_executor.submit(self.upsert_text_artifact, a, namespace=None, meta=meta, **kwargs) + for a in artifacts + ], + ) + else: + utils.execute_futures_dict( + { + namespace: self.futures_executor.submit( + self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs + ) + for namespace, artifact_list in artifacts.items() + for a in artifact_list + }, + ) def upsert_text_artifact( self, diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 01d6d1b1d..668b3aced 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -2,25 +2,22 @@ import uuid from abc import ABC -from concurrent import futures -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from griptape.common import Message, PromptStack +from griptape.mixins import FuturesExecutorMixin if TYPE_CHECKING: from griptape.engines.rag import RagContext @define(kw_only=True) -class BaseRagModule(ABC): +class BaseRagModule(FuturesExecutorMixin, ABC): name: str = field( default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True ) - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - ) def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack: messages = [] diff --git a/griptape/engines/rag/stages/base_rag_stage.py b/griptape/engines/rag/stages/base_rag_stage.py index 4f5a9bcd1..6a28551b4 100644 --- a/griptape/engines/rag/stages/base_rag_stage.py +++ b/griptape/engines/rag/stages/base_rag_stage.py @@ -1,20 +1,15 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from concurrent import futures -from typing import Callable -from attrs import Factory, define, field +from attrs import define from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseRagModule +from griptape.mixins import FuturesExecutorMixin @define(kw_only=True) -class BaseRagStage(ABC): - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - ) - +class BaseRagStage(FuturesExecutorMixin, ABC): @abstractmethod def run(self, context: RagContext) -> RagContext: ... diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py index 4bc0b2be5..de286317c 100644 --- a/griptape/engines/rag/stages/response_rag_stage.py +++ b/griptape/engines/rag/stages/response_rag_stage.py @@ -31,8 +31,9 @@ def modules(self) -> list[BaseRagModule]: def run(self, context: RagContext) -> RagContext: logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) - with self.futures_executor_fn() as executor: - results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules]) + results = utils.execute_futures_list( + [self.futures_executor.submit(r.run, context) for r in self.response_modules] + ) context.outputs = results diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py index fa618a7ff..6ce9fb19f 100644 --- a/griptape/engines/rag/stages/retrieval_rag_stage.py +++ b/griptape/engines/rag/stages/retrieval_rag_stage.py @@ -35,8 +35,9 @@ def modules(self) -> list[BaseRagModule]: def run(self, context: RagContext) -> RagContext: logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) - with self.futures_executor_fn() as executor: - results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules]) + results = utils.execute_futures_list( + [self.futures_executor.submit(r.run, context) for r in self.retrieval_modules] + ) # flatten the list of lists results = list(itertools.chain.from_iterable(results)) diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 09551d9ab..525b4df0a 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -1,11 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from concurrent import futures -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional -from attrs import Factory, define, field +from attrs import define, field +from griptape.mixins import FuturesExecutorMixin from griptape.utils.futures import execute_futures_dict from griptape.utils.hash import bytes_to_hash, str_to_hash @@ -16,11 +16,7 @@ @define -class BaseLoader(ABC): - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - kw_only=True, - ) +class BaseLoader(FuturesExecutorMixin, ABC): encoding: Optional[str] = field(default=None, kw_only=True) @abstractmethod @@ -36,10 +32,12 @@ def load_collection( # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} - with self.futures_executor_fn() as executor: - return execute_futures_dict( - {key: executor.submit(self.load, source, *args, **kwargs) for key, source in sources_by_key.items()}, - ) + return execute_futures_dict( + { + key: self.futures_executor.submit(self.load, source, *args, **kwargs) + for key, source in sources_by_key.items() + }, + ) def to_key(self, source: Any, *args, **kwargs) -> str: if isinstance(source, bytes): diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index d9eea53c2..1bfa95c9a 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,6 +4,7 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin +from .futures_executor_mixin import FuturesExecutorMixin __all__ = [ "ActivityMixin", @@ -12,4 +13,5 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", + "FuturesExecutorMixin", ] diff --git a/griptape/mixins/futures_executor_mixin.py b/griptape/mixins/futures_executor_mixin.py new file mode 100644 index 000000000..6c09aee32 --- /dev/null +++ b/griptape/mixins/futures_executor_mixin.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import threading +from abc import ABC +from concurrent import futures +from typing import Callable, Optional + +from attrs import Factory, define, field + + +@define(slots=False, kw_only=True) +class FuturesExecutorMixin(ABC): + futures_executor_fn: Callable[[], futures.Executor] = field( + default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), + ) + thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) + + _futures_executor: Optional[futures.Executor] = field(init=False, default=None) + + @property + def futures_executor(self) -> futures.Executor: + with self.thread_lock: + if self._futures_executor is None: + self._futures_executor = self.futures_executor_fn() + + return self._futures_executor + + def __del__(self) -> None: + with self.thread_lock: + if self._futures_executor: + self._futures_executor.shutdown(wait=True) + self._futures_executor = None diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index f8a3a6ee0..f1e1ec86b 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -1,14 +1,15 @@ from __future__ import annotations import concurrent.futures as futures -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional -from attrs import Factory, define, field +from attrs import define from graphlib import TopologicalSorter from griptape.artifacts import ErrorArtifact from griptape.common import observable from griptape.memory.structure import Run +from griptape.mixins import FuturesExecutorMixin from griptape.structures import Structure if TYPE_CHECKING: @@ -16,12 +17,7 @@ @define -class Workflow(Structure): - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - kw_only=True, - ) - +class Workflow(Structure, FuturesExecutorMixin): @property def input_task(self) -> Optional[BaseTask]: return self.order_tasks()[0] if self.tasks else None @@ -94,22 +90,21 @@ def insert_task( def try_run(self, *args) -> Workflow: exit_loop = False - with self.futures_executor_fn() as executor: - while not self.is_finished() and not exit_loop: - futures_list = {} - ordered_tasks = self.order_tasks() + while not self.is_finished() and not exit_loop: + futures_list = {} + ordered_tasks = self.order_tasks() - for task in ordered_tasks: - if task.can_execute(): - future = executor.submit(task.execute) - futures_list[future] = task + for task in ordered_tasks: + if task.can_execute(): + future = self.futures_executor.submit(task.execute) + futures_list[future] = task - # Wait for all tasks to complete - for future in futures.as_completed(futures_list): - if isinstance(future.result(), ErrorArtifact) and self.fail_fast: - exit_loop = True + # Wait for all tasks to complete + for future in futures.as_completed(futures_list): + if isinstance(future.result(), ErrorArtifact) and self.fail_fast: + exit_loop = True - break + break if self.conversation_memory and self.output is not None: run = Run(input=self.input_task.input, output=self.output) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 7d2fb5efd..befad59e0 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -139,8 +139,9 @@ def run(self) -> BaseArtifact: return ErrorArtifact("no tool output") def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]: - with self.futures_executor_fn() as executor: - results = utils.execute_futures_dict({a.tag: executor.submit(self.execute_action, a) for a in actions}) + results = utils.execute_futures_dict( + {a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions} + ) return list(results.values()) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 1899eccd6..f5a772e48 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -3,15 +3,15 @@ import logging import uuid from abc import ABC, abstractmethod -from concurrent import futures from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact from griptape.config import config from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus +from griptape.mixins import FuturesExecutorMixin if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -22,7 +22,7 @@ @define -class BaseTask(ABC): +class BaseTask(FuturesExecutorMixin, ABC): class State(Enum): PENDING = 1 EXECUTING = 2 @@ -37,10 +37,6 @@ class State(Enum): output: Optional[BaseArtifact] = field(default=None, init=False) context: dict[str, Any] = field(factory=dict, kw_only=True) - futures_executor_fn: Callable[[], futures.Executor] = field( - default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), - kw_only=True, - ) def __rshift__(self, other: BaseTask) -> BaseTask: self.add_child(other) diff --git a/tests/mocks/mock_futures_executor.py b/tests/mocks/mock_futures_executor.py new file mode 100644 index 000000000..cbbf84560 --- /dev/null +++ b/tests/mocks/mock_futures_executor.py @@ -0,0 +1,4 @@ +from griptape.mixins import FuturesExecutorMixin + + +class MockFuturesExecutor(FuturesExecutorMixin): ... diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 04cfef34b..114778f72 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -12,9 +12,7 @@ def test_publish_event(self): driver.publish_event(MockEvent().to_dict()) - executor.__enter__.assert_called_once() executor.submit.assert_called_once() - executor.__exit__.assert_called_once() def test__safe_try_publish_event(self): driver = MockEventListenerDriver(batched=False) diff --git a/tests/unit/mixins/test_futures_executor_mixin.py b/tests/unit/mixins/test_futures_executor_mixin.py new file mode 100644 index 000000000..3be336687 --- /dev/null +++ b/tests/unit/mixins/test_futures_executor_mixin.py @@ -0,0 +1,10 @@ +from concurrent import futures + +from tests.mocks.mock_futures_executor import MockFuturesExecutor + + +class TestFuturesExecutorMixin: + def test_futures_executor(self): + executor = futures.ThreadPoolExecutor() + + assert MockFuturesExecutor(futures_executor_fn=lambda: executor).futures_executor == executor