diff --git a/CHANGELOG.md b/CHANGELOG.md
index cb0a4e021..968dfdead 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.
 
+### Fixed
+
+- Occasional crash during `FuturesExecutorMixin` cleanup.
+
+### Deprecated
+
+- `FuturesExecutorMixin.futures_executor`. Use `FuturesExecutorMixin.create_futures_executor` instead.
+
 ## [1.1.1] - 2025-01-03
 
 ### Fixed
diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py
index 56d1d8c5e..2d1d0218c 100644
--- a/griptape/drivers/event_listener/base_event_listener_driver.py
+++ b/griptape/drivers/event_listener/base_event_listener_driver.py
@@ -30,17 +30,19 @@ def batch(self) -> list[dict]:
     def publish_event(self, event: BaseEvent | dict) -> None:
         event_payload = event if isinstance(event, dict) else event.to_dict()
 
-        if self.batched:
-            self._batch.append(event_payload)
-            if len(self.batch) >= self.batch_size:
-                self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
-                self._batch = []
-        else:
-            self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
+        with self.create_futures_executor() as futures_executor:
+            if self.batched:
+                self._batch.append(event_payload)
+                if len(self.batch) >= self.batch_size:
+                    futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
+                    self._batch = []
+            else:
+                futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
 
     def flush_events(self) -> None:
         if self.batch:
-            self.futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
+            with self.create_futures_executor() as futures_executor:
+                futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
             self._batch = []
 
     @abstractmethod
diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py
index 2cb2fa072..dc70028f0 100644
--- a/griptape/drivers/vector/base_vector_store_driver.py
+++ b/griptape/drivers/vector/base_vector_store_driver.py
@@ -45,30 +45,31 @@ def upsert_text_artifacts(
         meta: Optional[dict] = None,
         **kwargs,
     ) -> list[str] | dict[str, list[str]]:
-        if isinstance(artifacts, list):
-            return utils.execute_futures_list(
-                [
-                    self.futures_executor.submit(
-                        with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
-                    )
-                    for a in artifacts
-                ],
-            )
-        else:
-            futures_dict = {}
-
-            for namespace, artifact_list in artifacts.items():
-                for a in artifact_list:
-                    if not futures_dict.get(namespace):
-                        futures_dict[namespace] = []
-
-                    futures_dict[namespace].append(
-                        self.futures_executor.submit(
-                            with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
+        with self.create_futures_executor() as futures_executor:
+            if isinstance(artifacts, list):
+                return utils.execute_futures_list(
+                    [
+                        futures_executor.submit(
+                            with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
+                        )
+                        for a in artifacts
+                    ],
+                )
+            else:
+                futures_dict = {}
+
+                for namespace, artifact_list in artifacts.items():
+                    for a in artifact_list:
+                        if not futures_dict.get(namespace):
+                            futures_dict[namespace] = []
+
+                        futures_dict[namespace].append(
+                            futures_executor.submit(
+                                with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
+                            )
                         )
-                    )
 
-            return utils.execute_futures_list_dict(futures_dict)
+                return utils.execute_futures_list_dict(futures_dict)
 
     def upsert_text_artifact(
         self,
diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py
index 06d163944..b5e3473e3 100644
--- a/griptape/engines/rag/stages/response_rag_stage.py
+++ b/griptape/engines/rag/stages/response_rag_stage.py
@@ -32,9 +32,10 @@ def modules(self) -> list[BaseRagModule]:
     def run(self, context: RagContext) -> RagContext:
         logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))
 
-        results = utils.execute_futures_list(
-            [self.futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
-        )
+        with self.create_futures_executor() as futures_executor:
+            results = utils.execute_futures_list(
+                [futures_executor.submit(with_contextvars(r.run), context) for r in self.response_modules]
+            )
 
         context.outputs = results
diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py
index 3e2e78b6d..c6ff1e0ea 100644
--- a/griptape/engines/rag/stages/retrieval_rag_stage.py
+++ b/griptape/engines/rag/stages/retrieval_rag_stage.py
@@ -36,9 +36,10 @@ def modules(self) -> list[BaseRagModule]:
     def run(self, context: RagContext) -> RagContext:
         logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules))
 
-        results = utils.execute_futures_list(
-            [self.futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
-        )
+        with self.create_futures_executor() as futures_executor:
+            results = utils.execute_futures_list(
+                [futures_executor.submit(with_contextvars(r.run), context) for r in self.retrieval_modules]
+            )
 
         # flatten the list of lists
         results = list(itertools.chain.from_iterable(results))
diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py
index 63324e10c..31f7de4ee 100644
--- a/griptape/loaders/base_loader.py
+++ b/griptape/loaders/base_loader.py
@@ -61,12 +61,13 @@ def load_collection(
         # to avoid duplicate work.
         sources_by_key = {self.to_key(source): source for source in sources}
 
-        return execute_futures_dict(
-            {
-                key: self.futures_executor.submit(with_contextvars(self.load), source)
-                for key, source in sources_by_key.items()
-            },
-        )
+        with self.create_futures_executor() as futures_executor:
+            return execute_futures_dict(
+                {
+                    key: futures_executor.submit(with_contextvars(self.load), source)
+                    for key, source in sources_by_key.items()
+                },
+            )
 
     def to_key(self, source: S) -> str:
         """Converts the source to a key for the collection."""
diff --git a/griptape/mixins/futures_executor_mixin.py b/griptape/mixins/futures_executor_mixin.py
index f711a034f..3a36c265a 100644
--- a/griptape/mixins/futures_executor_mixin.py
+++ b/griptape/mixins/futures_executor_mixin.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-import contextlib
+import warnings
 from abc import ABC
 from concurrent import futures
 from typing import Callable
@@ -14,16 +14,27 @@ class FuturesExecutorMixin(ABC):
         default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
     )
 
-    futures_executor: futures.Executor = field(
-        default=Factory(lambda self: self.create_futures_executor(), takes_self=True)
+    _futures_executor: futures.Executor = field(
+        default=Factory(
+            lambda self: self.create_futures_executor(),
+            takes_self=True,
+        ),
+        alias="futures_executor",
     )
 
-    def __del__(self) -> None:
-        executor = self.futures_executor
-
-        if executor is not None:
-            self.futures_executor = None  # pyright: ignore[reportAttributeAccessIssue] In practice this is safe, nobody will access this attribute after this point
-
-        with contextlib.suppress(Exception):
-            # don't raise exceptions in __del__
-            executor.shutdown(wait=True)
+    @property
+    def futures_executor(self) -> futures.Executor:
+        self.__raise_deprecation_warning()
+        return self._futures_executor
+
+    @futures_executor.setter
+    def futures_executor(self, value: futures.Executor) -> None:
+        self.__raise_deprecation_warning()
+        self._futures_executor = value
+
+    def __raise_deprecation_warning(self) -> None:
+        warnings.warn(
+            "`FuturesExecutorMixin.futures_executor` is deprecated and will be removed in a future release. Use `FuturesExecutorMixin.create_futures_executor` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py
index 5228759db..5648d51d0 100644
--- a/griptape/structures/workflow.py
+++ b/griptape/structures/workflow.py
@@ -103,23 +103,24 @@ def insert_task(
     def try_run(self, *args) -> Workflow:
         exit_loop = False
 
-        while not self.is_finished() and not exit_loop:
-            futures_list = {}
-            ordered_tasks = self.order_tasks()
+        with self.create_futures_executor() as futures_executor:
+            while not self.is_finished() and not exit_loop:
+                futures_list = {}
+                ordered_tasks = self.order_tasks()
 
-            for task in ordered_tasks:
-                if task.can_run():
-                    future = self.futures_executor.submit(with_contextvars(task.run))
-                    futures_list[future] = task
+                for task in ordered_tasks:
+                    if task.can_run():
+                        future = futures_executor.submit(with_contextvars(task.run))
+                        futures_list[future] = task
 
-            # Wait for all tasks to complete
-            for future in futures.as_completed(futures_list):
-                if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
-                    exit_loop = True
+                # Wait for all tasks to complete
+                for future in futures.as_completed(futures_list):
+                    if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
+                        exit_loop = True
 
-                    break
+                        break
 
-        return self
+            return self
 
     def context(self, task: BaseTask) -> dict[str, Any]:
         context = super().context(task)
diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py
index 6f9d70053..9132e4ee2 100644
--- a/griptape/tasks/actions_subtask.py
+++ b/griptape/tasks/actions_subtask.py
@@ -139,9 +139,10 @@ def try_run(self) -> BaseArtifact:
             return ErrorArtifact("no tool output")
 
     def run_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
-        return utils.execute_futures_list(
-            [self.futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
-        )
+        with self.create_futures_executor() as futures_executor:
+            return utils.execute_futures_list(
+                [futures_executor.submit(with_contextvars(self.run_action), a) for a in actions]
+            )
 
     def run_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
         if action.tool is not None:
diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py
index 36c8f3711..5f6515d66 100644
--- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py
+++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py
@@ -8,7 +8,7 @@ class TestBaseEventListenerDriver:
     def test_publish_event_no_batched(self):
         executor = MagicMock()
         executor.__enter__.return_value = executor
-        driver = MockEventListenerDriver(batched=False, futures_executor=executor)
+        driver = MockEventListenerDriver(batched=False, create_futures_executor=lambda: executor)
         mock_event_payload = MockEvent().to_dict()
 
         driver.publish_event(mock_event_payload)
@@ -18,7 +18,7 @@ def test_publish_event_yes_batched(self):
         executor = MagicMock()
         executor.__enter__.return_value = executor
-        driver = MockEventListenerDriver(batched=True, futures_executor=executor)
+        driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
         mock_event_payload = MockEvent().to_dict()
 
         # Publish 9 events to fill the batch
@@ -38,7 +38,7 @@ def test_flush_events(self):
         executor = MagicMock()
         executor.__enter__.return_value = executor
-        driver = MockEventListenerDriver(batched=True, futures_executor=executor)
+        driver = MockEventListenerDriver(batched=True, create_futures_executor=lambda: executor)
         driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)
 
         driver.flush_events()
diff --git a/tests/unit/mixins/test_futures_executor_mixin.py b/tests/unit/mixins/test_futures_executor_mixin.py
index 437903fe3..0cc05eb61 100644
--- a/tests/unit/mixins/test_futures_executor_mixin.py
+++ b/tests/unit/mixins/test_futures_executor_mixin.py
@@ -1,5 +1,7 @@
 from concurrent import futures
 
+import pytest
+
 from tests.mocks.mock_futures_executor import MockFuturesExecutor
 
 
@@ -8,3 +10,9 @@ def test_futures_executor(self):
         executor = futures.ThreadPoolExecutor()
 
         assert MockFuturesExecutor(create_futures_executor=lambda: executor).futures_executor == executor
+
+    def test_deprecated_futures_executor(self):
+        mock_executor = MockFuturesExecutor()
+        with pytest.warns(DeprecationWarning):
+            assert mock_executor.futures_executor
+            mock_executor.futures_executor = futures.ThreadPoolExecutor()
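For callers that were passing `futures_executor=` directly (as the updated tests show), the migration follows the pattern applied throughout this diff: pass a `create_futures_executor` callable and open a short-lived executor in a `with` block at each call site. Below is a minimal, illustrative sketch of that pattern using only the standard library; `ParallelRunner` and `run_all` are hypothetical names, not part of Griptape's API.

```python
from __future__ import annotations

from concurrent import futures
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class ParallelRunner:
    # Hypothetical stand-in for a FuturesExecutorMixin subclass: a factory is
    # stored instead of a live executor, mirroring `create_futures_executor`.
    create_futures_executor: Callable[[], futures.Executor] = field(
        default=lambda: futures.ThreadPoolExecutor()
    )

    def run_all(self, fns: list[Callable[[], int]]) -> list[int]:
        # Pattern from this diff: each call opens its own executor so the
        # context manager handles shutdown, instead of keeping one on `self`.
        with self.create_futures_executor() as executor:
            submitted = [executor.submit(fn) for fn in fns]
            return [future.result() for future in submitted]


if __name__ == "__main__":
    runner = ParallelRunner()
    print(runner.run_all([lambda: 1, lambda: 2, lambda: 3]))  # [1, 2, 3]
```

Constructing the executor per call is what lets the context manager own shutdown, which removes the need for the `__del__` cleanup that was occasionally crashing.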