diff --git a/Makefile b/Makefile index 81eab80..edcdec8 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,12 @@ setup: develop develop: install-dev-requirements install-test-requirements -test: +types: + @echo "Type checking Python files" + .venv/bin/mypy --pretty + @echo "" + +test: types @echo "Running Python tests" export VIRTUAL_ENV=.venv; .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest tests/ || exit 1 @echo "" diff --git a/mocket/compat.py b/mocket/compat.py index be72767..4651b8e 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -1,27 +1,30 @@ +from __future__ import annotations + import codecs import os import shlex +from typing import Any, Final -ENCODING = os.getenv("MOCKET_ENCODING", "utf-8") +ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8") text_type = str byte_type = bytes basestring = (str,) -def encode_to_bytes(s, encoding=ENCODING): +def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: if isinstance(s, text_type): s = s.encode(encoding) return byte_type(s) -def decode_from_bytes(s, encoding=ENCODING): +def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: if isinstance(s, byte_type): s = codecs.decode(s, encoding, "ignore") return text_type(s) -def shsplit(s): +def shsplit(s: str | bytes) -> list[str]: s = decode_from_bytes(s) return shlex.split(s) diff --git a/mocket/utils.py b/mocket/utils.py index 5a0a420..7d4bf58 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,17 +1,26 @@ +from __future__ import annotations + import binascii import io import os import ssl -from typing import Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar from .compat import decode_from_bytes, encode_to_bytes from .exceptions import StrictMocketException +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + from typing_extensions import NoReturn + SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 class MocketSocketCore(io.BytesIO): - def write(self, content): + def write( # type: ignore[override] # BytesIO returns int + self, + content: ReadableBuffer, + ) -> None: super(MocketSocketCore, self).write(content) from mocket import Mocket @@ -20,7 +29,7 @@ def write(self, content): os.write(Mocket.w_fd, content) -def hexdump(binary_string): +def hexdump(binary_string: bytes) -> str: r""" >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) True @@ -29,7 +38,7 @@ def hexdump(binary_string): return " ".join(a + b for a, b in zip(bs[::2], bs[1::2])) -def hexload(string): +def hexload(string: str) -> bytes: r""" >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") True @@ -38,39 +47,40 @@ def hexload(string): return encode_to_bytes(binascii.unhexlify(string_no_spaces)) -def get_mocketize(wrapper_): +def get_mocketize(wrapper_: Callable) -> Callable: import decorator - if decorator.__version__ < "5": # pragma: no cover + if decorator.__version__ < "5": # type: ignore[attr-defined] # pragma: no cover return decorator.decorator(wrapper_) - return decorator.decorator(wrapper_, kwsyntax=True) + return decorator.decorator( # type: ignore[call-arg] # kwsyntax + wrapper_, + kwsyntax=True, + ) class MocketMode: - __shared_state = {} - STRICT = None - STRICT_ALLOWED = None + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None - def __init__(self): + def __init__(self) -> None: self.__dict__ = self.__shared_state - def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool: + def is_allowed(self, location: str | tuple[str, int]) -> bool: """ Checks if (`host`, `port`) or at least `host` are allowed locations to perform real `socket` calls """ if not self.STRICT: return True - try: - host, _ = location - except ValueError: - host = None - return location in self.STRICT_ALLOWED or ( - host is not None and host in self.STRICT_ALLOWED - ) + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED @staticmethod - def raise_not_allowed(): + def raise_not_allowed() -> NoReturn: from .mocket import Mocket current_entries = [ diff --git a/pyproject.toml b/pyproject.toml index 2d482f2..2e48ec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ test = [ "twine", "fastapi", "wait-for-it", + "mypy", + "types-decorator", ] speedups = [ "xxhash;platform_python_implementation=='CPython'", @@ -81,3 +83,27 @@ include = [ exclude = [ ".*", ] + +[tool.mypy] +python_version = "3.8" +files = [ + "mocket/exceptions.py", + "mocket/compat.py", + "mocket/utils.py", + # "tests/" + ] +strict = true +warn_unused_configs = true +ignore_missing_imports = true +warn_redundant_casts = true +warn_unused_ignores = true +show_error_codes = true +implicit_reexport = true +disallow_any_generics = false +follow_imports = "silent" # enable this once majority is typed +enable_error_code = ['ignore-without-code'] +disable_error_code = ["no-untyped-def"] # enable this once full type-coverage is reached + +[[tool.mypy.overrides]] +module = "tests.*" +disable_error_code = ['type-arg', 'no-untyped-def']