From b65b587aa12ec40df1eecb4c6aadaf8de8931187 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 29 Oct 2024 19:20:33 -0700 Subject: [PATCH] implement or and ior operators --- CHANGES.rst | 2 + src/werkzeug/datastructures/headers.py | 31 +++++++++++++++ src/werkzeug/datastructures/mixins.py | 21 ++++++++++ src/werkzeug/datastructures/structures.py | 25 ++++++++++++ tests/test_datastructures.py | 48 ++++++++++++++++++++++- 5 files changed, 126 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index c76cce4c5..0a0616d57 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -25,6 +25,8 @@ Unreleased :issue:`2970` - ``MultiDict.getlist`` catches ``TypeError`` in addition to ``ValueError`` when doing type conversion. :issue:`2976` +- Implement ``|`` and ``|=`` operators for ``MultiDict``, ``Headers``, and + ``CallbackDict``, and disallow ``|=`` on immutable types. :issue:`2977` Version 3.0.6 diff --git a/src/werkzeug/datastructures/headers.py b/src/werkzeug/datastructures/headers.py index a23a0e0b1..db53cda7b 100644 --- a/src/werkzeug/datastructures/headers.py +++ b/src/werkzeug/datastructures/headers.py @@ -41,6 +41,9 @@ class Headers(cabc.MutableMapping[str, str]): :param defaults: The list of default values for the :class:`Headers`. + .. versionchanged:: 3.1 + Implement ``|`` and ``|=`` operators. + .. versionchanged:: 2.1.0 Default values are validated the same as values added later. @@ -524,6 +527,31 @@ def update( # type: ignore[override] else: self.set(key, value) + def __or__( + self, other: cabc.Mapping[str, t.Any | cabc.Collection[t.Any]] + ) -> te.Self: + if not isinstance(other, cabc.Mapping): + return NotImplemented + + rv = self.copy() + rv.update(other) + return rv + + def __ior__( + self, + other: ( + cabc.Mapping[str, t.Any | cabc.Collection[t.Any]] + | cabc.Iterable[tuple[str, t.Any]] + ), + ) -> te.Self: + if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance( + other, str + ): + return NotImplemented + + self.update(other) + return self + def to_wsgi_list(self) -> list[tuple[str, str]]: """Convert the headers into a list suitable for WSGI. @@ -620,6 +648,9 @@ def __iter__(self) -> cabc.Iterator[tuple[str, str]]: # type: ignore[override] def copy(self) -> t.NoReturn: raise TypeError(f"cannot create {type(self).__name__!r} copies") + def __or__(self, other: t.Any) -> t.NoReturn: + raise TypeError(f"cannot create {type(self).__name__!r} copies") + # circular dependencies from .. import http diff --git a/src/werkzeug/datastructures/mixins.py b/src/werkzeug/datastructures/mixins.py index 76324d5a2..03d461ad8 100644 --- a/src/werkzeug/datastructures/mixins.py +++ b/src/werkzeug/datastructures/mixins.py @@ -76,6 +76,9 @@ def sort(self, key: t.Any = None, reverse: t.Any = False) -> t.NoReturn: class ImmutableDictMixin(t.Generic[K, V]): """Makes a :class:`dict` immutable. + .. versionchanged:: 3.1 + Disallow ``|=`` operator. + .. versionadded:: 0.5 :private: @@ -117,6 +120,9 @@ def setdefault(self, key: t.Any, default: t.Any = None) -> t.NoReturn: def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: _immutable_error(self) + def __ior__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) + def pop(self, key: t.Any, default: t.Any = None) -> t.NoReturn: _immutable_error(self) @@ -168,6 +174,9 @@ class ImmutableHeadersMixin: hashable though since the only usecase for this datastructure in Werkzeug is a view on a mutable structure. + .. versionchanged:: 3.1 + Disallow ``|=`` operator. + .. versionadded:: 0.5 :private: @@ -200,6 +209,9 @@ def extend(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: _immutable_error(self) + def __ior__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) + def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn: _immutable_error(self) @@ -233,6 +245,9 @@ def wrapper( class UpdateDictMixin(dict[K, V]): """Makes dicts call `self.on_update` on modifications. + .. versionchanged:: 3.1 + Implement ``|=`` operator. + .. versionadded:: 0.5 :private: @@ -294,3 +309,9 @@ def update( # type: ignore[override] super().update(**kwargs) else: super().update(arg, **kwargs) + + @_always_update + def __ior__( # type: ignore[override] + self, other: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]] + ) -> te.Self: + return super().__ior__(other) diff --git a/src/werkzeug/datastructures/structures.py b/src/werkzeug/datastructures/structures.py index a48d504e4..db2f99800 100644 --- a/src/werkzeug/datastructures/structures.py +++ b/src/werkzeug/datastructures/structures.py @@ -170,6 +170,9 @@ class MultiDict(TypeConversionDict[K, V]): :param mapping: the initial value for the :class:`MultiDict`. Either a regular dict, an iterable of ``(key, value)`` tuples or `None`. + + .. versionchanged:: 3.1 + Implement ``|`` and ``|=`` operators. """ def __init__( @@ -435,6 +438,28 @@ def update( # type: ignore[override] for key, value in iter_multi_items(mapping): self.add(key, value) + def __or__( # type: ignore[override] + self, other: cabc.Mapping[K, V | cabc.Collection[V]] + ) -> MultiDict[K, V]: + if not isinstance(other, cabc.Mapping): + return NotImplemented + + rv = self.copy() + rv.update(other) + return rv + + def __ior__( # type: ignore[override] + self, + other: cabc.Mapping[K, V | cabc.Collection[V]] | cabc.Iterable[tuple[K, V]], + ) -> te.Self: + if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance( + other, str + ): + return NotImplemented + + self.update(other) + return self + @t.overload def pop(self, key: K) -> V: ... @t.overload diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 9d11d2aab..76a5530fc 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -258,6 +258,17 @@ def test_basic_interface(self): md.setlist("foo", [1, 2]) assert md.getlist("foo") == [1, 2] + def test_or(self) -> None: + a = self.storage_class({"x": 1}) + b = a | {"y": 2} + assert isinstance(b, self.storage_class) + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = self.storage_class({"x": 1}) + a |= {"y": 2} + assert "x" in a and "y" in a + class _ImmutableDictTests: storage_class: type[dict] @@ -305,6 +316,17 @@ def test_dict_is_hashable(self): assert immutable in x assert immutable2 in x + def test_or(self) -> None: + a = self.storage_class({"x": 1}) + b = a | {"y": 2} + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = self.storage_class({"x": 1}) + + with pytest.raises(TypeError): + a |= {"y": 2} + class TestImmutableTypeConversionDict(_ImmutableDictTests): storage_class = ds.ImmutableTypeConversionDict @@ -799,6 +821,17 @@ def test_equality(self): assert h1 == h2 + def test_or(self) -> None: + a = ds.Headers({"x": 1}) + b = a | {"y": 2} + assert isinstance(b, ds.Headers) + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = ds.Headers({"x": 1}) + a |= {"y": 2} + assert "x" in a and "y" in a + class TestEnvironHeaders: storage_class = ds.EnvironHeaders @@ -840,6 +873,18 @@ def test_return_type_is_str(self): assert headers["Foo"] == "\xe2\x9c\x93" assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93") + def test_or(self) -> None: + headers = ds.EnvironHeaders({"x": "1"}) + + with pytest.raises(TypeError): + headers | {"y": "2"} + + def test_ior(self) -> None: + headers = ds.EnvironHeaders({}) + + with pytest.raises(TypeError): + headers |= {"y": "2"} + class TestHeaderSet: storage_class = ds.HeaderSet @@ -927,7 +972,7 @@ def test_callback_dict_writes(self): assert_calls, func = make_call_asserter() initial = {"a": "foo", "b": "bar"} dct = self.storage_class(initial=initial, on_update=func) - with assert_calls(8, "callback not triggered by write method"): + with assert_calls(9, "callback not triggered by write method"): # always-write methods dct["z"] = 123 dct["z"] = 123 # must trigger again @@ -937,6 +982,7 @@ def test_callback_dict_writes(self): dct.popitem() dct.update([]) dct.clear() + dct |= {} with assert_calls(0, "callback triggered by failed del"): pytest.raises(KeyError, lambda: dct.__delitem__("x")) with assert_calls(0, "callback triggered by failed pop"):