From b16e7248c272aa1dbebf3cec2ce29cda303dc0df Mon Sep 17 00:00:00 2001 From: Abel Cheung Date: Fri, 20 Oct 2023 14:38:41 +0000 Subject: [PATCH] Improve overloads and return types for `iterparse.__new__` Fixes #19 --- lxml-stubs/etree/_iterparse.pyi | 119 +++++++++++++++++++++++++++++--- test-rt/test_iterparse.py | 72 +++++++++++++++---- 2 files changed, 167 insertions(+), 24 deletions(-) diff --git a/lxml-stubs/etree/_iterparse.pyi b/lxml-stubs/etree/_iterparse.pyi index f67bb83..6286f59 100644 --- a/lxml-stubs/etree/_iterparse.pyi +++ b/lxml-stubs/etree/_iterparse.pyi @@ -27,6 +27,17 @@ _SaxNsEventValues: TypeAlias = tuple[str, str] | None # for start-ns & end-ns e class iterparse(Iterator[_T_co]): """Incremental parser + Annotation + ---------- + Totally 5 function signatures are available: + - Default XML mode, where only `end` event is emitted + - `start`, `end`, `comment` and `pi` events, where only + Element values are produced + - HTML mode (`html=True`), where namespace events are ignored + - XML mode with `start-ns` or `end-ns` events, producing + namespace tuple (for `start-ns`) or nothing (`end-ns`) + - Final catch-all signature for XML mode + Original Docstring ------------------ Parses XML into a tree and generates tuples (event, element) in a @@ -52,7 +63,7 @@ class iterparse(Iterator[_T_co]): libxml2 parser configuration. A DTD will also be loaded if validation or attribute default values are requested.""" - @overload # default events + @overload # default values, only 'end' event emitted def __new__( cls, source: _FilePath | IO[bytes] | SupportsReadClose[bytes], @@ -76,7 +87,31 @@ class iterparse(Iterator[_T_co]): collect_ids: bool = ..., schema: XMLSchema | None = ..., ) -> iterparse[tuple[Literal["end"], _Element]]: ... - @overload # html mode -> namespace events supressed + @overload # element-only events + def __new__( + cls, + source: _FilePath | IO[bytes] | SupportsReadClose[bytes], + events: Iterable[_NoNSEventNames], + *, + tag: _TagSelector | Iterable[_TagSelector] | None = ..., + attribute_defaults: bool = ..., + dtd_validation: bool = ..., + load_dtd: bool = ..., + no_network: bool = ..., + remove_blank_text: bool = ..., + compact: bool = ..., + resolve_entities: bool = ..., + remove_comments: bool = ..., + remove_pis: bool = ..., + strip_cdata: bool = ..., + encoding: _AnyStr | None = ..., + html: bool = ..., + recover: bool | None = ..., + huge_tree: bool = ..., + collect_ids: bool = ..., + schema: XMLSchema | None = ..., + ) -> iterparse[tuple[_NoNSEventNames, _Element]]: ... + @overload # html mode -> namespace events suppressed def __new__( cls, source: _FilePath | IO[bytes] | SupportsReadClose[bytes], @@ -100,7 +135,33 @@ class iterparse(Iterator[_T_co]): collect_ids: bool = ..., schema: XMLSchema | None = ..., ) -> iterparse[tuple[_NoNSEventNames, _Element]]: ... - @overload # custom events, xml mode + @overload # xml mode & NS-only events + def __new__( + cls, + source: _FilePath | IO[bytes] | SupportsReadClose[bytes], + events: Iterable[Literal['start-ns', 'end-ns']], + *, + tag: _TagSelector | Iterable[_TagSelector] | None = ..., + attribute_defaults: bool = ..., + dtd_validation: bool = ..., + load_dtd: bool = ..., + no_network: bool = ..., + remove_blank_text: bool = ..., + compact: bool = ..., + resolve_entities: bool = ..., + remove_comments: bool = ..., + remove_pis: bool = ..., + strip_cdata: bool = ..., + encoding: _AnyStr | None = ..., + html: Literal[False] = ..., + recover: bool | None = ..., + huge_tree: bool = ..., + collect_ids: bool = ..., + schema: XMLSchema | None = ..., + ) -> iterparse[ + tuple[Literal['start-ns'], tuple[str, str]] + | tuple[Literal['end-ns'], None]]: ... + @overload # xml mode, catch all def __new__( cls, source: _FilePath | IO[bytes] | SupportsReadClose[bytes], @@ -118,12 +179,16 @@ class iterparse(Iterator[_T_co]): remove_pis: bool = ..., strip_cdata: bool = ..., encoding: _AnyStr | None = ..., - html: bool = ..., + html: Literal[False] = ..., recover: bool | None = ..., huge_tree: bool = ..., collect_ids: bool = ..., schema: XMLSchema | None = ..., - ) -> iterparse[tuple[_SaxEventNames, _Element | _SaxNsEventValues]]: ... + ) -> iterparse[ + tuple[_NoNSEventNames, _Element] + | tuple[Literal['start-ns'], tuple[str, str]] + | tuple[Literal['end-ns'], None] + ]: ... def __next__(self) -> _T_co: ... # root property only present after parsing is done @property @@ -151,6 +216,17 @@ class iterwalk(Iterator[_T_co]): """Tree walker that generates events from an existing tree as if it was parsing XML data with ``iterparse()`` + Annotation + ---------- + Totally 4 function signatures, depending on `events` argument: + - Default value, where only `end` event is emitted + - `start`, `end`, `comment` and `pi` events, where only + Element values are produced + - Namespace events (`start-ns` or `end-ns`), producing + namespace tuple (for `start-ns`) or nothing (`end-ns`) + - Final catch-all for custom events combination + + Original Docstring ------------------ Just as for ``iterparse()``, the ``tag`` argument can be a single tag or a @@ -163,19 +239,40 @@ class iterwalk(Iterator[_T_co]): # There is no concept of html mode in iterwalk; namespace events # are not supressed like iterparse might do - @overload # custom events + @overload # default events def __new__( cls, element_or_tree: _ET_co | _ElementTree[_ET_co], - events: Iterable[_SaxEventNames], + events: None = ..., tag: _TagSelector | Iterable[_TagSelector] | None = ..., - ) -> iterwalk[tuple[_SaxEventNames, _ET_co | _SaxNsEventValues]]: ... - @overload # default events + ) -> iterwalk[tuple[Literal["end"], _ET_co]]: ... + @overload # element-only events def __new__( cls, element_or_tree: _ET_co | _ElementTree[_ET_co], - events: None = ..., + events: Iterable[_NoNSEventNames], tag: _TagSelector | Iterable[_TagSelector] | None = ..., - ) -> iterwalk[tuple[Literal["end"], _ET_co]]: ... + ) -> iterwalk[tuple[_NoNSEventNames, _ET_co]]: ... + @overload # namespace-only events + def __new__( + cls, + element_or_tree: _ET_co | _ElementTree[_ET_co], + events: Iterable[Literal['start-ns', 'end-ns']], + tag: _TagSelector | Iterable[_TagSelector] | None = ..., + ) -> iterwalk[ + tuple[Literal['start-ns'], tuple[str, str]] + | tuple[Literal['end-ns'], None] + ]: ... + @overload # catch-all + def __new__( + cls, + element_or_tree: _ET_co | _ElementTree[_ET_co], + events: Iterable[_SaxEventNames], + tag: _TagSelector | Iterable[_TagSelector] | None = ..., + ) -> iterwalk[ + tuple[_NoNSEventNames, _ET_co] + | tuple[Literal['start-ns'], tuple[str, str]] + | tuple[Literal['end-ns'], None] + ]: ... def __next__(self) -> _T_co: ... def skip_subtree(self) -> None: ... diff --git a/test-rt/test_iterparse.py b/test-rt/test_iterparse.py index 8a76cb0..26d0cf0 100644 --- a/test-rt/test_iterparse.py +++ b/test-rt/test_iterparse.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytest from pathlib import Path import _testutils @@ -20,46 +21,91 @@ def test_xml_default_event(self, xml_tree: _ElementTree[_Element]) -> None: def test_xml_more_event(self, xml_tree: _ElementTree[_Element]) -> None: walker = iterwalk(xml_tree, ["start", "end", "start-ns", "end-ns", "comment"]) reveal_type(walker) - for event, elem in walker: - reveal_type(event) - reveal_type(elem) + # Generated values are not unpacked here to test type narrowing + # See issue #19 for more info + for item in walker: + if item[0] == 'start-ns': + reveal_type(item[1]) + elif item[0] == 'end-ns': + reveal_type(item[1]) + else: + reveal_type(item[1]) def test_html_default_event(self, html_tree: _ElementTree[HtmlElement]) -> None: - walker = iterwalk(html_tree) + walker = iterwalk(html_tree, tag=('div', 'span')) reveal_type(walker) for event, elem in walker: reveal_type(event) reveal_type(elem) def test_html_more_event(self, html_tree: _ElementTree[HtmlElement]) -> None: - # BUG Since HtmlComment is pretended as HtmlElement subclass + # Since HtmlComment is pretended as HtmlElement subclass # in stub but not runtime, adding 'comment' event would fail - walker = iterwalk(html_tree, ["start", "end", "start-ns", "end-ns"], "div") + walker = iterwalk(html_tree, ("start", "end", "start-ns", "end-ns"), "div") reveal_type(walker) - for event, elem in walker: - reveal_type(event) - reveal_type(elem) + # Unlike iterparse(), iterwalk behaves the same with HTML + for item in walker: + if item[0] == 'start-ns': + reveal_type(item[1]) + elif item[0] == 'end-ns': + reveal_type(item[1]) + else: + reveal_type(item[1]) class TestIterparse: - def test_default_event(self, x1_filepath: Path) -> None: + def test_xml_default_event(self, x1_filepath: Path) -> None: walker = iterparse(x1_filepath) reveal_type(walker) for event, elem in walker: reveal_type(event) reveal_type(elem) + def test_xml_more_event(self, x1_filepath: Path) -> None: + walker = iterparse(x1_filepath, [ + "start", "end", "start-ns", "end-ns", "comment" + ]) + reveal_type(walker) + # Generated values are not unpacked here to test type narrowing + # See issue #19 for more info + for item in walker: + if item[0] == 'start-ns': + reveal_type(item[1]) + elif item[0] == 'end-ns': + reveal_type(item[1]) + else: + reveal_type(item[1]) + def test_html_mode(self, x1_filepath: Path) -> None: - walker = iterparse(source=str(x1_filepath), html=True) + walker = iterparse( + source=x1_filepath, + html=True, + events = ("start", "end", "start-ns", "end-ns", "comment")) reveal_type(walker) for event, elem in walker: reveal_type(event) reveal_type(elem) - def test_custom_event(self, x1_filepath: Path) -> None: + def test_plain_filename(self, x1_filepath: Path) -> None: + walker = iterparse(str(x1_filepath)) + reveal_type(walker) + for event, elem in walker: + reveal_type(event) + reveal_type(elem) + + def test_binary_io(self, x1_filepath: Path) -> None: with open(x1_filepath, "rb") as f: - walker = iterparse(f, ["start", "end", "start-ns", "end-ns"]) + walker = iterparse(f) reveal_type(walker) for event, elem in walker: reveal_type(event) reveal_type(elem) + + def test_text_io(self, x1_filepath: Path) -> None: + with pytest.raises( + TypeError, + match='reading file objects must return bytes objects'): + with open(x1_filepath, "r") as f: + walker = iterparse(f) # pyright: ignore + for event, elem in walker: # pyright: ignore + print(event, elem) # pyright: ignore