Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 15, 2024
1 parent 7312b18 commit ad4eedf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
26 changes: 16 additions & 10 deletions dlt/common/plugins/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import multiprocessing as mp
from functools import wraps
import threading
import queue


SAFE_SUFFIX = "_safe"
Expand All @@ -22,10 +23,11 @@ def create_safe_version(f: TFun) -> TFun:
@wraps(f)
def _wrap(self: "PluginsContext", *args: Any, **kwargs: Any) -> Any:
# send message to shared queue if this is not the main instance
if not self._main or threading.main_thread() != threading.current_thread():
self._queue.put((f.__name__, args, kwargs))
else:
getattr(self, (f.__name__ + SAFE_SUFFIX))(*args, **kwargs)
if self._plugins:
if not self._main or threading.main_thread() != threading.current_thread():
self._queue.put((f.__name__, args, kwargs))
else:
getattr(self, (f.__name__ + SAFE_SUFFIX))(*args, **kwargs)
return f(self, *args, **kwargs)

return _wrap # type: ignore
Expand All @@ -38,10 +40,7 @@ def __init__(self, main: bool = True) -> None:
self._callback_plugins: List[CallbackPlugin[BaseConfiguration]] = []
self._initial_plugins: TPluginArg = []
self._main = main

if self._main:
manager = mp.Manager()
self._queue = manager.Queue()
self._queue: Optional[queue.Queue[Any]] = None

def _resolve_plugin(self, plugin: TSinglePluginArg) -> Plugin[BaseConfiguration]:
resolved_plugin: Plugin[BaseConfiguration] = None
Expand All @@ -61,7 +60,7 @@ def _resolve_plugin(self, plugin: TSinglePluginArg) -> Plugin[BaseConfiguration]
raise TypeError(f"Plugin {plugin} is not a subclass of Plugin nor a plugin name string")
return resolved_plugin

# pickle support
# pickle support, send plugins config and queue to child process / threads
def __getstate__(self) -> Dict[str, Any]:
return {"plugins": self._initial_plugins, "queue": self._queue}

Expand All @@ -72,6 +71,8 @@ def __setstate__(self, d: Dict[str, Any]) -> None:

def process_queue(self) -> None:
assert self._main
if not self._plugins:
return
try:
while True:
name, args, kwargs = self._queue.get_nowait()
Expand All @@ -80,9 +81,10 @@ def process_queue(self) -> None:
pass

def setup_plugins(self, plugins: TPluginArg) -> None:
self._initial_plugins = plugins
if not plugins:
return
self._initial_plugins = plugins
# if not plugins:
if not isinstance(plugins, Iterable):
plugins = [plugins]
for p in plugins:
Expand All @@ -91,6 +93,10 @@ def setup_plugins(self, plugins: TPluginArg) -> None:
if isinstance(resolved_plugin, CallbackPlugin):
self._callback_plugins.append(resolved_plugin)

if self._plugins and self._main:
manager = mp.Manager()
self._queue = manager.Queue()

def get_plugin(self, plugin_name: str) -> Optional[Plugin[BaseConfiguration]]:
for p in self._plugins:
if p.NAME == plugin_name:
Expand Down
21 changes: 12 additions & 9 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,20 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:

# call step
self._plugin_ctx.on_step_start(f.__name__, self)
with self._container.injectable_context(self._plugin_ctx):
result = f(self, *args, **kwargs)
self._plugin_ctx.on_step_end(f.__name__, self)

# ensure messages queue is completely processed
self._plugin_ctx.process_queue()
try:
with self._container.injectable_context(self._plugin_ctx):
result = f(self, *args, **kwargs)
finally:
self._plugin_ctx.on_step_end(f.__name__, self)

# ensure messages queue is completely processed
self._plugin_ctx.process_queue()

if is_new_context:
self._plugin_ctx.on_end(self)
self._last_plugin_ctx = self._plugin_ctx
self._plugin_ctx = None
if is_new_context:
self._plugin_ctx.on_end(self)
self._last_plugin_ctx = self._plugin_ctx
self._plugin_ctx = None

return result

Expand Down
12 changes: 12 additions & 0 deletions tests/common/plugins/test_plugin_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ def __init__(self) -> None:
super().__init__()
self.start_steps: List[str] = []
self.end_steps: List[str] = []
self.on_start_called: int = 0
self.on_end_called: int = 0

def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None:
self.start_steps.append(step)

def on_step_end(self, step: str, pipeline: SupportsPipeline) -> None:
self.end_steps.append(step)

def on_start(self, pipeline: SupportsPipeline) -> None:
self.on_start_called += 1

def on_end(self, pipeline: SupportsPipeline) -> None:
self.on_end_called += 1


def test_simple_plugin_steps() -> None:
"""very simple test to see if plugins work"""
Expand All @@ -39,6 +47,10 @@ def test_simple_plugin_steps() -> None:

assert plug.start_steps == ["run", "extract", "normalize", "load"]
assert plug.end_steps == ["extract", "normalize", "load", "run"]
assert plug.on_start_called == 1
assert plug.on_end_called == 1
assert pipeline._last_plugin_ctx is not None
assert pipeline._plugin_ctx is None


def test_plugin_resolution() -> None:
Expand Down

0 comments on commit ad4eedf

Please sign in to comment.