diff --git a/streamz/core.py b/streamz/core.py index bd266467..a02a4128 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -273,7 +273,7 @@ def register_api(cls, modifier=identity, attribute_name=None): >>> Stream().foo(...) # this works now It attaches the callable as a normal attribute to the class object. In - doing so it respsects inheritance (all subclasses of Stream will also + doing so it respects inheritance (all subclasses of Stream will also get the foo attribute). By default callables are assumed to be instance methods. If you like @@ -285,6 +285,15 @@ def register_api(cls, modifier=identity, attribute_name=None): ... ... >>> Stream.foo(...) # Foo operates as a static method + + You can also provide an optional ``attribute_name`` argument to control + the name of the attribute your callable will be attached as. + + >>> @Stream.register_api(attribute_name="bar") + ... class foo(Stream): + ... ... + + >> Stream().bar(...) # foo was actually attached as bar """ def _(func): @functools.wraps(func) @@ -298,11 +307,17 @@ def wrapped(*args, **kwargs): @classmethod def register_plugin_entry_point(cls, entry_point, modifier=identity): def stub(*args, **kwargs): - attribute = entry_point.load() + node = entry_point.load() + if not issubclass(node, Stream): + raise TypeError( + f"Error loading {entry_point.name} " + f"from module {entry_point.module_name}: " + f"{entry_point.cls.__name__} must be a subclass of Stream" + ) cls.register_api( modifier=modifier, attribute_name=entry_point.name - )(attribute) - return attribute(*args, **kwargs) + )(node) + return node(*args, **kwargs) cls.register_api(modifier=modifier, attribute_name=entry_point.name)(stub) def start(self): diff --git a/streamz/plugins.py b/streamz/plugins.py index b1952651..313ee587 100644 --- a/streamz/plugins.py +++ b/streamz/plugins.py @@ -3,8 +3,8 @@ def load_plugins(cls): for entry_point in pkg_resources.iter_entry_points("streamz.sources"): - cls.register_plugin_entrypoint(entry_point, staticmethod) + cls.register_plugin_entry_point(entry_point, staticmethod) for entry_point in pkg_resources.iter_entry_points("streamz.nodes"): - cls.register_plugin_entrypoint(entry_point) + cls.register_plugin_entry_point(entry_point) for entry_point in pkg_resources.iter_entry_points("streamz.sinks"): - cls.register_plugin_entrypoint(entry_point) + cls.register_plugin_entry_point(entry_point) diff --git a/streamz/tests/test_plugins.py b/streamz/tests/test_plugins.py index 99ce36dd..116f440d 100644 --- a/streamz/tests/test_plugins.py +++ b/streamz/tests/test_plugins.py @@ -1,38 +1,55 @@ -from streamz.sources import Source -from streamz import Stream +import pytest +from streamz import Source, Stream class MockEntryPoint: - def __init__(self, name, cls): + def __init__(self, name, cls, module_name=None): self.name = name self.cls = cls + self.module_name = module_name def load(self): return self.cls def test_register_plugin_entry_point(): - class test(Stream): + class test_stream(Stream): pass - entry_point = MockEntryPoint("test_node", test) + entry_point = MockEntryPoint("test_node", test_stream) Stream.register_plugin_entry_point(entry_point) assert Stream.test_node.__name__ == "stub" Stream().test_node() - assert Stream.test_node.__name__ == "test" + assert Stream.test_node.__name__ == "test_stream" def test_register_plugin_entry_point_modifier(): - class test(Source): + class test_source(Source): pass - entry_point = MockEntryPoint("from_test", test) - Stream.register_plugin_entry_point(entry_point, staticmethod) + def modifier(fn): + fn.__name__ = 'modified_name' + return staticmethod(fn) + + entry_point = MockEntryPoint("from_test", test_source) + Stream.register_plugin_entry_point(entry_point, modifier) Stream.from_test() - assert Stream.from_test.__self__ is Stream + assert Stream.from_test.__name__ == "modified_name" + + +def test_register_plugin_entry_point_raises(): + class invalid_node: + pass + + entry_point = MockEntryPoint("test", invalid_node, "test_module.test") + + Stream.register_plugin_entry_point(entry_point) + + with pytest.raises(TypeError): + Stream.test()