diff --git a/dlt/common/plugins/plugins.py b/dlt/common/plugins/plugins.py index 34710dfa74..bf8ca71041 100644 --- a/dlt/common/plugins/plugins.py +++ b/dlt/common/plugins/plugins.py @@ -13,6 +13,7 @@ import multiprocessing as mp from functools import wraps import threading +import queue SAFE_SUFFIX = "_safe" @@ -22,10 +23,11 @@ def create_safe_version(f: TFun) -> TFun: @wraps(f) def _wrap(self: "PluginsContext", *args: Any, **kwargs: Any) -> Any: # send message to shared queue if this is not the main instance - if not self._main or threading.main_thread() != threading.current_thread(): - self._queue.put((f.__name__, args, kwargs)) - else: - getattr(self, (f.__name__ + SAFE_SUFFIX))(*args, **kwargs) + if self._plugins: + if not self._main or threading.main_thread() != threading.current_thread(): + self._queue.put((f.__name__, args, kwargs)) + else: + getattr(self, (f.__name__ + SAFE_SUFFIX))(*args, **kwargs) return f(self, *args, **kwargs) return _wrap # type: ignore @@ -38,10 +40,7 @@ def __init__(self, main: bool = True) -> None: self._callback_plugins: List[CallbackPlugin[BaseConfiguration]] = [] self._initial_plugins: TPluginArg = [] self._main = main - - if self._main: - manager = mp.Manager() - self._queue = manager.Queue() + self._queue: Optional[queue.Queue[Any]] = None def _resolve_plugin(self, plugin: TSinglePluginArg) -> Plugin[BaseConfiguration]: resolved_plugin: Plugin[BaseConfiguration] = None @@ -61,7 +60,7 @@ def _resolve_plugin(self, plugin: TSinglePluginArg) -> Plugin[BaseConfiguration] raise TypeError(f"Plugin {plugin} is not a subclass of Plugin nor a plugin name string") return resolved_plugin - # pickle support + # pickle support, send plugins config and queue to child process / threads def __getstate__(self) -> Dict[str, Any]: return {"plugins": self._initial_plugins, "queue": self._queue} @@ -72,6 +71,8 @@ def __setstate__(self, d: Dict[str, Any]) -> None: def process_queue(self) -> None: assert self._main + if not self._plugins: + return try: while True: name, args, kwargs = self._queue.get_nowait() @@ -80,9 +81,10 @@ def process_queue(self) -> None: pass def setup_plugins(self, plugins: TPluginArg) -> None: - self._initial_plugins = plugins if not plugins: return + self._initial_plugins = plugins + # if not plugins: if not isinstance(plugins, Iterable): plugins = [plugins] for p in plugins: @@ -91,6 +93,10 @@ def setup_plugins(self, plugins: TPluginArg) -> None: if isinstance(resolved_plugin, CallbackPlugin): self._callback_plugins.append(resolved_plugin) + if self._plugins and self._main: + manager = mp.Manager() + self._queue = manager.Queue() + def get_plugin(self, plugin_name: str) -> Optional[Plugin[BaseConfiguration]]: for p in self._plugins: if p.NAME == plugin_name: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 7496bcfab7..08925df273 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -185,17 +185,20 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # call step self._plugin_ctx.on_step_start(f.__name__, self) - with self._container.injectable_context(self._plugin_ctx): - result = f(self, *args, **kwargs) - self._plugin_ctx.on_step_end(f.__name__, self) - # ensure messages queue is completely processed - self._plugin_ctx.process_queue() + try: + with self._container.injectable_context(self._plugin_ctx): + result = f(self, *args, **kwargs) + finally: + self._plugin_ctx.on_step_end(f.__name__, self) + + # ensure messages queue is completely processed + self._plugin_ctx.process_queue() - if is_new_context: - self._plugin_ctx.on_end(self) - self._last_plugin_ctx = self._plugin_ctx - self._plugin_ctx = None + if is_new_context: + self._plugin_ctx.on_end(self) + self._last_plugin_ctx = self._plugin_ctx + self._plugin_ctx = None return result diff --git a/tests/common/plugins/test_plugin_basics.py b/tests/common/plugins/test_plugin_basics.py index a280e6fed4..50dd1960ab 100644 --- a/tests/common/plugins/test_plugin_basics.py +++ b/tests/common/plugins/test_plugin_basics.py @@ -20,6 +20,8 @@ def __init__(self) -> None: super().__init__() self.start_steps: List[str] = [] self.end_steps: List[str] = [] + self.on_start_called: int = 0 + self.on_end_called: int = 0 def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None: self.start_steps.append(step) @@ -27,6 +29,12 @@ def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None: def on_step_end(self, step: str, pipeline: SupportsPipeline) -> None: self.end_steps.append(step) + def on_start(self, pipeline: SupportsPipeline) -> None: + self.on_start_called += 1 + + def on_end(self, pipeline: SupportsPipeline) -> None: + self.on_end_called += 1 + def test_simple_plugin_steps() -> None: """very simple test to see if plugins work""" @@ -39,6 +47,10 @@ def test_simple_plugin_steps() -> None: assert plug.start_steps == ["run", "extract", "normalize", "load"] assert plug.end_steps == ["extract", "normalize", "load", "run"] + assert plug.on_start_called == 1 + assert plug.on_end_called == 1 + assert pipeline._last_plugin_ctx is not None + assert pipeline._plugin_ctx is None def test_plugin_resolution() -> None: