diff --git a/fs/opener/registry.py b/fs/opener/registry.py index 19547234..16f40417 100644 --- a/fs/opener/registry.py +++ b/fs/opener/registry.py @@ -8,7 +8,7 @@ import collections import contextlib -import pkg_resources +import sys from ..errors import ResourceReadOnly from .base import Opener @@ -21,6 +21,30 @@ from ..base import FS +if sys.version_info >= (3, 8): + import importlib.metadata + + if sys.version_info >= (3, 10): + + def entrypoints(group, name=None): + ep = importlib.metadata.entry_points(group=group, name=name) + return tuple(n for n in ep) + + else: + + def entrypoints(group, name=None): + ep = importlib.metadata.entry_points() + if name: + return tuple(n for n in ep.get(group, ()) if n.name == name) + return ep.get(group, ()) + +else: + import pkg_resources + + def entrypoints(group, name=None): + return tuple(pkg_resources.iter_entry_points(group, name)) + + class Registry(object): """A registry for `Opener` instances.""" @@ -74,10 +98,7 @@ def protocols(self): """`list`: the list of supported protocols.""" _protocols = list(self._protocols) if self.load_extern: - _protocols.extend( - entry_point.name - for entry_point in pkg_resources.iter_entry_points("fs.opener") - ) + _protocols.extend(n.name for n in entrypoints("fs.opener")) _protocols = list(collections.OrderedDict.fromkeys(_protocols)) return _protocols @@ -101,10 +122,9 @@ def get_opener(self, protocol): """ protocol = protocol or self.default_opener - if self.load_extern: - entry_point = next( - pkg_resources.iter_entry_points("fs.opener", protocol), None - ) + ep = entrypoints("fs.opener", protocol) + if self.load_extern and ep: + entry_point = ep[0] else: entry_point = None diff --git a/tests/test_opener.py b/tests/test_opener.py index 43d56903..7500080d 100644 --- a/tests/test_opener.py +++ b/tests/test_opener.py @@ -3,7 +3,6 @@ import sys import os -import pkg_resources import shutil import tempfile import unittest @@ -21,6 +20,11 @@ except ImportError: import mock +if sys.version_info >= (3, 8): + import importlib.metadata +else: + import pkg_resources + class TestParse(unittest.TestCase): def test_registry_repr(self): @@ -111,14 +115,25 @@ def test_protocols(self): def test_registry_protocols(self): # Check registry.protocols list the names of all available extension - extensions = [ - pkg_resources.EntryPoint("proto1", "mod1"), - pkg_resources.EntryPoint("proto2", "mod2"), - ] - m = mock.MagicMock(return_value=extensions) - with mock.patch.object( - sys.modules["pkg_resources"], "iter_entry_points", new=m - ): + if sys.version_info >= (3, 8): + extensions = ( + importlib.metadata.EntryPoint("proto1", "mod1", "fs.opener"), + importlib.metadata.EntryPoint("proto2", "mod2", "fs.opener"), + ) + if sys.version_info >= (3, 10): + m = mock.MagicMock(return_value=extensions) + else: + m = mock.MagicMock(return_value={"fs.opener": extensions}) + patch = mock.patch("importlib.metadata.entry_points", m) + else: + extensions = [ + pkg_resources.EntryPoint("proto1", "mod1"), + pkg_resources.EntryPoint("proto2", "mod2"), + ] + m = mock.MagicMock(return_value=extensions) + patch = mock.patch("pkg_resources.iter_entry_points", m) + + with patch: self.assertIn("proto1", opener.registry.protocols) self.assertIn("proto2", opener.registry.protocols) @@ -129,11 +144,19 @@ def test_unknown_protocol(self): def test_entry_point_load_error(self): entry_point = mock.MagicMock() + entry_point.name = "test" entry_point.load.side_effect = ValueError("some error") - iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) - - with mock.patch("pkg_resources.iter_entry_points", iter_entry_points): + if sys.version_info >= (3, 8): + if sys.version_info >= (3, 10): + entry_points = mock.MagicMock(return_value=tuple([entry_point])) + else: + entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]}) + patch = mock.patch("importlib.metadata.entry_points", entry_points) + else: + iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) + patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points) + with patch: with self.assertRaises(errors.EntryPointError) as ctx: opener.open_fs("test://") self.assertEqual( @@ -145,10 +168,19 @@ class NotAnOpener(object): pass entry_point = mock.MagicMock() + entry_point.name = "test" entry_point.load = mock.MagicMock(return_value=NotAnOpener) - iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) - with mock.patch("pkg_resources.iter_entry_points", iter_entry_points): + if sys.version_info >= (3, 8): + if sys.version_info >= (3, 10): + entry_points = mock.MagicMock(return_value=tuple([entry_point])) + else: + entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]}) + patch = mock.patch("importlib.metadata.entry_points", entry_points) + else: + iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) + patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points) + with patch: with self.assertRaises(errors.EntryPointError) as ctx: opener.open_fs("test://") self.assertEqual("entry point did not return an opener", str(ctx.exception)) @@ -162,10 +194,20 @@ def open_fs(self, *args, **kwargs): pass entry_point = mock.MagicMock() + entry_point.name = "test" entry_point.load = mock.MagicMock(return_value=BadOpener) - iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) - with mock.patch("pkg_resources.iter_entry_points", iter_entry_points): + if sys.version_info >= (3, 8): + if sys.version_info >= (3, 10): + entry_points = mock.MagicMock(return_value=tuple([entry_point])) + else: + entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]}) + patch = mock.patch("importlib.metadata.entry_points", entry_points) + else: + iter_entry_points = mock.MagicMock(return_value=iter([entry_point])) + patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points) + + with patch: with self.assertRaises(errors.EntryPointError) as ctx: opener.open_fs("test://") self.assertEqual(