Skip to content

Commit

Permalink
add on start and on end handlers (tests missing)
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 11, 2024
1 parent 003f4f5 commit 7312b18
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
8 changes: 8 additions & 0 deletions dlt/common/plugins/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def get_plugin(self, plugin_name: str) -> Optional[Plugin[BaseConfiguration]]:
#
# Main Pipeline Callbacks
#
def on_start(self, pipeline: SupportsPipeline) -> None:
for p in self._callback_plugins:
p.on_start(pipeline)

def on_end(self, pipeline: SupportsPipeline) -> None:
for p in self._callback_plugins:
p.on_end(pipeline)

def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None:
for p in self._callback_plugins:
p.on_step_start(step, pipeline)
Expand Down
6 changes: 6 additions & 0 deletions dlt/common/plugins/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@


class SupportsCallbackPlugin:
def on_start(self, pipeline: SupportsPipeline) -> None:
pass

def on_end(self, pipeline: SupportsPipeline) -> None:
pass

def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None:
pass

Expand Down
12 changes: 10 additions & 2 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ def with_plugins() -> Callable[[TFun], TFun]:
def decorator(f: TFun) -> TFun:
@wraps(f)
def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
is_new_context = self._plugin_ctx is None
if self._plugin_ctx is None:
self._plugin_ctx = PluginsContext()
self._plugin_ctx.setup_plugins(self.plugins)
self._plugin_ctx.on_start(self)

# call step
self._plugin_ctx.on_step_start(f.__name__, self)
Expand All @@ -190,6 +192,11 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
# ensure messages queue is completely processed
self._plugin_ctx.process_queue()

if is_new_context:
self._plugin_ctx.on_end(self)
self._last_plugin_ctx = self._plugin_ctx
self._plugin_ctx = None

return result

return _wrap # type: ignore
Expand Down Expand Up @@ -335,6 +342,7 @@ def __init__(
self._schema_storage_config: SchemaStorageConfiguration = None
self._trace: PipelineTrace = None
self._plugin_ctx: PluginsContext = None
self._last_plugin_ctx: PluginsContext = None
self._last_trace: PipelineTrace = None
self._state_restored: bool = False

Expand Down Expand Up @@ -838,8 +846,8 @@ def last_trace(self) -> PipelineTrace:

def get_plugin(self, plugin_name: str) -> Any:
"""Returns the plugin instance by name"""
if self._plugin_ctx is not None:
return self._plugin_ctx.get_plugin(plugin_name)
if self._last_plugin_ctx is not None:
return self._last_plugin_ctx.get_plugin(plugin_name)
return None

@deprecated(
Expand Down

0 comments on commit 7312b18

Please sign in to comment.