Skip to content

Commit

Permalink
restrict containers accepted by multi (#2995)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Nov 6, 2024
2 parents 1a1728e + 598bb1d commit ea93b54
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 28 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ Version 3.1.3

Unreleased

- Initial data passed to ``MultiDict`` and similar interfaces only accepts
``list``, ``tuple``, or ``set`` when passing multiple values. It had been
changed to accept any ``Collection``, but this matched types that should be
treated as single values, such as ``bytes``. :issue:`2994`


Version 3.1.2
-------------
Expand Down
27 changes: 14 additions & 13 deletions src/werkzeug/datastructures/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
defaults: (
Headers
| MultiDict[str, t.Any]
| cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
| cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | set[t.Any]]
| cabc.Iterable[tuple[str, t.Any]]
| None
) = None,
Expand Down Expand Up @@ -227,7 +227,7 @@ def extend(
arg: (
Headers
| MultiDict[str, t.Any]
| cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
| cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | set[t.Any]]
| cabc.Iterable[tuple[str, t.Any]]
| None
) = None,
Expand Down Expand Up @@ -491,12 +491,14 @@ def update(
arg: (
Headers
| MultiDict[str, t.Any]
| cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
| cabc.Mapping[
str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any]
]
| cabc.Iterable[tuple[str, t.Any]]
| None
) = None,
/,
**kwargs: t.Any | cabc.Collection[t.Any],
**kwargs: t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any],
) -> None:
"""Replace headers in this object with items from another
headers object and keyword arguments.
Expand All @@ -516,9 +518,7 @@ def update(
self.setlist(key, arg.getlist(key))
elif isinstance(arg, cabc.Mapping):
for key, value in arg.items():
if isinstance(value, cabc.Collection) and not isinstance(
value, str
):
if isinstance(value, (list, tuple, set)):
self.setlist(key, value)
else:
self.set(key, value)
Expand All @@ -527,13 +527,16 @@ def update(
self.set(key, value)

for key, value in kwargs.items():
if isinstance(value, cabc.Collection) and not isinstance(value, str):
if isinstance(value, (list, tuple, set)):
self.setlist(key, value)
else:
self.set(key, value)

def __or__(
self, other: cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
self,
other: cabc.Mapping[
str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any]
],
) -> te.Self:
if not isinstance(other, cabc.Mapping):
return NotImplemented
Expand All @@ -545,13 +548,11 @@ def __or__(
def __ior__(
self,
other: (
cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any]]
| cabc.Iterable[tuple[str, t.Any]]
),
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
if not isinstance(other, (cabc.Mapping, cabc.Iterable)):
return NotImplemented

self.update(other)
Expand Down
29 changes: 15 additions & 14 deletions src/werkzeug/datastructures/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def iter_multi_items(
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
),
) -> cabc.Iterator[tuple[K, V]]:
Expand All @@ -33,11 +33,11 @@ def iter_multi_items(
yield from mapping.items(multi=True)
elif isinstance(mapping, cabc.Mapping):
for key, value in mapping.items():
if isinstance(value, cabc.Collection) and not isinstance(value, str):
if isinstance(value, (list, tuple, set)):
for v in value:
yield key, v
else:
yield key, value # type: ignore[misc]
yield key, value
else:
yield from mapping

Expand Down Expand Up @@ -182,7 +182,7 @@ def __init__(
self,
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
| None
) = None,
Expand All @@ -194,7 +194,7 @@ def __init__(
elif isinstance(mapping, cabc.Mapping):
tmp = {}
for key, value in mapping.items():
if isinstance(value, cabc.Collection) and not isinstance(value, str):
if isinstance(value, (list, tuple, set)):
value = list(value)

if not value:
Expand Down Expand Up @@ -419,7 +419,7 @@ def update( # type: ignore[override]
self,
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
),
) -> None:
Expand All @@ -444,7 +444,7 @@ def update( # type: ignore[override]
self.add(key, value)

def __or__( # type: ignore[override]
self, other: cabc.Mapping[K, V | cabc.Collection[V]]
self, other: cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
) -> MultiDict[K, V]:
if not isinstance(other, cabc.Mapping):
return NotImplemented
Expand All @@ -455,11 +455,12 @@ def __or__( # type: ignore[override]

def __ior__( # type: ignore[override]
self,
other: cabc.Mapping[K, V | cabc.Collection[V]] | cabc.Iterable[tuple[K, V]],
other: (
cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
),
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
if not isinstance(other, (cabc.Mapping, cabc.Iterable)):
return NotImplemented

self.update(other)
Expand Down Expand Up @@ -600,7 +601,7 @@ def __init__(
self,
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
| None
) = None,
Expand Down Expand Up @@ -744,7 +745,7 @@ def update( # type: ignore[override]
self,
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
),
) -> None:
Expand Down Expand Up @@ -1009,7 +1010,7 @@ def __init__(
self,
mapping: (
MultiDict[K, V]
| cabc.Mapping[K, V | cabc.Collection[V]]
| cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]]
| cabc.Iterable[tuple[K, V]]
| None
) = None,
Expand Down
17 changes: 16 additions & 1 deletion tests/test_datastructures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import io
import pickle
import tempfile
import typing as t
from contextlib import contextmanager
from copy import copy
from copy import deepcopy
Expand Down Expand Up @@ -43,7 +46,7 @@ def items(self, multi=1):


class _MutableMultiDictTests:
storage_class: type["ds.MultiDict"]
storage_class: type[ds.MultiDict]

def test_pickle(self):
cls = self.storage_class
Expand Down Expand Up @@ -1280,3 +1283,15 @@ def test_range_to_header(ranges):
def test_range_validates_ranges(ranges):
with pytest.raises(ValueError):
ds.Range("bytes", ranges)


@pytest.mark.parametrize(
("value", "expect"),
[
({"a": "ab"}, [("a", "ab")]),
({"a": ["a", "b"]}, [("a", "a"), ("a", "b")]),
({"a": b"ab"}, [("a", b"ab")]),
],
)
def test_iter_multi_data(value: t.Any, expect: list[tuple[t.Any, t.Any]]) -> None:
assert list(ds.iter_multi_items(value)) == expect

0 comments on commit ea93b54

Please sign in to comment.