Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement or and ior operators #2979

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Unreleased
:issue:`2970`
- ``MultiDict.getlist`` catches ``TypeError`` in addition to ``ValueError``
when doing type conversion. :issue:`2976`
- Implement ``|`` and ``|=`` operators for ``MultiDict``, ``Headers``, and
``CallbackDict``, and disallow ``|=`` on immutable types. :issue:`2977`


Version 3.0.6
Expand Down
31 changes: 31 additions & 0 deletions src/werkzeug/datastructures/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class Headers(cabc.MutableMapping[str, str]):

:param defaults: The list of default values for the :class:`Headers`.

.. versionchanged:: 3.1
Implement ``|`` and ``|=`` operators.

.. versionchanged:: 2.1.0
Default values are validated the same as values added later.

Expand Down Expand Up @@ -524,6 +527,31 @@ def update( # type: ignore[override]
else:
self.set(key, value)

def __or__(
self, other: cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
) -> te.Self:
if not isinstance(other, cabc.Mapping):
return NotImplemented

rv = self.copy()
rv.update(other)
return rv

def __ior__(
self,
other: (
cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
| cabc.Iterable[tuple[str, t.Any]]
),
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
return NotImplemented

self.update(other)
return self

def to_wsgi_list(self) -> list[tuple[str, str]]:
"""Convert the headers into a list suitable for WSGI.

Expand Down Expand Up @@ -620,6 +648,9 @@ def __iter__(self) -> cabc.Iterator[tuple[str, str]]: # type: ignore[override]
def copy(self) -> t.NoReturn:
raise TypeError(f"cannot create {type(self).__name__!r} copies")

def __or__(self, other: t.Any) -> t.NoReturn:
raise TypeError(f"cannot create {type(self).__name__!r} copies")


# circular dependencies
from .. import http
21 changes: 21 additions & 0 deletions src/werkzeug/datastructures/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def sort(self, key: t.Any = None, reverse: t.Any = False) -> t.NoReturn:
class ImmutableDictMixin(t.Generic[K, V]):
"""Makes a :class:`dict` immutable.

.. versionchanged:: 3.1
Disallow ``|=`` operator.

.. versionadded:: 0.5

:private:
Expand Down Expand Up @@ -117,6 +120,9 @@ def setdefault(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
_immutable_error(self)

def __ior__(self, other: t.Any) -> t.NoReturn:
_immutable_error(self)

def pop(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
_immutable_error(self)

Expand Down Expand Up @@ -168,6 +174,9 @@ class ImmutableHeadersMixin:
hashable though since the only usecase for this datastructure
in Werkzeug is a view on a mutable structure.

.. versionchanged:: 3.1
Disallow ``|=`` operator.

.. versionadded:: 0.5

:private:
Expand Down Expand Up @@ -200,6 +209,9 @@ def extend(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
_immutable_error(self)

def __ior__(self, other: t.Any) -> t.NoReturn:
_immutable_error(self)

def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn:
_immutable_error(self)

Expand Down Expand Up @@ -233,6 +245,9 @@ def wrapper(
class UpdateDictMixin(dict[K, V]):
"""Makes dicts call `self.on_update` on modifications.

.. versionchanged:: 3.1
Implement ``|=`` operator.

.. versionadded:: 0.5

:private:
Expand Down Expand Up @@ -294,3 +309,9 @@ def update( # type: ignore[override]
super().update(**kwargs)
else:
super().update(arg, **kwargs)

@_always_update
def __ior__( # type: ignore[override]
self, other: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]]
) -> te.Self:
return super().__ior__(other)
25 changes: 25 additions & 0 deletions src/werkzeug/datastructures/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ class MultiDict(TypeConversionDict[K, V]):
:param mapping: the initial value for the :class:`MultiDict`. Either a
regular dict, an iterable of ``(key, value)`` tuples
or `None`.

.. versionchanged:: 3.1
Implement ``|`` and ``|=`` operators.
"""

def __init__(
Expand Down Expand Up @@ -435,6 +438,28 @@ def update( # type: ignore[override]
for key, value in iter_multi_items(mapping):
self.add(key, value)

def __or__( # type: ignore[override]
self, other: cabc.Mapping[K, V | cabc.Collection[V]]
) -> MultiDict[K, V]:
if not isinstance(other, cabc.Mapping):
return NotImplemented

rv = self.copy()
rv.update(other)
return rv

def __ior__( # type: ignore[override]
self,
other: cabc.Mapping[K, V | cabc.Collection[V]] | cabc.Iterable[tuple[K, V]],
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
return NotImplemented

self.update(other)
return self

@t.overload
def pop(self, key: K) -> V: ...
@t.overload
Expand Down
48 changes: 47 additions & 1 deletion tests/test_datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ def test_basic_interface(self):
md.setlist("foo", [1, 2])
assert md.getlist("foo") == [1, 2]

def test_or(self) -> None:
a = self.storage_class({"x": 1})
b = a | {"y": 2}
assert isinstance(b, self.storage_class)
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = self.storage_class({"x": 1})
a |= {"y": 2}
assert "x" in a and "y" in a


class _ImmutableDictTests:
storage_class: type[dict]
Expand Down Expand Up @@ -305,6 +316,17 @@ def test_dict_is_hashable(self):
assert immutable in x
assert immutable2 in x

def test_or(self) -> None:
a = self.storage_class({"x": 1})
b = a | {"y": 2}
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = self.storage_class({"x": 1})

with pytest.raises(TypeError):
a |= {"y": 2}


class TestImmutableTypeConversionDict(_ImmutableDictTests):
storage_class = ds.ImmutableTypeConversionDict
Expand Down Expand Up @@ -799,6 +821,17 @@ def test_equality(self):

assert h1 == h2

def test_or(self) -> None:
a = ds.Headers({"x": 1})
b = a | {"y": 2}
assert isinstance(b, ds.Headers)
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = ds.Headers({"x": 1})
a |= {"y": 2}
assert "x" in a and "y" in a


class TestEnvironHeaders:
storage_class = ds.EnvironHeaders
Expand Down Expand Up @@ -840,6 +873,18 @@ def test_return_type_is_str(self):
assert headers["Foo"] == "\xe2\x9c\x93"
assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93")

def test_or(self) -> None:
headers = ds.EnvironHeaders({"x": "1"})

with pytest.raises(TypeError):
headers | {"y": "2"}

def test_ior(self) -> None:
headers = ds.EnvironHeaders({})

with pytest.raises(TypeError):
headers |= {"y": "2"}


class TestHeaderSet:
storage_class = ds.HeaderSet
Expand Down Expand Up @@ -927,7 +972,7 @@ def test_callback_dict_writes(self):
assert_calls, func = make_call_asserter()
initial = {"a": "foo", "b": "bar"}
dct = self.storage_class(initial=initial, on_update=func)
with assert_calls(8, "callback not triggered by write method"):
with assert_calls(9, "callback not triggered by write method"):
# always-write methods
dct["z"] = 123
dct["z"] = 123 # must trigger again
Expand All @@ -937,6 +982,7 @@ def test_callback_dict_writes(self):
dct.popitem()
dct.update([])
dct.clear()
dct |= {}
with assert_calls(0, "callback triggered by failed del"):
pytest.raises(KeyError, lambda: dct.__delitem__("x"))
with assert_calls(0, "callback triggered by failed pop"):
Expand Down