diff --git a/CHANGELOG.md b/CHANGELOG.md index d3cce61..c676675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,14 @@ For the purpose of determining breaking changes: [python-versions]: https://devguide.python.org/versions/#supported-versions +## [Unreleased] + +[unreleased]: https://github.com/rogdham/bigxml/compare/v1.0.1...HEAD + +### :rocket: Added + +- Add support for buffer protocol ([PEP 688](https://peps.python.org/pep-0688/)) + ## [1.0.1] - 2024-04-27 [1.0.1]: https://github.com/rogdham/bigxml/compare/v1.0.0...v1.0.1 diff --git a/pyproject.toml b/pyproject.toml index 1cde03e..c11b4f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ requires-python = ">=3.8" dependencies = [ "defusedxml>=0.7.1", - "typing-extensions>=4.3.0 ; python_version<'3.10'", + "typing-extensions>=4.6.0 ; python_version<'3.12'", ] [project.urls] diff --git a/src/bigxml/stream.py b/src/bigxml/stream.py index bc061f6..8c95d20 100644 --- a/src/bigxml/stream.py +++ b/src/bigxml/stream.py @@ -1,17 +1,24 @@ from io import IOBase +import sys from typing import Any, Generator, Iterable, Optional, cast from bigxml.typing import Streamable, SupportsRead from bigxml.utils import autostart_generator +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import Buffer +else: # pragma: no cover + from collections.abc import Buffer + @autostart_generator def _flatten_stream(stream: Streamable) -> Generator[Optional[memoryview], int, None]: yield None - # bytes-like + # buffer protocol (bytes, etc.) try: - yield memoryview(cast(bytes, stream)) + # we try-except instead of isinstance(stream, Buffer) for compatibility reasons + yield memoryview(cast(Buffer, stream)) return # noqa: TRY300 except TypeError: pass diff --git a/src/bigxml/typing.py b/src/bigxml/typing.py index 86f9afd..6760e33 100644 --- a/src/bigxml/typing.py +++ b/src/bigxml/typing.py @@ -16,6 +16,10 @@ else: # pragma: no cover from typing import ParamSpec +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import Buffer +else: # pragma: no cover + from collections.abc import Buffer P = ParamSpec("P") T = TypeVar("T") @@ -30,7 +34,7 @@ class SupportsRead(Protocol[T_co]): def read(self, size: Optional[int] = None) -> T_co: ... # pragma: no cover -Streamable = Union[SupportsRead[bytes], bytes, Iterable["Streamable"]] +Streamable = Union[Buffer, SupportsRead[bytes], Iterable["Streamable"]] class ClassHandlerWithCustomWrapper0(Protocol[T_co]): diff --git a/tests/unit/test_stream.py b/tests/unit/test_stream.py index c6c549b..3e79a2e 100644 --- a/tests/unit/test_stream.py +++ b/tests/unit/test_stream.py @@ -1,5 +1,9 @@ +from array import array +import inspect from io import BytesIO, IOBase, StringIO +from mmap import mmap from string import ascii_lowercase +import sys from typing import Iterator, Optional, Tuple, cast import pytest @@ -14,27 +18,52 @@ def test_no_stream() -> None: assert stream.read(42) == b"" -def abcdef_generator() -> Iterator[bytes]: - yield b"abcdef" +DATA = b"a\x00b\x7fc\x80d\xffe" + + +def to_mmap(data: bytes) -> mmap: + out = mmap(-1, len(data)) + out.write(data) + return out + + +def custom_generator() -> Iterator[bytes]: + yield DATA + + +class CustomBuffer: + def __buffer__(self, flags: int) -> memoryview: + if flags != inspect.BufferFlags.FULL_RO: + raise TypeError("Only BufferFlags.FULL_RO supported") + return memoryview(DATA).toreadonly() @pytest.mark.parametrize( "stream", [ - b"abcdef", - bytearray(b"abcdef"), - memoryview(b"abcdef"), - BytesIO(b"abcdef"), - [b"abcdef"], - (b"abcdef",), - iter([b"abcdef"]), - abcdef_generator(), + DATA, + bytearray(DATA), + memoryview(DATA), + array("B", DATA), + to_mmap(DATA), + BytesIO(DATA), + [DATA], + (DATA,), + iter([DATA]), + custom_generator(), + pytest.param( + CustomBuffer(), + marks=pytest.mark.skipif( + sys.version_info < (3, 12), + reason="requires python3.12 or higher", + ), + ), ], ids=type, ) def test_types(stream: Streamable) -> None: stream = StreamChain(stream) - assert stream.read(42) == b"abcdef" + assert stream.read(42) == DATA assert stream.read(42) == b""