From b5aef48e63c7d2fca99dd27df8c503876e3df054 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Wed, 29 Dec 2021 13:13:08 +0530 Subject: [PATCH] Use `option` instead of long cast statements. Also try to automate and delegate resolution to flag initializer. This makes flag class independent of cli arg names --- proxy/common/flag.py | 130 +++++++++++++------------------------------ 1 file changed, 39 insertions(+), 91 deletions(-) diff --git a/proxy/common/flag.py b/proxy/common/flag.py index f0c7f7c7fb..9960647c1a 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -20,7 +20,6 @@ from typing import Optional, List, Any, cast from .plugins import Plugins -from .types import IpAddress from .utils import bytes_, is_py2, is_threadless, set_open_file_limit from .constants import COMMA, DEFAULT_DATA_DIRECTORY_PATH, DEFAULT_NUM_ACCEPTORS, DEFAULT_NUM_WORKERS from .constants import DEFAULT_DEVTOOLS_WS_PATH, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE @@ -109,12 +108,27 @@ def initialize( print(__version__) sys.exit(0) + # https://github.com/python/mypy/issues/5865 + def option(t: object, key: str, default: Optional[Any] = None) -> Any: + return cast( + t, # type: ignore + opts.get( + key, + default or getattr(args, key), + ), + ) + + # Command line arguments MUST always take preference + # over kwargs passed to the program constructor. + # for f in args.__dict__.keys(): + # print(f) + # print(option(Any, f)) + # proxy.py currently cannot serve over HTTPS and also perform TLS interception # at the same time. Check if user is trying to enable both feature # at the same time. # - # TODO: Use parser.add_mutually_exclusive_group() - # and remove this logic from here. + # TODO: Use parser.add_mutually_exclusive_group() and remove this logic from here. if (args.cert_file and args.key_file) and \ (args.ca_key_file and args.ca_cert_file and args.ca_signing_key_file): print( @@ -157,27 +171,9 @@ def initialize( # --enable flags must be parsed before loading plugins # otherwise we will miss the plugins passed via constructor - args.enable_web_server = cast( - bool, - opts.get( - 'enable_web_server', - args.enable_web_server, - ), - ) - args.enable_static_server = cast( - bool, - opts.get( - 'enable_static_server', - args.enable_static_server, - ), - ) - args.enable_events = cast( - bool, - opts.get( - 'enable_events', - args.enable_events, - ), - ) + args.enable_web_server = option(bool, 'enable_web_server') + args.enable_static_server = option(bool, 'enable_static_server') + args.enable_events = option(bool, 'enable_events') # Load default plugins along with user provided --plugins default_plugins = [ @@ -191,10 +187,6 @@ def initialize( default_plugins + auth_plugins + requested_plugins, ) - # https://github.com/python/mypy/issues/5865 - # - # def option(t: object, key: str, default: Any) -> Any: - # return cast(t, opts.get(key, default)) args.work_klass = work_klass args.plugins = plugins args.auth_code = cast( @@ -204,20 +196,8 @@ def initialize( auth_code, ), ) - args.server_recvbuf_size = cast( - int, - opts.get( - 'server_recvbuf_size', - args.server_recvbuf_size, - ), - ) - args.client_recvbuf_size = cast( - int, - opts.get( - 'client_recvbuf_size', - args.client_recvbuf_size, - ), - ) + args.server_recvbuf_size = option(int, 'server_recvbuf_size') + args.client_recvbuf_size = option(int, 'client_recvbuf_size') args.pac_file = cast( Optional[str], opts.get( 'pac_file', bytes_( @@ -241,44 +221,18 @@ def initialize( ], ), ) - args.disable_headers = disabled_headers if disabled_headers is not None else DEFAULT_DISABLE_HEADERS - args.certfile = cast( - Optional[str], opts.get( - 'cert_file', args.cert_file, - ), - ) - args.keyfile = cast(Optional[str], opts.get('key_file', args.key_file)) - args.ca_key_file = cast( - Optional[str], opts.get( - 'ca_key_file', args.ca_key_file, - ), - ) - args.ca_cert_file = cast( - Optional[str], opts.get( - 'ca_cert_file', args.ca_cert_file, - ), - ) - args.ca_signing_key_file = cast( - Optional[str], - opts.get( - 'ca_signing_key_file', - args.ca_signing_key_file, - ), - ) - args.ca_file = cast( - Optional[str], - opts.get( - 'ca_file', - args.ca_file, - ), - ) - args.hostname = cast( - IpAddress, - opts.get('hostname', ipaddress.ip_address(args.hostname)), - ) - args.unix_socket_path = opts.get( - 'unix_socket_path', args.unix_socket_path, - ) + args.disable_headers = disabled_headers \ + if disabled_headers is not None \ + else DEFAULT_DISABLE_HEADERS + args.certfile = option(Optional[str], 'cert_file') + args.keyfile = option(Optional[str], 'key_file') + args.ca_key_file = option(Optional[str], 'ca_key_file') + args.ca_cert_file = option(Optional[str], 'ca_cert_file') + args.ca_signing_key_file = option(Optional[str], 'ca_signing_key_file') + args.ca_file = option(Optional[str], 'ca_file') + args.hostname = option(str, 'hostname') + args.hostname = ipaddress.ip_address(args.hostname) + args.unix_socket_path = option(str, 'unix_socket_path') # AF_UNIX is not available on Windows # See https://bugs.python.org/issue33408 if not IS_WINDOWS: @@ -294,13 +248,13 @@ def initialize( # # assert args.unix_socket_path is None args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET - args.port = cast(int, opts.get('port', args.port)) - args.backlog = cast(int, opts.get('backlog', args.backlog)) - num_workers = opts.get('num_workers', args.num_workers) + args.port = option(int, 'port') + args.backlog = option(int, 'backlog') + num_workers = option(int, 'num_workers') args.num_workers = cast( int, num_workers if num_workers > 0 else multiprocessing.cpu_count(), ) - num_acceptors = opts.get('num_acceptors', args.num_acceptors) + num_acceptors = option(int, 'num_acceptors') # See https://github.com/abhinavsingh/proxy.py/pull/714 description # to understand rationale behind the following logic. # @@ -314,13 +268,7 @@ def initialize( int, num_acceptors if num_acceptors > 0 else multiprocessing.cpu_count(), ) - args.static_server_dir = cast( - str, - opts.get( - 'static_server_dir', - args.static_server_dir, - ), - ) + args.static_server_dir = option(str, 'static_server_dir') args.min_compression_limit = cast( bool, opts.get(