diff --git a/proxy/common/constants.py b/proxy/common/constants.py index bd0a40e785..9727143b9b 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -114,6 +114,8 @@ def _env_threadless_compliant() -> bool: '{response_bytes} bytes - {connection_time_ms}ms' DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \ '{request_method} {request_path} -> {upstream_proxy_pass} - {connection_time_ms}ms' +DEFAULT_LISTENER_POOL_KLASS = 'proxy.core.listener.pool.ListenerPool' +DEFAULT_ACCEPTOR_POOL_KLASS = 'proxy.core.acceptor.pool.AcceptorPool' DEFAULT_NUM_ACCEPTORS = 0 DEFAULT_NUM_WORKERS = 0 DEFAULT_OPEN_FILE_LIMIT = 1024 @@ -127,6 +129,7 @@ def _env_threadless_compliant() -> bool: DEFAULT_STATIC_SERVER_DIR = os.path.join(PROXY_PY_DIR, "public") DEFAULT_MIN_COMPRESSION_LENGTH = 20 # In bytes DEFAULT_THREADLESS = _env_threadless_compliant() +DEFAULT_THREADLESS_POOL_KLASS = 'proxy.core.work.pool.ThreadlessPool' DEFAULT_LOCAL_EXECUTOR = True DEFAULT_TIMEOUT = 10.0 DEFAULT_VERSION = False diff --git a/proxy/common/flag.py b/proxy/common/flag.py index f8395a6f62..5d3c986a38 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -138,6 +138,24 @@ def initialize( if isinstance(work_klass, str) \ else work_klass + # Load acceptor_pool_klass + acceptor_pool_klass = opts.get('acceptor_pool_klass', args.acceptor_pool_klass) + acceptor_pool_klass = Plugins.importer(bytes_(acceptor_pool_klass))[0] \ + if isinstance(acceptor_pool_klass, str) \ + else acceptor_pool_klass + + # Load listener_pool_klass + listener_pool_klass = opts.get('listener_pool_klass', args.listener_pool_klass) + listener_pool_klass = Plugins.importer(bytes_(listener_pool_klass))[0] \ + if isinstance(listener_pool_klass, str) \ + else listener_pool_klass + + # Load threadless_pool_klass + threadless_pool_klass = opts.get('threadless_pool_klass', args.threadless_pool_klass) + threadless_pool_klass = Plugins.importer(bytes_(threadless_pool_klass))[0] \ + if isinstance(threadless_pool_klass, str) \ + else threadless_pool_klass + # TODO: Plugin flag initialization logic must be moved within plugins. # # Generate auth_code required for basic authentication if enabled @@ -201,6 +219,8 @@ def initialize( # def option(t: object, key: str, default: Any) -> Any: # return cast(t, opts.get(key, default)) args.work_klass = work_klass + args.acceptor_pool_klass = acceptor_pool_klass + args.listener_pool_klass = listener_pool_klass args.plugins = plugins args.auth_code = cast( Optional[bytes], @@ -376,6 +396,7 @@ def initialize( # evaluates to False. args.threadless = cast(bool, opts.get('threadless', args.threadless)) args.threadless = is_threadless(args.threadless, args.threaded) + args.threadless_pool_klass = threadless_pool_klass args.pid_file = cast( Optional[str], opts.get( diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 09fb9f447f..19871cef39 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -24,7 +24,9 @@ from .acceptor import Acceptor from ..listener import ListenerPool from ...common.flag import flags -from ...common.constants import DEFAULT_NUM_ACCEPTORS +from ...common.constants import ( + DEFAULT_NUM_ACCEPTORS, DEFAULT_ACCEPTOR_POOL_KLASS, +) if TYPE_CHECKING: # pragma: no cover @@ -33,6 +35,14 @@ logger = logging.getLogger(__name__) +flags.add_argument( + '--acceptor-pool-klass', + type=str, + default=DEFAULT_ACCEPTOR_POOL_KLASS, + help='Default: ' + DEFAULT_ACCEPTOR_POOL_KLASS + + '. Acceptor pool klass.', +) + flags.add_argument( '--num-acceptors', type=int, diff --git a/proxy/core/listener/pool.py b/proxy/core/listener/pool.py index b362ae558c..aef0b724fd 100644 --- a/proxy/core/listener/pool.py +++ b/proxy/core/listener/pool.py @@ -13,12 +13,23 @@ from .tcp import TcpSocketListener from .unix import UnixSocketListener +from ...common.flag import flags +from ...common.constants import DEFAULT_LISTENER_POOL_KLASS if TYPE_CHECKING: # pragma: no cover from .base import BaseListener +flags.add_argument( + '--listener-pool-klass', + type=str, + default=DEFAULT_LISTENER_POOL_KLASS, + help='Default: ' + DEFAULT_LISTENER_POOL_KLASS + + '. Listener pool klass.', +) + + class ListenerPool: """Provides abstraction around starting multiple listeners based upon flags.""" diff --git a/proxy/core/work/pool.py b/proxy/core/work/pool.py index 5458f0a89d..12d738f2c0 100644 --- a/proxy/core/work/pool.py +++ b/proxy/core/work/pool.py @@ -15,7 +15,9 @@ from multiprocessing import connection from ...common.flag import flags -from ...common.constants import DEFAULT_THREADLESS, DEFAULT_NUM_WORKERS +from ...common.constants import ( + DEFAULT_THREADLESS, DEFAULT_NUM_WORKERS, DEFAULT_THREADLESS_POOL_KLASS, +) if TYPE_CHECKING: # pragma: no cover @@ -54,6 +56,14 @@ help='Defaults to number of CPU cores.', ) +flags.add_argument( + '--threadless-pool-klass', + type=str, + default=DEFAULT_THREADLESS_POOL_KLASS, + help='Default: ' + DEFAULT_THREADLESS_POOL_KLASS + + '. Threadless pool klass.', +) + class ThreadlessPool: """Manages lifecycle of threadless pool and delegates work to them diff --git a/proxy/proxy.py b/proxy/proxy.py index d9d9f89798..3b4e23fa0c 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -199,7 +199,10 @@ def setup(self) -> None: self._write_pid_file() # We setup listeners first because of flags.port override # in case of ephemeral port being used - self.listeners = ListenerPool(flags=self.flags) + self.listeners = cast( + 'ListenerPool', + self.flags.listener_pool_klass(flags=self.flags), + ) self.listeners.setup() # Override flags.port to match the actual port # we are listening upon. This is necessary to preserve @@ -234,20 +237,26 @@ def setup(self) -> None: # Setup remote executors only if # --local-executor mode isn't enabled. if self.remote_executors_enabled: - self.executors = ThreadlessPool( - flags=self.flags, - event_queue=event_queue, - executor_klass=RemoteFdExecutor, + self.executors = cast( + 'ThreadlessPool', + self.flags.threadless_pool_klass( + flags=self.flags, + event_queue=event_queue, + executor_klass=RemoteFdExecutor, + ), ) self.executors.setup() # Setup acceptors - self.acceptors = AcceptorPool( - flags=self.flags, - listeners=self.listeners, - executor_queues=self.executors.work_queues if self.executors else [], - executor_pids=self.executors.work_pids if self.executors else [], - executor_locks=self.executors.work_locks if self.executors else [], - event_queue=event_queue, + self.acceptors = cast( + 'AcceptorPool', + self.flags.acceptor_pool_klass( + flags=self.flags, + listeners=self.listeners, + executor_queues=self.executors.work_queues if self.executors else [], + executor_pids=self.executors.work_pids if self.executors else [], + executor_locks=self.executors.work_locks if self.executors else [], + event_queue=event_queue, + ), ) self.acceptors.setup() # Start SSH tunnel acceptor if enabled