From 32207fe8a7ebdfb95271d8430c4977c7a654928c Mon Sep 17 00:00:00 2001 From: Piotr Kopalko Date: Mon, 21 Aug 2023 23:45:34 +0200 Subject: [PATCH] feat(typing): type util package (#2170) * feat: Add typing to util misc module * feat: Add typing to util reader module * feat: Type util structures module * feat: Added typing to util.sync * feat: Added typing to util.uri * feat: Added typing util deprecation module * feat: Add typing to util package * feat: Disallow untyped defs in utils package * feat: Simplify typing of is_python_func * feat: Added future annotations to reader module * feat: Simplified CaseInsensitiveDict init typing * feat: Join imports * feat: Use Result type var instead of Any --- falcon/util/__init__.py | 4 +- falcon/util/deprecation.py | 19 ++++++--- falcon/util/misc.py | 35 +++++++++------- falcon/util/reader.py | 77 ++++++++++++++++++++++------------- falcon/util/structures.py | 83 ++++++++++++++++++++++---------------- falcon/util/sync.py | 35 +++++++++++----- falcon/util/uri.py | 52 ++++++++++++++---------- pyproject.toml | 3 +- 8 files changed, 190 insertions(+), 118 deletions(-) diff --git a/falcon/util/__init__.py b/falcon/util/__init__.py index 6de8d40d1..3fec8b06e 100644 --- a/falcon/util/__init__.py +++ b/falcon/util/__init__.py @@ -21,6 +21,7 @@ from http import cookies as http_cookies import sys +from types import ModuleType # Hoist misc. utils from falcon.constants import PYTHON_VERSION @@ -77,7 +78,7 @@ ) -def __getattr__(name): +def __getattr__(name: str) -> ModuleType: if name == 'json': import warnings import json # NOQA @@ -86,7 +87,6 @@ def __getattr__(name): 'Importing json from "falcon.util" is deprecated.', DeprecatedWarning ) return json - from types import ModuleType # fallback to the default implementation mod = sys.modules[__name__] diff --git a/falcon/util/deprecation.py b/falcon/util/deprecation.py index a0a62e25c..5e2607ad7 100644 --- a/falcon/util/deprecation.py +++ b/falcon/util/deprecation.py @@ -18,6 +18,9 @@ """ import functools +from typing import Any +from typing import Callable +from typing import Optional import warnings @@ -41,7 +44,9 @@ class DeprecatedWarning(UserWarning): pass -def deprecated(instructions, is_property=False, method_name=None): +def deprecated( + instructions: str, is_property: bool = False, method_name: Optional[str] = None +) -> Callable[[Callable[..., Any]], Any]: """Flag a method as deprecated. This function returns a decorator which can be used to mark deprecated @@ -60,7 +65,7 @@ def deprecated(instructions, is_property=False, method_name=None): """ - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[[Callable[..., Any]], Any]: object_name = 'property' if is_property else 'function' post_name = '' if is_property else '(...)' @@ -69,7 +74,7 @@ def decorator(func): ) @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: warnings.warn(message, category=DeprecatedWarning, stacklevel=2) return func(*args, **kwargs) @@ -79,7 +84,9 @@ def wrapper(*args, **kwargs): return decorator -def deprecated_args(*, allowed_positional, is_method=True): +def deprecated_args( + *, allowed_positional: int, is_method: bool = True +) -> Callable[..., Callable[..., Any]]: """Flag a method call with positional args as deprecated. Keyword Args: @@ -98,9 +105,9 @@ def deprecated_args(*, allowed_positional, is_method=True): if is_method: allowed_positional += 1 - def deprecated_args(fn): + def deprecated_args(fn: Callable[..., Any]) -> Callable[..., Callable[..., Any]]: @functools.wraps(fn) - def wraps(*args, **kwargs): + def wraps(*args: Any, **kwargs: Any) -> Callable[..., Any]: if len(args) > allowed_positional: warnings.warn( warn_text.format(fn=fn.__qualname__), diff --git a/falcon/util/misc.py b/falcon/util/misc.py index 1c05d1090..2b7fb7b25 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -22,12 +22,17 @@ now = falcon.http_now() """ - import datetime import functools import http import inspect import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union import unicodedata from falcon import status_codes @@ -69,18 +74,18 @@ _UNSAFE_CHARS = re.compile(r'[^a-zA-Z0-9.-]') # PERF(kgriffs): Avoid superfluous namespace lookups -strptime = datetime.datetime.strptime -utcnow = datetime.datetime.utcnow +strptime: Callable[[str, str], datetime.datetime] = datetime.datetime.strptime +utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow # NOTE(kgriffs,vytas): This is tested in the PyPy gate but we do not want devs # to have to install PyPy to check coverage on their workstations, so we use # the nocover pragma here. -def _lru_cache_nop(*args, **kwargs): # pragma: nocover - def decorator(func): +def _lru_cache_nop(maxsize: int) -> Callable[[Callable], Callable]: # pragma: nocover + def decorator(func: Callable) -> Callable: # NOTE(kgriffs): Partially emulate the lru_cache protocol; only add # cache_info() later if/when it becomes necessary. - func.cache_clear = lambda: None + func.cache_clear = lambda: None # type: ignore return func @@ -95,7 +100,7 @@ def decorator(func): _lru_cache_for_simple_logic = functools.lru_cache # type: ignore -def is_python_func(func): +def is_python_func(func: Union[Callable, Any]) -> bool: """Determine if a function or method uses a standard Python type. This helper can be used to check a function or method to determine if it @@ -251,7 +256,7 @@ def to_query_str( return query_str[:-1] -def get_bound_method(obj, method_name): +def get_bound_method(obj: object, method_name: str) -> Union[None, Callable[..., Any]]: """Get a bound method of the given object by name. Args: @@ -278,7 +283,7 @@ def get_bound_method(obj, method_name): return method -def get_argnames(func): +def get_argnames(func: Callable) -> List[str]: """Introspect the arguments of a callable. Args: @@ -308,7 +313,9 @@ def get_argnames(func): @deprecated('Please use falcon.code_to_http_status() instead.') -def get_http_status(status_code, default_reason=_DEFAULT_HTTP_REASON): +def get_http_status( + status_code: Union[str, int], default_reason: str = _DEFAULT_HTTP_REASON +) -> str: """Get both the http status code and description from just a code. Warning: @@ -387,7 +394,7 @@ def secure_filename(filename: str) -> str: @_lru_cache_for_simple_logic(maxsize=64) -def http_status_to_code(status): +def http_status_to_code(status: Union[http.HTTPStatus, int, bytes, str]) -> int: """Normalize an HTTP status to an integer code. This function takes a member of :class:`http.HTTPStatus`, an HTTP status @@ -425,7 +432,7 @@ def http_status_to_code(status): @_lru_cache_for_simple_logic(maxsize=64) -def code_to_http_status(status): +def code_to_http_status(status: Union[int, http.HTTPStatus, bytes, str]) -> str: """Normalize an HTTP status to an HTTP status line string. This function takes a member of :class:`http.HTTPStatus`, an ``int`` status @@ -473,7 +480,7 @@ def code_to_http_status(status): return '{} {}'.format(code, _DEFAULT_HTTP_REASON) -def _encode_items_to_latin1(data): +def _encode_items_to_latin1(data: Dict[str, str]) -> List[Tuple[bytes, bytes]]: """Decode all key/values of a dict to Latin-1. Args: @@ -491,7 +498,7 @@ def _encode_items_to_latin1(data): return result -def _isascii(string: str): +def _isascii(string: str) -> bool: """Return ``True`` if all characters in the string are ASCII. ASCII characters have code points in the range U+0000-U+007F. diff --git a/falcon/util/reader.py b/falcon/util/reader.py index dca60d094..263d66716 100644 --- a/falcon/util/reader.py +++ b/falcon/util/reader.py @@ -13,9 +13,14 @@ # limitations under the License. """Buffered stream reader.""" +from __future__ import annotations import functools import io +from typing import Callable +from typing import IO +from typing import List +from typing import Optional from falcon.errors import DelimiterError @@ -26,7 +31,12 @@ class BufferedReader: - def __init__(self, read, max_stream_len, chunk_size=None): + def __init__( + self, + read: Callable[[int], bytes], + max_stream_len: int, + chunk_size: Optional[int] = None, + ): self._read_func = read self._chunk_size = chunk_size or DEFAULT_CHUNK_SIZE self._max_join_size = self._chunk_size * _MAX_JOIN_CHUNKS @@ -36,7 +46,7 @@ def __init__(self, read, max_stream_len, chunk_size=None): self._buffer_pos = 0 self._max_bytes_remaining = max_stream_len - def _perform_read(self, size): + def _perform_read(self, size: int) -> bytes: # PERF(vytas): In Cython, bind types: # cdef bytes chunk # cdef Py_ssize_t chunk_len @@ -75,7 +85,7 @@ def _perform_read(self, size): self._max_bytes_remaining -= chunk_len result.write(chunk) - def _fill_buffer(self): + def _fill_buffer(self) -> None: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t read_size @@ -92,7 +102,7 @@ def _fill_buffer(self): self._buffer_len = len(self._buffer) - def peek(self, size=-1): + def peek(self, size: int = -1) -> bytes: if size < 0 or size > self._chunk_size: size = self._chunk_size @@ -101,7 +111,7 @@ def peek(self, size=-1): return self._buffer[self._buffer_pos : self._buffer_pos + size] - def _normalize_size(self, size): + def _normalize_size(self, size: Optional[int]) -> int: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t result # cdef Py_ssize_t max_size @@ -112,10 +122,10 @@ def _normalize_size(self, size): return max_size return size - def read(self, size=-1): + def read(self, size: int = -1) -> bytes: return self._read(self._normalize_size(size)) - def _read(self, size): + def _read(self, size: int) -> bytes: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t read_size # cdef bytes result @@ -150,7 +160,9 @@ def _read(self, size): self._buffer_pos = read_size return result + self._buffer[:read_size] - def read_until(self, delimiter, size=-1, consume_delimiter=False): + def read_until( + self, delimiter: bytes, size: int = -1, consume_delimiter: bool = False + ) -> bytes: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t read_size # cdef result @@ -168,15 +180,15 @@ def read_until(self, delimiter, size=-1, consume_delimiter=False): def _finalize_read_until( self, - size, - backlog, - have_bytes, - consume_bytes, - delimiter=None, - delimiter_pos=-1, - next_chunk=None, - next_chunk_len=0, - ): + size: int, + backlog: List[bytes], + have_bytes: int, + consume_bytes: int, + delimiter: Optional[bytes] = None, + delimiter_pos: int = -1, + next_chunk: Optional[bytes] = None, + next_chunk_len: int = 0, + ) -> bytes: if delimiter_pos < 0 and delimiter is not None: delimiter_pos = self._buffer.find(delimiter, self._buffer_pos) @@ -192,6 +204,7 @@ def _finalize_read_until( ret_value = b''.join(backlog) if next_chunk_len > 0: + assert next_chunk if self._buffer_len == 0: self._buffer = next_chunk self._buffer_len = next_chunk_len @@ -214,7 +227,9 @@ def _finalize_read_until( return ret_value - def _read_until(self, delimiter, size, consume_delimiter): + def _read_until( + self, delimiter: bytes, size: int, consume_delimiter: bool + ) -> bytes: # PERF(vytas): In Cython, bind types: # cdef list result = [] # cdef Py_ssize_t have_bytes = 0 @@ -223,7 +238,7 @@ def _read_until(self, delimiter, size, consume_delimiter): # cdef Py_ssize_t consume_bytes # cdef Py_ssize_t offset - result = [] + result: List[bytes] = [] have_bytes = 0 delimiter_len_1 = len(delimiter) - 1 delimiter_pos = -1 @@ -321,7 +336,7 @@ def _read_until(self, delimiter, size, consume_delimiter): self._buffer_pos = 0 self._buffer = next_chunk - def pipe(self, destination=None): + def pipe(self, destination: Optional[IO] = None) -> None: while True: chunk = self.read(self._chunk_size) if not chunk: @@ -331,8 +346,12 @@ def pipe(self, destination=None): destination.write(chunk) def pipe_until( - self, delimiter, destination=None, consume_delimiter=False, _size=None - ): + self, + delimiter: bytes, + destination: Optional[IO] = None, + consume_delimiter: bool = False, + _size: Optional[int] = None, + ) -> None: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t remaining @@ -354,14 +373,14 @@ def pipe_until( raise DelimiterError('expected delimiter missing') self._buffer_pos += delimiter_len - def exhaust(self): + def exhaust(self) -> None: self.pipe() - def delimit(self, delimiter): + def delimit(self, delimiter: bytes) -> BufferedReader: read = functools.partial(self.read_until, delimiter) return type(self)(read, self._normalize_size(None), self._chunk_size) - def readline(self, size=-1): + def readline(self, size: int = -1) -> bytes: size = self._normalize_size(size) result = self.read_until(b'\n', size) @@ -369,7 +388,7 @@ def readline(self, size=-1): return result + self.read(1) return result - def readlines(self, hint=-1): + def readlines(self, hint: int = -1) -> List[bytes]: # PERF(vytas): In Cython, bind types: # cdef Py_ssize_t read # cdef list result = [] @@ -391,14 +410,14 @@ def readlines(self, hint=-1): # --- implementing IOBase methods, the duck-typing way --- - def readable(self): + def readable(self) -> bool: """Return ``True`` always.""" return True - def seekable(self): + def seekable(self) -> bool: """Return ``False`` always.""" return False - def writeable(self): + def writeable(self) -> bool: """Return ``False`` always.""" return False diff --git a/falcon/util/structures.py b/falcon/util/structures.py index 13025cf4d..fc7ba2a88 100644 --- a/falcon/util/structures.py +++ b/falcon/util/structures.py @@ -25,9 +25,19 @@ things = falcon.CaseInsensitiveDict() """ +from __future__ import annotations from collections.abc import Mapping from collections.abc import MutableMapping +from typing import Any +from typing import Dict +from typing import ItemsView +from typing import Iterable +from typing import Iterator +from typing import KeysView +from typing import Optional +from typing import Tuple +from typing import ValuesView # TODO(kgriffs): If we ever diverge from what is upstream in Requests, @@ -61,34 +71,34 @@ class CaseInsensitiveDict(MutableMapping): # pragma: no cover """ - def __init__(self, data=None, **kwargs): - self._store = dict() + def __init__(self, data: Optional[Iterable[Tuple[str, Any]]] = None, **kwargs: Any): + self._store: Dict[str, Tuple[str, Any]] = dict() if data is None: data = {} self.update(data, **kwargs) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: # Use the lowercased key for lookups, but store the actual # key alongside the value. self._store[key.lower()] = (key, value) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self._store[key.lower()][1] - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._store[key.lower()] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return (casedkey for casedkey, mappedvalue in self._store.values()) - def __len__(self): + def __len__(self) -> int: return len(self._store) - def lower_items(self): + def lower_items(self) -> Iterator[Tuple[str, Any]]: """Like iteritems(), but with all lowercase keys.""" return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Mapping): other = CaseInsensitiveDict(other) else: @@ -97,10 +107,10 @@ def __eq__(self, other): return dict(self.lower_items()) == dict(other.lower_items()) # Copy is required - def copy(self): + def copy(self) -> CaseInsensitiveDict: return CaseInsensitiveDict(self._store.values()) - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (self.__class__.__name__, dict(self.items())) @@ -131,77 +141,80 @@ class Context: True """ - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return self.__dict__.__contains__(key) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Optional[Any]: # PERF(vytas): On CPython, using this mapping interface (instead of a # standard dict) to get, set and delete items incurs overhead # approximately comparable to that of two function calls # (per get/set/delete operation, that is). return self.__dict__.__getitem__(key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: return self.__dict__.__setitem__(key, value) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: self.__dict__.__delitem__(key) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return self.__dict__.__iter__() - def __len__(self): + def __len__(self) -> int: return self.__dict__.__len__() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, type(self)): return self.__dict__.__eq__(other.__dict__) return self.__dict__.__eq__(other) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if isinstance(other, type(self)): return self.__dict__.__ne__(other.__dict__) return self.__dict__.__ne__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.__dict__) - def __repr__(self): + def __repr__(self) -> str: return '{}({})'.format(type(self).__name__, self.__dict__.__repr__()) - def __str__(self): + def __str__(self) -> str: return '{}({})'.format(type(self).__name__, self.__dict__.__str__()) - def clear(self): + def clear(self) -> None: return self.__dict__.clear() - def copy(self): + def copy(self) -> Context: ctx = type(self)() ctx.update(self.__dict__) return ctx - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: return self.__dict__.get(key, default) - def items(self): + def items(self) -> ItemsView[str, Any]: return self.__dict__.items() - def keys(self): + def keys(self) -> KeysView[str]: return self.__dict__.keys() - def pop(self, key, default=None): + def pop(self, key: str, default: Optional[Any] = None) -> Optional[Any]: return self.__dict__.pop(key, default) - def popitem(self): + def popitem(self) -> Tuple[str, Any]: + return self.__dict__.popitem() - def setdefault(self, key, default_value=None): + def setdefault( + self, key: str, default_value: Optional[Any] = None + ) -> Optional[Any]: return self.__dict__.setdefault(key, default_value) - def update(self, items): + def update(self, items: dict[str, Any]) -> None: self.__dict__.update(items) - def values(self): + def values(self) -> ValuesView: return self.__dict__.values() @@ -243,7 +256,7 @@ def on_get(self, req, resp): is_weak = False - def strong_compare(self, other): + def strong_compare(self, other: ETag) -> bool: """Perform a strong entity-tag comparison. Two entity-tags are equivalent if both are not weak and their @@ -262,7 +275,7 @@ def strong_compare(self, other): return self == other and not (self.is_weak or other.is_weak) - def dumps(self): + def dumps(self) -> str: """Serialize the ETag to a string suitable for use in a precondition header. (See also: RFC 7232, Section 2.3) @@ -280,7 +293,7 @@ def dumps(self): return '"' + self + '"' @classmethod - def loads(cls, etag_str): + def loads(cls, etag_str: str) -> ETag: """Deserialize a single entity-tag string from a precondition header. Note: diff --git a/falcon/util/sync.py b/falcon/util/sync.py index 7b6107f19..db9e21e37 100644 --- a/falcon/util/sync.py +++ b/falcon/util/sync.py @@ -4,7 +4,12 @@ from functools import wraps import inspect import os +from typing import Any +from typing import Awaitable from typing import Callable +from typing import Optional +from typing import TypeVar +from typing import Union __all__ = [ @@ -17,14 +22,13 @@ 'wrap_sync_to_async_unsafe', ] - _one_thread_to_rule_them_all = ThreadPoolExecutor(max_workers=1) create_task = asyncio.create_task get_running_loop = asyncio.get_running_loop -def wrap_sync_to_async_unsafe(func) -> Callable: +def wrap_sync_to_async_unsafe(func: Callable[..., Any]) -> Callable[..., Any]: """Wrap a callable in a coroutine that executes the callable directly. This helper makes it easier to use synchronous callables with ASGI @@ -48,13 +52,15 @@ def wrap_sync_to_async_unsafe(func) -> Callable: """ @wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return func(*args, **kwargs) return wrapper -def wrap_sync_to_async(func, threadsafe=None) -> Callable: +def wrap_sync_to_async( + func: Callable[..., Any], threadsafe: Optional[bool] = None +) -> Callable[..., Any]: """Wrap a callable in a coroutine that executes the callable in the background. This helper makes it easier to call functions that can not be @@ -94,7 +100,7 @@ def wrap_sync_to_async(func, threadsafe=None) -> Callable: executor = _one_thread_to_rule_them_all @wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: return await get_running_loop().run_in_executor( executor, partial(func, *args, **kwargs) ) @@ -102,7 +108,9 @@ async def wrapper(*args, **kwargs): return wrapper -async def sync_to_async(func, *args, **kwargs): +async def sync_to_async( + func: Callable[..., Any], *args: Any, **kwargs: Any +) -> Callable[..., Awaitable[Any]]: """Schedule a synchronous callable on the default executor and await the result. This helper makes it easier to call functions that can not be @@ -153,7 +161,9 @@ def _should_wrap_non_coroutines() -> bool: return 'FALCON_ASGI_WRAP_NON_COROUTINES' in os.environ -def _wrap_non_coroutine_unsafe(func): +def _wrap_non_coroutine_unsafe( + func: Optional[Callable[..., Any]] +) -> Union[Callable[..., Awaitable[Any]], Callable[..., Any], None]: """Wrap a coroutine using ``wrap_sync_to_async_unsafe()`` for internal test cases. This method is intended for Falcon's own test suite and should not be @@ -180,7 +190,12 @@ def _wrap_non_coroutine_unsafe(func): return wrap_sync_to_async_unsafe(func) -def async_to_sync(coroutine, *args, **kwargs): +Result = TypeVar('Result') + + +def async_to_sync( + coroutine: Callable[..., Awaitable[Result]], *args: Any, **kwargs: Any +) -> Result: """Invoke a coroutine function from a synchronous caller. This method can be used to invoke an asynchronous task from a synchronous @@ -212,7 +227,7 @@ def async_to_sync(coroutine, *args, **kwargs): return loop.run_until_complete(coroutine(*args, **kwargs)) -def runs_sync(coroutine): +def runs_sync(coroutine: Callable[..., Awaitable[Result]]) -> Callable[..., Result]: """Transform a coroutine function into a synchronous method. This is achieved by always invoking the decorated coroutine function via @@ -234,7 +249,7 @@ def runs_sync(coroutine): """ @wraps(coroutine) - def invoke(*args, **kwargs): + def invoke(*args: Any, **kwargs: Any) -> Any: return async_to_sync(coroutine, *args, **kwargs) return invoke diff --git a/falcon/util/uri.py b/falcon/util/uri.py index e6078dfb9..f4092634c 100644 --- a/falcon/util/uri.py +++ b/falcon/util/uri.py @@ -22,8 +22,12 @@ name, port = uri.parse_host('example.org:8080') """ - +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional from typing import Tuple, TYPE_CHECKING +from typing import Union from falcon.constants import PYPY @@ -52,7 +56,7 @@ } -def _create_char_encoder(allowed_chars): +def _create_char_encoder(allowed_chars: str) -> Callable[[int], str]: lookup = {} @@ -67,13 +71,15 @@ def _create_char_encoder(allowed_chars): return lookup.__getitem__ -def _create_str_encoder(is_value, check_is_escaped=False): +def _create_str_encoder( + is_value: bool, check_is_escaped: bool = False +) -> Callable[[str], str]: allowed_chars = _UNRESERVED if is_value else _ALL_ALLOWED allowed_chars_plus_percent = allowed_chars + '%' encode_char = _create_char_encoder(allowed_chars) - def encoder(uri): + def encoder(uri: str) -> str: # PERF(kgriffs): Very fast way to check, learned from urlib.quote if not uri.rstrip(allowed_chars): return uri @@ -107,7 +113,7 @@ def encoder(uri): # partially encoded, the caller will need to normalize it # before passing it in here. - uri = uri.encode() + encoded_uri = uri.encode() # Use our map to encode each char and join the result into a new uri # @@ -115,7 +121,7 @@ def encoder(uri): # CPython 3 (tested on CPython 3.5 and 3.7). A list comprehension # can be faster on PyPy3, but the difference is on the order of # nanoseconds in that case, so we aren't going to worry about it. - return ''.join(map(encode_char, uri)) + return ''.join(map(encode_char, encoded_uri)) return encoder @@ -143,7 +149,7 @@ def encoder(uri): """ encode_value = _create_str_encoder(True) -encode_value.name = 'encode_value' +encode_value.__name__ = 'encode_value' encode_value.__doc__ = """Encodes a value string according to RFC 3986. Disallowed characters are percent-encoded in a way that models @@ -171,7 +177,7 @@ def encoder(uri): """ encode_check_escaped = _create_str_encoder(False, True) -encode_check_escaped.name = 'encode_check_escaped' +encode_check_escaped.__name__ = 'encode_check_escaped' encode_check_escaped.__doc__ = """Encodes a full or relative URI according to RFC 3986. RFC 3986 defines a set of "unreserved" characters as well as a @@ -195,7 +201,7 @@ def encoder(uri): """ encode_value_check_escaped = _create_str_encoder(True, True) -encode_value_check_escaped.name = 'encode_value_check_escaped' +encode_value_check_escaped.__name__ = 'encode_value_check_escaped' encode_value_check_escaped.__doc__ = """Encodes a value string according to RFC 3986. RFC 3986 defines a set of "unreserved" characters as well as a @@ -224,7 +230,7 @@ def encoder(uri): """ -def _join_tokens_bytearray(tokens): +def _join_tokens_bytearray(tokens: List[bytes]) -> str: decoded_uri = bytearray(tokens[0]) for token in tokens[1:]: token_partial = token[:2] @@ -238,7 +244,7 @@ def _join_tokens_bytearray(tokens): return decoded_uri.decode('utf-8', 'replace') -def _join_tokens_list(tokens): +def _join_tokens_list(tokens: List[bytes]) -> str: decoded = tokens[:1] # PERF(vytas): Do not copy list: a simple bool flag is fastest on PyPy JIT. skip = True @@ -270,7 +276,7 @@ def _join_tokens_list(tokens): _join_tokens = _join_tokens_list if PYPY else _join_tokens_bytearray -def decode(encoded_uri, unquote_plus=True): +def decode(encoded_uri: str, unquote_plus: bool = True) -> str: """Decode percent-encoded characters in a URI or query string. This function models the behavior of `urllib.parse.unquote_plus`, @@ -306,31 +312,33 @@ def decode(encoded_uri, unquote_plus=True): # NOTE(kgriffs): Clients should never submit a URI that has # unescaped non-ASCII chars in them, but just in case they # do, let's encode into a non-lossy format. - decoded_uri = decoded_uri.encode() + reencoded_uri = decoded_uri.encode() # PERF(kgriffs): This was found to be faster than using # a regex sub call or list comprehension with a join. - tokens = decoded_uri.split(b'%') + tokens = reencoded_uri.split(b'%') # PERF(vytas): Just use in-place add for a low number of items: if len(tokens) < 8: - decoded_uri = tokens[0] + reencoded_uri = tokens[0] for token in tokens[1:]: token_partial = token[:2] try: - decoded_uri += _HEX_TO_BYTE[token_partial] + token[2:] + reencoded_uri += _HEX_TO_BYTE[token_partial] + token[2:] except KeyError: # malformed percentage like "x=%" or "y=%+" - decoded_uri += b'%' + token + reencoded_uri += b'%' + token # Convert back to str - return decoded_uri.decode('utf-8', 'replace') + return reencoded_uri.decode('utf-8', 'replace') # NOTE(vytas): Decode percent-encoded bytestring fragments and join them # back to a string using the platform-dependent method. return _join_tokens(tokens) -def parse_query_string(query_string: str, keep_blank: bool = False, csv: bool = True): +def parse_query_string( + query_string: str, keep_blank: bool = False, csv: bool = True +) -> Dict[str, Union[str, List[str]]]: """Parse a query string into a dict. Query string parameters are assumed to use standard form-encoding. Only @@ -455,7 +463,9 @@ def parse_query_string(query_string: str, keep_blank: bool = False, csv: bool = return params -def parse_host(host: str, default_port=None) -> Tuple[str, int]: +def parse_host( + host: str, default_port: Optional[int] = None +) -> Tuple[str, Optional[int]]: """Parse a canonical 'host:port' string into parts. Parse a host string (which may or may not contain a port) into @@ -506,7 +516,7 @@ def parse_host(host: str, default_port=None) -> Tuple[str, int]: return (name, int(port)) -def unquote_string(quoted): +def unquote_string(quoted: str) -> str: """Unquote an RFC 7320 "quoted-string". Args: diff --git a/pyproject.toml b/pyproject.toml index f3d986b06..ee04f5013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ [[tool.mypy.overrides]] module = [ - "falcon.stream" + "falcon.stream", + "falcon.util.*" ] disallow_untyped_defs = true