From f3d19ff9d749b3ecef346d742cde9a88a05657f6 Mon Sep 17 00:00:00 2001 From: Alexey Pelykh Date: Tue, 14 Mar 2023 15:30:07 +0100 Subject: [PATCH] Support plugins defined as inner classes (#1318) * Support plugins defined as inner classes * Prefer __qualname__ over __name__ for classes --------- Co-authored-by: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> --- docs/changelog-fragments.d/1318.feature.md | 1 + proxy/common/plugins.py | 52 ++++++++++++++----- proxy/core/acceptor/acceptor.py | 2 +- proxy/core/work/fd/fd.py | 2 +- proxy/core/work/work.py | 2 +- proxy/http/exception/http_request_rejected.py | 2 +- proxy/http/exception/proxy_auth_failed.py | 2 +- proxy/http/exception/proxy_conn_failed.py | 2 +- proxy/http/proxy/plugin.py | 2 +- proxy/http/proxy/server.py | 8 +-- proxy/http/server/plugin.py | 2 +- tests/common/my_plugins/__init__.py | 25 +++++++++ tests/common/test_flags.py | 39 +++++++++++++- tests/core/test_event_dispatcher.py | 8 +-- tests/core/test_event_queue.py | 4 +- tests/core/test_event_subscriber.py | 4 +- 16 files changed, 122 insertions(+), 35 deletions(-) create mode 100644 docs/changelog-fragments.d/1318.feature.md create mode 100644 tests/common/my_plugins/__init__.py diff --git a/docs/changelog-fragments.d/1318.feature.md b/docs/changelog-fragments.d/1318.feature.md new file mode 100644 index 0000000000..aa194db536 --- /dev/null +++ b/docs/changelog-fragments.d/1318.feature.md @@ -0,0 +1 @@ +Support plugins defined as inner classes diff --git a/proxy/common/plugins.py b/proxy/common/plugins.py index c919154ccd..f92193ee8c 100644 --- a/proxy/common/plugins.py +++ b/proxy/common/plugins.py @@ -13,6 +13,7 @@ import logging import importlib import itertools +from types import ModuleType from typing import Any, Dict, List, Tuple, Union, Optional from .utils import text_, bytes_ @@ -75,14 +76,14 @@ def load( # this plugin_ is implementing base_klass = None for k in mro: - if bytes_(k.__name__) in p: + if bytes_(k.__qualname__) in p: base_klass = k break if base_klass is None: raise ValueError('%s is NOT a valid plugin' % text_(plugin_)) - if klass not in p[bytes_(base_klass.__name__)]: - p[bytes_(base_klass.__name__)].append(klass) - logger.info('Loaded plugin %s.%s', module_name, klass.__name__) + if klass not in p[bytes_(base_klass.__qualname__)]: + p[bytes_(base_klass.__qualname__)].append(klass) + logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__) # print(p) return p @@ -90,16 +91,39 @@ def load( def importer(plugin: Union[bytes, type]) -> Tuple[type, str]: """Import and returns the plugin.""" if isinstance(plugin, type): - return (plugin, '__main__') + if inspect.isclass(plugin): + return (plugin, plugin.__module__ or '__main__') + raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin)) plugin_ = text_(plugin.strip()) assert plugin_ != '' - module_name, klass_name = plugin_.rsplit(text_(DOT), 1) - klass = getattr( - importlib.import_module( - module_name.replace( - os.path.sep, text_(DOT), - ), - ), - klass_name, - ) + path = plugin_.split(text_(DOT)) + klass = None + + def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]: + klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT)) + try: + klass_module = importlib.import_module(klass_module_name) + except ModuleNotFoundError: + return None + klass_container: Union[ModuleType, type] = klass_module + for klass_path_part in klass_path: + try: + klass_container = getattr(klass_container, klass_path_part) + except AttributeError: + return None + if not isinstance(klass_container, type) or not inspect.isclass(klass_container): + return None + return klass_container + + module_name = None + for module_name_parts in range(len(path) - 1, 0, -1): + module_name = '.'.join(path[0:module_name_parts]) + klass = locate_klass(module_name, path[module_name_parts:]) + if klass: + break + if klass is None: + module_name = '__main__' + klass = locate_klass(module_name, path) + if klass is None or module_name is None: + raise ValueError('%s is not resolvable as a plugin class' % text_(plugin)) return (klass, module_name) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 299c0efbd0..e6db855ee2 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -246,7 +246,7 @@ def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None: conn, addr, event_queue=self.event_queue, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) # TODO: Move me into target method logger.debug( # pragma: no cover diff --git a/proxy/core/work/fd/fd.py b/proxy/core/work/fd/fd.py index cb6e903d74..577e5c7c37 100644 --- a/proxy/core/work/fd/fd.py +++ b/proxy/core/work/fd/fd.py @@ -39,7 +39,7 @@ def work(self, *args: Any) -> None: self.works[fileno].publish_event( event_name=eventNames.WORK_STARTED, event_payload={'fileno': fileno, 'addr': addr}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) try: self.works[fileno].initialize() diff --git a/proxy/core/work/work.py b/proxy/core/work/work.py index d68969a726..a9ba046fae 100644 --- a/proxy/core/work/work.py +++ b/proxy/core/work/work.py @@ -83,7 +83,7 @@ def shutdown(self) -> None: self.publish_event( event_name=eventNames.WORK_FINISHED, event_payload={}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) def run(self) -> None: diff --git a/proxy/http/exception/http_request_rejected.py b/proxy/http/exception/http_request_rejected.py index 2b2e7a13b1..633a8e07e9 100644 --- a/proxy/http/exception/http_request_rejected.py +++ b/proxy/http/exception/http_request_rejected.py @@ -36,7 +36,7 @@ def __init__( self.reason: Optional[bytes] = reason self.headers: Optional[Dict[bytes, bytes]] = headers self.body: Optional[bytes] = body - klass_name = self.__class__.__name__ + klass_name = self.__class__.__qualname__ super().__init__( message='%s %r' % (klass_name, reason) if reason diff --git a/proxy/http/exception/proxy_auth_failed.py b/proxy/http/exception/proxy_auth_failed.py index afb2e4048e..022f8fcf1e 100644 --- a/proxy/http/exception/proxy_auth_failed.py +++ b/proxy/http/exception/proxy_auth_failed.py @@ -28,7 +28,7 @@ class ProxyAuthenticationFailed(HttpProtocolException): incoming request doesn't present necessary credentials.""" def __init__(self, **kwargs: Any) -> None: - super().__init__(self.__class__.__name__, **kwargs) + super().__init__(self.__class__.__qualname__, **kwargs) def response(self, _request: 'HttpParser') -> memoryview: return PROXY_AUTH_FAILED_RESPONSE_PKT diff --git a/proxy/http/exception/proxy_conn_failed.py b/proxy/http/exception/proxy_conn_failed.py index 2001b33605..e3854dd039 100644 --- a/proxy/http/exception/proxy_conn_failed.py +++ b/proxy/http/exception/proxy_conn_failed.py @@ -29,7 +29,7 @@ def __init__(self, host: str, port: int, reason: str, **kwargs: Any): self.host: str = host self.port: int = port self.reason: str = reason - super().__init__('%s %s' % (self.__class__.__name__, reason), **kwargs) + super().__init__('%s %s' % (self.__class__.__qualname__, reason), **kwargs) def response(self, _request: 'HttpParser') -> memoryview: return BAD_GATEWAY_RESPONSE_PKT diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index 7768e3c59d..a9c10e88f3 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -51,7 +51,7 @@ def name(self) -> str: Defaults to name of the class. This helps plugin developers to directly access a specific plugin by its name.""" - return self.__class__.__name__ # pragma: no cover + return self.__class__.__qualname__ # pragma: no cover def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['HostPort']]: """Resolve upstream server host to an IP address. diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 4838219030..f18f45fc55 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -883,7 +883,7 @@ def emit_request_complete(self) -> None: if self.request.method == httpMethods.POST else None, }, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) def emit_response_events(self, chunk_size: int) -> None: @@ -911,7 +911,7 @@ def emit_response_headers_complete(self) -> None: for k, v in self.response.headers.items() }, }, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) def emit_response_chunk_received(self, chunk_size: int) -> None: @@ -925,7 +925,7 @@ def emit_response_chunk_received(self, chunk_size: int) -> None: 'chunk_size': chunk_size, 'encoded_chunk_size': chunk_size, }, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) def emit_response_complete(self) -> None: @@ -938,7 +938,7 @@ def emit_response_complete(self) -> None: event_payload={ 'encoded_response_size': self.response.total_size, }, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) # diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index 544d39ab8f..434fba24b8 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -72,7 +72,7 @@ def name(self) -> str: Defaults to name of the class. This helps plugin developers to directly access a specific plugin by its name.""" - return self.__class__.__name__ # pragma: no cover + return self.__class__.__qualname__ # pragma: no cover @abstractmethod def routes(self) -> List[Tuple[int, str]]: diff --git a/tests/common/my_plugins/__init__.py b/tests/common/my_plugins/__init__.py new file mode 100644 index 0000000000..f0c5ddc802 --- /dev/null +++ b/tests/common/my_plugins/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from typing import Any + +from proxy.http.proxy import HttpProxyPlugin + + +class MyHttpProxyPlugin(HttpProxyPlugin): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class OuterClass: + + class MyHttpProxyPlugin(HttpProxyPlugin): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) diff --git a/tests/common/test_flags.py b/tests/common/test_flags.py index 1fc92119c5..48a3761862 100644 --- a/tests/common/test_flags.py +++ b/tests/common/test_flags.py @@ -8,7 +8,7 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -from typing import Dict, List +from typing import Any, Dict, List import unittest from unittest import mock @@ -19,6 +19,7 @@ from proxy.common.utils import bytes_ from proxy.common.version import __version__ from proxy.common.constants import PLUGIN_HTTP_PROXY, PY2_DEPRECATION_MESSAGE +from . import my_plugins class TestFlags(unittest.TestCase): @@ -140,6 +141,42 @@ def test_unique_plugin_from_class(self) -> None: ], }) + def test_plugin_from_inner_class_by_type(self) -> None: + self.flags = FlagParser.initialize( + [], plugins=[ + TestFlags.MyHttpProxyPlugin, + my_plugins.MyHttpProxyPlugin, + my_plugins.OuterClass.MyHttpProxyPlugin, + ], + ) + self.assert_plugins({ + 'HttpProtocolHandlerPlugin': [ + TestFlags.MyHttpProxyPlugin, + my_plugins.MyHttpProxyPlugin, + my_plugins.OuterClass.MyHttpProxyPlugin, + ], + }) + + def test_plugin_from_inner_class_by_name(self) -> None: + self.flags = FlagParser.initialize( + [], plugins=[ + b'tests.common.test_flags.TestFlags.MyHttpProxyPlugin', + b'tests.common.my_plugins.MyHttpProxyPlugin', + b'tests.common.my_plugins.OuterClass.MyHttpProxyPlugin', + ], + ) + self.assert_plugins({ + 'HttpProtocolHandlerPlugin': [ + TestFlags.MyHttpProxyPlugin, + my_plugins.MyHttpProxyPlugin, + my_plugins.OuterClass.MyHttpProxyPlugin, + ], + }) + + class MyHttpProxyPlugin(HttpProxyPlugin): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + def test_basic_auth_flag_is_base64_encoded(self) -> None: flags = FlagParser.initialize(['--basic-auth', 'user:pass']) self.assertEqual(flags.auth_code, b'dXNlcjpwYXNz') diff --git a/tests/core/test_event_dispatcher.py b/tests/core/test_event_dispatcher.py index 63999187b6..79ea042a72 100644 --- a/tests/core/test_event_dispatcher.py +++ b/tests/core/test_event_dispatcher.py @@ -40,7 +40,7 @@ def test_empties_queue(self) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) self.dispatcher.run_once() with self.assertRaises(queue.Empty): @@ -64,7 +64,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) # consume self.dispatcher.run_once() @@ -79,7 +79,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection: 'event_timestamp': 1234567, 'event_name': eventNames.WORK_STARTED, 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, + 'publisher_id': self.__class__.__qualname__, }, ) return relay_recv @@ -101,7 +101,7 @@ def test_unsubscribe(self) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) self.dispatcher.run_once() with self.assertRaises(EOFError): diff --git a/tests/core/test_event_queue.py b/tests/core/test_event_queue.py index 29c4299805..5450d5ace6 100644 --- a/tests/core/test_event_queue.py +++ b/tests/core/test_event_queue.py @@ -34,7 +34,7 @@ def test_publish(self, mock_time: mock.Mock) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) self.assertEqual( evq.queue.get(), { @@ -44,7 +44,7 @@ def test_publish(self, mock_time: mock.Mock) -> None: 'event_timestamp': 1234567, 'event_name': eventNames.WORK_STARTED, 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, + 'publisher_id': self.__class__.__qualname__, }, ) diff --git a/tests/core/test_event_subscriber.py b/tests/core/test_event_subscriber.py index 59be997d41..9e0ad3dad9 100644 --- a/tests/core/test_event_subscriber.py +++ b/tests/core/test_event_subscriber.py @@ -50,7 +50,7 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None: request_id='1234', event_name=eventNames.WORK_STARTED, event_payload={'hello': 'events'}, - publisher_id=self.__class__.__name__, + publisher_id=self.__class__.__qualname__, ) self.dispatcher.run_once() self.subscriber.unsubscribe() @@ -69,6 +69,6 @@ def callback(self, ev: Dict[str, Any]) -> None: 'event_timestamp': 1234567, 'event_name': eventNames.WORK_STARTED, 'event_payload': {'hello': 'events'}, - 'publisher_id': self.__class__.__name__, + 'publisher_id': self.__class__.__qualname__, }, )