Skip to content

Commit

Permalink
feature: better typing using Buffer (PEP 688) (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeut authored Jan 2, 2024
1 parent 4109a0f commit 4e00978
Show file tree
Hide file tree
Showing 17 changed files with 320 additions and 171 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:

- uses: wntrblm/[email protected]
with:
python-versions: "3.8, 3.11, pypy3.10"
python-versions: "3.8, 3.12, pypy3.10"

- name: Install Intel SDE
run: |
Expand Down
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,9 @@ repos:
rev: v1.8.0
hooks:
- id: mypy
exclude: docs|conftest.py|tests
exclude: docs|conftest.py
args: ["--python-version=3.7"]
additional_dependencies:
- nox
- pytest
- types-setuptools
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ include LICENSE
# Include headers
include src/pybase64/_pybase64_get_simd_flags.h

# Include type stub for extension
include src/pybase64/_pybase64.pyi

# Include full base64 folder
graft base64
# but the git folder
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# |version| and |release|, also used in various other places throughout the
# built documents.
# Get version
_version = runpy.run_path(os.path.join(here, "..", "src", "pybase64", "_version.py"))["__version__"]
_version = runpy.run_path(os.path.join(here, "..", "src", "pybase64", "_version.py"))["_version"]
# The short X.Y version.
version = _version
# The full version, including alpha/beta/rc tags.
Expand Down
11 changes: 5 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def lint(session: nox.Session) -> None:
def update_env_macos(session: nox.Session, env: dict[str, str]) -> None:
if sys.platform.startswith("darwin"):
# we don't support universal builds
machine = session.run(
machine = session.run( # type: ignore[union-attr]
"python", "-sSEc", "import platform; print(platform.machine())", silent=True
).strip()
env["ARCHFLAGS"] = f"-arch {machine}"
Expand All @@ -37,7 +37,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
where = HERE / "src" / "pybase64"
else:
command = "import sysconfig; print(sysconfig.get_path('platlib'))"
platlib = session.run("python", "-c", command, silent=True).strip()
platlib = session.run("python", "-c", command, silent=True).strip() # type: ignore[union-attr]
where = Path(platlib) / "pybase64"
assert where.exists()

Expand All @@ -54,7 +54,7 @@ def remove_extension(session: nox.Session, in_place: bool = False) -> None:
@nox.session(python="3.12")
def develop(session: nox.Session) -> None:
"""create venv for dev."""
session.install("-r", "requirements-test.txt")
session.install("nox", "setuptools", "-r", "requirements-test.txt")
# make extension mandatory by exporting CIBUILDWHEEL=1
env = {"CIBUILDWHEEL": "1"}
update_env_macos(session, env)
Expand All @@ -76,7 +76,7 @@ def test(session: nox.Session) -> None:
session.run("pytest", *session.posargs, env=env)


@nox.session(python=["3.8", "3.11", "pypy3.10"])
@nox.session(python=["3.8", "3.12", "pypy3.10"])
def _coverage(session: nox.Session) -> None:
"""internal coverage run. Do not run manually"""
with_sde = "--with-sde" in session.posargs
Expand All @@ -86,7 +86,6 @@ def _coverage(session: nox.Session) -> None:
"--cov=pybase64",
"--cov=tests",
"--cov-append",
"--cov-branch",
"--cov-report=",
)
pytest_command = ("pytest", *coverage_args)
Expand Down Expand Up @@ -144,7 +143,7 @@ def coverage(session: nox.Session) -> None:
posargs.add("--report")
session.notify("_coverage-3.8", ["--clean"])
session.notify("_coverage-pypy3.10", [])
session.notify("_coverage-3.11", posargs)
session.notify("_coverage-3.12", posargs)


@nox.session(python="3.11")
Expand Down
60 changes: 52 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,56 @@ test-requires = "-r requirements-test.txt"
test-command = "pytest {project}/tests"
build-verbosity = 1

[tool.coverage.run]
branch = true

[tool.coverage.report]
exclude_lines = ["pragma: no cover", "class .*\\(Protocol\\):", "if TYPE_CHECKING:"]

[tool.mypy]
python_version = "3.7"
files = [
"src/**/*.py",
"test/**/*.py",
"noxfile.py",
"setup.py",
]
warn_unused_configs = true
show_error_codes = true

warn_redundant_casts = true
no_implicit_reexport = true
strict_equality = true
warn_unused_ignores = true
check_untyped_defs = true
ignore_missing_imports = false

disallow_subclassing_any = true
disallow_any_generics = true
disallow_untyped_defs = true
disallow_untyped_calls = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true
disallow_any_explicit = true
warn_return_any = true

no_implicit_optional = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true

[[tool.mypy.overrides]]
module = ["pybase64.__main__"]
disallow_any_explicit = false

[[tool.mypy.overrides]]
module = ["tests.test_pybase64"]
disallow_any_explicit = false


[tool.pytest.ini_options]
minversion = "7.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config", "-p", "no:legacypath"]

[tool.ruff]
target-version = "py37"
line-length = 100
Expand Down Expand Up @@ -37,15 +87,9 @@ extend-select = [
ignore = [
"PLR", # Design related pylint codes
]
typing-modules = ["pybase64._typing"]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"typing.Callable".msg = "Use collections.abc.Callable instead."
"typing.Iterator".msg = "Use collections.abc.Iterator instead."
"typing.Sequence".msg = "Use collections.abc.Sequence instead."

[tool.mypy]
python_version = "3.7"
follow_imports = "silent"
ignore_missing_imports = true
disallow_untyped_defs = true
disallow_any_generics = true
warn_unused_ignores = true
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pytest==7.4.4
typing_extensions>=4.6.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sysconfig
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from typing import Generator, cast

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
Expand All @@ -25,7 +25,7 @@
# Get version
version_dict: dict[str, object] = {}
exec(HERE.joinpath("src", "pybase64", "_version.py").read_text(), {}, version_dict)
version = version_dict["__version__"]
version = cast(str, version_dict["_version"])

# Get the long description from the README file
long_description = HERE.joinpath("README.rst").read_text()
Expand Down Expand Up @@ -244,7 +244,7 @@ def run(self) -> None:
# simple. Or you can use find_packages().
packages=find_packages(where="src"),
package_dir={"": "src"},
package_data={"pybase64": ["py.typed"]},
package_data={"pybase64": ["py.typed", "_pybase64.pyi"]},
# To provide executable scripts, use entry points in preference to the
# "scripts" keyword. Entry points provide cross-platform support and allow
# pip to create the appropriate form of executable for the target platform.
Expand Down
32 changes: 25 additions & 7 deletions src/pybase64/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from ._license import _license
from ._version import __version__
from ._version import _version

if TYPE_CHECKING:
from ._typing import Buffer

try:
from ._pybase64 import (
Expand All @@ -16,7 +19,7 @@
encodebytes,
)
except ImportError:
from ._fallback import ( # noqa: F401
from ._fallback import (
_get_simd_name,
_get_simd_path,
b64decode,
Expand All @@ -27,6 +30,21 @@
)


__all__ = (
"b64decode",
"b64decode_as_bytearray",
"b64encode",
"b64encode_as_string",
"encodebytes",
"standard_b64encode",
"standard_b64decode",
"urlsafe_b64encode",
"urlsafe_b64decode",
)

__version__ = _version


def get_license_text() -> str:
"""Returns pybase64 license information as a :class:`str` object.
Expand All @@ -47,7 +65,7 @@ def get_version() -> str:
return f"{__version__} (C extension inactive)"


def standard_b64encode(s: Any) -> bytes:
def standard_b64encode(s: Buffer) -> bytes:
"""Encode bytes using the standard Base64 alphabet.
Argument ``s`` is a :term:`bytes-like object` to encode.
Expand All @@ -57,7 +75,7 @@ def standard_b64encode(s: Any) -> bytes:
return b64encode(s)


def standard_b64decode(s: Any) -> bytes:
def standard_b64decode(s: str | Buffer) -> bytes:
"""Decode bytes encoded with the standard Base64 alphabet.
Argument ``s`` is a :term:`bytes-like object` or ASCII string to
Expand All @@ -73,7 +91,7 @@ def standard_b64decode(s: Any) -> bytes:
return b64decode(s)


def urlsafe_b64encode(s: Any) -> bytes:
def urlsafe_b64encode(s: Buffer) -> bytes:
"""Encode bytes using the URL- and filesystem-safe Base64 alphabet.
Argument ``s`` is a :term:`bytes-like object` to encode.
Expand All @@ -85,7 +103,7 @@ def urlsafe_b64encode(s: Any) -> bytes:
return b64encode(s, b"-_")


def urlsafe_b64decode(s: Any) -> bytes:
def urlsafe_b64decode(s: str | Buffer) -> bytes:
"""Decode bytes using the URL- and filesystem-safe Base64 alphabet.
Argument ``s`` is a :term:`bytes-like object` or ASCII string to
Expand Down
23 changes: 13 additions & 10 deletions src/pybase64/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@
import sys
from base64 import b64decode as b64decodeValidate
from base64 import encodebytes as b64encodebytes
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from timeit import default_timer as timer
from typing import Any, BinaryIO
from typing import TYPE_CHECKING, Any, BinaryIO, cast

import pybase64

if TYPE_CHECKING:
from pybase64._typing import Decode, Encode, EncodeBytes


def bench_one(
duration: float,
data: bytes,
enc: Callable[..., bytes],
dec: Callable[..., bytes],
encbytes: Callable[[Any], bytes],
altchars: Any | None = None,
enc: Encode,
dec: Decode,
encbytes: EncodeBytes,
altchars: bytes | None = None,
validate: bool = False,
) -> None:
duration = duration / 2.0
Expand Down Expand Up @@ -93,11 +96,11 @@ def bench_one(


def readall(file: BinaryIO) -> bytes:
if file == sys.stdin:
if file == cast(BinaryIO, sys.stdin):
# Python 3 < 3.9 does not honor the binary flag,
# read from the underlying buffer
if hasattr(file, "buffer"):
return file.buffer.read()
return cast(BinaryIO, file.buffer).read()
return file.read() # pragma: no cover
# do not close the file
try:
Expand All @@ -108,7 +111,7 @@ def readall(file: BinaryIO) -> bytes:


def writeall(file: BinaryIO, data: bytes) -> None:
if file == sys.stdout:
if file == cast(BinaryIO, sys.stdout):
# Python 3 does not honor the binary flag,
# write to the underlying buffer
if hasattr(file, "buffer"):
Expand Down Expand Up @@ -142,7 +145,7 @@ def benchmark(duration: float, input: BinaryIO) -> None:
duration,
data,
base64.b64encode,
b64decodeValidate,
b64decodeValidate, # type: ignore[arg-type] # c.f. https://github.com/python/typeshed/pull/11210
b64encodebytes,
altchars,
validate,
Expand Down
Loading

0 comments on commit 4e00978

Please sign in to comment.