Skip to content

Commit

Permalink
refactor/feat: add scheme_name to Security() and refactor security …
Browse files Browse the repository at this point in the history
…/ binders (#40)
  • Loading branch information
adriangb authored Feb 3, 2022
1 parent 4d5b8a9 commit 4548dc2
Show file tree
Hide file tree
Showing 60 changed files with 567 additions and 400 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.9.4"
version = "0.10.0"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,30 @@
},
"parameters": [
{
"description": "Count of items to skip starting from the 0th item",
"description": "Maximum number of items to return",
"required": True,
"style": "form",
"explode": True,
"schema": {
"title": "Skip",
"title": "Limit",
"exclusiveMinimum": 0.0,
"type": "integer",
"description": "Maximum number of items to return",
},
"name": "skip",
"name": "limit",
"in": "query",
},
{
"description": "Maximum number of items to return",
"description": "Count of items to skip starting from the 0th item",
"required": True,
"style": "form",
"explode": True,
"schema": {
"title": "Limit",
"title": "Skip",
"exclusiveMinimum": 0.0,
"type": "integer",
"description": "Maximum number of items to return",
},
"name": "limit",
"name": "skip",
"in": "query",
},
],
Expand Down
8 changes: 4 additions & 4 deletions tests/test_docs/tutorial/query_params/test_tutorial_001.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
{
"style": "form",
"explode": True,
"schema": {"title": "Skip", "type": "integer", "default": 0},
"name": "skip",
"schema": {"title": "Limit", "type": "integer", "default": 2},
"name": "limit",
"in": "query",
},
{
"style": "form",
"explode": True,
"schema": {"title": "Limit", "type": "integer", "default": 2},
"name": "limit",
"schema": {"title": "Skip", "type": "integer", "default": 0},
"name": "skip",
"in": "query",
},
],
Expand Down
14 changes: 7 additions & 7 deletions tests/test_docs/tutorial/query_params/test_tutorial_003.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@
},
},
"parameters": [
{
"style": "form",
"explode": True,
"schema": {"title": "Skip", "type": "integer", "default": 0},
"name": "skip",
"in": "query",
},
{
"style": "form",
"explode": True,
Expand All @@ -45,6 +38,13 @@
"name": "limit",
"in": "query",
},
{
"style": "form",
"explode": True,
"schema": {"title": "Skip", "type": "integer", "default": 0},
"name": "skip",
"in": "query",
},
],
}
}
Expand Down
80 changes: 80 additions & 0 deletions tests/test_security/test_security_scheme_naming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict

from xpresso import App, Path, Security
from xpresso.security import APIKeyHeader
from xpresso.testclient import TestClient
from xpresso.typing import Annotated


def test_duplicate_scheme_name_resolution() -> None:

api_key1 = APIKeyHeader(name="key1")
api_key2 = APIKeyHeader(name="key2")

def endpoint(
key1: Annotated[str, Security(api_key1)],
key2: Annotated[str, Security(api_key2)],
) -> None:
...

app = App([Path("/", get=endpoint)])

openapi_schema: Dict[str, Any] = {
"openapi": "3.0.3",
"info": {"title": "API", "version": "0.1.0"},
"paths": {
"/": {
"get": {
"responses": {"200": {"description": "Successful Response"}},
"security": [{"APIKeyHeader_1": []}, {"APIKeyHeader_2": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyHeader_1": {"type": "apiKey", "name": "key1", "in": "header"},
"APIKeyHeader_2": {"type": "apiKey", "name": "key2", "in": "header"},
}
},
}

client = TestClient(app)

response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema, response.json()


def test_scheme_name() -> None:
api_key = APIKeyHeader(name="key1")

def endpoint(
key: Annotated[str, Security(api_key, scheme_name="foobarbaz")],
) -> None:
...

app = App([Path("/", get=endpoint)])

openapi_schema: Dict[str, Any] = {
"openapi": "3.0.3",
"info": {"title": "API", "version": "0.1.0"},
"paths": {
"/": {
"get": {
"responses": {"200": {"description": "Successful Response"}},
"security": [{"foobarbaz": []}],
}
}
},
"components": {
"securitySchemes": {
"foobarbaz": {"type": "apiKey", "name": "key1", "in": "header"}
}
},
}

client = TestClient(app)

response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema
2 changes: 1 addition & 1 deletion xpresso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PathParam,
QueryParam,
RepeatedFormField,
Security,
)
from xpresso.datastructures import UploadFile
from xpresso.dependencies.models import Dependant
Expand All @@ -35,7 +36,6 @@
from xpresso.routing.pathitem import Path
from xpresso.routing.router import Router
from xpresso.routing.websockets import WebSocketRoute
from xpresso.security._functions import Security
from xpresso.websockets import WebSocket

__all__ = (
Expand Down
2 changes: 1 addition & 1 deletion xpresso/_utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic.schema import field_schema

from xpresso._utils.typing import filter_pydantic_models_from_mapping
from xpresso.binders._openapi_providers.api import ModelNameMap, Schemas
from xpresso.binders.api import ModelNameMap, Schemas
from xpresso.openapi import models as openapi_models
from xpresso.openapi.constants import REF_PREFIX

Expand Down
24 changes: 2 additions & 22 deletions xpresso/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
from xpresso.exceptions import HTTPException, RequestValidationError
from xpresso.middleware.exceptions import ExceptionMiddleware
from xpresso.openapi import models as openapi_models
from xpresso.openapi._builder import SecurityModels, genrate_openapi
from xpresso.openapi._builder import genrate_openapi
from xpresso.openapi._html import get_swagger_ui_html
from xpresso.routing.pathitem import Path
from xpresso.routing.router import Router
from xpresso.routing.websockets import WebSocketRoute
from xpresso.security._dependants import Security

ExceptionHandler = typing.Callable[
[Request, Exception], typing.Union[Response, typing.Awaitable[Response]]
Expand Down Expand Up @@ -240,31 +239,12 @@ def _setup(self) -> typing.List[typing.Callable[..., typing.AsyncIterator[None]]
async def get_openapi(self) -> openapi_models.OpenAPI:
return genrate_openapi(
visitor=visit_routes(app_type=App, router=self.router, nodes=[self, self.router], path=""), # type: ignore # for Pylance
container=self.container,
version=self.openapi_version,
info=self.openapi_info,
servers=self.servers,
security_models=await self.gather_security_models(),
)

async def gather_security_models(self) -> SecurityModels:
security_dependants: typing.List[Security] = []
for route in visit_routes(app_type=App, router=self.router, nodes=[self, self.router], path=""): # type: ignore[misc]
if isinstance(route.route, Path):
for operation in route.route.operations.values():
dependant = operation.dependant
assert dependant is not None
for subdependant in dependant.dag:
if isinstance(subdependant, Security):
security_dependants.append(subdependant)
executor = AsyncExecutor()
return {
sec_dep: await self.container.execute_async(
self.container.solve(sec_dep),
executor=executor,
)
for sec_dep in security_dependants
}

def _get_doc_routes(
self,
openapi_url: typing.Optional[str],
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xpresso.typing import Some


def validate(
def validate_body_field(
values: typing.Optional[Some[typing.Any]],
*,
field: ModelField,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from di.typing import get_markers_from_parameter
from starlette.requests import Request

from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.dependants import BodyBinderMarker
from xpresso.exceptions import HTTPException

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette.datastructures import FormData

from xpresso._utils.typing import model_field_from_param
from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.dependants import BodyBinderMarker
from xpresso.typing import Some

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
get_validator as get_media_type_validator,
)
from xpresso._utils.typing import model_field_from_param
from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders._extractors.body.utils import stream_to_bytes
from xpresso.binders._extractors.validator import validate as validate_data
from xpresso.binders._body.extractors.body_field_validation import validate_body_field
from xpresso.binders._utils.stream_to_bytes import convert_stream_to_bytes
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.exceptions import RequestValidationError
from xpresso.typing import Some

Expand All @@ -33,12 +33,14 @@ async def extract_from_request(self, request: Request) -> typing.Any:
self.media_type_validator.validate(media_type, loc=("body",))
if self.field.type_ is bytes:
if self.consume:
data = await stream_to_bytes(request.stream())
data = await convert_stream_to_bytes(request.stream())
if data is None:
return validate_data(Some(b""), field=self.field, loc=("body",))
return validate_body_field(
Some(b""), field=self.field, loc=("body",)
)
else:
data = await request.body()
return validate_data(Some(data), field=self.field, loc=("body",))
return validate_body_field(Some(data), field=self.field, loc=("body",))
# create an UploadFile from the body's stream
file: UploadFile = self.field.type_( # use the field type to allow users to subclass UploadFile
filename="body", content_type=media_type or "*/*"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
get_validator as get_media_type_validator,
)
from xpresso._utils.typing import model_field_from_param
from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders._extractors.body.form_field import FormFieldBodyExtractorMarker
from xpresso.binders._extractors.validator import validate as validate_data
from xpresso.binders._body.extractors.body_field_validation import validate_body_field
from xpresso.binders._body.extractors.form_field import FormFieldBodyExtractorMarker
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.dependants import BodyBinderMarker
from xpresso.typing import Some

Expand All @@ -38,12 +38,12 @@ async def extract_from_request(self, request: Request) -> typing.Any:
and self.field.required is not True
):
# this is the only way to know the body is empty
return validate_data(
return validate_body_field(
None,
field=self.field,
loc=("body",),
)
return validate_data(
return validate_body_field(
Some(await self._extract(await request.form(), loc=("body",))),
field=self.field,
loc=("body",),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from starlette.datastructures import FormData

from xpresso._utils.typing import model_field_from_param
from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders._extractors.exceptions import InvalidSerialization
from xpresso.binders._extractors.form_utils import (
from xpresso.binders._utils.forms import (
Extractor,
UnexpectedFileReceived,
get_extractor,
)
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders.exceptions import InvalidSerialization
from xpresso.exceptions import RequestValidationError
from xpresso.typing import Some

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
get_validator as get_media_type_validator,
)
from xpresso._utils.typing import model_field_from_param
from xpresso.binders._extractors.api import BodyExtractor, BodyExtractorMarker
from xpresso.binders._extractors.body.utils import stream_to_bytes
from xpresso.binders._extractors.validator import validate as validate_data
from xpresso.binders._body.extractors.body_field_validation import validate_body_field
from xpresso.binders._utils.stream_to_bytes import convert_stream_to_bytes
from xpresso.binders.api import BodyExtractor, BodyExtractorMarker
from xpresso.exceptions import RequestValidationError
from xpresso.typing import Some

Expand All @@ -46,12 +46,12 @@ async def extract_from_request(self, request: Request) -> typing.Any:
request.headers.get("content-type", None), loc=loc
)
if self.consume:
data_from_stream = await stream_to_bytes(request.stream())
data_from_stream = await convert_stream_to_bytes(request.stream())
if data_from_stream is None:
return validate_data(None, field=self.field, loc=loc)
return validate_body_field(None, field=self.field, loc=loc)
else:
data_from_stream = await request.body()
return validate_data(
return validate_body_field(
Some(await self._decode(data_from_stream, loc=loc)),
field=self.field,
loc=loc,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
from di.typing import get_markers_from_parameter

from xpresso._utils.typing import model_field_from_param
from xpresso.binders._openapi_providers.api import (
ModelNameMap,
OpenAPIBody,
OpenAPIBodyMarker,
Schemas,
)
from xpresso.binders.api import ModelNameMap, OpenAPIBody, OpenAPIBodyMarker, Schemas
from xpresso.binders.dependants import BodyBinderMarker
from xpresso.openapi import models as openapi_models

Expand Down
Loading

0 comments on commit 4548dc2

Please sign in to comment.