diff --git a/docs/examples/middleware/session/cookie_backend.py b/docs/examples/middleware/session/cookie_backend.py index d3c6e4de25..cebbbfbdbf 100644 --- a/docs/examples/middleware/session/cookie_backend.py +++ b/docs/examples/middleware/session/cookie_backend.py @@ -3,6 +3,6 @@ from litestar import Litestar from litestar.middleware.session.client_side import CookieBackendConfig -session_config = CookieBackendConfig(secret=urandom(16)) # type: ignore[arg-type] +session_config = CookieBackendConfig(secret=urandom(16)) # type: ignore app = Litestar(middleware=[session_config.middleware]) diff --git a/docs/examples/request_data/custom_request.py b/docs/examples/request_data/custom_request.py new file mode 100644 index 0000000000..a1a21f9fe2 --- /dev/null +++ b/docs/examples/request_data/custom_request.py @@ -0,0 +1,32 @@ +from litestar import Litestar, Request, get +from litestar.connection.base import empty_receive, empty_send +from litestar.enums import HttpMethod +from litestar.types import Receive, Scope, Send + +KITTEN_NAMES_MAP = { + HttpMethod.GET: "Whiskers", +} + + +class CustomRequest(Request): + """Enrich request with the kitten name.""" + + __slots__ = ("kitten_name",) + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize CustomRequest class.""" + super().__init__(scope=scope, receive=receive, send=send) + self.kitten_name = KITTEN_NAMES_MAP.get(scope["method"], "Mittens") + + +@get(path="/kitten-name") +def get_kitten_name(request: CustomRequest) -> str: + """Get kitten name based on the HTTP method.""" + return request.kitten_name + + +app = Litestar( + route_handlers=[get_kitten_name], + request_class=CustomRequest, + debug=True, +) diff --git a/docs/examples/responses/json_suffix_responses.py b/docs/examples/responses/json_suffix_responses.py new file mode 100644 index 0000000000..6de3b3e048 --- /dev/null +++ b/docs/examples/responses/json_suffix_responses.py @@ -0,0 +1,16 @@ +from typing import Any, Dict + +import litestar.status_codes +from litestar import Litestar, get + + +@get("/resources", status_code=litestar.status_codes.HTTP_418_IM_A_TEAPOT, media_type="application/problem+json") +async def retrieve_resource() -> Dict[str, Any]: + return { + "title": "Server thinks it is a teapot", + "type": "Server delusion", + "status": litestar.status_codes.HTTP_418_IM_A_TEAPOT, + } + + +app = Litestar(route_handlers=[retrieve_resource]) diff --git a/docs/examples/testing/test_set_session_data.py b/docs/examples/testing/test_set_session_data.py index 864f921aa5..913c690aa8 100644 --- a/docs/examples/testing/test_set_session_data.py +++ b/docs/examples/testing/test_set_session_data.py @@ -14,6 +14,8 @@ def get_session_data(request: Request) -> Dict[str, Any]: app = Litestar(route_handlers=[get_session_data], middleware=[session_config.middleware]) -with TestClient(app=app, session_config=session_config) as client: - client.set_session_data({"foo": "bar"}) - assert client.get("/test").json() == {"foo": "bar"} + +def test_get_session_data() -> None: + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} diff --git a/docs/examples/websockets/custom_websocket.py b/docs/examples/websockets/custom_websocket.py new file mode 100644 index 0000000000..954a3bdee0 --- /dev/null +++ b/docs/examples/websockets/custom_websocket.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from litestar import Litestar, WebSocket, websocket_listener +from litestar.types.asgi_types import WebSocketMode + + +class CustomWebSocket(WebSocket): + async def receive_data(self, mode: WebSocketMode) -> str | bytes: + """Return fixed response for every websocket message.""" + await super().receive_data(mode=mode) + return "Fixed response" + + +@websocket_listener("/") +async def handler(data: str) -> str: + return data + + +app = Litestar([handler], websocket_class=CustomWebSocket) diff --git a/docs/release-notes/changelog.rst b/docs/release-notes/changelog.rst index b2d82d05b3..3dbbe2da95 100644 --- a/docs/release-notes/changelog.rst +++ b/docs/release-notes/changelog.rst @@ -3,6 +3,129 @@ 2.x Changelog ============= +.. changelog:: 2.7.0 + :date: 2024-03-10 + + .. change:: missing cors headers in response + :type: bugfix + :pr: 3179 + :issue: 3178 + + Set CORS Middleware headers as per spec. + Addresses issues outlined on https://github.com/litestar-org/litestar/issues/3178 + + .. change:: sending empty data in sse in js client + :type: bugfix + :pr: 3176 + + Fix an issue with SSE where JavaScript clients fail to receive an event without data. + The `spec `_ is not clear in whether or not an event without data is ok. + Considering the EventSource "client" is not ok with it, and that it's so easy DX-wise to make the mistake not explicitly sending it, this change fixes it by defaulting to the empty-string + + .. change:: Support ``ResponseSpec(..., examples=[...])`` + :type: feature + :pr: 3100 + :issue: 3068 + + Allow defining custom examples for the responses via ``ResponseSpec``. + The examples set this way are always generated locally, for each response: + Examples that go within the schema definition cannot be set by this. + + .. code-block:: json + + { + "paths": { + "/": { + "get": { + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + "examples": "..."}} + }} + }} + } + } + + + .. change:: support "+json"-suffixed response media types + :type: feature + :pr: 3096 + :issue: 3088 + + Automatically encode responses with media type of the form "application/+json" as json. + + .. change:: Allow reusable ``Router`` instances + :type: feature + :pr: 3103 + :issue: 3012 + + It was not possible to re-attach a router instance once it was attached. This + makes that possible. + + The router instance now gets deecopied when it's registered to another router. + + The application startup performance gets a hit here, but the same approach is + already used for controllers and handlers, so this only harmonizes the + implementation. + + .. change:: only display path in ``ValidationException``\ s + :type: feature + :pr: 3064 + :issue: 3061 + + Fix an issue where ``ValidationException`` exposes the full URL in the error response, leaking internal IP(s) or other similar infra related information. + + .. change:: expose ``request_class`` to other layers + :type: feature + :pr: 3125 + + Expose ``request_class`` to other layers + + .. change:: expose ``websocket_class`` + :type: feature + :pr: 3152 + + Expose ``websocket_class`` to other layers + + .. change:: Add ``type_decoders`` to Router and route handlers + :type: feature + :pr: 3153 + + Add ``type_decoders`` to ``__init__`` method for handler, routers and decorators to keep consistency with ``type_encoders`` parameter + + .. change:: Pass ``type_decoders`` in ``WebsocketListenerRouteHandler`` + :type: feature + :pr: 3162 + + Pass ``type_decoders`` to parent's ``__init__`` in ``WebsocketListenerRouteHandler`` init, otherwise ``type_decoders`` will be ``None`` + replace params order in docs, ``__init__`` (`decoders` before `encoders`) + + .. change:: 3116 enhancement session middleware + :type: feature + :pr: 3127 + :issue: 3116 + + For server side sessions, the session id is now generated before the route handler. Thus, on first visit, a session id will be available inside the route handler's scope instead of afterwards + A new abstract method ``get_session_id`` was added to ``BaseSessionBackend`` since this method will be called for both ClientSideSessions and ServerSideSessions. Only for ServerSideSessions it will return an actual id. + Using ``request.set_session(...)`` will return the session id for ServerSideSessions and None for ClientSideSessions + The session auth MiddlewareWrapper now refers to the Session Middleware via the configured backend, instead of it being hardcoded + + .. change:: make random seed for openapi example generation configurable + :type: feature + :pr: 3166 + + Allow random seed used for generating the examples in the OpenAPI schema (when ``create_examples`` is set to ``True``) to be configured by the user. + This is related to https://github.com/litestar-org/litestar/issues/3059 however whether this change is enough to close that issue or not is not confirmed. + + .. change:: generate openapi components schemas in a deterministic order + :type: feature + :pr: 3172 + + Ensure that the insertion into the ``Components.schemas`` dictionary of the OpenAPI spec will be in alphabetical order (based on the normalized name of the ``Schema``). + + .. changelog:: 2.6.3 :date: 2024-03-04 diff --git a/docs/usage/requests.rst b/docs/usage/requests.rst index f23ddd79fd..fdc65f2784 100644 --- a/docs/usage/requests.rst +++ b/docs/usage/requests.rst @@ -138,3 +138,25 @@ for ``Body`` , by using :class:`RequestEncodingType.MESSAGEPACK <.enums.RequestE .. literalinclude:: /examples/request_data/msgpack_request.py :caption: msgpack_request.py :language: python + + +Custom Request +-------------- + +.. versionadded:: 2.7.0 + +Litestar supports custom ``request_class`` instances, which can be used to further configure the default :class:`Request`. +The example below illustrates how to implement custom request class for the whole application. + +.. dropdown:: Example of a custom request at the application level + + .. literalinclude:: /examples/request_data/custom_request.py + :language: python + +.. admonition:: Layered architecture + + Request classes are part of Litestar's layered architecture, which means you can + set a request class on every layer of the application. If you have set a request + class on multiple layers, the layer closest to the route handler will take precedence. + + You can read more about this in the :ref:`usage/applications:layered architecture` section diff --git a/docs/usage/responses.rst b/docs/usage/responses.rst index 7f9f82af03..6b9780cfc6 100644 --- a/docs/usage/responses.rst +++ b/docs/usage/responses.rst @@ -77,6 +77,17 @@ As previously mentioned, the default ``media_type`` is ``MediaType.JSON``. which If you need to return other values and would like to extend serialization you can do this :ref:`custom responses `. +You can also set an application media type string with the ``+json`` suffix +defined in `RFC 6839 `_ +as the ``media_type`` and it will be recognized and serialized as json. +For example, you can use ``application/problem+json`` +(see `RFC 7807 `_) +and it will work just like json but have the appropriate content-type header +and show up in the generated OpenAPI schema. + +.. literalinclude:: /examples/responses/json_suffix_responses.py + :language: python + MessagePack responses +++++++++++++++++++++ diff --git a/docs/usage/websockets.rst b/docs/usage/websockets.rst index cc43904f17..6cbbc8c267 100644 --- a/docs/usage/websockets.rst +++ b/docs/usage/websockets.rst @@ -249,3 +249,25 @@ encapsulate more complex logic. .. literalinclude:: /examples/websockets/listener_class_based_async.py :language: python + + +Custom WebSocket +---------------- + +.. versionadded:: 2.7.0 + +Litestar supports custom ``websocket_class`` instances, which can be used to further configure the default :class:`WebSocket`. +The example below illustrates how to implement custom websocket class for the whole application. + +.. dropdown:: Example of a custom websocket at the application level + + .. literalinclude:: /examples/websockets/custom_websocket.py + :language: python + +.. admonition:: Layered architecture + + WebSocket classes are part of Litestar's layered architecture, which means you can + set a websocket class on every layer of the application. If you have set a webscoket + class on multiple layers, the layer closest to the route handler will take precedence. + + You can read more about this in the :ref:`usage/applications:layered architecture` section diff --git a/litestar/_asgi/routing_trie/mapping.py b/litestar/_asgi/routing_trie/mapping.py index 8ef4a27c56..7a56b97a2f 100644 --- a/litestar/_asgi/routing_trie/mapping.py +++ b/litestar/_asgi/routing_trie/mapping.py @@ -212,7 +212,7 @@ def build_route_middleware_stack( handler, kwargs = cast("tuple[Any, dict[str, Any]]", middleware) asgi_handler = handler(app=asgi_handler, **kwargs) else: - asgi_handler = middleware(app=asgi_handler) # type: ignore + asgi_handler = middleware(app=asgi_handler) # type: ignore[call-arg] # we wrap the entire stack again in ExceptionHandlerMiddleware return wrap_in_exception_handler( diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index bf6f34a4ec..e3b347eadb 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -11,6 +11,7 @@ ) from litestar.datastructures import Headers from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.url import URL from litestar.enums import ParamType, RequestEncodingType from litestar.exceptions import ValidationException from litestar.params import BodyKwarg @@ -106,8 +107,12 @@ def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: values.update(connection_mapping) except KeyError as e: param = alias_to_params[e.args[0]] + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) raise ValidationException( - f"Missing required {param.param_type.value} parameter {param.field_alias!r} for url {connection.url}" + f"Missing required {param.param_type.value} parameter {param.field_alias!r} for path {path}" ) from e return extractor @@ -130,7 +135,7 @@ def create_query_default_dict( for k, v in parsed_query: if k in sequence_query_parameter_names: - output[k].append(v) # type: ignore + output[k].append(v) # type: ignore[union-attr] else: output[k] = v diff --git a/litestar/_kwargs/kwargs_model.py b/litestar/_kwargs/kwargs_model.py index f6374017d0..e69563622e 100644 --- a/litestar/_kwargs/kwargs_model.py +++ b/litestar/_kwargs/kwargs_model.py @@ -137,7 +137,7 @@ def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], "headers": headers_extractor, "cookies": cookies_extractor, "query": query_extractor, - "body": body_extractor, # type: ignore + "body": body_extractor, # type: ignore[dict-item] } extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ diff --git a/litestar/_openapi/datastructures.py b/litestar/_openapi/datastructures.py index 64466ec0ee..d97c8db405 100644 --- a/litestar/_openapi/datastructures.py +++ b/litestar/_openapi/datastructures.py @@ -150,7 +150,8 @@ def generate_components_schemas(self) -> dict[str, Schema]: self.set_reference_paths(name_, registered_schema) components_schemas[name_] = registered_schema.schema - return components_schemas + # Sort them by name to ensure they're always generated in the same order. + return {name: components_schemas[name] for name in sorted(components_schemas.keys())} class OpenAPIContext: diff --git a/litestar/_openapi/plugin.py b/litestar/_openapi/plugin.py index ffb0fd2d06..9bdbdecebd 100644 --- a/litestar/_openapi/plugin.py +++ b/litestar/_openapi/plugin.py @@ -32,8 +32,15 @@ def __init__(self, app: Litestar) -> None: self._openapi_schema: OpenAPI | None = None def _build_openapi_schema(self) -> OpenAPI: - openapi = self.openapi_config.to_openapi_schema() - context = OpenAPIContext(openapi_config=self.openapi_config, plugins=self.app.plugins.openapi) + openapi_config = self.openapi_config + + if openapi_config.create_examples: + from litestar._openapi.schema_generation.examples import ExampleFactory + + ExampleFactory.seed_random(openapi_config.random_seed) + + openapi = openapi_config.to_openapi_schema() + context = OpenAPIContext(openapi_config=openapi_config, plugins=self.app.plugins.openapi) openapi.paths = { route.path_format or "/": create_path_item_for_route(context, route) for route in self.included_routes.values() diff --git a/litestar/_openapi/responses.py b/litestar/_openapi/responses.py index ba1bfa6aac..7701d02659 100644 --- a/litestar/_openapi/responses.py +++ b/litestar/_openapi/responses.py @@ -9,9 +9,10 @@ from typing import TYPE_CHECKING, Any, Iterator from litestar._openapi.schema_generation import SchemaCreator +from litestar._openapi.schema_generation.utils import get_formatted_examples from litestar.enums import MediaType from litestar.exceptions import HTTPException, ValidationException -from litestar.openapi.spec import Example, OpenAPIResponse +from litestar.openapi.spec import Example, OpenAPIResponse, Reference from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType from litestar.openapi.spec.header import OpenAPIHeader from litestar.openapi.spec.media_type import OpenAPIMediaType @@ -240,13 +241,18 @@ def create_additional_responses(self) -> Iterator[tuple[str, OpenAPIResponse]]: prefer_alias=False, generate_examples=additional_response.generate_examples, ) + field_def = FieldDefinition.from_annotation(additional_response.data_container) + + examples: dict[str, Example | Reference] | None = ( + dict(get_formatted_examples(field_def, additional_response.examples)) + if additional_response.examples + else None + ) content: dict[str, OpenAPIMediaType] | None if additional_response.data_container is not None: - schema = schema_creator.for_field_definition( - FieldDefinition.from_annotation(additional_response.data_container) - ) - content = {additional_response.media_type: OpenAPIMediaType(schema=schema)} + schema = schema_creator.for_field_definition(field_def) + content = {additional_response.media_type: OpenAPIMediaType(schema=schema, examples=examples)} else: content = None diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 891ace9c43..2691ab12e9 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -189,7 +189,7 @@ def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Sche Returns: A schema instance. """ - enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore + enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated] if include_null and None not in enum_values: enum_values.append(None) return Schema(type=_types_in_list(enum_values), enum=enum_values) diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py index f3b37ea701..42c79947f5 100644 --- a/litestar/_signature/model.py +++ b/litestar/_signature/model.py @@ -32,6 +32,7 @@ _validate_signature_dependencies, ) from litestar.datastructures.state import ImmutableState +from litestar.datastructures.url import URL from litestar.dto import AbstractDTO, DTOData from litestar.enums import ParamType, ScopeType from litestar.exceptions import InternalServerException, ValidationException @@ -119,7 +120,11 @@ def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessa for err_message in messages if ("key" in err_message and err_message["key"] not in cls._dependency_name_set) or "key" not in err_message ]: - return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors) + path = URL.from_components( + path=connection.url.path, + query=connection.url.query, + ) + return ValidationException(detail=f"Validation failed for {method} {path}", extra=client_errors) return InternalServerException() @classmethod diff --git a/litestar/app.py b/litestar/app.py index 4bdb2648c3..e1bd989d75 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -216,8 +216,8 @@ def __init__( stores: StoreRegistry | dict[str, Store] | None = None, tags: Sequence[str] | None = None, template_config: TemplateConfigType | None = None, - type_encoders: TypeEncodersMap | None = None, type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, lifespan: Sequence[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, @@ -308,9 +308,9 @@ def __init__( tags: A sequence of string tags that will be appended to the schema of all route handlers under the application. template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` - type_encoders: A mapping of types to callables that transform them into types supported for serialization. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket connections. experimental_features: An iterable of experimental features to enable @@ -413,12 +413,12 @@ def __init__( self.on_shutdown = config.on_shutdown self.on_startup = config.on_startup self.openapi_config = config.openapi_config - self.request_class = config.request_class or Request + self.request_class: type[Request] = config.request_class or Request self.response_cache_config = config.response_cache_config self.state = config.state self._static_files_config = config.static_files_config self.template_engine = config.template_config.engine_instance if config.template_config else None - self.websocket_class = config.websocket_class or WebSocket + self.websocket_class: type[WebSocket] = config.websocket_class or WebSocket self.debug = config.debug self.pdb_on_exception: bool = config.pdb_on_exception self.include_in_schema = include_in_schema @@ -449,6 +449,7 @@ def __init__( opt=config.opt, parameters=config.parameters, path="", + request_class=self.request_class, response_class=config.response_class, response_cookies=config.response_cookies, response_headers=config.response_headers, @@ -462,6 +463,7 @@ def __init__( type_encoders=config.type_encoders, type_decoders=config.type_decoders, include_in_schema=config.include_in_schema, + websocket_class=self.websocket_class, ) for route_handler in config.route_handlers: @@ -575,7 +577,7 @@ async def __call__( await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] async def _call_lifespan_hook(self, hook: LifespanHook) -> None: - ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore + ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore[call-arg] if is_async_callable(hook): # pyright: ignore[reportGeneralTypeIssues] await ret diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index fa482482d4..f36cd7703c 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -441,7 +441,7 @@ def validate_ssl_file_paths(certfile_arg: str | None, keyfile_arg: str | None) - raise LitestarCLIException(f"File provided for {argname} was not found: {path}") resolved_paths.append(str(path)) - return tuple(resolved_paths) # type: ignore + return tuple(resolved_paths) # type: ignore[return-value] def create_ssl_files( diff --git a/litestar/cli/commands/schema.py b/litestar/cli/commands/schema.py index b145722b58..a323bc7871 100644 --- a/litestar/cli/commands/schema.py +++ b/litestar/cli/commands/schema.py @@ -41,7 +41,7 @@ def _generate_openapi_schema(app: Litestar, output: Path) -> None: raise LitestarCLIException(f"failed to write schema to path {output}") from e -@schema_group.command("openapi") # type: ignore +@schema_group.command("openapi") # type: ignore[misc] @option( "--output", help="output file path", @@ -54,7 +54,7 @@ def generate_openapi_schema(app: Litestar, output: Path) -> None: _generate_openapi_schema(app, output) -@schema_group.command("typescript") # type: ignore +@schema_group.command("typescript") # type: ignore[misc] @option( "--output", help="output file path", diff --git a/litestar/cli/commands/sessions.py b/litestar/cli/commands/sessions.py index 783f982834..f048dd13a7 100644 --- a/litestar/cli/commands/sessions.py +++ b/litestar/cli/commands/sessions.py @@ -29,7 +29,7 @@ def sessions_group() -> None: """Manage server-side sessions.""" -@sessions_group.command("delete") # type: ignore +@sessions_group.command("delete") # type: ignore[misc] @argument("session-id") def delete_session_command(session_id: str, app: Litestar) -> None: """Delete a specific session.""" @@ -43,7 +43,7 @@ def delete_session_command(session_id: str, app: Litestar) -> None: console.print(f"[green]Deleted session {session_id!r}") -@sessions_group.command("clear") # type: ignore +@sessions_group.command("clear") # type: ignore[misc] def clear_sessions_command(app: Litestar) -> None: """Delete all sessions.""" import anyio diff --git a/litestar/config/app.py b/litestar/config/app.py index 859999bfd6..0acefb1ae9 100644 --- a/litestar/config/app.py +++ b/litestar/config/app.py @@ -198,10 +198,10 @@ class AppConfig: """A list of string tags that will be appended to the schema of all route handlers under the application.""" template_config: TemplateConfigType | None = field(default=None) """An instance of :class:`TemplateConfig <.template.TemplateConfig>`.""" - type_encoders: TypeEncodersMap | None = field(default=None) - """A mapping of types to callables that transform them into types supported for serialization.""" type_decoders: TypeDecodersSequence | None = field(default=None) """A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.""" + type_encoders: TypeEncodersMap | None = field(default=None) + """A mapping of types to callables that transform them into types supported for serialization.""" websocket_class: type[WebSocket] | None = field(default=None) """An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket connections.""" multipart_form_part_limit: int = field(default=1000) diff --git a/litestar/connection/base.py b/litestar/connection/base.py index 7fb7098101..d14c6620e5 100644 --- a/litestar/connection/base.py +++ b/litestar/connection/base.py @@ -9,6 +9,7 @@ from litestar.datastructures.url import URL, Address, make_absolute_url from litestar.exceptions import ImproperlyConfiguredException from litestar.types.empty import Empty +from litestar.utils.empty import value_or_default from litestar.utils.scope.state import ScopeState if TYPE_CHECKING: @@ -287,7 +288,7 @@ def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> value: Dictionary or pydantic model instance for the session data. Returns: - None. + None """ self.scope["session"] = value @@ -301,6 +302,10 @@ def clear_session(self) -> None: None. """ self.scope["session"] = Empty + self._connection_state.session_id = Empty + + def get_session_id(self) -> str | None: + return value_or_default(value=self._connection_state.session_id, default=None) def url_for(self, name: str, **path_parameters: Any) -> str: """Return the url for a given route handler name. diff --git a/litestar/contrib/piccolo.py b/litestar/contrib/piccolo.py index 217db8445d..73bd27150a 100644 --- a/litestar/contrib/piccolo.py +++ b/litestar/contrib/piccolo.py @@ -56,7 +56,7 @@ def _parse_piccolo_type(column: Column, extra: dict[str, Any]) -> FieldDefinitio else: meta = Meta(max_length=column.length, extra=extra) elif isinstance(column, column_types.Array): - column_type = List[column.base_column.value_type] # type: ignore + column_type = List[column.base_column.value_type] # type: ignore[name-defined] meta = Meta(extra=extra) elif isinstance(column, (column_types.JSON, column_types.JSONB)): column_type = str diff --git a/litestar/contrib/prometheus/controller.py b/litestar/contrib/prometheus/controller.py index bee6eeee53..15f5bf1d52 100644 --- a/litestar/contrib/prometheus/controller.py +++ b/litestar/contrib/prometheus/controller.py @@ -43,11 +43,11 @@ async def get(self) -> Response: registry = REGISTRY if "prometheus_multiproc_dir" in os.environ or "PROMETHEUS_MULTIPROC_DIR" in os.environ: registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) # type: ignore + multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call] if self.openmetrics_format: headers = {"Content-Type": OPENMETRICS_CONTENT_TYPE_LATEST} - return Response(openmetrics_generate_latest(registry), status_code=200, headers=headers) # type: ignore + return Response(openmetrics_generate_latest(registry), status_code=200, headers=headers) # type: ignore[no-untyped-call] headers = {"Content-Type": CONTENT_TYPE_LATEST} return Response(generate_latest(registry), status_code=200, headers=headers) diff --git a/litestar/contrib/pydantic/pydantic_schema_plugin.py b/litestar/contrib/pydantic/pydantic_schema_plugin.py index af719a49af..ffc50f13e1 100644 --- a/litestar/contrib/pydantic/pydantic_schema_plugin.py +++ b/litestar/contrib/pydantic/pydantic_schema_plugin.py @@ -211,7 +211,7 @@ def __init__(self, prefer_alias: bool = False) -> None: @staticmethod def is_plugin_supported_type(value: Any) -> bool: - return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore + return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore[arg-type] @staticmethod def is_undefined_sentinel(value: Any) -> bool: diff --git a/litestar/controller.py b/litestar/controller.py index cd2cc84fad..967454b168 100644 --- a/litestar/controller.py +++ b/litestar/controller.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: + from litestar.connection import Request, WebSocket from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement @@ -61,6 +62,7 @@ class Controller: "owner", "parameters", "path", + "request_class", "response_class", "response_cookies", "response_headers", @@ -70,6 +72,7 @@ class Controller: "tags", "type_encoders", "type_decoders", + "websocket_class", ) after_request: AfterRequestHookHandler | None @@ -127,6 +130,10 @@ class Controller: All route handlers under the controller will have the fragment appended to them. If not set it defaults to ``/``. """ + request_class: type[Request] | None + """A custom subclass of :class:`Request <.connection.Request>` to be used as the default request for all route + handlers under the controller. + """ response_class: type[Response] | None """A custom subclass of :class:`Response <.response.Response>` to be used as the default response for all route handlers under the controller. @@ -150,10 +157,14 @@ class Controller: These types will be added to the signature namespace using their ``__name__`` attribute. """ - type_encoders: TypeEncodersMap | None - """A mapping of types to callables that transform them into types supported for serialization.""" type_decoders: TypeDecodersSequence | None """A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.""" + type_encoders: TypeEncodersMap | None + """A mapping of types to callables that transform them into types supported for serialization.""" + websocket_class: type[WebSocket] | None + """A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default websocket for all route + handlers under the controller. + """ def __init__(self, owner: Router) -> None: """Initialize a controller. diff --git a/litestar/data_extractors.py b/litestar/data_extractors.py index 5a6f6607f0..61993b4552 100644 --- a/litestar/data_extractors.py +++ b/litestar/data_extractors.py @@ -152,7 +152,7 @@ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedR A string keyed dictionary of extracted values. """ extractors = ( - {**self.connection_extractors, **self.request_extractors} # type: ignore + {**self.connection_extractors, **self.request_extractors} # type: ignore[misc] if isinstance(connection, Request) else self.connection_extractors ) @@ -162,7 +162,7 @@ async def extract( self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str] ) -> ExtractedRequestData: extractors = ( - {**self.connection_extractors, **self.request_extractors} # type: ignore + {**self.connection_extractors, **self.request_extractors} # type: ignore[misc] if isinstance(connection, Request) else self.connection_extractors ) diff --git a/litestar/file_system.py b/litestar/file_system.py index d7655485e1..fcb77c7925 100644 --- a/litestar/file_system.py +++ b/litestar/file_system.py @@ -49,7 +49,7 @@ async def open(self, file: PathType, mode: str, buffering: int = -1) -> AsyncFil mode: Mode, similar to the built ``open``. buffering: Buffer size. """ - return await open_file(file=file, mode=mode, buffering=buffering) # type: ignore + return await open_file(file=file, mode=mode, buffering=buffering) # type: ignore[call-overload, no-any-return] class FileSystemAdapter: diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index aa02b56a4f..9dbb70e28b 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -90,8 +90,8 @@ def __init__( return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, signature_types: Sequence[Any] | None = None, - type_encoders: TypeEncodersMap | None = None, type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: """Initialize ``HTTPRouteHandler``. @@ -115,8 +115,8 @@ def __init__( modelling. signature_types: A sequence of types for use in forward reference resolution during signature modeling. These types will be added to the signature namespace using their ``__name__`` attribute. - type_encoders: A mapping of types to callables that transform them into types supported for serialization. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ self._parsed_fn_signature: ParsedSignature | EmptyType = Empty @@ -149,7 +149,7 @@ def __init__( self.type_encoders = type_encoders self.paths = ( - {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore + {normalize_path(p) for p in path} if path and isinstance(path, list) else {normalize_path(path or "/")} # type: ignore[arg-type] ) def __call__(self, fn: AsyncAnyCallable) -> Self: @@ -523,7 +523,7 @@ def resolve_return_dto(self) -> type[AbstractDTO] | None: async def authorize_connection(self, connection: ASGIConnection) -> None: """Ensure the connection is authorized by running all the route guards in scope.""" for guard in self.resolve_guards(): - await guard(connection, copy(self)) # type: ignore + await guard(connection, copy(self)) # type: ignore[misc] @staticmethod def _validate_dependency_is_unique(dependencies: dict[str, Provide], key: str, provider: Provide) -> None: diff --git a/litestar/handlers/http_handlers/_utils.py b/litestar/handlers/http_handlers/_utils.py index 2df6717be6..ec95145ab8 100644 --- a/litestar/handlers/http_handlers/_utils.py +++ b/litestar/handlers/http_handlers/_utils.py @@ -92,7 +92,7 @@ def create_generic_asgi_response_handler(after_request: AfterRequestHookHandler """ async def handler(data: ASGIApp, **kwargs: Any) -> ASGIApp: - return await after_request(data) if after_request else data # type: ignore + return await after_request(data) if after_request else data # type: ignore[arg-type, misc, no-any-return] return handler @@ -149,7 +149,7 @@ async def handler( **kwargs: Any, # kwargs is for return dto ) -> ASGIApp: response = await after_request(data) if after_request else data # type:ignore[arg-type,misc] - return response.to_asgi_response( # type: ignore + return response.to_asgi_response( # type: ignore[no-any-return] app=None, background=background, cookies=cookie_list, diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index d234a24cca..757253e675 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast from litestar._layers.utils import narrow_response_cookies, narrow_response_headers +from litestar.connection import Request from litestar.datastructures.cookie import Cookie from litestar.datastructures.response_header import ResponseHeader from litestar.enums import HttpMethod, MediaType @@ -51,12 +52,12 @@ from litestar.app import Litestar from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.config.response_cache import CACHE_FOREVER - from litestar.connection import Request from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.datastructures import ResponseSpec from litestar.openapi.spec import SecurityRequirement from litestar.types.callable_types import AsyncAnyCallable, OperationIDCreator + from litestar.types.composite_types import TypeDecodersSequence __all__ = ("HTTPRouteHandler", "route") @@ -98,6 +99,7 @@ class HTTPRouteHandler(BaseRouteHandler): "operation_class", "operation_id", "raises", + "request_class", "response_class", "response_cookies", "response_description", @@ -134,6 +136,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -155,6 +158,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -196,6 +200,8 @@ def __init__( opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -224,6 +230,7 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -244,6 +251,7 @@ def __init__( opt=opt, return_dto=return_dto, signature_namespace=signature_namespace, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -257,6 +265,7 @@ def __init__( self.cache_key_builder = cache_key_builder self.etag = etag self.media_type: MediaType | str = media_type or "" + self.request_class = request_class self.response_class = response_class self.response_cookies: Sequence[Cookie] | None = narrow_response_cookies(response_cookies) self.response_headers: Sequence[ResponseHeader] | None = narrow_response_headers(response_headers) @@ -295,6 +304,19 @@ def __call__(self, fn: AnyCallable) -> HTTPRouteHandler: super().__call__(fn) return self + def resolve_request_class(self) -> type[Request]: + """Return the closest custom Request class in the owner graph or the default Request class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`Request <.connection.Request>` class for the route handler. + """ + return next( + (layer.request_class for layer in reversed(self.ownership_layers) if layer.request_class is not None), + Request, + ) + def resolve_response_class(self) -> type[Response]: """Return the closest custom Response class in the owner graph or the default Response class. @@ -304,11 +326,7 @@ def resolve_response_class(self) -> type[Response]: The default :class:`Response <.response.Response>` class for the route handler. """ return next( - ( - layer.response_class - for layer in list(reversed(self.ownership_layers)) - if layer.response_class is not None - ), + (layer.response_class for layer in reversed(self.ownership_layers) if layer.response_class is not None), Response, ) @@ -525,7 +543,7 @@ async def to_response(self, app: Litestar, data: Any, request: Request) -> ASGIA data = return_dto_type(request).data_to_encodable_type(data) response_handler = self.get_response_handler(is_response_type_data=isinstance(data, Response)) - return await response_handler(app=app, data=data, request=request) # type: ignore + return await response_handler(app=app, data=data, request=request) # type: ignore[call-arg] def on_registration(self, app: Litestar) -> None: super().on_registration(app) diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index f78168efef..1ae72e559b 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -6,8 +6,8 @@ from litestar.exceptions import HTTPException, ImproperlyConfiguredException from litestar.openapi.spec import Operation from litestar.response.file import ASGIFileResponse, File +from litestar.types import Empty, TypeDecodersSequence from litestar.types.builtin_types import NoneType -from litestar.types.empty import Empty from litestar.utils import is_class_and_subclass from .base import HTTPRouteHandler @@ -17,6 +17,7 @@ from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.config.response_cache import CACHE_FOREVER + from litestar.connection import Request from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.datastructures import ResponseSpec @@ -70,6 +71,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -91,6 +93,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -131,6 +134,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -159,6 +164,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -191,6 +198,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -203,6 +211,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -234,6 +243,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -255,6 +265,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -295,6 +306,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -323,6 +336,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -356,6 +371,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -368,6 +384,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -399,6 +416,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -420,6 +438,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -464,6 +483,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -492,6 +513,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -525,6 +548,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -537,6 +561,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -581,6 +606,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -602,6 +628,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -642,6 +669,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -670,6 +699,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -702,6 +733,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -714,6 +746,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -745,6 +778,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -766,6 +800,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -806,6 +841,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -834,6 +871,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -866,6 +905,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -878,6 +918,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) @@ -909,6 +950,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, name: str | None = None, opt: Mapping[str, Any] | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -930,6 +972,7 @@ def __init__( security: Sequence[SecurityRequirement] | None = None, summary: str | None = None, tags: Sequence[str] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, **kwargs: Any, ) -> None: @@ -970,6 +1013,8 @@ def __init__( name: A string identifying the route handler. opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as route handler's + default request. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as route handler's default response. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -998,6 +1043,8 @@ def __init__( security: A sequence of dictionaries that contain information about which security scheme can be used on the endpoint. summary: Text used for the route's schema summary section. tags: A sequence of string tags that will be appended to the OpenAPI schema. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ @@ -1030,6 +1077,7 @@ def __init__( opt=opt, path=path, raises=raises, + request_class=request_class, response_class=response_class, response_cookies=response_cookies, response_description=response_description, @@ -1042,6 +1090,7 @@ def __init__( summary=summary, sync_to_thread=sync_to_thread, tags=tags, + type_decoders=type_decoders, type_encoders=type_encoders, **kwargs, ) diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 4d6bc99e81..6d195d79ce 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -47,6 +47,7 @@ from litestar import Router from litestar.dto import AbstractDTO from litestar.types.asgi_types import WebSocketMode + from litestar.types.composite_types import TypeDecodersSequence __all__ = ("WebsocketListener", "WebsocketListenerRouteHandler", "websocket_listener") @@ -61,6 +62,7 @@ class WebsocketListenerRouteHandler(WebsocketRouteHandler): "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept", "on_accept": "Callback invoked after a WebSocket connection has been accepted", "on_disconnect": "Callback invoked after a WebSocket connection has been closed", + "weboscket_class": "WebSocket class", "_connection_lifespan": None, "_handle_receive": None, "_handle_send": None, @@ -85,7 +87,9 @@ def __init__( opt: dict[str, Any] | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: ... @@ -109,7 +113,9 @@ def __init__( opt: dict[str, Any] | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: ... @@ -133,7 +139,9 @@ def __init__( opt: dict[str, Any] | None = None, return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: """Initialize ``WebsocketRouteHandler`` @@ -166,8 +174,12 @@ def __init__( outbound response data. signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. """ if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]): raise ImproperlyConfiguredException( @@ -184,7 +196,9 @@ def __init__( self.connection_accept_handler = connection_accept_handler self.on_accept = ensure_async_callable(on_accept) if on_accept else None self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None + self.type_decoders = type_decoders self.type_encoders = type_encoders + self.websocket_class = websocket_class listener_dependencies = dict(dependencies or {}) @@ -209,6 +223,9 @@ def __init__( signature_namespace=signature_namespace, dto=dto, return_dto=return_dto, + type_decoders=type_decoders, + type_encoders=type_encoders, + websocket_class=websocket_class, **kwargs, ) @@ -342,10 +359,20 @@ class WebsocketListener(ABC): """ A mapping of names to types for use in forward reference resolution during signature modelling. """ + type_decoders: TypeDecodersSequence | None = None + """ + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + """ type_encoders: TypeEncodersMap | None = None """ type_encoders: A mapping of types to callables that transform them into types supported for serialization. """ + websocket_class: type[WebSocket] | None = None + """ + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ def __init__(self, owner: Router) -> None: """Initialize a WebsocketListener instance. @@ -371,7 +398,9 @@ def to_handler(self) -> WebsocketListenerRouteHandler: path=self.path, return_dto=self.return_dto, signature_namespace=self.signature_namespace, + type_decoders=self.type_decoders, type_encoders=self.type_encoders, + websocket_class=self.websocket_class, )(self.on_receive) handler.owner = self._owner return handler diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index 850cf59c3c..edb49c3030 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Mapping +from litestar.connection import WebSocket from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import BaseRouteHandler from litestar.types.builtin_types import NoneType @@ -28,6 +29,7 @@ def __init__( name: str | None = None, opt: dict[str, Any] | None = None, signature_namespace: Mapping[str, Any] | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: """Initialize ``WebsocketRouteHandler`` @@ -46,7 +48,10 @@ def __init__( signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. """ + self.websocket_class = websocket_class super().__init__( path=path, @@ -60,6 +65,19 @@ def __init__( **kwargs, ) + def resolve_websocket_class(self) -> type[WebSocket]: + """Return the closest custom WebSocket class in the owner graph or the default Websocket class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. + """ + return next( + (layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None), + WebSocket, + ) + def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" super()._validate_handler_function() diff --git a/litestar/logging/config.py b/litestar/logging/config.py index d7733ac39d..d9d376e3f2 100644 --- a/litestar/logging/config.py +++ b/litestar/logging/config.py @@ -30,9 +30,9 @@ try: from structlog.types import BindableLogger, Processor, WrappedLogger except ImportError: - BindableLogger = Any # type: ignore - Processor = Any # type: ignore - WrappedLogger = Any # type: ignore + BindableLogger = Any # type: ignore[assignment, misc] + Processor = Any # type: ignore[misc] + WrappedLogger = Any # type: ignore[misc] default_handlers: dict[str, dict[str, Any]] = { diff --git a/litestar/middleware/logging.py b/litestar/middleware/logging.py index dc827e303e..f5c5ecf9b2 100644 --- a/litestar/middleware/logging.py +++ b/litestar/middleware/logging.py @@ -41,7 +41,7 @@ structlog_installed = True except ImportError: - BindableLogger = object # type: ignore + BindableLogger = object # type: ignore[assignment, misc] structlog_installed = False diff --git a/litestar/middleware/rate_limit.py b/litestar/middleware/rate_limit.py index 8517581ccd..79755ed176 100644 --- a/litestar/middleware/rate_limit.py +++ b/litestar/middleware/rate_limit.py @@ -242,7 +242,7 @@ class RateLimitConfig: def __post_init__(self) -> None: if self.check_throttle_handler: - self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore + self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler) # type: ignore[arg-type] @property def middleware(self) -> DefineMiddleware: diff --git a/litestar/middleware/session/base.py b/litestar/middleware/session/base.py index b7e2e1586a..bb39fa682b 100644 --- a/litestar/middleware/session/base.py +++ b/litestar/middleware/session/base.py @@ -145,6 +145,17 @@ def deserialize_data(data: Any) -> dict[str, Any]: """ return cast("dict[str, Any]", decode_json(value=data)) + @abstractmethod + def get_session_id(self, connection: ASGIConnection) -> str | None: + """Try to fetch session id from connection ScopeState. If one does not exist, generate one. + + Args: + connection: Originating ASGIConnection containing the scope + + Returns: + Session id str or None if the concept of a session id does not apply. + """ + @abstractmethod async def store_in_message(self, scope_session: ScopeSession, message: Message, connection: ASGIConnection) -> None: """Store the necessary information in the outgoing ``Message`` @@ -241,5 +252,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: connection = ASGIConnection[Any, Any, Any, Any](scope, receive=receive, send=send) scope["session"] = await self.backend.load_from_connection(connection) + connection._connection_state.session_id = self.backend.get_session_id(connection) # pyright: ignore [reportGeneralTypeIssues] await self.app(scope, receive, self.create_send_wrapper(connection)) diff --git a/litestar/middleware/session/client_side.py b/litestar/middleware/session/client_side.py index cce502f64d..f709410478 100644 --- a/litestar/middleware/session/client_side.py +++ b/litestar/middleware/session/client_side.py @@ -206,6 +206,9 @@ async def load_from_connection(self, connection: ASGIConnection) -> dict[str, An return self.load_data(data) return {} + def get_session_id(self, connection: ASGIConnection) -> str | None: + return None + @dataclass class CookieBackendConfig(BaseBackendConfig[ClientSideSessionBackend]): # pyright: ignore diff --git a/litestar/middleware/session/server_side.py b/litestar/middleware/session/server_side.py index cec0011d80..91708ac80d 100644 --- a/litestar/middleware/session/server_side.py +++ b/litestar/middleware/session/server_side.py @@ -77,6 +77,26 @@ async def delete(self, session_id: str, store: Store) -> None: """ await store.delete(session_id) + def get_session_id(self, connection: ASGIConnection) -> str: + """Try to fetch session id from the connection. If one does not exist, generate one. + + If a session ID already exists in the cookies, it is returned. + If there is no ID in the cookies but one in the connection state, then the session exists but has not yet + been returned to the user. + Otherwise, a new session must be created. + + Args: + connection: Originating ASGIConnection containing the scope + Returns: + Session id str or None if the concept of a session id does not apply. + """ + session_id = connection.cookies.get(self.config.key) + if not session_id or session_id == "null": + session_id = connection.get_session_id() + if not session_id: + session_id = self.generate_session_id() + return session_id + def generate_session_id(self) -> str: """Generate a new session-ID, with n=:attr:`session_id_bytes ` random bytes. @@ -104,9 +124,7 @@ async def store_in_message(self, scope_session: ScopeSession, message: Message, scope = connection.scope store = self.config.get_store_from_app(scope["app"]) headers = MutableScopeHeaders.from_message(message) - session_id = connection.cookies.get(self.config.key) - if not session_id or session_id == "null": - session_id = self.generate_session_id() + session_id = self.get_session_id(connection) cookie_params = dict(extract_dataclass_items(self.config, exclude_none=True, include=Cookie.__dict__.keys())) diff --git a/litestar/openapi/config.py b/litestar/openapi/config.py index 8c095c8102..c935693696 100644 --- a/litestar/openapi/config.py +++ b/litestar/openapi/config.py @@ -43,6 +43,8 @@ class OpenAPIConfig: create_examples: bool = field(default=False) """Generate examples using the polyfactory library.""" + random_seed: int = 10 + """The random seed used when creating the examples to ensure deterministic generation of examples.""" openapi_controller: type[OpenAPIController] = field(default_factory=lambda: OpenAPIController) """Controller for generating OpenAPI routes. diff --git a/litestar/openapi/datastructures.py b/litestar/openapi/datastructures.py index 93d0cf4d3f..5796a48d4c 100644 --- a/litestar/openapi/datastructures.py +++ b/litestar/openapi/datastructures.py @@ -6,6 +6,7 @@ from litestar.enums import MediaType if TYPE_CHECKING: + from litestar.openapi.spec import Example from litestar.types import DataContainerType @@ -24,3 +25,5 @@ class ResponseSpec: """A description of the response.""" media_type: MediaType = field(default=MediaType.JSON) """Response media type.""" + examples: list[Example] | None = field(default=None) + """A list of Example models.""" diff --git a/litestar/response/base.py b/litestar/response/base.py index 7f16cbe011..16523fefc3 100644 --- a/litestar/response/base.py +++ b/litestar/response/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +import re from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterable, Literal, Mapping, TypeVar, overload from litestar.datastructures.cookie import Cookie @@ -35,6 +36,8 @@ T = TypeVar("T") +MEDIA_TYPE_APPLICATION_JSON_PATTERN = re.compile(r"^application/(?:.+\+)?json") + class ASGIResponse: """A low-level ASGI response class.""" @@ -385,7 +388,9 @@ def render(self, content: Any, media_type: str, enc_hook: Serializer = default_s if media_type == MediaType.MESSAGEPACK: return encode_msgpack(content, enc_hook) - if media_type.startswith("application/json"): + if MEDIA_TYPE_APPLICATION_JSON_PATTERN.match( + media_type, + ): return encode_json(content, enc_hook) raise ImproperlyConfiguredException(f"unsupported media_type {media_type} for content {content!r}") diff --git a/litestar/router.py b/litestar/router.py index d65f39d47b..decb08632f 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from copy import copy +from copy import copy, deepcopy from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast from litestar._layers.utils import narrow_response_cookies, narrow_response_headers @@ -20,6 +20,7 @@ if TYPE_CHECKING: + from litestar.connection import Request, WebSocket from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement @@ -66,6 +67,7 @@ class Router: "parameters", "path", "registered_route_handler_ids", + "request_class", "response_class", "response_cookies", "response_headers", @@ -74,8 +76,9 @@ class Router: "security", "signature_namespace", "tags", - "type_encoders", "type_decoders", + "type_encoders", + "websocket_class", ) def __init__( @@ -95,6 +98,7 @@ def __init__( middleware: Sequence[Middleware] | None = None, opt: Mapping[str, Any] | None = None, parameters: ParametersMap | None = None, + request_class: type[Request] | None = None, response_class: type[Response] | None = None, response_cookies: ResponseCookies | None = None, response_headers: ResponseHeaders | None = None, @@ -104,8 +108,9 @@ def __init__( signature_namespace: Mapping[str, Any] | None = None, signature_types: Sequence[Any] | None = None, tags: Sequence[str] | None = None, - type_encoders: TypeEncodersMap | None = None, type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, ) -> None: """Initialize a ``Router``. @@ -136,6 +141,8 @@ def __init__( paths. path: A path fragment that is prefixed to all route handlers, controllers and other routers associated with the router instance. + request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the default for all route handlers, controllers and other routers associated with the router instance. response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances. @@ -154,8 +161,10 @@ def __init__( These types will be added to the signature namespace using their ``__name__`` attribute. tags: A sequence of string tags that will be appended to the schema of all route handlers under the application. - type_encoders: A mapping of types to callables that transform them into types supported for serialization. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. """ self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore @@ -173,6 +182,7 @@ def __init__( self.owner: Router | None = None self.parameters = dict(parameters or {}) self.path = normalize_path(path) + self.request_class = request_class self.response_class = response_class self.response_cookies = narrow_response_cookies(response_cookies) self.response_headers = narrow_response_headers(response_headers) @@ -186,6 +196,7 @@ def __init__( self.registered_route_handler_ids: set[int] = set() self.type_encoders = dict(type_encoders) if type_encoders is not None else None self.type_decoders = list(type_decoders) if type_decoders is not None else None + self.websocket_class = websocket_class for route_handler in route_handlers or []: self.register(value=route_handler) @@ -309,14 +320,12 @@ def _validate_registration_value(self, value: ControllerRouterHandler) -> Contro return value(owner=self).to_handler() # pyright: ignore if isinstance(value, Router): - if value.owner: - raise ImproperlyConfiguredException(f"Router with path {value.path} has already been registered") - if value is self: raise ImproperlyConfiguredException("Cannot register a router on itself") - value.owner = self - return value + router_copy = deepcopy(value) + router_copy.owner = self + return router_copy if isinstance(value, (ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler)): value.owner = self diff --git a/litestar/routes/http.py b/litestar/routes/http.py index d93fd66e9d..b1f70cb36b 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -73,8 +73,8 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: Returns: None """ - request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive, send=send) route_handler, parameter_model = self.route_handler_map[scope["method"]] + request: Request[Any, Any, Any] = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send) if route_handler.resolve_guards(): await route_handler.authorize_connection(connection=request) diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index 9b309fe107..ebf4959d46 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -54,11 +54,14 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N Returns: None """ - websocket: WebSocket[Any, Any, Any] = scope["app"].websocket_class(scope=scope, receive=receive, send=send) if not self.handler_parameter_model: # pragma: no cover raise ImproperlyConfiguredException("handler parameter model not defined") + websocket: WebSocket[Any, Any, Any] = self.route_handler.resolve_websocket_class()( + scope=scope, receive=receive, send=send + ) + if self.route_handler.resolve_guards(): await self.route_handler.authorize_connection(connection=websocket) diff --git a/litestar/security/jwt/auth.py b/litestar/security/jwt/auth.py index c819ed0b61..2a0f09497f 100644 --- a/litestar/security/jwt/auth.py +++ b/litestar/security/jwt/auth.py @@ -212,7 +212,7 @@ def format_auth_header(self, encoded_token: str) -> str: Returns: The encoded token formatted for the HTTP headers """ - security = self.openapi_components.security_schemes.get(self.openapi_security_scheme_name, None) # type: ignore + security = self.openapi_components.security_schemes.get(self.openapi_security_scheme_name, None) # type: ignore[union-attr] return f"{security.scheme} {encoded_token}" if isinstance(security, SecurityScheme) else encoded_token diff --git a/litestar/security/session_auth/middleware.py b/litestar/security/session_auth/middleware.py index e1dce45691..bb3fce4349 100644 --- a/litestar/security/session_auth/middleware.py +++ b/litestar/security/session_auth/middleware.py @@ -8,7 +8,6 @@ AuthenticationResult, ) from litestar.middleware.exceptions import ExceptionHandlerMiddleware -from litestar.middleware.session.base import SessionMiddleware from litestar.types import Empty, Method, Scopes __all__ = ("MiddlewareWrapper", "SessionAuthMiddleware") @@ -54,14 +53,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: exclude_http_methods=self.config.exclude_http_methods, exclude_opt_key=self.config.exclude_opt_key, scopes=self.config.scopes, - retrieve_user_handler=self.config.retrieve_user_handler, # type: ignore + retrieve_user_handler=self.config.retrieve_user_handler, # type: ignore[arg-type] ) exception_middleware = ExceptionHandlerMiddleware( app=auth_middleware, exception_handlers=litestar_app.exception_handlers or {}, # pyright: ignore debug=None, ) - self.app = SessionMiddleware( + self.app = self.config.session_backend_config.middleware.middleware( app=exception_middleware, backend=self.config.session_backend, ) diff --git a/litestar/serialization/msgspec_hooks.py b/litestar/serialization/msgspec_hooks.py index b26415fb2c..779f2e003e 100644 --- a/litestar/serialization/msgspec_hooks.py +++ b/litestar/serialization/msgspec_hooks.py @@ -163,7 +163,7 @@ def decode_json(value: str | bytes, target_type: type[T], type_decoders: TypeDec ... -def decode_json( # type: ignore +def decode_json( # type: ignore[misc] value: str | bytes, target_type: type[T] | EmptyType = Empty, # pyright: ignore type_decoders: TypeDecodersSequence | None = None, diff --git a/litestar/testing/client/base.py b/litestar/testing/client/base.py index 93f1082a97..3c25be117b 100644 --- a/litestar/testing/client/base.py +++ b/litestar/testing/client/base.py @@ -22,7 +22,6 @@ from httpx._types import CookieTypes from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend - from litestar.middleware.session.client_side import ClientSideSessionBackend from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send T = TypeVar("T", bound=ASGIApp) @@ -155,20 +154,16 @@ def portal(self) -> Generator[BlockingPortal, None, None]: ) as portal: yield portal - @staticmethod - def _create_session_cookies(backend: ClientSideSessionBackend, data: dict[str, Any]) -> dict[str, str]: - encoded_data = backend.dump_data(data=data) - return {cookie.key: cast("str", cookie.value) for cookie in backend._create_session_cookies(encoded_data)} - async def _set_session_data(self, data: dict[str, Any]) -> None: mutable_headers = MutableScopeHeaders() + connection = fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ) + session_id = self.session_backend.get_session_id(connection) + connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues] await self.session_backend.store_in_message( - scope_session=data, - message=fake_http_send_message(mutable_headers), - connection=fake_asgi_connection( - app=self.app, - cookies=dict(self.cookies), # type: ignore[arg-type] - ), + scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection ) response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) diff --git a/litestar/testing/client/sync_client.py b/litestar/testing/client/sync_client.py index b63017b036..df8174fbb5 100644 --- a/litestar/testing/client/sync_client.py +++ b/litestar/testing/client/sync_client.py @@ -511,7 +511,7 @@ def websocket_connect( self, "GET", url, - headers={**dict(headers or {}), **default_headers}, # type: ignore + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] params=params, cookies=cookies, auth=auth, diff --git a/litestar/types/builtin_types.py b/litestar/types/builtin_types.py index e8555faeb8..335dedd798 100644 --- a/litestar/types/builtin_types.py +++ b/litestar/types/builtin_types.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Type, Union -from typing_extensions import _TypedDictMeta # type: ignore +from typing_extensions import _TypedDictMeta # type: ignore[attr-defined] if TYPE_CHECKING: from typing_extensions import TypeAlias diff --git a/litestar/types/serialization.py b/litestar/types/serialization.py index 50bb5684a8..0f61e10533 100644 --- a/litestar/types/serialization.py +++ b/litestar/types/serialization.py @@ -29,12 +29,12 @@ try: from pydantic import BaseModel except ImportError: - BaseModel = Any # type: ignore + BaseModel = Any # type: ignore[assignment, misc] try: from attrs import AttrsInstance except ImportError: - AttrsInstance = Any # type: ignore + AttrsInstance = Any # type: ignore[assignment, misc] __all__ = ( "LitestarEncodableType", diff --git a/litestar/utils/helpers.py b/litestar/utils/helpers.py index bf88d46d4c..c25fe35f25 100644 --- a/litestar/utils/helpers.py +++ b/litestar/utils/helpers.py @@ -55,7 +55,7 @@ def get_enum_string_value(value: Enum | str) -> str: Returns: A string. """ - return value.value if isinstance(value, Enum) else value # type:ignore + return value.value if isinstance(value, Enum) else value # type: ignore[no-any-return] def unwrap_partial(value: MaybePartial[T]) -> T: diff --git a/litestar/utils/predicates.py b/litestar/utils/predicates.py index b2c4fb5d72..11d5f792ca 100644 --- a/litestar/utils/predicates.py +++ b/litestar/utils/predicates.py @@ -47,7 +47,7 @@ try: import attrs except ImportError: - attrs = Empty # type: ignore + attrs = Empty # type: ignore[assignment] __all__ = ( "is_annotated_type", @@ -148,7 +148,7 @@ def is_generic(annotation: Any) -> bool: Returns: True if the annotation is a subclass of :data:`Generic ` otherwise ``False``. """ - return is_class_and_subclass(annotation, Generic) # type: ignore + return is_class_and_subclass(annotation, Generic) # type: ignore[arg-type] def is_mapping(annotation: Any) -> TypeGuard[Mapping[Any, Any]]: @@ -200,7 +200,7 @@ def is_non_string_sequence(annotation: Any) -> TypeGuard[Sequence[Any]]: try: return not issubclass(origin or annotation, (str, bytes)) and issubclass( origin or annotation, - ( # type: ignore + ( # type: ignore[arg-type] Tuple, List, Set, diff --git a/litestar/utils/scope/state.py b/litestar/utils/scope/state.py index 31f6442e61..bed43940e2 100644 --- a/litestar/utils/scope/state.py +++ b/litestar/utils/scope/state.py @@ -41,6 +41,7 @@ class ScopeState: "msgpack", "parsed_query", "response_compressed", + "session_id", "url", "_compat_ns", ) @@ -62,6 +63,7 @@ def __init__(self) -> None: self.msgpack = Empty self.parsed_query = Empty self.response_compressed = Empty + self.session_id = Empty self.url = Empty self._compat_ns: dict[str, Any] = {} @@ -81,6 +83,7 @@ def __init__(self) -> None: msgpack: Any | EmptyType parsed_query: tuple[tuple[str, str], ...] | EmptyType response_compressed: bool | EmptyType + session_id: str | None | EmptyType url: URL | EmptyType _compat_ns: dict[str, Any] diff --git a/pyproject.toml b/pyproject.toml index ee1e2920be..56244f770a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ maintainers = [ name = "litestar" readme = "README.md" requires-python = ">=3.8,<4.0" -version = "2.6.3" +version = "2.7.0" [project.urls] Blog = "https://blog.litestar.dev" @@ -227,6 +227,7 @@ python_version = "3.8" disallow_any_generics = false disallow_untyped_decorators = true +enable_error_code = "ignore-without-code" implicit_reexport = false show_error_codes = true strict = true diff --git a/tests/e2e/test_router_registration.py b/tests/e2e/test_router_registration.py index a935460425..e7ed74f2f1 100644 --- a/tests/e2e/test_router_registration.py +++ b/tests/e2e/test_router_registration.py @@ -132,9 +132,7 @@ def second_route_handler(self) -> None: def test_register_already_registered_router() -> None: first_router = Router(path="/first", route_handlers=[]) Router(path="/second", route_handlers=[first_router]) - - with pytest.raises(ImproperlyConfiguredException): - Router(path="/third", route_handlers=[first_router]) + Router(path="/third", route_handlers=[first_router]) def test_register_router_on_itself() -> None: diff --git a/tests/e2e/test_routing/test_route_indexing.py b/tests/e2e/test_routing/test_route_indexing.py index 881b49afbe..c09f15f48a 100644 --- a/tests/e2e/test_routing/test_route_indexing.py +++ b/tests/e2e/test_routing/test_route_indexing.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("decorator", [get, post, patch, put, delete]) def test_indexes_handlers(decorator: Type[HTTPRouteHandler]) -> None: - @decorator("/path-one/{param:str}", name="handler-name") # type: ignore + @decorator("/path-one/{param:str}", name="handler-name") # type: ignore[call-arg] def handler() -> None: return None @@ -59,18 +59,18 @@ async def websocket_handler(socket: Any) -> None: @pytest.mark.parametrize("decorator", [get, post, patch, put, delete]) def test_default_indexes_handlers(decorator: Type[HTTPRouteHandler]) -> None: - @decorator("/handler") # type: ignore + @decorator("/handler") # type: ignore[call-arg] def handler() -> None: pass - @decorator("/named_handler", name="named_handler") # type: ignore + @decorator("/named_handler", name="named_handler") # type: ignore[call-arg] def named_handler() -> None: pass class MyController(Controller): path = "/test" - @decorator() # type: ignore + @decorator() # type: ignore[call-arg] def handler(self) -> None: pass @@ -95,11 +95,11 @@ def handler(self) -> None: @pytest.mark.parametrize("decorator", [get, post, patch, put, delete]) def test_indexes_handlers_with_multiple_paths(decorator: Type[HTTPRouteHandler]) -> None: - @decorator(["/path-one", "/path-one/{param:str}"], name="handler") # type: ignore + @decorator(["/path-one", "/path-one/{param:str}"], name="handler") # type: ignore[call-arg] def handler() -> None: return None - @decorator(["/path-two"], name="handler-two") # type: ignore + @decorator(["/path-two"], name="handler-two") # type: ignore[call-arg] def handler_two() -> None: return None diff --git a/tests/e2e/test_routing/test_route_reverse.py b/tests/e2e/test_routing/test_route_reverse.py index bf9627e805..0a8d914883 100644 --- a/tests/e2e/test_routing/test_route_reverse.py +++ b/tests/e2e/test_routing/test_route_reverse.py @@ -10,26 +10,26 @@ @pytest.mark.parametrize("decorator", [get, post, patch, put, delete]) def test_route_reverse(decorator: Type[HTTPRouteHandler]) -> None: - @decorator("/path-one/{param:str}", name="handler-name") # type: ignore + @decorator("/path-one/{param:str}", name="handler-name") # type: ignore[call-arg] def handler() -> None: return None - @decorator("/path-two", name="handler-no-params") # type: ignore + @decorator("/path-two", name="handler-no-params") # type: ignore[call-arg] def handler_no_params() -> None: return None - @decorator("/multiple/{str_param:str}/params/{int_param:int}/", name="multiple-params-handler-name") # type: ignore + @decorator("/multiple/{str_param:str}/params/{int_param:int}/", name="multiple-params-handler-name") # type: ignore[call-arg] def handler2() -> None: return None @decorator( ["/handler3", "/handler3/{str_param:str}/", "/handler3/{str_param:str}/{int_param:int}/"], name="multiple-default-params", - ) # type: ignore + ) # type: ignore[call-arg] def handler3(str_param: str = "default", int_param: int = 0) -> None: return None - @decorator(["/handler4/int/{int_param:int}", "/handler4/str/{str_param:str}"], name="handler4") # type: ignore + @decorator(["/handler4/int/{int_param:int}", "/handler4/str/{str_param:str}"], name="handler4") # type: ignore[call-arg] def handler4(int_param: int = 1, str_param: str = "str") -> None: return None @@ -62,7 +62,7 @@ def handler4(int_param: int = 1, str_param: str = "str") -> None: "complex_path_param", [("time", time(hour=14), "14:00"), ("float", float(1 / 3), "0.33")], ) -def test_route_reverse_validation_complex_params(complex_path_param) -> None: # type: ignore +def test_route_reverse_validation_complex_params(complex_path_param) -> None: # type: ignore[no-untyped-def] param_type, param_value, param_manual_str = complex_path_param @get(f"/abc/{{param:{param_type}}}", name="handler") diff --git a/tests/examples/test_request_data.py b/tests/examples/test_request_data.py index f86d20a94b..2c03e08bbe 100644 --- a/tests/examples/test_request_data.py +++ b/tests/examples/test_request_data.py @@ -1,3 +1,4 @@ +from docs.examples.request_data.custom_request import app as custom_request_class_app from docs.examples.request_data.msgpack_request import app as msgpack_app from docs.examples.request_data.request_data_1 import app from docs.examples.request_data.request_data_2 import app as app_2 @@ -99,3 +100,9 @@ def test_msgpack_app() -> None: with TestClient(app=msgpack_app) as client: response = client.post("/", content=encode_msgpack(test_data)) assert response.json() == test_data + + +def test_custom_request_app() -> None: + with TestClient(app=custom_request_class_app) as client: + response = client.get("/kitten-name") + assert response.content == b"Whiskers" diff --git a/tests/examples/test_responses/test_json_suffix_responses.py b/tests/examples/test_responses/test_json_suffix_responses.py new file mode 100644 index 0000000000..f853600087 --- /dev/null +++ b/tests/examples/test_responses/test_json_suffix_responses.py @@ -0,0 +1,15 @@ +from docs.examples.responses.json_suffix_responses import app + +from litestar.testing import TestClient + + +def test_json_suffix_responses() -> None: + with TestClient(app=app) as client: + res = client.get("/resources") + assert res.status_code == 418 + assert res.json() == { + "title": "Server thinks it is a teapot", + "type": "Server delusion", + "status": 418, + } + assert res.headers["content-type"] == "application/problem+json" diff --git a/tests/examples/test_websockets.py b/tests/examples/test_websockets.py new file mode 100644 index 0000000000..c2898aa117 --- /dev/null +++ b/tests/examples/test_websockets.py @@ -0,0 +1,12 @@ +from docs.examples.websockets.custom_websocket import app as custom_websocket_class_app + +from litestar.testing.client.sync_client import TestClient + + +def test_custom_websocket_class(): + client = TestClient(app=custom_websocket_class_app) + + with client.websocket_connect("/") as ws: + ws.send({"data": "I should not be in response"}) + data = ws.receive() + assert data["text"] == "Fixed response" diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 46c89d68d1..e2acc93c67 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -425,7 +425,7 @@ def test_remove_default_schema_routes() -> None: api_config = MagicMock() api_config.openapi_controller.path = "/schema" - results = remove_default_schema_routes(http_routes, api_config) # type: ignore + results = remove_default_schema_routes(http_routes, api_config) # type: ignore[arg-type] assert len(results) == 3 for result in results: words = re.split(r"(^\/[a-z]+)", result.path) @@ -441,7 +441,7 @@ def test_remove_routes_with_patterns() -> None: http_routes.append(http_route) patterns = ("/destroy", "/pizza", "[]") - results = remove_routes_with_patterns(http_routes, patterns) # type: ignore + results = remove_routes_with_patterns(http_routes, patterns) # type: ignore[arg-type] paths = [route.path for route in results] assert len(paths) == 2 for route in ["/", "/foo"]: diff --git a/tests/unit/test_connection/test_request.py b/tests/unit/test_connection/test_request.py index 0bff39624f..b7203904e6 100644 --- a/tests/unit/test_connection/test_request.py +++ b/tests/unit/test_connection/test_request.py @@ -140,7 +140,7 @@ def test_custom_request_class() -> None: class MyRequest(Request): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.scope["called"] = True # type: ignore + self.scope["called"] = True # type: ignore[typeddict-unknown-key] @get("/", signature_types=[MyRequest]) def handler(request: MyRequest) -> None: @@ -382,7 +382,7 @@ def test_request_state() -> None: def handler(request: Request[Any, Any, Any]) -> dict[Any, Any]: request.state.test = 1 assert request.state.test == 1 - return request.state.dict() # type: ignore + return request.state.dict() # type: ignore[no-any-return] with create_test_client(handler) as client: response = client.get("/") @@ -431,7 +431,7 @@ def post_body() -> Generator[bytes, None, None]: def test_request_send_push_promise() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled - scope["extensions"]["http.response.push"] = {} # type: ignore + scope["extensions"]["http.response.push"] = {} # type: ignore[index] request = Request[Any, Any, Any](scope, receive, send) await request.send_push_promise("/style.css") @@ -490,7 +490,7 @@ def test_request_send_push_promise_without_setting_send() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled - scope["extensions"]["http.response.push"] = {} # type: ignore + scope["extensions"]["http.response.push"] = {} # type: ignore[index] data = "OK" request = Request[Any, Any, Any](scope) diff --git a/tests/unit/test_connection/test_websocket.py b/tests/unit/test_connection/test_websocket.py index 8480bf60c4..3eef092e02 100644 --- a/tests/unit/test_connection/test_websocket.py +++ b/tests/unit/test_connection/test_websocket.py @@ -71,7 +71,7 @@ async def test_custom_request_class() -> None: class MyWebSocket(WebSocket[Any, Any, Any]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.scope["called"] = True # type: ignore + self.scope["called"] = True # type: ignore[typeddict-unknown-key] @websocket("/", signature_types=[MyWebSocket]) async def handler(socket: MyWebSocket) -> None: diff --git a/tests/unit/test_contrib/test_htmx/test_htmx_response.py b/tests/unit/test_contrib/test_htmx/test_htmx_response.py index 8a7b4e3bcc..0a0e94a1ac 100644 --- a/tests/unit/test_contrib/test_htmx/test_htmx_response.py +++ b/tests/unit/test_contrib/test_htmx/test_htmx_response.py @@ -187,7 +187,7 @@ def handler() -> TriggerEvent: return TriggerEvent( content="Success!", name="alert", - after="invalid", # type: ignore + after="invalid", # type: ignore[arg-type] params={"warning": "Confirm your choice!"}, ) @@ -361,7 +361,7 @@ def handler() -> HTMXTemplate: context={"request": {"scope": {"path": "nope"}}}, trigger_event="showMessage", params={"alert": "Confirm your Choice."}, - after="begin", # type: ignore + after="begin", # type: ignore[arg-type] ) with create_test_client( diff --git a/tests/unit/test_contrib/test_opentelemetry.py b/tests/unit/test_contrib/test_opentelemetry.py index 6a788f10c0..ef16140b5e 100644 --- a/tests/unit/test_contrib/test_opentelemetry.py +++ b/tests/unit/test_contrib/test_opentelemetry.py @@ -33,7 +33,7 @@ def create_config(**kwargs: Any) -> Tuple[OpenTelemetryConfig, InMemoryMetricRea tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) aggregation_last_value = {Counter: ExplicitBucketHistogramAggregation()} - reader = InMemoryMetricReader(preferred_aggregation=aggregation_last_value) # type: ignore + reader = InMemoryMetricReader(preferred_aggregation=aggregation_last_value) # type: ignore[arg-type] meter_provider = MeterProvider(resource=resource, metric_readers=[reader]) set_meter_provider(meter_provider) @@ -60,9 +60,9 @@ def handler() -> dict: assert reader.get_metrics_data() first_span, second_span, third_span = cast("Tuple[Span, Span, Span]", exporter.get_finished_spans()) - assert dict(first_span.attributes) == {"http.status_code": 200, "type": "http.response.start"} # type: ignore - assert dict(second_span.attributes) == {"type": "http.response.body"} # type: ignore - assert dict(third_span.attributes) == { # type: ignore + assert dict(first_span.attributes) == {"http.status_code": 200, "type": "http.response.start"} # type: ignore[arg-type] + assert dict(second_span.attributes) == {"type": "http.response.body"} # type: ignore[arg-type] + assert dict(third_span.attributes) == { # type: ignore[arg-type] "http.scheme": "http", "http.host": "testserver.local", "net.host.port": 80, @@ -107,11 +107,11 @@ async def handler(socket: "WebSocket") -> None: first_span, second_span, third_span, fourth_span, fifth_span = cast( "Tuple[Span, Span, Span, Span, Span]", exporter.get_finished_spans() ) - assert dict(first_span.attributes) == {"type": "websocket.connect"} # type: ignore - assert dict(second_span.attributes) == {"type": "websocket.accept"} # type: ignore - assert dict(third_span.attributes) == {"http.status_code": 200, "type": "websocket.send"} # type: ignore - assert dict(fourth_span.attributes) == {"type": "websocket.close"} # type: ignore - assert dict(fifth_span.attributes) == { # type: ignore + assert dict(first_span.attributes) == {"type": "websocket.connect"} # type: ignore[arg-type] + assert dict(second_span.attributes) == {"type": "websocket.accept"} # type: ignore[arg-type] + assert dict(third_span.attributes) == {"http.status_code": 200, "type": "websocket.send"} # type: ignore[arg-type] + assert dict(fourth_span.attributes) == {"type": "websocket.close"} # type: ignore[arg-type] + assert dict(fifth_span.attributes) == { # type: ignore[arg-type] "http.scheme": "ws", "http.host": "testserver", "net.host.port": 80, diff --git a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py index 6df59aaedb..c382ae1ba4 100644 --- a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py +++ b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py @@ -119,19 +119,19 @@ def test_piccolo_dto_openapi_spec_generation() -> None: post_operation = concert_path.post assert ( - post_operation.request_body.content["application/json"].schema.ref # type: ignore + post_operation.request_body.content["application/json"].schema.ref # type: ignore[union-attr] == "#/components/schemas/CreateConcertConcertRequestBody" ) studio_path_get_operation = studio_path.get assert ( - studio_path_get_operation.responses["200"].content["application/json"].schema.ref # type: ignore + studio_path_get_operation.responses["200"].content["application/json"].schema.ref # type: ignore[index, union-attr] == "#/components/schemas/RetrieveStudioRecordingStudioResponseBody" ) venues_path_get_operation = venues_path.get assert ( - venues_path_get_operation.responses["200"].content["application/json"].schema.items.ref # type: ignore + venues_path_get_operation.responses["200"].content["application/json"].schema.items.ref # type: ignore[index, union-attr] == "#/components/schemas/RetrieveVenuesVenueResponseBody" ) diff --git a/tests/unit/test_data_extractors.py b/tests/unit/test_data_extractors.py index b204707bd1..5ca03bff76 100644 --- a/tests/unit/test_data_extractors.py +++ b/tests/unit/test_data_extractors.py @@ -27,7 +27,7 @@ async def test_connection_data_extractor() -> None: request.scope["path_params"] = {"first": "10", "second": "20", "third": "30"} extractor = ConnectionDataExtractor(parse_body=True, parse_query=True) extracted_data = extractor(request) - assert await extracted_data.get("body") == await request.json() # type: ignore + assert await extracted_data.get("body") == await request.json() # type: ignore[misc] assert extracted_data.get("content_type") == request.content_type assert extracted_data.get("headers") == dict(request.headers) assert extracted_data.get("headers") == dict(request.headers) @@ -48,24 +48,24 @@ def test_parse_query() -> None: assert parsed_extracted_data.get("query") == request.query_params.dict() assert unparsed_extracted_data.get("query") == request.scope["query_string"] # Close to avoid warnings about un-awaited coroutines. - parsed_extracted_data.get("body").close() # type: ignore - unparsed_extracted_data.get("body").close() # type: ignore + parsed_extracted_data.get("body").close() # type: ignore[union-attr] + unparsed_extracted_data.get("body").close() # type: ignore[union-attr] async def test_parse_json_data() -> None: request = factory.post(path="/a/b/c", data={"hello": "world"}) - assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == await request.json() # type: ignore - assert await ConnectionDataExtractor()(request).get("body") == await request.body() # type: ignore + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == await request.json() # type: ignore[misc] + assert await ConnectionDataExtractor()(request).get("body") == await request.body() # type: ignore[misc] async def test_parse_form_data() -> None: request = factory.post(path="/a/b/c", data={"file": b"123"}, request_media_type=RequestEncodingType.MULTI_PART) - assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore[misc] async def test_parse_url_encoded() -> None: request = factory.post(path="/a/b/c", data={"key": "123"}, request_media_type=RequestEncodingType.URL_ENCODED) - assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore[misc] @pytest.mark.parametrize("req", [factory.get(headers={"Special": "123"}), factory.get(headers={"special": "123"})]) @@ -74,7 +74,7 @@ def test_request_extraction_header_obfuscation(req: Request[Any, Any, Any]) -> N extracted_data = extractor(req) assert extracted_data.get("headers") == {"special": "*****"} # Close to avoid warnings about un-awaited coroutines. - extracted_data.get("body").close() # type: ignore + extracted_data.get("body").close() # type: ignore[union-attr] @pytest.mark.parametrize( @@ -89,7 +89,7 @@ def test_request_extraction_cookie_obfuscation(req: Request[Any, Any, Any], key: extracted_data = extractor(req) assert extracted_data.get("cookies") == {"Path": "/", "SameSite": "lax", key: "*****"} # Close to avoid warnings about un-awaited coroutines. - extracted_data.get("body").close() # type: ignore + extracted_data.get("body").close() # type: ignore[union-attr] async def test_response_data_extractor() -> None: @@ -105,7 +105,7 @@ async def send(message: "Any") -> None: await response({}, empty_receive, send) # type: ignore[arg-type] assert len(messages) == 2 - extracted_data = extractor(messages) # type: ignore + extracted_data = extractor(messages) # type: ignore[arg-type] assert extracted_data.get("status_code") == HTTP_200_OK assert extracted_data.get("body") == b'{"hello":"world"}' assert extracted_data.get("headers") == {**headers, "content-length": "17"} diff --git a/tests/unit/test_datastructures/test_headers.py b/tests/unit/test_datastructures/test_headers.py index 355804ac99..dd66409bfe 100644 --- a/tests/unit/test_datastructures/test_headers.py +++ b/tests/unit/test_datastructures/test_headers.py @@ -36,7 +36,7 @@ class TestHeader(Header): def _get_header_value(self) -> str: return "" - def from_header(self, header_value: str) -> "Header": # type: ignore + def from_header(self, header_value: str) -> "Header": # type: ignore[explicit-override, override] return self with pytest.raises(ImproperlyConfiguredException): diff --git a/tests/unit/test_di.py b/tests/unit/test_di.py index cdcccb8dcc..781e805cb1 100644 --- a/tests/unit/test_di.py +++ b/tests/unit/test_di.py @@ -152,7 +152,7 @@ def test_dependency_has_async_callable(dep: Any, exp: bool) -> None: def test_raises_when_dependency_is_not_callable() -> None: with pytest.raises(ImproperlyConfiguredException): - Provide(123) # type: ignore + Provide(123) # type: ignore[arg-type] @pytest.mark.parametrize( diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index c88611395d..082543b25b 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -95,7 +95,7 @@ def route_handler(request: Request[Any, Any, Any], event_id: int) -> None: async def test_raises_when_decorator_called_without_callable() -> None: with pytest.raises(ImproperlyConfiguredException): - listener("test_even")(True) # type: ignore + listener("test_even")(True) # type: ignore[arg-type] async def test_raises_when_not_initialized() -> None: diff --git a/tests/unit/test_guards.py b/tests/unit/test_guards.py index f630020f51..8ab76e0115 100644 --- a/tests/unit/test_guards.py +++ b/tests/unit/test_guards.py @@ -96,7 +96,7 @@ def http_route_handler() -> None: assert ( len( app.asgi_router.root_route_map_node.children["/http"] - .asgi_handlers["GET"][1] # type: ignore + .asgi_handlers["GET"][1] # type: ignore[arg-type] ._resolved_guards ) == 2 @@ -104,7 +104,7 @@ def http_route_handler() -> None: assert ( len( app.asgi_router.root_route_map_node.children["/router/http"] - .asgi_handlers["GET"][1] # type: ignore + .asgi_handlers["GET"][1] # type: ignore[arg-type] ._resolved_guards ) == 3 diff --git a/tests/unit/test_handlers/test_asgi_handlers/test_validations.py b/tests/unit/test_handlers/test_asgi_handlers/test_validations.py index 55ef2b7bd9..dcebc1eb17 100644 --- a/tests/unit/test_handlers/test_asgi_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_asgi_handlers/test_validations.py @@ -44,4 +44,4 @@ def sync_fn(scope: "Scope", receive: "Receive", send: "Send") -> None: return None with pytest.raises(ImproperlyConfiguredException): - asgi(path="/")(sync_fn).on_registration(Litestar()) # type: ignore + asgi(path="/")(sync_fn).on_registration(Litestar()) # type: ignore[arg-type] diff --git a/tests/unit/test_handlers/test_base_handlers/test_opt.py b/tests/unit/test_handlers/test_base_handlers/test_opt.py index 49a3e3b0dc..5cab773cbe 100644 --- a/tests/unit/test_handlers/test_base_handlers/test_opt.py +++ b/tests/unit/test_handlers/test_base_handlers/test_opt.py @@ -46,7 +46,7 @@ async def socket_handler(socket: "WebSocket") -> None: ) def test_opt_settings(decorator: "RouteHandlerType", handler: Callable) -> None: base_opt = {"base": 1, "kwarg_value": 0} - result = decorator("/", opt=base_opt, kwarg_value=2)(handler) # type: ignore + result = decorator("/", opt=base_opt, kwarg_value=2)(handler) # type: ignore[arg-type, call-arg] assert result.opt == {"base": 1, "kwarg_value": 2} diff --git a/tests/unit/test_handlers/test_http_handlers/test_validations.py b/tests/unit/test_handlers/test_http_handlers/test_validations.py index 9665b93c3e..5395c77295 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_http_handlers/test_validations.py @@ -20,11 +20,11 @@ def test_route_handler_validation_http_method() -> None: # doesn't raise for http methods for value in (*list(HttpMethod), *[x.upper() for x in list(HttpMethod)]): - assert route(http_method=value) # type: ignore + assert route(http_method=value) # type: ignore[arg-type, truthy-bool] # raises for invalid values with pytest.raises(ValidationException): - HTTPRouteHandler(http_method="deleze") # type: ignore + HTTPRouteHandler(http_method="deleze") # type: ignore[arg-type] # also when passing an empty list with pytest.raises(ImproperlyConfiguredException): @@ -32,14 +32,14 @@ def test_route_handler_validation_http_method() -> None: # also when passing malformed tokens with pytest.raises(ValidationException): - route(http_method=[HttpMethod.GET, "poft"], status_code=HTTP_200_OK) # type: ignore + route(http_method=[HttpMethod.GET, "poft"], status_code=HTTP_200_OK) # type: ignore[list-item] async def test_function_validation() -> None: with pytest.raises(ImproperlyConfiguredException): @get(path="/") - def method_with_no_annotation(): # type: ignore + def method_with_no_annotation(): # type: ignore[no-untyped-def] pass Litestar(route_handlers=[method_with_no_annotation]) @@ -105,7 +105,7 @@ def test_function_1(socket: WebSocket) -> None: with pytest.raises(ImproperlyConfiguredException): @get("/person") - def test_function_2(self, data: DataclassPerson) -> None: # type: ignore + def test_function_2(self, data: DataclassPerson) -> None: # type: ignore[no-untyped-def] return None Litestar(route_handlers=[test_function_2]) diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_kwarg_handling.py b/tests/unit/test_handlers/test_websocket_handlers/test_kwarg_handling.py index 8a997d8df0..cf45e0d9ec 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_kwarg_handling.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_kwarg_handling.py @@ -29,7 +29,7 @@ async def websocket_handler( client = create_test_client(route_handlers=websocket_handler) # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"cookie": "yum"} # type: ignore + client.cookies = {"cookie": "yum"} # type: ignore[assignment] with client.websocket_connect("/1?qp=1", headers={"some-header": "abc"}) as ws: ws.send_json({"data": "123"}) diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_validations.py b/tests/unit/test_handlers/test_websocket_handlers/test_validations.py index b64b563f1b..a0c043f023 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_validations.py @@ -12,7 +12,7 @@ def fn_without_socket_arg(websocket: WebSocket) -> None: pass with pytest.raises(ImproperlyConfiguredException): - websocket(path="/")(fn_without_socket_arg).on_registration(Litestar()) # type: ignore + websocket(path="/")(fn_without_socket_arg).on_registration(Litestar()) # type: ignore[arg-type] def test_raises_for_return_annotation() -> None: @@ -33,7 +33,7 @@ def test_raises_when_no_function() -> None: def test_raises_when_sync_handler_user() -> None: with pytest.raises(ImproperlyConfiguredException): - @websocket(path="/") # type: ignore + @websocket(path="/") # type: ignore[arg-type] def sync_websocket_handler(socket: WebSocket) -> None: ... diff --git a/tests/unit/test_kwargs/test_cookie_params.py b/tests/unit/test_kwargs/test_cookie_params.py index 1d13716d4c..0a23c6eafa 100644 --- a/tests/unit/test_kwargs/test_cookie_params.py +++ b/tests/unit/test_kwargs/test_cookie_params.py @@ -27,12 +27,12 @@ def test_cookie_params(t_type: Type, param_dict: dict, param: ParameterKwarg, ex test_path = "/test" @get(path=test_path) - def test_method(special_cookie: t_type = param) -> None: # type: ignore + def test_method(special_cookie: t_type = param) -> None: # type: ignore[valid-type] if special_cookie: - assert special_cookie in (param_dict.get("special-cookie"), int(param_dict.get("special-cookie"))) # type: ignore + assert special_cookie in (param_dict.get("special-cookie"), int(param_dict.get("special-cookie"))) # type: ignore[arg-type] with create_test_client(test_method) as client: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = param_dict # type: ignore + client.cookies = param_dict # type: ignore[assignment] response = client.get(test_path) assert response.status_code == expected_code, response.json() diff --git a/tests/unit/test_kwargs/test_header_params.py b/tests/unit/test_kwargs/test_header_params.py index 1e71d17512..8a654ba779 100644 --- a/tests/unit/test_kwargs/test_header_params.py +++ b/tests/unit/test_kwargs/test_header_params.py @@ -27,9 +27,9 @@ def test_header_params( test_path = "/test" @get(path=test_path) - def test_method(special_header: t_type = param) -> None: # type: ignore + def test_method(special_header: t_type = param) -> None: # type: ignore[valid-type] if special_header: - assert special_header in (param_dict.get("special-header"), int(param_dict.get("special-header"))) # type: ignore + assert special_header in (param_dict.get("special-header"), int(param_dict.get("special-header"))) # type: ignore[arg-type] with create_test_client(test_method) as client: response = client.get(test_path, headers=param_dict) diff --git a/tests/unit/test_kwargs/test_layered_params.py b/tests/unit/test_kwargs/test_layered_params.py index 201ad78277..e3dff0159a 100644 --- a/tests/unit/test_kwargs/test_layered_params.py +++ b/tests/unit/test_kwargs/test_layered_params.py @@ -51,7 +51,7 @@ def my_handler( }, ) as client: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"app4": "jeronimo"} # type: ignore + client.cookies = {"app4": "jeronimo"} # type: ignore[assignment] query = {"controller1": "99", "controller3": "tuna", "router1": "albatross", "app2": ["x", "y"]} headers = {"router3": "10"} @@ -110,12 +110,12 @@ def my_handler(self) -> dict: query.pop(parameter) # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = cookies # type: ignore + client.cookies = cookies # type: ignore[assignment] response = client.get("/router/controller/1", params=query, headers=headers) assert response.status_code == HTTP_400_BAD_REQUEST - assert response.json()["detail"].startswith(f"Missing required {param_type} parameter '{parameter}' for url") + assert response.json()["detail"].startswith(f"Missing required {param_type} parameter '{parameter}' for path") def test_layered_parameters_defaults_and_overrides() -> None: diff --git a/tests/unit/test_kwargs/test_multipart_data.py b/tests/unit/test_kwargs/test_multipart_data.py index 25ddd33808..50552c959d 100644 --- a/tests/unit/test_kwargs/test_multipart_data.py +++ b/tests/unit/test_kwargs/test_multipart_data.py @@ -90,7 +90,7 @@ def test_request_body_multi_part(t_type: type) -> None: data = asdict(Form(name="Moishe Zuchmir", age=30, programmer=True, value="100")) @post(path=test_path, signature_namespace={"t_type": t_type}) - def test_method(data: Annotated[t_type, Body(media_type=RequestEncodingType.MULTI_PART)]) -> None: # type: ignore + def test_method(data: Annotated[t_type, Body(media_type=RequestEncodingType.MULTI_PART)]) -> None: # type: ignore[valid-type] assert data with create_test_client(test_method) as client: diff --git a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py index 7bba51cc3d..4fdc12205f 100644 --- a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py +++ b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py @@ -49,10 +49,10 @@ def route_handler(state: ImmutableState) -> str: @pytest.mark.parametrize("state_typing", (State, CustomState)) def test_application_state_injection(state_typing: Type[State]) -> None: @get("/", media_type=MediaType.TEXT) - def route_handler(state: state_typing) -> str: # type: ignore + def route_handler(state: state_typing) -> str: # type: ignore[valid-type] assert state - state.called = True # type: ignore - return cast("str", state.msg) # type: ignore + state.called = True # type: ignore[attr-defined] + return cast("str", state.msg) # type: ignore[attr-defined] with create_test_client(route_handler, state=State({"called": False})) as client: client.app.state.msg = "hello" diff --git a/tests/unit/test_logging/test_logging_config.py b/tests/unit/test_logging/test_logging_config.py index e850305d5a..2e363d45cb 100644 --- a/tests/unit/test_logging/test_logging_config.py +++ b/tests/unit/test_logging/test_logging_config.py @@ -116,7 +116,7 @@ def test_get_picologging_logger() -> None: def test_connection_logger(handlers: Any, listener: Any) -> None: @get("/") def handler(request: Request) -> Dict[str, bool]: - return {"isinstance": isinstance(request.logger.handlers[0], listener)} # type: ignore + return {"isinstance": isinstance(request.logger.handlers[0], listener)} # type: ignore[attr-defined] with create_test_client(route_handlers=[handler], logging_config=LoggingConfig(handlers=handlers)) as client: response = client.get("/") @@ -141,7 +141,7 @@ def test_root_logger(handlers: Any, listener: Any) -> None: logging_config = LoggingConfig(handlers=handlers) get_logger = logging_config.configure() root_logger = get_logger() - assert isinstance(root_logger.handlers[0], listener) # type: ignore + assert isinstance(root_logger.handlers[0], listener) # type: ignore[attr-defined] @pytest.mark.parametrize( @@ -183,4 +183,4 @@ def test_customizing_handler(handlers: Any, listener: Any, monkeypatch: pytest.M logging_config = LoggingConfig(handlers=handlers) get_logger = logging_config.configure() root_logger = get_logger() - assert isinstance(root_logger.handlers[0], listener) # type: ignore + assert isinstance(root_logger.handlers[0], listener) # type: ignore[attr-defined] diff --git a/tests/unit/test_middleware/test_allowed_hosts_middleware.py b/tests/unit/test_middleware/test_allowed_hosts_middleware.py index 7da16afcb5..4e79154c68 100644 --- a/tests/unit/test_middleware/test_allowed_hosts_middleware.py +++ b/tests/unit/test_middleware/test_allowed_hosts_middleware.py @@ -35,12 +35,12 @@ def handler() -> None: assert len(unpacked_middleware) == 4 allowed_hosts_middleware = cast("Any", unpacked_middleware[1]) assert isinstance(allowed_hosts_middleware, AllowedHostsMiddleware) - assert allowed_hosts_middleware.allowed_hosts_regex.pattern == ".*\\.example.com$|moishe.zuchmir.com" # type: ignore + assert allowed_hosts_middleware.allowed_hosts_regex.pattern == ".*\\.example.com$|moishe.zuchmir.com" # type: ignore[union-attr] def test_allowed_hosts_middleware_hosts_regex() -> None: config = AllowedHostsConfig(allowed_hosts=["*.example.com", "moishe.zuchmir.com"]) - middleware = AllowedHostsMiddleware(app=DummyApp(), config=config) # type: ignore + middleware = AllowedHostsMiddleware(app=DummyApp(), config=config) # type: ignore[abstract] assert middleware.allowed_hosts_regex is not None assert middleware.allowed_hosts_regex.pattern == ".*\\.example.com$|moishe.zuchmir.com" @@ -59,7 +59,7 @@ def test_allowed_hosts_middleware_redirect_regex() -> None: config = AllowedHostsConfig( allowed_hosts=["*.example.com", "www.moishe.zuchmir.com", "www.yada.bada.bing.io", "example.com"] ) - middleware = AllowedHostsMiddleware(app=DummyApp(), config=config) # type: ignore + middleware = AllowedHostsMiddleware(app=DummyApp(), config=config) # type: ignore[abstract] assert middleware.redirect_domains is not None assert middleware.redirect_domains.pattern == "moishe.zuchmir.com|yada.bada.bing.io" @@ -75,23 +75,23 @@ def handler() -> dict: config = AllowedHostsConfig(allowed_hosts=["*.example.com", "moishe.zuchmir.com"]) with create_test_client(handler, allowed_hosts=config) as client: - client.base_url = "http://x.example.com" # type: ignore + client.base_url = "http://x.example.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_200_OK - client.base_url = "http://x.y.example.com" # type: ignore + client.base_url = "http://x.y.example.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_200_OK - client.base_url = "http://moishe.zuchmir.com" # type: ignore + client.base_url = "http://moishe.zuchmir.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_200_OK - client.base_url = "http://x.moishe.zuchmir.com" # type: ignore + client.base_url = "http://x.moishe.zuchmir.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_400_BAD_REQUEST - client.base_url = "http://x.example.x.com" # type: ignore + client.base_url = "http://x.example.x.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_400_BAD_REQUEST @@ -105,7 +105,7 @@ def handler() -> dict: config = AllowedHostsConfig(allowed_hosts=["*", "*.example.com", "moishe.zuchmir.com"]) with create_test_client(handler, allowed_hosts=config) as client: - client.base_url = "http://any.domain.allowed.com" # type: ignore + client.base_url = "http://any.domain.allowed.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_200_OK @@ -118,7 +118,7 @@ def handler() -> dict: config = AllowedHostsConfig(allowed_hosts=["www.moishe.zuchmir.com"]) with create_test_client(handler, allowed_hosts=config) as client: - client.base_url = "http://moishe.zuchmir.com" # type: ignore + client.base_url = "http://moishe.zuchmir.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_200_OK assert str(response.url) == "http://www.moishe.zuchmir.com/" @@ -132,7 +132,7 @@ def handler() -> dict: config = AllowedHostsConfig(allowed_hosts=["www.moishe.zuchmir.com"], www_redirect=False) with create_test_client(handler, allowed_hosts=config) as client: - client.base_url = "http://moishe.zuchmir.com" # type: ignore + client.base_url = "http://moishe.zuchmir.com" # type: ignore[assignment] response = client.get("/") assert response.status_code == HTTP_400_BAD_REQUEST diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index 2c0b1334d8..e698131b6d 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -123,7 +123,7 @@ def exception_handler(request: Request, exc: Exception) -> Response: return Response(content={"an": "error"}, status_code=HTTP_500_INTERNAL_SERVER_ERROR) app = Litestar(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None) - assert app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0].exception_handlers == { # type: ignore + assert app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0].exception_handlers == { # type: ignore[attr-defined] Exception: exception_handler, StarletteHTTPException: _starlette_exception_handler, } diff --git a/tests/unit/test_middleware/test_logging_middleware.py b/tests/unit/test_middleware/test_logging_middleware.py index 761a088c62..6ce6e5dea1 100644 --- a/tests/unit/test_middleware/test_logging_middleware.py +++ b/tests/unit/test_middleware/test_logging_middleware.py @@ -45,10 +45,10 @@ def handler_fn() -> Response: def test_logging_middleware_config_validation() -> None: with pytest.raises(ImproperlyConfiguredException): - LoggingMiddlewareConfig(response_log_fields=None) # type: ignore + LoggingMiddlewareConfig(response_log_fields=None) # type: ignore[arg-type] with pytest.raises(ImproperlyConfiguredException): - LoggingMiddlewareConfig(request_log_fields=None) # type: ignore + LoggingMiddlewareConfig(request_log_fields=None) # type: ignore[arg-type] def test_logging_middleware_regular_logger( @@ -59,7 +59,7 @@ def test_logging_middleware_regular_logger( ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. client.app.get_logger = get_logger - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK assert len(caplog.messages) == 2 @@ -80,7 +80,7 @@ def test_logging_middleware_struct_logger(handler: HTTPRouteHandler) -> None: logging_config=StructLoggingConfig(), ) as client, capture_logs() as cap_logs: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK assert len(cap_logs) == 2 @@ -126,7 +126,7 @@ def handler2() -> None: route_handlers=[handler, handler2], middleware=[config.middleware] ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] client.app.get_logger = get_logger response = client.get("/exclude") @@ -150,7 +150,7 @@ def handler2() -> None: route_handlers=[handler, handler2], middleware=[config.middleware] ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] client.app.get_logger = get_logger response = client.get("/exclude") @@ -172,7 +172,7 @@ def test_logging_middleware_compressed_response_body( middleware=[LoggingMiddlewareConfig(include_compressed_body=include).middleware], ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] client.app.get_logger = get_logger response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK @@ -249,7 +249,7 @@ def test_logging_middleware_log_fields( ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. client.app.get_logger = get_logger - client.cookies = {"request-cookie": "abc"} # type: ignore + client.cookies = {"request-cookie": "abc"} # type: ignore[assignment] response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK assert len(caplog.messages) == 2 diff --git a/tests/unit/test_middleware/test_middleware_handling.py b/tests/unit/test_middleware/test_middleware_handling.py index 606a3cca02..803c4064e9 100644 --- a/tests/unit/test_middleware/test_middleware_handling.py +++ b/tests/unit/test_middleware/test_middleware_handling.py @@ -34,9 +34,9 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No class BaseMiddlewareRequestLoggingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # type: ignore + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # type: ignore[explicit-override, override] logging.getLogger(__name__).info("%s - %s", request.method, request.url) - return await call_next(request) # type: ignore + return await call_next(request) # type: ignore[arg-type, return-value] class MiddlewareWithArgsAndKwargs(BaseHTTPMiddleware): @@ -45,7 +45,7 @@ def __init__(self, arg: int = 0, *, app: Any, kwarg: str) -> None: self.arg = arg self.kwarg = kwarg - async def dispatch( # type: ignore + async def dispatch( # type: ignore[empty-body, explicit-override, override] self, request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: ... diff --git a/tests/unit/test_middleware/test_rate_limit_middleware.py b/tests/unit/test_middleware/test_rate_limit_middleware.py index 24ace66c69..e2a30af9a4 100644 --- a/tests/unit/test_middleware/test_rate_limit_middleware.py +++ b/tests/unit/test_middleware/test_rate_limit_middleware.py @@ -120,6 +120,7 @@ def handler() -> None: assert response.status_code == HTTP_200_OK +@travel(datetime.utcnow, tick=False) def test_exclude_patterns() -> None: @get("/excluded") def handler() -> None: @@ -145,6 +146,7 @@ def handler2() -> None: assert response.status_code == HTTP_429_TOO_MANY_REQUESTS +@travel(datetime.utcnow, tick=False) def test_exclude_opt_key() -> None: @get("/excluded", skip_rate_limiting=True) def handler() -> None: @@ -170,6 +172,7 @@ def handler2() -> None: assert response.status_code == HTTP_429_TOO_MANY_REQUESTS +@travel(datetime.utcnow, tick=False) def test_check_throttle_handler() -> None: @get("/path1") def handler1() -> None: @@ -198,6 +201,7 @@ def check_throttle_handler(request: Request[Any, Any, Any]) -> bool: assert response.status_code == HTTP_200_OK +@travel(datetime.utcnow, tick=False) async def test_rate_limiting_works_with_mounted_apps(tmpdir: "Path") -> None: # https://github.com/litestar-org/litestar/issues/781 @get("/not-excluded") diff --git a/tests/unit/test_middleware/test_session/test_middleware.py b/tests/unit/test_middleware/test_session/test_middleware.py index fb60bb15e0..ef0aa89d11 100644 --- a/tests/unit/test_middleware/test_session/test_middleware.py +++ b/tests/unit/test_middleware/test_session/test_middleware.py @@ -1,13 +1,13 @@ -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Union from litestar import HttpMethod, Request, Response, get, post, route +from litestar.middleware.session.server_side import ServerSideSessionConfig from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import create_test_client from litestar.types import Empty if TYPE_CHECKING: from litestar.middleware.session.base import BaseBackendConfig - from litestar.middleware.session.server_side import ServerSideSessionConfig def test_session_middleware_not_installed_raises() -> None: @@ -37,6 +37,7 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]: with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: response = client.get("/session") assert response.json() == {"has_session": False} + first_session_id = client.cookies.get("session") client.post("/session") @@ -52,6 +53,57 @@ def session_handler(request: Request) -> Optional[Dict[str, bool]]: response = client.get("/session") assert response.json() == {"has_session": True} + second_session_id = client.cookies.get("session") + assert first_session_id != second_session_id + + +def test_session_id_correctness(session_backend_config: "BaseBackendConfig") -> None: + # Test that `request.get_session_id()` is the same as in the cookies + @route("/session", http_method=[HttpMethod.POST]) + def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]: + request.set_session({"foo": "bar"}) + return {"session_id": request.get_session_id()} + + with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: + if isinstance(session_backend_config, ServerSideSessionConfig): + # Generic verification that a session id is set before entering the route handler scope + response = client.post("/session") + request_session_id = response.json()["session_id"] + cookie_session_id = client.cookies.get("session") + assert request_session_id == cookie_session_id + else: + # Client side config does not have a session id in cookies + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + + +def test_keep_session_id(session_backend_config: "BaseBackendConfig") -> None: + # Test that session is only created if not already exists + @route("/session", http_method=[HttpMethod.POST]) + def session_handler(request: Request) -> Optional[Dict[str, Union[str, None]]]: + request.set_session({"foo": "bar"}) + return {"session_id": request.get_session_id()} + + with create_test_client(route_handlers=[session_handler], middleware=[session_backend_config.middleware]) as client: + if isinstance(session_backend_config, ServerSideSessionConfig): + # Generic verification that a session id is set before entering the route handler scope + response = client.post("/session") + first_call_id = response.json()["session_id"] + response = client.post("/session") + second_call_id = response.json()["session_id"] + assert first_call_id == second_call_id == client.cookies.get("session") + else: + # Client side config does not have a session id in cookies + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None + response = client.post("/session") + assert response.json()["session_id"] is None + assert client.cookies.get("session") is not None def test_set_empty(session_backend_config: "BaseBackendConfig") -> None: diff --git a/tests/unit/test_openapi/test_integration.py b/tests/unit/test_openapi/test_integration.py index 7176344769..30b0a7a1bc 100644 --- a/tests/unit/test_openapi/test_integration.py +++ b/tests/unit/test_openapi/test_integration.py @@ -9,7 +9,7 @@ import yaml from typing_extensions import Annotated -from litestar import Controller, Litestar, get, post +from litestar import Controller, Litestar, delete, get, patch, post from litestar._openapi.plugin import OpenAPIPlugin from litestar.app import DEFAULT_OPENAPI_CONFIG from litestar.enums import MediaType, OpenAPIMediaType, ParamType @@ -371,3 +371,75 @@ def get_handler() -> None: path_item = openapi.paths["/"] assert path_item.get is not None assert path_item.post is not None + + +@pytest.mark.parametrize(("random_seed_one", "random_seed_two", "should_be_equal"), [(10, 10, True), (10, 20, False)]) +def test_seeding(random_seed_one: int, random_seed_two: int, should_be_equal: bool) -> None: + @post("/", sync_to_thread=False) + def post_handler(q: str) -> None: + ... + + @get("/", sync_to_thread=False) + def get_handler(q: str) -> None: + ... + + app = Litestar( + [get_handler, post_handler], openapi_config=OpenAPIConfig("Litestar", "v0.0.1", True, random_seed_one) + ) + openapi_plugin = app.plugins.get(OpenAPIPlugin) + openapi_one = openapi_plugin.provide_openapi() + + app = Litestar( + [get_handler, post_handler], openapi_config=OpenAPIConfig("Litestar", "v0.0.1", True, random_seed_two) + ) + openapi_plugin = app.plugins.get(OpenAPIPlugin) + openapi_two = openapi_plugin.provide_openapi() + + if should_be_equal: + assert openapi_one == openapi_two + else: + assert openapi_one != openapi_two + + +def test_components_schemas_in_alphabetical_order() -> None: + # https://github.com/litestar-org/litestar/issues/3059 + + @dataclass + class A: + ... + + @dataclass + class B: + ... + + @dataclass + class C: + ... + + class TestController(Controller): + @post("/", sync_to_thread=False) + def post_handler(self, data: B) -> None: + ... + + @get("/", sync_to_thread=False) + def get_handler(self) -> A: # type: ignore[empty-body] + ... + + @patch("/", sync_to_thread=False) + def patch_handler(self, data: C) -> A: # type: ignore[empty-body] + ... + + @delete("/", sync_to_thread=False) + def delete_handler(self, data: B) -> None: + ... + + app = Litestar([TestController], signature_types=[A, B, C]) + openapi_plugin = app.plugins.get(OpenAPIPlugin) + openapi = openapi_plugin.provide_openapi() + + expected_keys = [ + "test_components_schemas_in_alphabetical_order.A", + "test_components_schemas_in_alphabetical_order.B", + "test_components_schemas_in_alphabetical_order.C", + ] + assert list(openapi.components.schemas.keys()) == expected_keys diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index 6439322883..8940c3d0a3 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -38,7 +38,7 @@ def create_factory(route: BaseRoute, handler: HTTPRouteHandler) -> ParameterFact def _create_parameters(app: Litestar, path: str) -> List["OpenAPIParameter"]: index = find_index(app.routes, lambda x: x.path_format == path) route = app.routes[index] - route_handler = route.route_handler_map["GET"][0] # type: ignore + route_handler = route.route_handler_map["GET"][0] # type: ignore[union-attr] handler = route_handler.fn assert callable(handler) return create_factory(route, route_handler).create_parameters_for_handler() @@ -167,10 +167,10 @@ def handler(a: ADep, b: BDep, c: float, d: float) -> str: app = Litestar(route_handlers=[handler]) assert isinstance(app.openapi_schema, OpenAPI) - open_api_path_item = app.openapi_schema.paths["/test"] # type: ignore - open_api_parameters = open_api_path_item.get.parameters # type: ignore - assert len(open_api_parameters) == 2 # type: ignore - assert {p.name for p in open_api_parameters} == {"query_param", "other_param"} # type: ignore + open_api_path_item = app.openapi_schema.paths["/test"] # type: ignore[index] + open_api_parameters = open_api_path_item.get.parameters # type: ignore[union-attr] + assert len(open_api_parameters) == 2 # type: ignore[arg-type] + assert {p.name for p in open_api_parameters} == {"query_param", "other_param"} # type: ignore[union-attr] def test_raise_for_multiple_parameters_of_same_name_and_differing_types() -> None: @@ -199,7 +199,7 @@ def handler(dep: Optional[int] = Dependency()) -> None: return None app = Litestar(route_handlers=[handler]) - param_name_set = {p.name for p in cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters} # type: ignore + param_name_set = {p.name for p in cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters} # type: ignore[index, redundant-cast, union-attr] assert "dep" not in param_name_set assert "param" in param_name_set @@ -210,7 +210,7 @@ def handler(dep: Optional[int] = Dependency()) -> None: return None app = Litestar(route_handlers=[handler]) - assert cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters is None # type: ignore + assert cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters is None # type: ignore[index, redundant-cast, union-attr] def test_non_dependency_in_doc_params_if_not_provided() -> None: @@ -219,7 +219,7 @@ def handler(param: Optional[int]) -> None: return None app = Litestar(route_handlers=[handler]) - param_name_set = {p.name for p in cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters} # type: ignore + param_name_set = {p.name for p in cast("OpenAPI", app.openapi_schema).paths["/"].get.parameters} # type: ignore[index, redundant-cast, union-attr] assert "param" in param_name_set @@ -267,48 +267,48 @@ def my_handler( local, app3, controller1, router1, router3, app4, app2, controller3 = tuple(parameters) assert app4.param_in == ParamType.COOKIE - assert app4.schema.type == OpenAPIType.STRING # type: ignore + assert app4.schema.type == OpenAPIType.STRING # type: ignore[union-attr] assert app4.required - assert app4.schema.examples # type: ignore + assert app4.schema.examples # type: ignore[union-attr] assert app2.param_in == ParamType.QUERY - assert app2.schema.type == OpenAPIType.ARRAY # type: ignore + assert app2.schema.type == OpenAPIType.ARRAY # type: ignore[union-attr] assert app2.required - assert app2.schema.examples # type: ignore + assert app2.schema.examples # type: ignore[union-attr] assert app3.param_in == ParamType.QUERY - assert app3.schema.type == OpenAPIType.BOOLEAN # type: ignore + assert app3.schema.type == OpenAPIType.BOOLEAN # type: ignore[union-attr] assert not app3.required - assert app3.schema.examples # type: ignore + assert app3.schema.examples # type: ignore[union-attr] assert router1.param_in == ParamType.QUERY - assert router1.schema.type == OpenAPIType.STRING # type: ignore + assert router1.schema.type == OpenAPIType.STRING # type: ignore[union-attr] assert router1.required - assert router1.schema.pattern == "^[a-zA-Z]$" # type: ignore - assert router1.schema.examples # type: ignore + assert router1.schema.pattern == "^[a-zA-Z]$" # type: ignore[union-attr] + assert router1.schema.examples # type: ignore[union-attr] assert router3.param_in == ParamType.HEADER - assert router3.schema.type == OpenAPIType.NUMBER # type: ignore + assert router3.schema.type == OpenAPIType.NUMBER # type: ignore[union-attr] assert router3.required - assert router3.schema.multiple_of == 5.0 # type: ignore - assert router3.schema.examples # type: ignore + assert router3.schema.multiple_of == 5.0 # type: ignore[union-attr] + assert router3.schema.examples # type: ignore[union-attr] assert controller1.param_in == ParamType.QUERY - assert controller1.schema.type == OpenAPIType.INTEGER # type: ignore + assert controller1.schema.type == OpenAPIType.INTEGER # type: ignore[union-attr] assert controller1.required - assert controller1.schema.exclusive_maximum == 100.0 # type: ignore - assert controller1.schema.examples # type: ignore + assert controller1.schema.exclusive_maximum == 100.0 # type: ignore[union-attr] + assert controller1.schema.examples # type: ignore[union-attr] assert controller3.param_in == ParamType.QUERY - assert controller3.schema.type == OpenAPIType.NUMBER # type: ignore + assert controller3.schema.type == OpenAPIType.NUMBER # type: ignore[union-attr] assert controller3.required - assert controller3.schema.minimum == 5.0 # type: ignore - assert controller3.schema.examples # type: ignore + assert controller3.schema.minimum == 5.0 # type: ignore[union-attr] + assert controller3.schema.examples # type: ignore[union-attr] assert local.param_in == ParamType.PATH - assert local.schema.type == OpenAPIType.INTEGER # type: ignore + assert local.schema.type == OpenAPIType.INTEGER # type: ignore[union-attr] assert local.required - assert local.schema.examples # type: ignore + assert local.schema.examples # type: ignore[union-attr] def test_parameter_examples() -> None: diff --git a/tests/unit/test_openapi/test_path_item.py b/tests/unit/test_openapi/test_path_item.py index e1033edea7..839e829d85 100644 --- a/tests/unit/test_openapi/test_path_item.py +++ b/tests/unit/test_openapi/test_path_item.py @@ -29,7 +29,7 @@ def route(person_controller: type[Controller]) -> HTTPRoute: @pytest.fixture() def routes_with_router(person_controller: type[Controller]) -> tuple[HTTPRoute, HTTPRoute]: - class PersonControllerV2(person_controller): # type: ignore + class PersonControllerV2(person_controller): # type: ignore[misc, valid-type] pass router_v1 = Router(path="/v1", route_handlers=[person_controller]) @@ -106,7 +106,7 @@ async def root(self, *, request: Request[str, str, Any]) -> None: index = find_index(app.routes, lambda x: x.path_format == "/") route_with_multiple_methods = cast("HTTPRoute", app.routes[index]) factory = create_factory(route_with_multiple_methods) - factory.context.openapi_config.operation_id_creator = lambda x: "abc" # type: ignore + factory.context.openapi_config.operation_id_creator = lambda x: "abc" # type: ignore[assignment, misc] schema = create_factory(route_with_multiple_methods).create_path_item() assert schema.get assert schema.get.operation_id diff --git a/tests/unit/test_openapi/test_request_body.py b/tests/unit/test_openapi/test_request_body.py index 45cc96ed05..1a9b072c96 100644 --- a/tests/unit/test_openapi/test_request_body.py +++ b/tests/unit/test_openapi/test_request_body.py @@ -49,7 +49,7 @@ def _factory(route_handler: BaseRouteHandler, data_field: FieldDefinition) -> Re def test_create_request_body(person_controller: Type[Controller], create_request: RequestBodyFactory) -> None: for route in Litestar(route_handlers=[person_controller]).routes: - for route_handler, _ in route.route_handler_map.values(): # type: ignore + for route_handler, _ in route.route_handler_map.values(): # type: ignore[union-attr] handler_fields = route_handler.parsed_fn_signature.parameters if "data" in handler_fields: request_body = create_request(route_handler, handler_fields["data"]) diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index 1788dd1064..148fe28a71 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -28,7 +28,7 @@ from litestar.handlers import HTTPRouteHandler from litestar.openapi.config import OpenAPIConfig from litestar.openapi.datastructures import ResponseSpec -from litestar.openapi.spec import OpenAPIHeader, OpenAPIMediaType, Reference, Schema +from litestar.openapi.spec import Example, OpenAPIHeader, OpenAPIMediaType, Reference, Schema from litestar.openapi.spec.enums import OpenAPIType from litestar.response import File, Redirect, Stream, Template from litestar.response.base import T @@ -440,6 +440,28 @@ def handler() -> DataclassPerson: assert responses["400"].description == "Overwritten response" +def test_additional_responses_with_custom_examples(create_factory: CreateFactoryFixture) -> None: + @get(responses={200: ResponseSpec(DataclassPerson, examples=[Example(value={"string": "example", "number": 1})])}) + def handler() -> DataclassPerson: + return DataclassPersonFactory.build() + + factory = create_factory(handler) + responses = factory.create_additional_responses() + status_code, response = next(responses) + assert response.content + assert response.content["application/json"].examples == { + "dataclassperson-example-1": Example( + value={ + "string": "example", + "number": 1, + } + ), + } + + with pytest.raises(StopIteration): + next(responses) + + def test_create_response_for_response_subclass(create_factory: CreateFactoryFixture) -> None: class CustomResponse(Response[T]): pass @@ -505,4 +527,4 @@ def handler() -> File: return File("test.txt") response = create_factory(handler).create_success_response() - assert next(iter(response.content.values())).schema.content_media_type == expected # type: ignore + assert next(iter(response.content.values())).schema.content_media_type == expected # type: ignore[union-attr] diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index 625287e593..1b05ade837 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -288,11 +288,11 @@ class Lookup(msgspec.Struct): schema = get_schema_for_field_definition(FieldDefinition.from_kwarg(name="Lookup", annotation=Lookup)) - assert schema.properties["id"].type == OpenAPIType.STRING # type: ignore - assert schema.properties["id"].examples == {"id-example-1": Example(value="example")} # type: ignore - assert schema.properties["id"].description == "description" # type: ignore - assert schema.properties["id"].title == "title" # type: ignore - assert schema.properties["id"].max_length == 16 # type: ignore + assert schema.properties["id"].type == OpenAPIType.STRING # type: ignore[index, union-attr] + assert schema.properties["id"].examples == {"id-example-1": Example(value="example")} # type: ignore[index, union-attr] + assert schema.properties["id"].description == "description" # type: ignore[index] + assert schema.properties["id"].title == "title" # type: ignore[index, union-attr] + assert schema.properties["id"].max_length == 16 # type: ignore[index, union-attr] assert schema.required == ["id"] @@ -312,20 +312,20 @@ class MyDataclass: schema = get_schema_for_field_definition(FieldDefinition.from_kwarg(name="MyDataclass", annotation=MyDataclass)) - assert schema.properties["constrained_int"].exclusive_minimum == 1 # type: ignore - assert schema.properties["constrained_int"].exclusive_maximum == 10 # type: ignore - assert schema.properties["constrained_float"].minimum == 1 # type: ignore - assert schema.properties["constrained_float"].maximum == 10 # type: ignore - assert datetime.utcfromtimestamp(schema.properties["constrained_date"].exclusive_minimum) == datetime.fromordinal( # type: ignore + assert schema.properties["constrained_int"].exclusive_minimum == 1 # type: ignore[index, union-attr] + assert schema.properties["constrained_int"].exclusive_maximum == 10 # type: ignore[index, union-attr] + assert schema.properties["constrained_float"].minimum == 1 # type: ignore[index, union-attr] + assert schema.properties["constrained_float"].maximum == 10 # type: ignore[index, union-attr] + assert datetime.utcfromtimestamp(schema.properties["constrained_date"].exclusive_minimum) == datetime.fromordinal( # type: ignore[arg-type, index, union-attr] historical_date.toordinal() ) - assert datetime.utcfromtimestamp(schema.properties["constrained_date"].exclusive_maximum) == datetime.fromordinal( # type: ignore + assert datetime.utcfromtimestamp(schema.properties["constrained_date"].exclusive_maximum) == datetime.fromordinal( # type: ignore[arg-type, index, union-attr] today.toordinal() ) - assert schema.properties["constrained_lower_case"].description == "must be in lower case" # type: ignore - assert schema.properties["constrained_upper_case"].description == "must be in upper case" # type: ignore - assert schema.properties["constrained_is_ascii"].pattern == "[[:ascii:]]" # type: ignore - assert schema.properties["constrained_is_digit"].pattern == "[[:digit:]]" # type: ignore + assert schema.properties["constrained_lower_case"].description == "must be in lower case" # type: ignore[index] + assert schema.properties["constrained_upper_case"].description == "must be in upper case" # type: ignore[index] + assert schema.properties["constrained_is_ascii"].pattern == "[[:ascii:]]" # type: ignore[index, union-attr] + assert schema.properties["constrained_is_digit"].pattern == "[[:digit:]]" # type: ignore[index, union-attr] def test_literal_enums() -> None: diff --git a/tests/unit/test_openapi/test_tags.py b/tests/unit/test_openapi/test_tags.py index eb21cc263d..91236a4f56 100644 --- a/tests/unit/test_openapi/test_tags.py +++ b/tests/unit/test_openapi/test_tags.py @@ -47,12 +47,12 @@ def openapi_schema(app: Litestar) -> "OpenAPI": def test_openapi_schema_handler_tags(openapi_schema: "OpenAPI") -> None: - assert openapi_schema.paths["/handler"].get.tags == ["handler"] # type: ignore + assert openapi_schema.paths["/handler"].get.tags == ["handler"] # type: ignore[index, union-attr] def test_openapi_schema_controller_tags(openapi_schema: "OpenAPI") -> None: - assert openapi_schema.paths["/controller"].get.tags == ["a", "controller", "handler"] # type: ignore + assert openapi_schema.paths["/controller"].get.tags == ["a", "controller", "handler"] # type: ignore[index, union-attr] def test_openapi_schema_router_tags(openapi_schema: "OpenAPI") -> None: - assert openapi_schema.paths["/router/controller"].get.tags == ["a", "controller", "handler", "router"] # type: ignore + assert openapi_schema.paths["/router/controller"].get.tags == ["a", "controller", "handler", "router"] # type: ignore[index, union-attr] diff --git a/tests/unit/test_pagination.py b/tests/unit/test_pagination.py index 27d37b8fe0..5a0b63500f 100644 --- a/tests/unit/test_pagination.py +++ b/tests/unit/test_pagination.py @@ -80,11 +80,11 @@ async def get_items(self, limit: int, offset: int) -> List[DataclassPerson]: def test_classic_pagination_data_shape(paginator: Any) -> None: @get("/async") async def async_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: - return await paginator(page_size=page_size, current_page=current_page) # type: ignore + return await paginator(page_size=page_size, current_page=current_page) # type: ignore[no-any-return] @get("/sync") def sync_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: - return paginator(page_size=page_size, current_page=current_page) # type: ignore + return paginator(page_size=page_size, current_page=current_page) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler]) as client: if isinstance(paginator, TestSyncClassicPaginator): @@ -104,11 +104,11 @@ def sync_handler(page_size: int, current_page: int) -> ClassicPagination[Datacla def test_classic_pagination_openapi_schema(paginator: Any) -> None: @get("/async") async def async_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: - return await paginator(page_size=page_size, current_page=current_page) # type: ignore + return await paginator(page_size=page_size, current_page=current_page) # type: ignore[no-any-return] @get("/sync") def sync_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: - return paginator(page_size=page_size, current_page=current_page) # type: ignore + return paginator(page_size=page_size, current_page=current_page) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: schema = client.app.openapi_schema @@ -137,11 +137,11 @@ def sync_handler(page_size: int, current_page: int) -> ClassicPagination[Datacla def test_limit_offset_pagination_data_shape(paginator: Any) -> None: @get("/async") async def async_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: - return await paginator(limit=limit, offset=offset) # type: ignore + return await paginator(limit=limit, offset=offset) # type: ignore[no-any-return] @get("/sync") def sync_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: - return paginator(limit=limit, offset=offset) # type: ignore + return paginator(limit=limit, offset=offset) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler]) as client: if isinstance(paginator, TestSyncOffsetPaginator): @@ -161,11 +161,11 @@ def sync_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: def test_limit_offset_pagination_openapi_schema(paginator: Any) -> None: @get("/async") async def async_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: - return await paginator(limit=limit, offset=offset) # type: ignore + return await paginator(limit=limit, offset=offset) # type: ignore[no-any-return] @get("/sync") def sync_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: - return paginator(limit=limit, offset=offset) # type: ignore + return paginator(limit=limit, offset=offset) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: schema = client.app.openapi_schema @@ -218,11 +218,11 @@ async def get_items( def test_cursor_pagination_data_shape(paginator: Any) -> None: @get("/async") async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: - return await paginator(cursor=cursor, results_per_page=5) # type: ignore + return await paginator(cursor=cursor, results_per_page=5) # type: ignore[no-any-return] @get("/sync") def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: - return paginator(cursor=cursor, results_per_page=5) # type: ignore + return paginator(cursor=cursor, results_per_page=5) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler]) as client: if isinstance(paginator, TestSyncCursorPagination): @@ -241,11 +241,11 @@ def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, Dataclas def test_cursor_pagination_openapi_schema(paginator: Any) -> None: @get("/async") async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: - return await paginator(cursor=cursor, results_per_page=5) # type: ignore + return await paginator(cursor=cursor, results_per_page=5) # type: ignore[no-any-return] @get("/sync") def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: - return paginator(cursor=cursor, results_per_page=5) # type: ignore + return paginator(cursor=cursor, results_per_page=5) # type: ignore[no-any-return] with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: schema = client.app.openapi_schema diff --git a/tests/unit/test_request_class_resolution.py b/tests/unit/test_request_class_resolution.py new file mode 100644 index 0000000000..081126c625 --- /dev/null +++ b/tests/unit/test_request_class_resolution.py @@ -0,0 +1,65 @@ +from typing import Optional, Type + +import pytest + +from litestar import Controller, HttpMethod, Litestar, Request, Router, get +from litestar.handlers.http_handlers.base import HTTPRouteHandler + +RouterRequest: Type[Request] = type("RouterRequest", (Request,), {}) +ControllerRequest: Type[Request] = type("ControllerRequest", (Request,), {}) +AppRequest: Type[Request] = type("AppRequest", (Request,), {}) +HandlerRequest: Type[Request] = type("HandlerRequest", (Request,), {}) + + +@pytest.mark.parametrize( + "handler_request_class, controller_request_class, router_request_class, app_request_class, has_default_app_class, expected", + ( + (HandlerRequest, ControllerRequest, RouterRequest, AppRequest, True, HandlerRequest), + (None, ControllerRequest, RouterRequest, AppRequest, True, ControllerRequest), + (None, None, RouterRequest, AppRequest, True, RouterRequest), + (None, None, None, AppRequest, True, AppRequest), + (None, None, None, None, True, Request), + (None, None, None, None, False, Request), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above controller layer", + "Custom class for all above router layer", + "No custom class for layers", + "No default class in app", + ), +) +def test_request_class_resolution_of_layers( + handler_request_class: Optional[Type[Request]], + controller_request_class: Optional[Type[Request]], + router_request_class: Optional[Type[Request]], + app_request_class: Optional[Type[Request]], + has_default_app_class: bool, + expected: Type[Request], +) -> None: + class MyController(Controller): + @get() + def handler(self, request: Request) -> None: + assert type(request) is expected + + if controller_request_class: + MyController.request_class = ControllerRequest + + router = Router(path="/", route_handlers=[MyController]) + + if router_request_class: + router.request_class = router_request_class + + app = Litestar(route_handlers=[router]) + + if app_request_class or not has_default_app_class: + app.request_class = app_request_class # type: ignore[assignment] + + route_handler: HTTPRouteHandler = app.route_handler_method_map["/"][HttpMethod.GET] # type: ignore[assignment] + + if handler_request_class: + route_handler.request_class = handler_request_class + + request_class = route_handler.resolve_request_class() + assert request_class is expected diff --git a/tests/unit/test_response/test_base_response.py b/tests/unit/test_response/test_base_response.py index ec3e1351e8..804da3f7be 100644 --- a/tests/unit/test_response/test_base_response.py +++ b/tests/unit/test_response/test_base_response.py @@ -167,9 +167,10 @@ def handler() -> Response: ({}, MediaType.TEXT, False), ({"abc": "def"}, MediaType.JSON, False), (Empty, MediaType.JSON, True), + ({"key": "value"}, "application/something+json", False), ), ) -def test_render_method(body: Any, media_type: MediaType, should_raise: bool) -> None: +def test_render_method(body: Any, media_type: str, should_raise: bool) -> None: @get("/", media_type=media_type) def handler() -> Any: return body diff --git a/tests/unit/test_response/test_response_cookies.py b/tests/unit/test_response/test_response_cookies.py index be9b8400bf..b1627aa7a8 100644 --- a/tests/unit/test_response/test_response_cookies.py +++ b/tests/unit/test_response/test_response_cookies.py @@ -37,7 +37,7 @@ def test_method(self) -> None: response_cookies=[app_first, app_second], route_handlers=[first_router, second_router], ) - route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore + route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] response_cookies = {cookie.key: cookie.value for cookie in route_handler.resolve_response_cookies()} assert response_cookies["first"] == local_first.value assert response_cookies["second"] == controller_second.value diff --git a/tests/unit/test_response/test_response_headers.py b/tests/unit/test_response/test_response_headers.py index 769069d030..adffade596 100644 --- a/tests/unit/test_response/test_response_headers.py +++ b/tests/unit/test_response/test_response_headers.py @@ -38,7 +38,7 @@ def test_method(self) -> None: route_handlers=[first_router, second_router], ) - route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore + route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] resolved_headers = {header.name: header for header in route_handler.resolve_response_headers()} assert resolved_headers["first"].value == local_first.value assert resolved_headers["second"].value == controller_second.value @@ -184,6 +184,6 @@ def my_handler() -> None: app = Litestar(route_handlers=[my_handler]) - route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore + route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore[union-attr] resolved_headers = {header.name: header for header in route_handler.resolve_response_headers()} assert resolved_headers[header.HEADER_NAME].value == header.to_header() diff --git a/tests/unit/test_response/test_streaming_response.py b/tests/unit/test_response/test_streaming_response.py index 3619d7ac64..5f5a59a3d8 100644 --- a/tests/unit/test_response/test_streaming_response.py +++ b/tests/unit/test_response/test_streaming_response.py @@ -56,7 +56,7 @@ async def stream_indefinitely() -> AsyncIterator[bytes]: response = ASGIStreamingResponse(iterator=stream_indefinitely()) with anyio.move_on_after(1) as cancel_scope: - await response({}, receive_disconnect, send) # type: ignore + await response({}, receive_disconnect, send) # type: ignore[arg-type] assert not cancel_scope.cancel_called, "Content streaming should stop itself." @@ -80,7 +80,7 @@ async def send(message: "Message") -> None: response = ASGIStreamingResponse(iterator=cycle(["1", "2", "3"])) with anyio.move_on_after(1) as cancel_scope: - await response({}, receive_disconnect, send) # type: ignore + await response({}, receive_disconnect, send) # type: ignore[arg-type] assert not cancel_scope.cancel_called, "Content streaming should stop itself." diff --git a/tests/unit/test_response/test_type_decoders.py b/tests/unit/test_response/test_type_decoders.py new file mode 100644 index 0000000000..4d747d8505 --- /dev/null +++ b/tests/unit/test_response/test_type_decoders.py @@ -0,0 +1,108 @@ +from typing import Any, Literal, Type, Union +from unittest import mock + +import pytest + +from litestar import Controller, Litestar, Router, get +from litestar.datastructures.url import URL +from litestar.enums import HttpMethod +from litestar.handlers.http_handlers.base import HTTPRouteHandler +from litestar.handlers.websocket_handlers.listener import ( + WebsocketListener, + WebsocketListenerRouteHandler, + websocket_listener, +) +from litestar.types.composite_types import TypeDecodersSequence + +handler_decoder, router_decoder, controller_decoder, app_decoder = 4 * [(lambda t: t is URL, lambda t, v: URL(v))] + + +@pytest.fixture(scope="module") +def controller() -> Type[Controller]: + class MyController(Controller): + path = "/controller" + type_decoders = [controller_decoder] + + @get("/http", type_decoders=[handler_decoder]) + def http(self) -> Any: + ... + + @websocket_listener("/ws", type_decoders=[handler_decoder]) + async def handler(self, data: str) -> None: + ... + + return MyController + + +@pytest.fixture(scope="module") +def websocket_listener_handler() -> Type[WebsocketListener]: + class WebSocketHandler(WebsocketListener): + path = "/ws-listener" + type_decoders = [handler_decoder] + + def on_receive(self, data: str) -> None: # pyright: ignore [reportIncompatibleMethodOverride] + ... + + return WebSocketHandler + + +@pytest.fixture(scope="module") +def http_handler() -> HTTPRouteHandler: + @get("/http", type_decoders=[handler_decoder]) + def http() -> Any: + ... + + return http + + +@pytest.fixture(scope="module") +def websocket_handler() -> WebsocketListenerRouteHandler: + @websocket_listener("/ws", type_decoders=[handler_decoder]) + async def websocket(data: str) -> None: + ... + + return websocket + + +@pytest.fixture(scope="module") +def router( + controller: Type[Controller], + websocket_listener_handler: Type[WebsocketListenerRouteHandler], + http_handler: Type[HTTPRouteHandler], + websocket_handler: Type[WebsocketListenerRouteHandler], +) -> Router: + return Router( + "/router", + type_decoders=[router_decoder], + route_handlers=[controller, websocket_listener_handler, http_handler, websocket_handler], + ) + + +@pytest.fixture(scope="module") +@mock.patch("litestar.app.Litestar._get_default_plugins", mock.Mock(return_value=[])) +def app(router: Router) -> Litestar: + return Litestar([router], type_decoders=[app_decoder]) + + +@pytest.mark.parametrize( + "path, method, type_decoders", + ( + ("/router/controller/http", HttpMethod.GET, [app_decoder, router_decoder, controller_decoder, handler_decoder]), + ("/router/controller/ws", "websocket", [app_decoder, router_decoder, controller_decoder, handler_decoder]), + ("/router/http", HttpMethod.GET, [app_decoder, router_decoder, handler_decoder]), + ("/router/ws", "websocket", [app_decoder, router_decoder, handler_decoder]), + ("/router/ws-listener", "websocket", [app_decoder, router_decoder, handler_decoder]), + ), + ids=( + "Controller http endpoint type decoders", + "Controller ws endpoint type decoders", + "Router http endpoint type decoders", + "Router ws endpoint type decoders", + "Router ws listener type decoders", + ), +) +def test_resolve_type_decoders( + path: str, method: Union[HttpMethod, Literal["websocket"]], type_decoders: TypeDecodersSequence, app: Litestar +) -> None: + handler = app.route_handler_method_map[path][method] + assert handler.resolve_type_decoders() == type_decoders diff --git a/tests/unit/test_response/test_type_encoders.py b/tests/unit/test_response/test_type_encoders.py index 28964f2b43..2d661e99d9 100644 --- a/tests/unit/test_response/test_type_encoders.py +++ b/tests/unit/test_response/test_type_encoders.py @@ -33,7 +33,7 @@ def handler(self) -> Any: router = Router("/router", type_encoders={router_type: router_encoder}, route_handlers=[MyController]) app = Litestar([router], type_encoders={app_type: app_encoder}) - route_handler = app.routes[0].route_handler_map[HttpMethod.GET][0] # type: ignore + route_handler = app.routes[0].route_handler_map[HttpMethod.GET][0] # type: ignore[union-attr] encoders = route_handler.resolve_type_encoders() assert encoders.get(handler_type) == handler_encoder assert encoders.get(controller_type) == controller_encoder diff --git a/tests/unit/test_response_class_resolution.py b/tests/unit/test_response_class_resolution.py index 9beb6c2ece..0744a6d302 100644 --- a/tests/unit/test_response_class_resolution.py +++ b/tests/unit/test_response_class_resolution.py @@ -1,72 +1,62 @@ -from typing import Optional +from typing import Optional, Type import pytest from litestar import Controller, HttpMethod, Litestar, Response, Router, get -from litestar.types import Empty +from litestar.handlers.http_handlers.base import HTTPRouteHandler -router_response = type("router_response", (Response,), {}) -controller_response = type("controller_response", (Response,), {}) -app_response = type("app_response", (Response,), {}) -handler_response = type("local_response", (Response,), {}) - -test_path = "/test" +RouterResponse: Type[Response] = type("RouterResponse", (Response,), {}) +ControllerResponse: Type[Response] = type("ControllerResponse", (Response,), {}) +AppResponse: Type[Response] = type("AppResponse", (Response,), {}) +HandlerResponse: Type[Response] = type("HandlerResponse", (Response,), {}) @pytest.mark.parametrize( - "layer, expected", - [[0, handler_response], [1, controller_response], [2, router_response], [3, app_response], [None, Response]], + "handler_response_class, controller_response_class, router_response_class, app_response_class, expected", + ( + (HandlerResponse, ControllerResponse, RouterResponse, AppResponse, HandlerResponse), + (None, ControllerResponse, RouterResponse, AppResponse, ControllerResponse), + (None, None, RouterResponse, AppResponse, RouterResponse), + (None, None, None, AppResponse, AppResponse), + (None, None, None, None, Response), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above controller layer", + "Custom class for all above router layer", + "No custom class for layers", + ), ) -def test_response_class_resolution_of_layers(layer: Optional[int], expected: Response) -> None: +def test_response_class_resolution_of_layers( + handler_response_class: Optional[Type[Response]], + controller_response_class: Optional[Type[Response]], + router_response_class: Optional[Type[Response]], + app_response_class: Optional[Type[Response]], + expected: Type[Response], +) -> None: class MyController(Controller): - path = test_path - - @get( - path="/{path_param:str}", - ) - def test_method(self) -> None: + @get() + def handler(self) -> None: pass - MyController.test_method._resolved_response_class = Empty if layer != 0 else expected # type: ignore - MyController.response_class = None if layer != 1 else expected # type: ignore - router = Router(path="/users", route_handlers=[MyController], response_class=None if layer != 2 else expected) # type: ignore - app = Litestar(route_handlers=[router], response_class=None if layer != 3 else expected) # type: ignore - route_handler, _ = app.routes[0].route_handler_map[HttpMethod.GET] # type: ignore - layer_map = { - 0: route_handler, - 1: MyController, - 2: router, - 3: app, - } - component = layer_map.get(layer) # type: ignore - if component: - component.response_class = expected - assert component.response_class is expected - response_class = route_handler.resolve_response_class() - assert response_class is expected - if component: - component.response_class = None - assert component.response_class is None - + if controller_response_class: + MyController.response_class = ControllerResponse -def test_response_class_resolution_overrides() -> None: - class MyController(Controller): - path = "/path" - response_class = controller_response + router = Router(path="/", route_handlers=[MyController]) - @get("/", response_class=handler_response) - def handler(self) -> None: - return + if router_response_class: + router.response_class = router_response_class - assert MyController.handler.resolve_response_class() is handler_response + app = Litestar(route_handlers=[router]) + if app_response_class: + app.response_class = app_response_class -def test_response_class_resolution_defaults() -> None: - class MyController(Controller): - path = "/path" + route_handler: HTTPRouteHandler = app.route_handler_method_map["/"][HttpMethod.GET] # type: ignore[assignment] - @get("/") - def handler(self) -> None: - return + if handler_response_class: + route_handler.response_class = handler_response_class - assert MyController.handler.resolve_response_class() is Response + response_class = route_handler.resolve_response_class() + assert response_class is expected diff --git a/tests/unit/test_security/test_jwt/test_auth.py b/tests/unit/test_security/test_jwt/test_auth.py index 2b6c22371c..fd2e8cbe4a 100644 --- a/tests/unit/test_security/test_jwt/test_auth.py +++ b/tests/unit/test_security/test_jwt/test_auth.py @@ -175,7 +175,7 @@ async def retrieve_user_handler(token: Token, connection: Any) -> Any: key=auth_cookie, auth_header=auth_header, default_token_expiration=default_token_expiration, - retrieve_user_handler=retrieve_user_handler, # type: ignore + retrieve_user_handler=retrieve_user_handler, # type: ignore[var-annotated] token_secret=token_secret, ) @@ -298,7 +298,7 @@ def west_handler() -> None: def test_jwt_auth_openapi() -> None: - jwt_auth = JWTAuth[Any](token_secret="abc123", retrieve_user_handler=lambda _: None) # type: ignore + jwt_auth = JWTAuth[Any](token_secret="abc123", retrieve_user_handler=lambda _: None) # type: ignore[arg-type, misc] assert jwt_auth.openapi_components.to_schema() == { "schemas": {}, "securitySchemes": { @@ -348,7 +348,7 @@ async def retrieve_user_handler(token: Token, connection: Any) -> Any: jwt_auth = OAuth2PasswordBearerAuth( token_url="/login", token_secret="abc123", - retrieve_user_handler=retrieve_user_handler, # type: ignore + retrieve_user_handler=retrieve_user_handler, # type: ignore[var-annotated] ) @get("/login") diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index 70c28203fe..1842a8092e 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -33,7 +33,7 @@ def fn(a: int) -> None: def test_create_signature_validation() -> None: @get() - def my_fn(typed: int, untyped) -> None: # type: ignore + def my_fn(typed: int, untyped) -> None: # type: ignore[no-untyped-def] pass with pytest.raises(ImproperlyConfiguredException): @@ -74,7 +74,7 @@ def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> No response = client.get("/?param=thirteen") assert response.json() == { - "detail": "Validation failed for GET http://testserver.local/?param=thirteen", + "detail": "Validation failed for GET /?param=thirteen", "extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}], "status_code": 400, } @@ -94,7 +94,7 @@ def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> No response = client.get("/?param=thirteen") assert response.json() == { - "detail": "Validation failed for GET http://testserver.local/?param=thirteen", + "detail": "Validation failed for GET /?param=thirteen", "extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}], "status_code": 400, } diff --git a/tests/unit/test_static_files/test_file_serving_resolution.py b/tests/unit/test_static_files/test_file_serving_resolution.py index 3bc9c70d6a..f204b56d78 100644 --- a/tests/unit/test_static_files/test_file_serving_resolution.py +++ b/tests/unit/test_static_files/test_file_serving_resolution.py @@ -159,7 +159,7 @@ def handler(f: str) -> str: def test_static_substring_of_self(tmpdir: Path, make_config: MakeConfig) -> None: - path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore + path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore[arg-type, func-returns-value] path.write_text("content", "utf-8") configs, handlers = make_config(StaticFilesConfig(path="/static", directories=[tmpdir])) @@ -209,7 +209,7 @@ def test_static_files_response_encoding(tmp_path: Path, extension: str, make_con def test_static_files_content_disposition( tmpdir: Path, send_as_attachment: bool, disposition: str, make_config: MakeConfig ) -> None: - path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore + path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore[arg-type, func-returns-value] path.write_text("content", "utf-8") configs, handlers = make_config( @@ -223,7 +223,7 @@ def test_static_files_content_disposition( def test_service_from_relative_path_using_string(tmpdir: Path, make_config: MakeConfig) -> None: - sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore + sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore[arg-type, func-returns-value] path = tmpdir / "test.txt" path.write_text("content", "utf-8") @@ -237,7 +237,7 @@ def test_service_from_relative_path_using_string(tmpdir: Path, make_config: Make def test_service_from_relative_path_using_path(tmpdir: Path, make_config: MakeConfig) -> None: - sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore + sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore[arg-type, func-returns-value] path = tmpdir / "test.txt" path.write_text("content", "utf-8") @@ -251,7 +251,7 @@ def test_service_from_relative_path_using_path(tmpdir: Path, make_config: MakeCo def test_service_from_base_path_using_string(tmpdir: Path) -> None: - sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore + sub_dir = Path(tmpdir.mkdir("low")).resolve() # type: ignore[arg-type, func-returns-value] path = tmpdir / "test.txt" path.write_text("content", "utf-8") diff --git a/tests/unit/test_testing/test_request_factory.py b/tests/unit/test_testing/test_request_factory.py index 7869e83b01..2b848ddc8c 100644 --- a/tests/unit/test_testing/test_request_factory.py +++ b/tests/unit/test_testing/test_request_factory.py @@ -69,7 +69,7 @@ async def test_request_factory_create_with_data(data_cls: DataContainerType) -> request = RequestFactory()._create_request_with_data( HttpMethod.POST, "/", - data=data_cls(**person_data), # type: ignore + data=data_cls(**person_data), # type: ignore[operator] ) body = await request.body() assert json.loads(body) == person_data diff --git a/tests/unit/test_utils/test_sync.py b/tests/unit/test_utils/test_sync.py index b2582b9836..6e3be65c94 100644 --- a/tests/unit/test_utils/test_sync.py +++ b/tests/unit/test_utils/test_sync.py @@ -32,10 +32,10 @@ async def my_method(self, value: int) -> None: wrapped_method = ensure_async_callable(instance.my_method) - await wrapped_method(1) # type: ignore + await wrapped_method(1) # type: ignore[unused-coroutine] assert instance.value == 1 - await wrapped_method(value=10) # type: ignore + await wrapped_method(value=10) # type: ignore[unused-coroutine] assert instance.value == 10 @@ -62,10 +62,10 @@ async def my_function(new_value: int) -> None: wrapped_function = ensure_async_callable(my_function) - await wrapped_function(1) # type: ignore + await wrapped_function(1) # type: ignore[unused-coroutine] assert obj["value"] == 1 - await wrapped_function(new_value=10) # type: ignore + await wrapped_function(new_value=10) # type: ignore[unused-coroutine] assert obj["value"] == 10 @@ -98,8 +98,8 @@ async def __call__(self, new_value: int) -> None: wrapped_class = ensure_async_callable(instance) - await wrapped_class(1) # type: ignore + await wrapped_class(1) # type: ignore[unused-coroutine] assert instance.value == 1 - await wrapped_class(new_value=10) # type: ignore + await wrapped_class(new_value=10) # type: ignore[unused-coroutine] assert instance.value == 10 diff --git a/tests/unit/test_websocket_class_resolution.py b/tests/unit/test_websocket_class_resolution.py new file mode 100644 index 0000000000..4ec1967320 --- /dev/null +++ b/tests/unit/test_websocket_class_resolution.py @@ -0,0 +1,115 @@ +from typing import Type, Union + +import pytest + +from litestar import Controller, Litestar, Router, WebSocket +from litestar.handlers.websocket_handlers.listener import WebsocketListener, websocket_listener + +RouterWebSocket: Type[WebSocket] = type("RouterWebSocket", (WebSocket,), {}) +ControllerWebSocket: Type[WebSocket] = type("ControllerWebSocket", (WebSocket,), {}) +AppWebSocket: Type[WebSocket] = type("AppWebSocket", (WebSocket,), {}) +HandlerWebSocket: Type[WebSocket] = type("HandlerWebSocket", (WebSocket,), {}) + + +@pytest.mark.parametrize( + "handler_websocket_class, controller_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected", + ( + (HandlerWebSocket, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket), + (None, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, ControllerWebSocket), + (None, None, RouterWebSocket, AppWebSocket, True, RouterWebSocket), + (None, None, None, AppWebSocket, True, AppWebSocket), + (None, None, None, None, True, WebSocket), + (None, None, None, None, False, WebSocket), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above controller layer", + "Custom class for all above router layer", + "No custom class for layers", + "No default class in app", + ), +) +def test_websocket_class_resolution_of_layers( + handler_websocket_class: Union[Type[WebSocket], None], + controller_websocket_class: Union[Type[WebSocket], None], + router_websocket_class: Union[Type[WebSocket], None], + app_websocket_class: Union[Type[WebSocket], None], + has_default_app_class: bool, + expected: Type[WebSocket], +) -> None: + class MyController(Controller): + @websocket_listener("/") + def handler(self, data: str) -> None: + return + + if controller_websocket_class: + MyController.websocket_class = ControllerWebSocket + + router = Router(path="/", route_handlers=[MyController]) + + if router_websocket_class: + router.websocket_class = router_websocket_class + + app = Litestar(route_handlers=[router]) + + if app_websocket_class or not has_default_app_class: + app.websocket_class = app_websocket_class # type: ignore[assignment] + + route_handler = app.routes[0].route_handler # type: ignore[union-attr] + + if handler_websocket_class: + route_handler.websocket_class = handler_websocket_class # type: ignore[union-attr] + + websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + assert websocket_class is expected + + +@pytest.mark.parametrize( + "handler_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected", + ( + (HandlerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket), + (None, RouterWebSocket, AppWebSocket, True, RouterWebSocket), + (None, None, AppWebSocket, True, AppWebSocket), + (None, None, None, True, WebSocket), + (None, None, None, False, WebSocket), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above router layer", + "No custom class for layers", + "No default class in app", + ), +) +def test_listener_websocket_class_resolution_of_layers( + handler_websocket_class: Union[Type[WebSocket], None], + router_websocket_class: Union[Type[WebSocket], None], + app_websocket_class: Union[Type[WebSocket], None], + has_default_app_class: bool, + expected: Type[WebSocket], +) -> None: + class Handler(WebsocketListener): + path = "/" + websocket_class = handler_websocket_class + + def on_receive(self, data: str) -> str: # pyright: ignore + return data + + router = Router(path="/", route_handlers=[Handler]) + + if router_websocket_class: + router.websocket_class = router_websocket_class + + app = Litestar(route_handlers=[router]) + + if app_websocket_class or not has_default_app_class: + app.websocket_class = app_websocket_class # type: ignore[assignment] + + route_handler = app.routes[0].route_handler # type: ignore[union-attr] + + if handler_websocket_class: + route_handler.websocket_class = handler_websocket_class # type: ignore[union-attr] + + websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + assert websocket_class is expected