diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2886f57..796fe11 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -17,7 +17,7 @@ jobs: - uses: wntrblm/nox@2023.04.22 with: - python-versions: "3.8, 3.11, pypy3.10" + python-versions: "3.8, 3.12, pypy3.10" - name: Install Intel SDE run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92912a4..f9664b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,4 +25,9 @@ repos: rev: v1.8.0 hooks: - id: mypy - exclude: docs|conftest.py|tests + exclude: docs|conftest.py + args: ["--python-version=3.7"] + additional_dependencies: + - nox + - pytest + - types-setuptools diff --git a/MANIFEST.in b/MANIFEST.in index bb06398..189dfb1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,9 @@ include LICENSE # Include headers include src/pybase64/_pybase64_get_simd_flags.h +# Include type stub for extension +include src/pybase64/_pybase64.pyi + # Include full base64 folder graft base64 # but the git folder diff --git a/docs/conf.py b/docs/conf.py index f6873a0..bac1178 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ # |version| and |release|, also used in various other places throughout the # built documents. # Get version -_version = runpy.run_path(os.path.join(here, "..", "src", "pybase64", "_version.py"))["__version__"] +_version = runpy.run_path(os.path.join(here, "..", "src", "pybase64", "_version.py"))["_version"] # The short X.Y version. version = _version # The full version, including alpha/beta/rc tags. diff --git a/noxfile.py b/noxfile.py index d2c65ee..76a1cc8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -25,7 +25,7 @@ def lint(session: nox.Session) -> None: def update_env_macos(session: nox.Session, env: dict[str, str]) -> None: if sys.platform.startswith("darwin"): # we don't support universal builds - machine = session.run( + machine = session.run( # type: ignore[union-attr] "python", "-sSEc", "import platform; print(platform.machine())", silent=True ).strip() env["ARCHFLAGS"] = f"-arch {machine}" @@ -37,7 +37,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None: where = HERE / "src" / "pybase64" else: command = "import sysconfig; print(sysconfig.get_path('platlib'))" - platlib = session.run("python", "-c", command, silent=True).strip() + platlib = session.run("python", "-c", command, silent=True).strip() # type: ignore[union-attr] where = Path(platlib) / "pybase64" assert where.exists() @@ -54,7 +54,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None: @nox.session(python="3.12") def develop(session: nox.Session) -> None: """create venv for dev.""" - session.install("-r", "requirements-test.txt") + session.install("nox", "setuptools", "-r", "requirements-test.txt") # make extension mandatory by exporting CIBUILDWHEEL=1 env = {"CIBUILDWHEEL": "1"} update_env_macos(session, env) @@ -76,7 +76,7 @@ def test(session: nox.Session) -> None: session.run("pytest", *session.posargs, env=env) -@nox.session(python=["3.8", "3.11", "pypy3.10"]) +@nox.session(python=["3.8", "3.12", "pypy3.10"]) def _coverage(session: nox.Session) -> None: """internal coverage run. Do not run manually""" with_sde = "--with-sde" in session.posargs @@ -86,7 +86,6 @@ def _coverage(session: nox.Session) -> None: "--cov=pybase64", "--cov=tests", "--cov-append", - "--cov-branch", "--cov-report=", ) pytest_command = ("pytest", *coverage_args) @@ -144,7 +143,7 @@ def coverage(session: nox.Session) -> None: posargs.add("--report") session.notify("_coverage-3.8", ["--clean"]) session.notify("_coverage-pypy3.10", []) - session.notify("_coverage-3.11", posargs) + session.notify("_coverage-3.12", posargs) @nox.session(python="3.11") diff --git a/pyproject.toml b/pyproject.toml index 0391469..ad0713e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,56 @@ test-requires = "-r requirements-test.txt" test-command = "pytest {project}/tests" build-verbosity = 1 +[tool.coverage.run] +branch = true + +[tool.coverage.report] +exclude_lines = ["pragma: no cover", "class .*\\(Protocol\\):", "if TYPE_CHECKING:"] + +[tool.mypy] +python_version = "3.7" +files = [ + "src/**/*.py", + "test/**/*.py", + "noxfile.py", + "setup.py", +] +warn_unused_configs = true +show_error_codes = true + +warn_redundant_casts = true +no_implicit_reexport = true +strict_equality = true +warn_unused_ignores = true +check_untyped_defs = true +ignore_missing_imports = false + +disallow_subclassing_any = true +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_any_explicit = true +warn_return_any = true + +no_implicit_optional = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["pybase64.__main__"] +disallow_any_explicit = false + +[[tool.mypy.overrides]] +module = ["tests.test_pybase64"] +disallow_any_explicit = false + + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "-p", "no:legacypath"] + [tool.ruff] target-version = "py37" line-length = 100 @@ -37,15 +87,9 @@ extend-select = [ ignore = [ "PLR", # Design related pylint codes ] +typing-modules = ["pybase64._typing"] [tool.ruff.lint.flake8-tidy-imports.banned-api] "typing.Callable".msg = "Use collections.abc.Callable instead." +"typing.Iterator".msg = "Use collections.abc.Iterator instead." "typing.Sequence".msg = "Use collections.abc.Sequence instead." - -[tool.mypy] -python_version = "3.7" -follow_imports = "silent" -ignore_missing_imports = true -disallow_untyped_defs = true -disallow_any_generics = true -warn_unused_ignores = true diff --git a/requirements-test.txt b/requirements-test.txt index cb87efc..c16bdd0 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1 +1,2 @@ pytest==7.4.4 +typing_extensions>=4.6.0 diff --git a/setup.py b/setup.py index 93df91f..6e04871 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ import sysconfig from contextlib import contextmanager from pathlib import Path -from typing import Generator +from typing import Generator, cast from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext @@ -25,7 +25,7 @@ # Get version version_dict: dict[str, object] = {} exec(HERE.joinpath("src", "pybase64", "_version.py").read_text(), {}, version_dict) -version = version_dict["__version__"] +version = cast(str, version_dict["_version"]) # Get the long description from the README file long_description = HERE.joinpath("README.rst").read_text() @@ -244,7 +244,7 @@ def run(self) -> None: # simple. Or you can use find_packages(). packages=find_packages(where="src"), package_dir={"": "src"}, - package_data={"pybase64": ["py.typed"]}, + package_data={"pybase64": ["py.typed", "_pybase64.pyi"]}, # To provide executable scripts, use entry points in preference to the # "scripts" keyword. Entry points provide cross-platform support and allow # pip to create the appropriate form of executable for the target platform. diff --git a/src/pybase64/__init__.py b/src/pybase64/__init__.py index d8e2ff6..79768bd 100644 --- a/src/pybase64/__init__.py +++ b/src/pybase64/__init__.py @@ -1,9 +1,12 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING from ._license import _license -from ._version import __version__ +from ._version import _version + +if TYPE_CHECKING: + from ._typing import Buffer try: from ._pybase64 import ( @@ -16,7 +19,7 @@ encodebytes, ) except ImportError: - from ._fallback import ( # noqa: F401 + from ._fallback import ( _get_simd_name, _get_simd_path, b64decode, @@ -27,6 +30,21 @@ ) +__all__ = ( + "b64decode", + "b64decode_as_bytearray", + "b64encode", + "b64encode_as_string", + "encodebytes", + "standard_b64encode", + "standard_b64decode", + "urlsafe_b64encode", + "urlsafe_b64decode", +) + +__version__ = _version + + def get_license_text() -> str: """Returns pybase64 license information as a :class:`str` object. @@ -47,7 +65,7 @@ def get_version() -> str: return f"{__version__} (C extension inactive)" -def standard_b64encode(s: Any) -> bytes: +def standard_b64encode(s: Buffer) -> bytes: """Encode bytes using the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` to encode. @@ -57,7 +75,7 @@ def standard_b64encode(s: Any) -> bytes: return b64encode(s) -def standard_b64decode(s: Any) -> bytes: +def standard_b64decode(s: str | Buffer) -> bytes: """Decode bytes encoded with the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` or ASCII string to @@ -73,7 +91,7 @@ def standard_b64decode(s: Any) -> bytes: return b64decode(s) -def urlsafe_b64encode(s: Any) -> bytes: +def urlsafe_b64encode(s: Buffer) -> bytes: """Encode bytes using the URL- and filesystem-safe Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` to encode. @@ -85,7 +103,7 @@ def urlsafe_b64encode(s: Any) -> bytes: return b64encode(s, b"-_") -def urlsafe_b64decode(s: Any) -> bytes: +def urlsafe_b64decode(s: str | Buffer) -> bytes: """Decode bytes using the URL- and filesystem-safe Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` or ASCII string to diff --git a/src/pybase64/__main__.py b/src/pybase64/__main__.py index 319005c..b6d8f24 100644 --- a/src/pybase64/__main__.py +++ b/src/pybase64/__main__.py @@ -5,20 +5,23 @@ import sys from base64 import b64decode as b64decodeValidate from base64 import encodebytes as b64encodebytes -from collections.abc import Callable, Sequence +from collections.abc import Sequence from timeit import default_timer as timer -from typing import Any, BinaryIO +from typing import TYPE_CHECKING, Any, BinaryIO, cast import pybase64 +if TYPE_CHECKING: + from pybase64._typing import Decode, Encode, EncodeBytes + def bench_one( duration: float, data: bytes, - enc: Callable[..., bytes], - dec: Callable[..., bytes], - encbytes: Callable[[Any], bytes], - altchars: Any | None = None, + enc: Encode, + dec: Decode, + encbytes: EncodeBytes, + altchars: bytes | None = None, validate: bool = False, ) -> None: duration = duration / 2.0 @@ -93,11 +96,11 @@ def bench_one( def readall(file: BinaryIO) -> bytes: - if file == sys.stdin: + if file == cast(BinaryIO, sys.stdin): # Python 3 < 3.9 does not honor the binary flag, # read from the underlying buffer if hasattr(file, "buffer"): - return file.buffer.read() + return cast(BinaryIO, file.buffer).read() return file.read() # pragma: no cover # do not close the file try: @@ -108,7 +111,7 @@ def readall(file: BinaryIO) -> bytes: def writeall(file: BinaryIO, data: bytes) -> None: - if file == sys.stdout: + if file == cast(BinaryIO, sys.stdout): # Python 3 does not honor the binary flag, # write to the underlying buffer if hasattr(file, "buffer"): @@ -142,7 +145,7 @@ def benchmark(duration: float, input: BinaryIO) -> None: duration, data, base64.b64encode, - b64decodeValidate, + b64decodeValidate, # type: ignore[arg-type] # c.f. https://github.com/python/typeshed/pull/11210 b64encodebytes, altchars, validate, diff --git a/src/pybase64/_fallback.py b/src/pybase64/_fallback.py index 7cdc644..4c0866f 100644 --- a/src/pybase64/_fallback.py +++ b/src/pybase64/_fallback.py @@ -4,17 +4,10 @@ from base64 import b64encode as builtin_encode from base64 import encodebytes as builtin_encodebytes from binascii import Error as BinAsciiError -from typing import Any - -__all__ = [ - "_get_simd_name", - "_get_simd_path", - "b64decode", - "b64encode", - "b64encode_as_string", - "encodebytes", -] +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._typing import Buffer _bytes_types = (bytes, bytearray) # Types acceptable as binary data @@ -28,7 +21,7 @@ def _get_simd_path() -> int: return 0 -def _get_bytes(s: Any) -> bytes | bytearray: +def _get_bytes(s: str | Buffer) -> bytes | bytearray: if isinstance(s, str): try: return s.encode("ascii") @@ -51,7 +44,9 @@ def _get_bytes(s: Any) -> bytes | bytearray: raise TypeError(msg) from None -def b64decode(s: Any, altchars: Any = None, validate: bool = False) -> bytes: +def b64decode( + s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False +) -> bytes: """Decode bytes encoded with the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` or ASCII string to @@ -97,7 +92,9 @@ def b64decode(s: Any, altchars: Any = None, validate: bool = False) -> bytes: return builtin_decode(s, altchars, validate=False) -def b64decode_as_bytearray(s: Any, altchars: Any = None, validate: bool = False) -> bytearray: +def b64decode_as_bytearray( + s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False +) -> bytearray: """Decode bytes encoded with the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` or ASCII string to @@ -120,7 +117,7 @@ def b64decode_as_bytearray(s: Any, altchars: Any = None, validate: bool = False) return bytearray(b64decode(s, altchars=altchars, validate=validate)) -def b64encode(s: Any, altchars: Any = None) -> bytes: +def b64encode(s: Buffer, altchars: str | Buffer | None = None) -> bytes: """Encode bytes using the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` to encode. @@ -140,7 +137,7 @@ def b64encode(s: Any, altchars: Any = None) -> bytes: return builtin_encode(s, altchars) -def b64encode_as_string(s: Any, altchars: Any = None) -> str: +def b64encode_as_string(s: Buffer, altchars: str | Buffer | None = None) -> str: """Encode bytes using the standard Base64 alphabet. Argument ``s`` is a :term:`bytes-like object` to encode. @@ -154,7 +151,7 @@ def b64encode_as_string(s: Any, altchars: Any = None) -> str: return b64encode(s, altchars).decode("ascii") -def encodebytes(s: Any) -> bytes: +def encodebytes(s: Buffer) -> bytes: """Encode bytes into a bytes object with newlines (b'\\\\n') inserted after every 76 bytes of output, and ensuring that there is a trailing newline, as per :rfc:`2045` (MIME). diff --git a/src/pybase64/_license.pyi b/src/pybase64/_license.pyi new file mode 100644 index 0000000..39c5306 --- /dev/null +++ b/src/pybase64/_license.pyi @@ -0,0 +1 @@ +_license: str diff --git a/src/pybase64/_pybase64.pyi b/src/pybase64/_pybase64.pyi new file mode 100644 index 0000000..4cb4ab0 --- /dev/null +++ b/src/pybase64/_pybase64.pyi @@ -0,0 +1,16 @@ +from ._typing import Buffer + +def _get_simd_flags_compile() -> int: ... +def _get_simd_flags_runtime() -> int: ... +def _get_simd_name(flags: int) -> str: ... +def _get_simd_path() -> int: ... +def _set_simd_path(flags: int) -> None: ... +def b64decode( + s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False +) -> bytes: ... +def b64decode_as_bytearray( + s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False +) -> bytearray: ... +def b64encode(s: Buffer, altchars: str | Buffer | None = None) -> bytes: ... +def b64encode_as_string(s: Buffer, altchars: str | Buffer | None = None) -> str: ... +def encodebytes(s: Buffer) -> bytes: ... diff --git a/src/pybase64/_typing.py b/src/pybase64/_typing.py new file mode 100644 index 0000000..3cf65af --- /dev/null +++ b/src/pybase64/_typing.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import sys + +if sys.version_info < (3, 8): # pragma: no cover + from typing_extensions import Protocol +else: + from typing import Protocol + +if sys.version_info < (3, 12): + from typing_extensions import Buffer +else: + from collections.abc import Buffer + + +class Decode(Protocol): + __name__: str + __module__: str + + def __call__( + self, s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False + ) -> bytes: + ... + + +class Encode(Protocol): + __name__: str + __module__: str + + def __call__(self, s: Buffer, altchars: Buffer | None = None) -> bytes: + ... + + +class EncodeBytes(Protocol): + __name__: str + __module__: str + + def __call__(self, s: Buffer) -> bytes: + ... + + +__all__ = ("Buffer", "Decode", "Encode", "EncodeBytes") diff --git a/src/pybase64/_version.py b/src/pybase64/_version.py index e1e107f..123f2e6 100644 --- a/src/pybase64/_version.py +++ b/src/pybase64/_version.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "1.3.2" +_version = "1.3.2" diff --git a/tests/test_main.py b/tests/test_main.py index e37da13..8520418 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,9 @@ from __future__ import annotations -import os import re import sys +from collections.abc import Iterator, Sequence +from pathlib import Path import pybase64 import pytest @@ -10,24 +11,22 @@ @pytest.fixture() -def emptyfile(tmpdir): - _file = os.path.join(tmpdir.strpath, "empty") - with open(_file, "wb"): - pass - yield _file - os.remove(_file) +def emptyfile(tmp_path: Path) -> Iterator[str]: + _file = tmp_path / "empty" + _file.write_bytes(b"") + yield str(_file) + _file.unlink() @pytest.fixture() -def hellofile(tmpdir): - _file = os.path.join(tmpdir.strpath, "helloworld") - with open(_file, "wb") as f: - f.write(b"hello world !/?\n") - yield _file - os.remove(_file) +def hellofile(tmp_path: Path) -> Iterator[str]: + _file = tmp_path / "helloworld" + _file.write_bytes(b"hello world !/?\n") + yield str(_file) + _file.unlink() -def idfn_test_help(args): +def idfn_test_help(args: Sequence[str]) -> str: if len(args) == 0: return "(empty)" return " ".join(args) @@ -44,7 +43,7 @@ def idfn_test_help(args): ], ids=idfn_test_help, ) -def test_help(capsys, args): +def test_help(capsys: pytest.CaptureFixture[str], args: Sequence[str]) -> None: command = "pybase64" if len(args) == 2: command += f" {args[0]}" @@ -57,7 +56,7 @@ def test_help(capsys, args): assert exit_info.value.code == 0 -def test_version(capsys): +def test_version(capsys: pytest.CaptureFixture[str]) -> None: with pytest.raises(SystemExit) as exit_info: main(["-V"]) captured = capsys.readouterr() @@ -66,7 +65,7 @@ def test_version(capsys): assert exit_info.value.code == 0 -def test_license(capsys): +def test_license(capsys: pytest.CaptureFixture[str]) -> None: restr = "\n".join(x + "\n[=]+\n.*Copyright.*\n[=]+\n" for x in ["pybase64", "libbase64"]) regex = re.compile("^" + restr + "$", re.DOTALL) with pytest.raises(SystemExit) as exit_info: @@ -77,7 +76,7 @@ def test_license(capsys): assert exit_info.value.code == 0 -def test_benchmark(capsys, emptyfile): +def test_benchmark(capsys: pytest.CaptureFixture[str], emptyfile: str) -> None: main(["benchmark", "-d", "0.005", emptyfile]) captured = capsys.readouterr() assert captured.err == "" @@ -93,14 +92,18 @@ def test_benchmark(capsys, emptyfile): ], ids=["0", "1", "2"], ) -def test_encode(capsysbinary, hellofile, args, expect): +def test_encode( + capsysbinary: pytest.CaptureFixture[bytes], hellofile: str, args: Sequence[str], expect: bytes +) -> None: main(["encode", *args, hellofile]) captured = capsysbinary.readouterr() assert captured.err == b"" assert captured.out == expect -def test_encode_ouputfile(capsys, emptyfile, hellofile): +def test_encode_ouputfile( + capsys: pytest.CaptureFixture[str], emptyfile: str, hellofile: str +) -> None: main(["encode", "-o", hellofile, emptyfile]) captured = capsys.readouterr() assert captured.err == "" @@ -120,17 +123,21 @@ def test_encode_ouputfile(capsys, emptyfile, hellofile): ], ids=["0", "1", "2", "3"], ) -def test_decode(capsysbinary, tmpdir, args, b64string): - iname = os.path.join(tmpdir.strpath, "in") - with open(iname, "wb") as f: - f.write(b64string) - main(["decode", *args, iname]) +def test_decode( + capsysbinary: pytest.CaptureFixture[bytes], + tmp_path: Path, + args: Sequence[str], + b64string: bytes, +) -> None: + input_file = tmp_path / "in" + input_file.write_bytes(b64string) + main(["decode", *args, str(input_file)]) captured = capsysbinary.readouterr() assert captured.err == b"" assert captured.out == b"hello world !/?\n" -def test_subprocess(): +def test_subprocess() -> None: import subprocess process = subprocess.Popen( diff --git a/tests/test_pybase64.py b/tests/test_pybase64.py index a3f0d42..de3ccbf 100644 --- a/tests/test_pybase64.py +++ b/tests/test_pybase64.py @@ -4,9 +4,13 @@ import os from base64 import encodebytes as b64encodebytes from binascii import Error as BinAsciiError +from collections.abc import Callable, Iterator +from enum import IntEnum +from typing import Any import pybase64 import pytest +from pybase64._typing import Buffer, Decode, Encode try: from pybase64._pybase64 import ( @@ -25,16 +29,18 @@ _has_extension = False -def unused_args(*args): # noqa: ARG001 +def unused_args(*args: Any) -> None: # noqa: ARG001 return None -def b64encode_as_string(s, altchars=None): +def b64encode_as_string(s: Buffer, altchars: str | Buffer | None = None) -> bytes: """helper returning bytes instead of string for tests""" return pybase64.b64encode_as_string(s, altchars).encode("ascii") -def b64decode_as_bytearray(s, altchars=None, validate=False): +def b64decode_as_bytearray( + s: str | Buffer, altchars: str | Buffer | None = None, validate: bool = False +) -> bytes: """helper returning bytes instead of bytearray for tests""" return bytes(pybase64.b64decode_as_bytearray(s, altchars, validate)) @@ -45,40 +51,30 @@ def b64decode_as_bytearray(s, altchars=None, validate=False): ) -STD = 0 -URL = 1 -ALT1 = 2 -ALT2 = 3 -ALT3 = 4 -name_lut = ["standard", "urlsafe", "alternative", "alternative2", "alternative3"] +class AltCharsId(IntEnum): + STD = 0 + URL = 1 + ALT1 = 2 + ALT2 = 3 + ALT3 = 4 + + altchars_lut = [b"+/", b"-_", b"@&", b"+,", b";/"] -enc_helper_lut = [ +enc_helper_lut: list[Callable[[Buffer], bytes]] = [ pybase64.standard_b64encode, pybase64.urlsafe_b64encode, - None, - None, - None, ] -ref_enc_helper_lut = [ - pybase64.standard_b64encode, - pybase64.urlsafe_b64encode, - None, - None, - None, +ref_enc_helper_lut: list[Callable[[Buffer], bytes]] = [ + base64.standard_b64encode, + base64.urlsafe_b64encode, ] -dec_helper_lut = [ +dec_helper_lut: list[Callable[[str | Buffer], bytes]] = [ pybase64.standard_b64decode, pybase64.urlsafe_b64decode, - None, - None, - None, ] -ref_dec_helper_lut = [ +ref_dec_helper_lut: list[Callable[[str | Buffer], bytes]] = [ base64.standard_b64decode, base64.urlsafe_b64decode, - None, - None, - None, ] std = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/A" @@ -126,12 +122,12 @@ def b64decode_as_bytearray(s, altchars=None, validate=False): compile_flags += [(1 << i)] -def get_simd_name(simd_id): +def get_simd_name(simd_id: int) -> str: if _has_extension: simd_flag = compile_flags[simd_id] - simd_name = "c" if simd_flag == 0 else _get_simd_name(simd_flag).lower() + simd_name = "C" if simd_flag == 0 else _get_simd_name(simd_flag) else: - simd_name = "py" + simd_name = "PY" return simd_name @@ -141,13 +137,11 @@ def get_simd_name(simd_id): param_validate = pytest.mark.parametrize("validate", [False, True], ids=["novalidate", "validate"]) -param_altchars = pytest.mark.parametrize( - "altchars_id", [STD, URL, ALT1, ALT2, ALT3], ids=lambda x: name_lut[x] -) +param_altchars = pytest.mark.parametrize("altchars_id", list(AltCharsId), ids=lambda x: x.name) param_altchars_helper = pytest.mark.parametrize( - "altchars_id", [STD, URL], ids=lambda x: name_lut[x] + "altchars_id", [AltCharsId.STD, AltCharsId.URL], ids=lambda x: x.name ) @@ -157,7 +151,7 @@ def get_simd_name(simd_id): @pytest.fixture() -def simd(request): +def simd(request: pytest.FixtureRequest) -> Iterator[int]: simd_id = request.param if not _has_extension: assert simd_id == 0 @@ -170,14 +164,12 @@ def simd(request): old_flag = _get_simd_path() _set_simd_path(flag) assert _get_simd_path() == flag - try: - yield simd_id - finally: - _set_simd_path(old_flag) + yield simd_id + _set_simd_path(old_flag) @param_simd -def test_version(simd): +def test_version(simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests assert pybase64.get_version().startswith(pybase64.__version__) @@ -185,7 +177,7 @@ def test_version(simd): @param_simd @param_vector @param_altchars_helper -def test_enc_helper(altchars_id, vector_id, simd): +def test_enc_helper(altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_bin[altchars_id][vector_id] test = enc_helper_lut[altchars_id](vector) @@ -196,7 +188,7 @@ def test_enc_helper(altchars_id, vector_id, simd): @param_simd @param_vector @param_altchars_helper -def test_dec_helper(altchars_id, vector_id, simd): +def test_dec_helper(altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] test = dec_helper_lut[altchars_id](vector) @@ -207,7 +199,7 @@ def test_dec_helper(altchars_id, vector_id, simd): @param_simd @param_vector @param_altchars_helper -def test_dec_helper_unicode(altchars_id, vector_id, simd): +def test_dec_helper_unicode(altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] test = dec_helper_lut[altchars_id](str(vector, "utf-8")) @@ -218,7 +210,7 @@ def test_dec_helper_unicode(altchars_id, vector_id, simd): @param_simd @param_vector @param_altchars_helper -def test_rnd_helper(altchars_id, vector_id, simd): +def test_rnd_helper(altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] test = dec_helper_lut[altchars_id](vector) @@ -229,7 +221,7 @@ def test_rnd_helper(altchars_id, vector_id, simd): @param_simd @param_vector @param_altchars_helper -def test_rnd_helper_unicode(altchars_id, vector_id, simd): +def test_rnd_helper_unicode(altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] test = dec_helper_lut[altchars_id](str(vector, "utf-8")) @@ -239,9 +231,9 @@ def test_rnd_helper_unicode(altchars_id, vector_id, simd): @param_simd @param_vector -def test_encbytes(vector_id, simd): +def test_encbytes(vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests - vector = test_vectors_bin[STD][vector_id] + vector = test_vectors_bin[AltCharsId.STD][vector_id] test = pybase64.encodebytes(vector) base = b64encodebytes(vector) assert test == base @@ -251,7 +243,7 @@ def test_encbytes(vector_id, simd): @param_vector @param_altchars @param_encode_functions -def test_enc(efn, altchars_id, vector_id, simd): +def test_enc(efn: Encode, altchars_id: int, vector_id: int, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_bin[altchars_id][vector_id] altchars = altchars_lut[altchars_id] @@ -265,7 +257,7 @@ def test_enc(efn, altchars_id, vector_id, simd): @param_altchars @param_validate @param_decode_functions -def test_dec(dfn, altchars_id, vector_id, validate, simd): +def test_dec(dfn: Decode, altchars_id: int, vector_id: int, validate: bool, simd: int) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] @@ -282,15 +274,16 @@ def test_dec(dfn, altchars_id, vector_id, validate, simd): @param_altchars @param_validate @param_decode_functions -def test_dec_unicode(dfn, altchars_id, vector_id, validate, simd): +def test_dec_unicode( + dfn: Decode, altchars_id: int, vector_id: int, validate: bool, simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests - vector = test_vectors_b64[altchars_id][vector_id] - vector = str(vector, "utf-8") - altchars = None if altchars_id == STD else str(altchars_lut[altchars_id], "utf-8") + vector = str(test_vectors_b64[altchars_id][vector_id], "utf-8") + altchars = None if altchars_id == AltCharsId.STD else str(altchars_lut[altchars_id], "utf-8") if validate: - base = base64.b64decode(vector, altchars, validate) + base = base64.b64decode(vector, altchars, validate) # type: ignore[arg-type] # c.f. https://github.com/python/typeshed/pull/11210 else: - base = base64.b64decode(vector, altchars) + base = base64.b64decode(vector, altchars) # type: ignore[arg-type] # c.f. https://github.com/python/typeshed/pull/11210 test = dfn(vector, altchars, validate) assert test == base @@ -301,7 +294,9 @@ def test_dec_unicode(dfn, altchars_id, vector_id, validate, simd): @param_validate @param_encode_functions @param_decode_functions -def test_rnd(dfn, efn, altchars_id, vector_id, validate, simd): +def test_rnd( + dfn: Decode, efn: Encode, altchars_id: int, vector_id: int, validate: bool, simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] @@ -316,7 +311,9 @@ def test_rnd(dfn, efn, altchars_id, vector_id, validate, simd): @param_validate @param_encode_functions @param_decode_functions -def test_rnd_unicode(dfn, efn, altchars_id, vector_id, validate, simd): +def test_rnd_unicode( + dfn: Decode, efn: Encode, altchars_id: int, vector_id: int, validate: bool, simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id] altchars = altchars_lut[altchars_id] @@ -330,7 +327,9 @@ def test_rnd_unicode(dfn, efn, altchars_id, vector_id, validate, simd): @param_altchars @param_validate @param_decode_functions -def test_invalid_padding_dec(dfn, altchars_id, vector_id, validate, simd): +def test_invalid_padding_dec( + dfn: Decode, altchars_id: int, vector_id: int, validate: bool, simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests vector = test_vectors_b64[altchars_id][vector_id][1:] if len(vector) > 0: @@ -339,7 +338,7 @@ def test_invalid_padding_dec(dfn, altchars_id, vector_id, validate, simd): dfn(vector, altchars, validate) -params_invalid_altchars = [ +params_invalid_altchars_values = [ [b"", AssertionError], [b"-", AssertionError], [b"-__", AssertionError], @@ -349,15 +348,17 @@ def test_invalid_padding_dec(dfn, altchars_id, vector_id, validate, simd): ] params_invalid_altchars = pytest.mark.parametrize( "altchars,exception", - params_invalid_altchars, - ids=[str(i) for i in range(len(params_invalid_altchars))], + params_invalid_altchars_values, + ids=[str(i) for i in range(len(params_invalid_altchars_values))], ) @param_simd @params_invalid_altchars @param_encode_functions -def test_invalid_altchars_enc(efn, altchars, exception, simd): +def test_invalid_altchars_enc( + efn: Encode, altchars: Any, exception: type[BaseException], simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests with pytest.raises(exception): efn(b"ABCD", altchars) @@ -366,7 +367,9 @@ def test_invalid_altchars_enc(efn, altchars, exception, simd): @param_simd @params_invalid_altchars @param_decode_functions -def test_invalid_altchars_dec(dfn, altchars, exception, simd): +def test_invalid_altchars_dec( + dfn: Decode, altchars: Any, exception: type[BaseException], simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests with pytest.raises(exception): dfn(b"ABCD", altchars) @@ -375,19 +378,21 @@ def test_invalid_altchars_dec(dfn, altchars, exception, simd): @param_simd @params_invalid_altchars @param_decode_functions -def test_invalid_altchars_dec_validate(dfn, altchars, exception, simd): +def test_invalid_altchars_dec_validate( + dfn: Decode, altchars: Any, exception: type[BaseException], simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests with pytest.raises(exception): dfn(b"ABCD", altchars, True) -params_invalid_data_novalidate = [ +params_invalid_data_novalidate_values = [ [b"A@@@@FG", None, BinAsciiError], ["ABC€", None, ValueError], [3.0, None, TypeError], [memoryview(b"ABCDEFGH")[::2], None, BufferError], ] -params_invalid_data_validate = [ +params_invalid_data_validate_values = [ [b"\x00\x00\x00\x00", None, BinAsciiError], [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError], [b"A@@@=FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError], @@ -399,28 +404,32 @@ def test_invalid_altchars_dec_validate(dfn, altchars, exception, simd): ] params_invalid_data_all = pytest.mark.parametrize( "vector,altchars,exception", - params_invalid_data_novalidate + params_invalid_data_validate, + params_invalid_data_novalidate_values + params_invalid_data_validate_values, ids=[ str(i) - for i in range(len(params_invalid_data_novalidate) + len(params_invalid_data_validate)) + for i in range( + len(params_invalid_data_novalidate_values) + len(params_invalid_data_validate_values) + ) ], ) params_invalid_data_novalidate = pytest.mark.parametrize( "vector,altchars,exception", - params_invalid_data_novalidate, - ids=[str(i) for i in range(len(params_invalid_data_novalidate))], + params_invalid_data_novalidate_values, + ids=[str(i) for i in range(len(params_invalid_data_novalidate_values))], ) params_invalid_data_validate = pytest.mark.parametrize( "vector,altchars,exception", - params_invalid_data_validate, - ids=[str(i) for i in range(len(params_invalid_data_validate))], + params_invalid_data_validate_values, + ids=[str(i) for i in range(len(params_invalid_data_validate_values))], ) @param_simd @params_invalid_data_novalidate @param_decode_functions -def test_invalid_data_dec(dfn, vector, altchars, exception, simd): +def test_invalid_data_dec( + dfn: Decode, vector: Any, altchars: Buffer | None, exception: type[BaseException], simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests with pytest.raises(exception): dfn(vector, altchars) @@ -429,7 +438,9 @@ def test_invalid_data_dec(dfn, vector, altchars, exception, simd): @param_simd @params_invalid_data_validate @param_decode_functions -def test_invalid_data_dec_skip(dfn, vector, altchars, exception, simd): +def test_invalid_data_dec_skip( + dfn: Decode, vector: Any, altchars: Buffer | None, exception: type[BaseException], simd: int +) -> None: unused_args(exception, simd) # simd is a parameter in order to control the order of tests test = dfn(vector, altchars) base = base64.b64decode(vector, altchars) @@ -439,59 +450,61 @@ def test_invalid_data_dec_skip(dfn, vector, altchars, exception, simd): @param_simd @params_invalid_data_all @param_decode_functions -def test_invalid_data_dec_validate(dfn, vector, altchars, exception, simd): +def test_invalid_data_dec_validate( + dfn: Decode, vector: Any, altchars: Buffer | None, exception: type[BaseException], simd: int +) -> None: unused_args(simd) # simd is a parameter in order to control the order of tests with pytest.raises(exception): dfn(vector, altchars, True) -params_invalid_data_enc = [ +params_invalid_data_enc_values = [ ["this is a test", TypeError], [memoryview(b"abcd")[::2], BufferError], ] -params_invalid_data_encodebytes = [ - *params_invalid_data_enc, +params_invalid_data_encodebytes_values = [ + *params_invalid_data_enc_values, [memoryview(b"abcd").cast("B", (2, 2)), TypeError], [memoryview(b"abcd").cast("I"), TypeError], ] params_invalid_data_enc = pytest.mark.parametrize( "vector,exception", - params_invalid_data_enc, - ids=[str(i) for i in range(len(params_invalid_data_enc))], + params_invalid_data_enc_values, + ids=[str(i) for i in range(len(params_invalid_data_enc_values))], ) params_invalid_data_encodebytes = pytest.mark.parametrize( "vector,exception", - params_invalid_data_encodebytes, - ids=[str(i) for i in range(len(params_invalid_data_encodebytes))], + params_invalid_data_encodebytes_values, + ids=[str(i) for i in range(len(params_invalid_data_encodebytes_values))], ) @params_invalid_data_enc @param_encode_functions -def test_invalid_data_enc(efn, vector, exception): +def test_invalid_data_enc(efn: Encode, vector: Any, exception: type[BaseException]) -> None: with pytest.raises(exception): efn(vector) @params_invalid_data_encodebytes -def test_invalid_data_encodebytes(vector, exception): +def test_invalid_data_encodebytes(vector: Any, exception: type[BaseException]) -> None: with pytest.raises(exception): pybase64.encodebytes(vector) @param_encode_functions -def test_invalid_args_enc_0(efn): +def test_invalid_args_enc_0(efn: Encode) -> None: with pytest.raises(TypeError): - efn() + efn() # type: ignore[call-arg] @param_decode_functions -def test_invalid_args_dec_0(dfn): +def test_invalid_args_dec_0(dfn: Decode) -> None: with pytest.raises(TypeError): - dfn() + dfn() # type: ignore[call-arg] -def test_flags(request): +def test_flags(request: pytest.FixtureRequest) -> None: cpu = request.config.getoption("--sde-cpu", skip=True) assert { "p4p": 1 | 2, # SSE3 @@ -505,7 +518,7 @@ def test_flags(request): @param_encode_functions -def test_enc_multi_dimensional(efn): +def test_enc_multi_dimensional(efn: Encode) -> None: source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" vector = memoryview(source).cast("B", (4, len(source) // 4)) assert vector.c_contiguous @@ -515,7 +528,7 @@ def test_enc_multi_dimensional(efn): @param_decode_functions -def test_dec_multi_dimensional(dfn): +def test_dec_multi_dimensional(dfn: Decode) -> None: source = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV" vector = memoryview(source).cast("B", (4, len(source) // 4)) assert vector.c_contiguous