Skip to content

Commit

Permalink
tests, check class of node before registering
Browse files Browse the repository at this point in the history
  • Loading branch information
roveo committed Nov 10, 2020
1 parent b52b92c commit dda82ff
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
23 changes: 19 additions & 4 deletions streamz/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def register_api(cls, modifier=identity, attribute_name=None):
>>> Stream().foo(...) # this works now
It attaches the callable as a normal attribute to the class object. In
doing so it respsects inheritance (all subclasses of Stream will also
doing so it respects inheritance (all subclasses of Stream will also
get the foo attribute).
By default callables are assumed to be instance methods. If you like
Expand All @@ -285,6 +285,15 @@ def register_api(cls, modifier=identity, attribute_name=None):
... ...
>>> Stream.foo(...) # Foo operates as a static method
You can also provide an optional ``attribute_name`` argument to control
the name of the attribute your callable will be attached as.
>>> @Stream.register_api(attribute_name="bar")
... class foo(Stream):
... ...
>> Stream().bar(...) # foo was actually attached as bar
"""
def _(func):
@functools.wraps(func)
Expand All @@ -298,11 +307,17 @@ def wrapped(*args, **kwargs):
@classmethod
def register_plugin_entry_point(cls, entry_point, modifier=identity):
def stub(*args, **kwargs):
attribute = entry_point.load()
node = entry_point.load()
if not issubclass(node, Stream):
raise TypeError(
f"Error loading {entry_point.name} "
f"from module {entry_point.module_name}: "
f"{entry_point.cls.__name__} must be a subclass of Stream"
)
cls.register_api(
modifier=modifier, attribute_name=entry_point.name
)(attribute)
return attribute(*args, **kwargs)
)(node)
return node(*args, **kwargs)
cls.register_api(modifier=modifier, attribute_name=entry_point.name)(stub)

def start(self):
Expand Down
6 changes: 3 additions & 3 deletions streamz/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

def load_plugins(cls):
for entry_point in pkg_resources.iter_entry_points("streamz.sources"):
cls.register_plugin_entrypoint(entry_point, staticmethod)
cls.register_plugin_entry_point(entry_point, staticmethod)
for entry_point in pkg_resources.iter_entry_points("streamz.nodes"):
cls.register_plugin_entrypoint(entry_point)
cls.register_plugin_entry_point(entry_point)
for entry_point in pkg_resources.iter_entry_points("streamz.sinks"):
cls.register_plugin_entrypoint(entry_point)
cls.register_plugin_entry_point(entry_point)
37 changes: 27 additions & 10 deletions streamz/tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
from streamz.sources import Source
from streamz import Stream
import pytest
from streamz import Source, Stream


class MockEntryPoint:

def __init__(self, name, cls):
def __init__(self, name, cls, module_name=None):
self.name = name
self.cls = cls
self.module_name = module_name

def load(self):
return self.cls


def test_register_plugin_entry_point():
class test(Stream):
class test_stream(Stream):
pass

entry_point = MockEntryPoint("test_node", test)
entry_point = MockEntryPoint("test_node", test_stream)
Stream.register_plugin_entry_point(entry_point)

assert Stream.test_node.__name__ == "stub"

Stream().test_node()

assert Stream.test_node.__name__ == "test"
assert Stream.test_node.__name__ == "test_stream"


def test_register_plugin_entry_point_modifier():
class test(Source):
class test_source(Source):
pass

entry_point = MockEntryPoint("from_test", test)
Stream.register_plugin_entry_point(entry_point, staticmethod)
def modifier(fn):
fn.__name__ = 'modified_name'
return staticmethod(fn)

entry_point = MockEntryPoint("from_test", test_source)
Stream.register_plugin_entry_point(entry_point, modifier)

Stream.from_test()

assert Stream.from_test.__self__ is Stream
assert Stream.from_test.__name__ == "modified_name"


def test_register_plugin_entry_point_raises():
class invalid_node:
pass

entry_point = MockEntryPoint("test", invalid_node, "test_module.test")

Stream.register_plugin_entry_point(entry_point)

with pytest.raises(TypeError):
Stream.test()

0 comments on commit dda82ff

Please sign in to comment.