feat: support context managers for extractors (#95)
Use this feature to close files and forms automatically
adriangb authored Apr 27, 2022
1 parent 5df6a45 commit a5634fd
Showing 6 changed files with 115 additions and 30 deletions.
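
What "close files and forms automatically" means in practice: an extractor can now be written as an async generator (or return an async context manager), so cleanup such as form.close() or UploadFile.close() runs after the endpoint has used the extracted value. A minimal, self-contained sketch of the pattern — the Form stand-in and the extract_form helper are illustrative only, not xpresso API:

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class Form(dict):
    """Stand-in for starlette's FormData, which may own spooled temporary upload files."""

    async def close(self) -> None:
        print("form closed")  # the real object closes its temporary upload files


@asynccontextmanager
async def extract_form() -> AsyncIterator[Form]:
    form = Form(name="xpresso")
    try:
        yield form  # the endpoint runs while the form is available
    finally:
        await form.close()  # always runs, even if the endpoint raises


async def main() -> None:
    async with extract_form() as form:
        print(form["name"])


asyncio.run(main())
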
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.40.0"
version = "0.41.0"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
60 changes: 43 additions & 17 deletions xpresso/binders/_binders/file_body.py
@@ -2,6 +2,7 @@
import enum
import inspect
import typing
from contextlib import asynccontextmanager

from pydantic.fields import ModelField
from starlette.datastructures import UploadFile
@@ -11,6 +12,11 @@
from xpresso._utils.typing import Literal
from xpresso.binders._binders.media_type_validator import MediaTypeValidator
from xpresso.binders._binders.pydantic_validators import validate_body_field
from xpresso.binders._binders.utils import (
Consumer,
ConsumerContextManager,
wrap_consumer_as_cm,
)
from xpresso.binders.api import ModelNameMap, SupportsExtractor, SupportsOpenAPI
from xpresso.openapi import models as openapi_models
from xpresso.openapi._utils import parse_examples
@@ -35,6 +41,10 @@ def get_file_type(field: ModelField) -> FileType:
raise TypeError(f"Target type {field.type_.__name__} is not recognized")


RequestConsumer = Consumer[Request]
RequestConsumerContextManger = ConsumerContextManager[Request]


async def consume_into_bytes(request: Request) -> bytes:
res = bytearray()
async for chunk in request.stream():
@@ -48,30 +58,42 @@ async def read_into_bytes(request: Request) -> bytes:

def create_consume_into_uploadfile(
cls: typing.Type[UploadFile],
) -> typing.Callable[[Request], typing.Awaitable[UploadFile]]:
async def consume_into_uploadfile(request: Request) -> UploadFile:
) -> RequestConsumerContextManger:
@asynccontextmanager
async def consume_into_uploadfile(
request: Request,
) -> typing.AsyncIterator[UploadFile]:
file = cls(
filename="body", content_type=request.headers.get("Content-Type", "*/*")
)
async for chunk in request.stream():
if chunk:
await file.write(chunk)
await file.seek(0)
return file
try:
yield file
finally:
await file.close()

return consume_into_uploadfile


def create_read_into_uploadfile(
cls: typing.Type[UploadFile],
) -> typing.Callable[[Request], typing.Awaitable[UploadFile]]:
async def read_into_uploadfile(request: Request) -> UploadFile:
) -> RequestConsumerContextManger:
@asynccontextmanager
async def read_into_uploadfile(
request: Request,
) -> typing.AsyncIterator[UploadFile]:
file = cls(
filename="body", content_type=request.headers.get("Content-Type", "*/*")
)
await file.write(await request.body())
await file.seek(0)
return file
try:
yield file
finally:
await file.close()

return read_into_uploadfile

@@ -95,7 +117,7 @@ def has_body(conn: HTTPConnection) -> bool:

class Extractor(typing.NamedTuple):
media_type_validator: MediaTypeValidator
consumer: typing.Callable[[Request], typing.Awaitable[typing.Any]]
consumer_cm: RequestConsumerContextManger
field: ModelField

def __hash__(self) -> int:
@@ -104,13 +126,17 @@ def __hash__(self) -> int:
def __eq__(self, __o: object) -> bool:
return isinstance(__o, Extractor)

async def extract(self, connection: HTTPConnection) -> typing.Any:
async def extract(
self, connection: HTTPConnection
) -> typing.AsyncIterator[typing.Any]:
assert isinstance(connection, Request)
if not has_body(connection):
return validate_body_field(None, field=self.field, loc=("body",))
yield validate_body_field(None, field=self.field, loc=("body",))
return
media_type = connection.headers.get("content-type", None)
self.media_type_validator.validate(media_type)
return await self.consumer(connection)
async with self.consumer_cm(connection) as res:
yield res


class ExtractorMarker(typing.NamedTuple):
@@ -123,27 +149,27 @@ def register_parameter(self, param: inspect.Parameter) -> SupportsExtractor:
media_type_validator = MediaTypeValidator(self.media_type)
else:
media_type_validator = MediaTypeValidator(None)
consumer: typing.Callable[[Request], typing.Any]
consumer_cm: RequestConsumerContextManger
field = model_field_from_param(param, arbitrary_types_allowed=True)
file_type = get_file_type(field)
if file_type is FileType.bytes:
if self.consume:
consumer = consume_into_bytes
consumer_cm = wrap_consumer_as_cm(consume_into_bytes)
else:
consumer = read_into_bytes
consumer_cm = wrap_consumer_as_cm(read_into_bytes)
elif file_type is FileType.uploadfile:
if self.consume:
consumer = create_consume_into_uploadfile(field.type_)
consumer_cm = create_consume_into_uploadfile(field.type_)
else:
consumer = create_read_into_uploadfile(field.type_)
consumer_cm = create_read_into_uploadfile(field.type_)
else: # stream
if self.consume:
consumer = consume_into_stream
consumer_cm = wrap_consumer_as_cm(consume_into_stream)
else:
raise ValueError("consume=False is not supported for streams")
return Extractor(
media_type_validator=media_type_validator,
consumer=consumer,
consumer_cm=consumer_cm,
field=field,
)

13 changes: 10 additions & 3 deletions xpresso/binders/_binders/form_body.py
@@ -231,26 +231,33 @@ def __hash__(self) -> int:
def __eq__(self, __o: object) -> bool:
return isinstance(__o, Extractor)

async def extract(self, connection: HTTPConnection) -> typing.Any:
async def extract(
self, connection: HTTPConnection
) -> typing.AsyncIterator[typing.Any]:
assert isinstance(connection, Request)
content_type = connection.headers.get("content-type", None)
if (
content_type is None
and connection.headers.get("content-length", "0") == "0"
):
return validate_body_field(None, field=self.field, loc=("body",))
yield validate_body_field(None, field=self.field, loc=("body",))
return
self.media_type_validator.validate(content_type)
form = await connection.form()
res: typing.Dict[str, typing.Any] = {}
for param_name, extractor in self.field_extractors.items():
extracted = await extractor.extract(form)
if isinstance(extracted, Some):
res[param_name] = extracted.value
return validate_body_field(
validated_form = validate_body_field(
Some(res),
field=self.field,
loc=("body",),
)
try:
yield validated_form
finally:
await form.close()


class ExtractorMarker(typing.NamedTuple):
45 changes: 38 additions & 7 deletions xpresso/binders/_binders/union.py
@@ -1,14 +1,23 @@
import contextlib
import inspect
import typing

import xpresso.openapi.models as openapi_models
from xpresso._utils.pydantic_utils import model_field_from_param
from xpresso._utils.typing import Annotated, get_args, get_origin
from xpresso.binders._binders.utils import (
Consumer,
ConsumerContextManager,
wrap_consumer_as_cm,
)
from xpresso.binders.api import ModelNameMap, SupportsExtractor, SupportsOpenAPI
from xpresso.binders.dependants import Binder, BinderMarker
from xpresso.exceptions import HTTPException, RequestValidationError
from xpresso.requests import HTTPConnection, Request

RequestConsumer = Consumer[Request]
RequestConsumerContextManger = ConsumerContextManager[Request]


def get_binders_from_union_annotation(param: inspect.Parameter) -> typing.List[Binder]:
providers: typing.List[Binder] = []
@@ -60,15 +69,24 @@ def register_parameter(self, param: inspect.Parameter) -> SupportsOpenAPI:
)


SupportsExtractorCM = typing.Callable[
[HTTPConnection], typing.AsyncContextManager[typing.Any]
]


class Extractor(typing.NamedTuple):
providers: typing.Tuple[SupportsExtractor, ...]
extractors: typing.Iterable[SupportsExtractorCM]

async def extract(self, connection: HTTPConnection) -> typing.Any:
async def extract(
self, connection: HTTPConnection
) -> typing.AsyncIterator[typing.Any]:
assert isinstance(connection, Request)
errors: "typing.List[typing.Union[HTTPException, RequestValidationError]]" = []
for provider in self.providers:
for extractor in self.extractors:
try:
return await provider.extract(connection)
async with extractor(connection) as res:
yield res
return
except (HTTPException, RequestValidationError) as error:
errors.append(error)
# if any body accepted the request but didn't pass validation, return the error from that one
@@ -81,9 +99,22 @@ async def extract(self, connection: HTTPConnection) -> typing.Any:
# and leaking implementation details
raise next(iter(errors))

def __hash__(self) -> int:
return id(self)

def __eq__(self, __o: object) -> bool:
return False


class ExtractorMarker(typing.NamedTuple):
def register_parameter(self, param: inspect.Parameter) -> SupportsExtractor:
return Extractor(
tuple(b.extractor for b in get_binders_from_union_annotation(param))
)
extractors: typing.List[SupportsExtractorCM] = []
for binder in get_binders_from_union_annotation(param):
extractor = binder.extractor.extract
if inspect.isasyncgenfunction(extractor):
extractors.append(
contextlib.asynccontextmanager(extractor) # type: ignore[arg-type]
)
else:
extractors.append(wrap_consumer_as_cm(extractor))
return Extractor(extractors)
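
For reference, a sketch of the dispatch that ExtractorMarker.register_parameter performs above: async-generator extractors are wrapped with contextlib.asynccontextmanager, while plain awaitable extractors go through wrap_consumer_as_cm. The awaitable_style and generator_style functions are made-up examples, and the import assumes an xpresso checkout that includes this commit:

import asyncio
import contextlib
import inspect
import typing

from xpresso.binders._binders.utils import wrap_consumer_as_cm


async def awaitable_style(conn: str) -> str:
    # old-style extractor: returns an awaitable value, no cleanup step
    return conn.upper()


async def generator_style(conn: str) -> typing.AsyncIterator[str]:
    # new-style extractor: yields the value, cleanup runs after the yield
    try:
        yield conn.upper()
    finally:
        pass  # e.g. close a file or form here


def as_cm(
    extract: typing.Callable[..., typing.Any]
) -> typing.Callable[[str], typing.AsyncContextManager[str]]:
    if inspect.isasyncgenfunction(extract):
        return contextlib.asynccontextmanager(extract)
    return wrap_consumer_as_cm(extract)


async def main() -> None:
    for fn in (awaitable_style, generator_style):
        async with as_cm(fn)("hello") as value:
            print(value)  # HELLO, twice


asyncio.run(main())
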
16 changes: 16 additions & 0 deletions xpresso/binders/_binders/utils.py
@@ -0,0 +1,16 @@
import contextlib
import typing

T = typing.TypeVar("T")

Consumer = typing.Callable[[T], typing.Any]
ConsumerContextManager = typing.Callable[[T], typing.AsyncContextManager[typing.Any]]


def wrap_consumer_as_cm(consumer: Consumer[T]) -> ConsumerContextManager[T]:
@contextlib.asynccontextmanager
async def consume(request: T) -> typing.AsyncIterator[typing.Any]:
res = await consumer(request)
yield res

return consume
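
A quick usage sketch of the new helper: wrap_consumer_as_cm turns a plain awaitable consumer into an async context manager with no extra cleanup, so the rest of the code can treat every consumer uniformly. The toy read_body consumer (taking a str instead of a Request) is illustrative only, and the import assumes an xpresso version that includes this module:

import asyncio

from xpresso.binders._binders.utils import wrap_consumer_as_cm


async def read_body(request: str) -> bytes:
    # toy consumer: a str stands in for a starlette Request
    return request.encode()


async def main() -> None:
    consume_cm = wrap_consumer_as_cm(read_body)
    async with consume_cm("hello") as body:  # the value is produced on __aenter__
        print(body)  # b'hello'; nothing to clean up on __aexit__


asyncio.run(main())
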
9 changes: 7 additions & 2 deletions xpresso/binders/api.py
@@ -1,4 +1,4 @@
from typing import Any, Awaitable, Dict, List, Union
from typing import Any, AsyncIterator, Awaitable, Dict, List, Union

from starlette.requests import HTTPConnection

@@ -7,14 +7,19 @@


class SupportsExtractor(Protocol):
def extract(self, connection: HTTPConnection) -> Union[Awaitable[Any], Any]:
def extract(
self, connection: HTTPConnection
) -> Union[Awaitable[Any], AsyncIterator[Any]]:
"""Extract data from an incoming connection.
The `connection` parameter will always be either a Request object or a WebSocket object,
which are both subclasses of HTTPConnection.
If you just need access to headers, query params, or any other metadata present in HTTPConnection
then you can use the parameter directly.
Otherwise, you can do `isinstance(connection, Request)` before accessing `Request.stream()` and such.
The return value can be an awaitable or an async iterable (context manager like).
The iterator versions will be wrapped with `@contextlib.{async}contextmanager`.
"""
...

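
Per the updated docstring, SupportsExtractor.extract may now be an async generator; xpresso wraps such extractors with contextlib.asynccontextmanager so the code after the yield runs once the endpoint is done with the extracted value. A hypothetical custom extractor, driven by hand the same way the framework would drive it (the class name, the stand-in resource, and passing connection=None are all illustrative):

import asyncio
import contextlib
import typing


class TempFileExtractor:
    async def extract(self, connection: typing.Any) -> typing.AsyncIterator[bytes]:
        resource = b"spooled upload"  # stand-in for an open temporary file
        try:
            yield resource  # handed to the endpoint
        finally:
            print("cleaned up")  # e.g. await file.close()


async def main() -> None:
    extractor = TempFileExtractor()
    cm = contextlib.asynccontextmanager(extractor.extract)
    async with cm(connection=None) as value:
        print(value)


asyncio.run(main())
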
