Skip to content

Commit

Permalink
bug: fix bug in parsing of repeated headers
Browse files Browse the repository at this point in the history
Closes #19
  • Loading branch information
adriangb committed Jan 21, 2022
1 parent e361d9c commit edcde50
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.4.3"
version = "0.4.4"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
18 changes: 7 additions & 11 deletions tests/test_extractors/params/test_header_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"headers,status_code,json_response",
[
({"Header": "123"}, 200, {"Header": "123"}),
({"Header": "1,2,3"}, 200, {"Header": "1,2,3"}),
({"Header": "1,2,3"}, 200, {"Header": "1"}),
({"Header": ""}, 200, {"Header": ""}),
],
)
Expand Down Expand Up @@ -47,16 +47,8 @@ async def test(header: Annotated[str, HeaderParam(explode=explode)]) -> Any:
({"Header": "123"}, 200, {"Header": 123}),
(
{"Header": "1,2,3"},
422,
{
"detail": [
{
"loc": ["header", "header"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
},
200,
{"Header": 1},
),
(
{"Header": ""},
Expand Down Expand Up @@ -111,6 +103,7 @@ async def test(header: Annotated[int, HeaderParam(explode=explode)]) -> Any:
[
({"Header": "1,2"}, 200, {"Header": ["1", "2"]}),
({"Header": "1,2,"}, 200, {"Header": ["1", "2", ""]}),
({"Header": "1, 2"}, 200, {"Header": ["1", "2"]}),
({"Header": ""}, 200, {"Header": []}),
({"Header": ","}, 200, {"Header": ["", ""]}),
({}, 200, {"Header": []}),
Expand Down Expand Up @@ -153,6 +146,7 @@ async def test(header: Annotated[List[str], HeaderParam(explode=explode)]) -> An
]
},
),
({"Header": "1, 2"}, 200, {"Header": [1, 2]}),
({"Header": ""}, 200, {"Header": []}),
(
{"Header": ","},
Expand Down Expand Up @@ -200,6 +194,7 @@ async def test(header: Annotated[List[int], HeaderParam(explode=explode)]) -> An
[
# explode = True
(True, {"Header": "foo=1,bar=2"}, 200, {"foo": 1, "bar": "2", "baz": "3"}),
(True, {"Header": "foo=1, bar=2"}, 200, {"foo": 1, "bar": "2", "baz": "3"}),
(
True,
{"Header": "foo=1,bar=2,baz=4"},
Expand Down Expand Up @@ -284,6 +279,7 @@ async def test(header: Annotated[List[int], HeaderParam(explode=explode)]) -> An
),
# explode = False
(False, {"Header": "foo,1,bar,2"}, 200, {"foo": 1, "bar": "2", "baz": "3"}),
(False, {"Header": "foo, 1, bar, 2"}, 200, {"foo": 1, "bar": "2", "baz": "3"}),
(
False,
{"Header": "foo,1,bar,2,baz,4"},
Expand Down
34 changes: 27 additions & 7 deletions xpresso/binders/_extractors/params/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@
from xpresso.typing import Some


def collect_scalar(value: Optional[str]) -> Optional[Some[str]]:
if value is None:
return None
split = value.split(",")
if len(split) == 1:
return Some(split[0])
return Some(next(iter(split)))


def collect_sequence(value: Optional[str]) -> Optional[Some[List[str]]]:
if not value:
return Some(cast(List[str], []))
return Some(value.split(","))
return Some([v.lstrip() for v in value.split(",")])


def collect_object(
Expand All @@ -44,10 +53,12 @@ def collect_object(
if len(split) == 1 or not field or field[0] == "=":
raise InvalidSerialization(f"invalid object style header: {value}")
name, val = split
res[name] = val
res[name.lstrip()] = val
return Some(res)
try:
groups = cast(Iterable[Tuple[str, str]], grouped(value.split(",")))
groups = cast(
Iterable[Tuple[str, str]], grouped([v.lstrip() for v in value.split(",")])
)
except ValueError:
raise InvalidSerialization(f"invalid object style header: {value}")
return Some(dict(groups))
Expand All @@ -65,7 +76,7 @@ def get_extractor(explode: bool, field: ModelField) -> Extractor:
if is_mapping_like(field):
return functools.partial(collect_object, explode)
# single item
return lambda value: Some(value) if value is not None else None
return collect_scalar


ERRORS = {
Expand All @@ -87,12 +98,21 @@ async def extract(
send: starlette.types.Send,
connection: HTTPConnection,
) -> Any:
param_value: "Optional[str]" = None
# parse headers according to RFC 7230
# this means treating repeated headers and "," seperated ones the same
# so here we merge them all into one "," seperated string
# also note that whitespaces after a "," don't matter, so we .lstrip() as needed
header_values: "List[str]" = []
for name, value in scope["headers"]:
if name == self.header_name:
param_value = value.decode("latin-1")
header_values.append(value.decode("latin-1"))
header_value: "Optional[str]"
if header_values:
header_value = ",".join(header_values)
else:
header_value = None
try:
extracted = self.extractor(param_value)
extracted = self.extractor(header_value)
except InvalidSerialization as exc:
raise ERRORS[scope["type"]](
[ErrorWrapper(exc=exc, loc=("header", self.name))]
Expand Down

0 comments on commit edcde50

Please sign in to comment.