diff --git a/dlt/common/plugins/plugins.py b/dlt/common/plugins/plugins.py index 1c9fca7983..34710dfa74 100644 --- a/dlt/common/plugins/plugins.py +++ b/dlt/common/plugins/plugins.py @@ -100,6 +100,14 @@ def get_plugin(self, plugin_name: str) -> Optional[Plugin[BaseConfiguration]]: # # Main Pipeline Callbacks # + def on_start(self, pipeline: SupportsPipeline) -> None: + for p in self._callback_plugins: + p.on_start(pipeline) + + def on_end(self, pipeline: SupportsPipeline) -> None: + for p in self._callback_plugins: + p.on_end(pipeline) + def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None: for p in self._callback_plugins: p.on_step_start(step, pipeline) diff --git a/dlt/common/plugins/reference.py b/dlt/common/plugins/reference.py index eee9fd9534..4cb9891131 100644 --- a/dlt/common/plugins/reference.py +++ b/dlt/common/plugins/reference.py @@ -10,6 +10,12 @@ class SupportsCallbackPlugin: + def on_start(self, pipeline: SupportsPipeline) -> None: + pass + + def on_end(self, pipeline: SupportsPipeline) -> None: + pass + def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None: pass diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b6043fbd55..7496bcfab7 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -177,9 +177,11 @@ def with_plugins() -> Callable[[TFun], TFun]: def decorator(f: TFun) -> TFun: @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + is_new_context = self._plugin_ctx is None if self._plugin_ctx is None: self._plugin_ctx = PluginsContext() self._plugin_ctx.setup_plugins(self.plugins) + self._plugin_ctx.on_start(self) # call step self._plugin_ctx.on_step_start(f.__name__, self) @@ -190,6 +192,11 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # ensure messages queue is completely processed self._plugin_ctx.process_queue() + if is_new_context: + self._plugin_ctx.on_end(self) + self._last_plugin_ctx = self._plugin_ctx + self._plugin_ctx = None + return result return _wrap # type: ignore @@ -335,6 +342,7 @@ def __init__( self._schema_storage_config: SchemaStorageConfiguration = None self._trace: PipelineTrace = None self._plugin_ctx: PluginsContext = None + self._last_plugin_ctx: PluginsContext = None self._last_trace: PipelineTrace = None self._state_restored: bool = False @@ -838,8 +846,8 @@ def last_trace(self) -> PipelineTrace: def get_plugin(self, plugin_name: str) -> Any: """Returns the plugin instance by name""" - if self._plugin_ctx is not None: - return self._plugin_ctx.get_plugin(plugin_name) + if self._last_plugin_ctx is not None: + return self._last_plugin_ctx.get_plugin(plugin_name) return None @deprecated(