-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add litestar framework #311
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from supertokens_python.framework.litestar import litestar_middleware | ||
|
||
get_middleware = litestar_middleware.get_middleware | ||
|
||
__all__ = ("get_middleware",) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from supertokens_python.framework.types import Framework | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Request | ||
|
||
|
||
class LitestarFramework(Framework): | ||
def wrap_request(self, unwrapped: Request[Any, Any, Any]): | ||
from supertokens_python.framework.litestar.litestar_request import ( | ||
LitestarRequest, | ||
) | ||
|
||
return LitestarRequest(unwrapped) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from __future__ import annotations | ||
|
||
from functools import lru_cache | ||
from typing import TYPE_CHECKING, Any | ||
|
||
if TYPE_CHECKING: | ||
from litestar.middleware.base import AbstractMiddleware | ||
|
||
|
||
@lru_cache | ||
def get_middleware() -> type[AbstractMiddleware]: | ||
from supertokens_python import Supertokens | ||
from supertokens_python.exceptions import SuperTokensError | ||
from supertokens_python.framework.litestar.litestar_request import LitestarRequest | ||
from supertokens_python.framework.litestar.litestar_response import LitestarResponse | ||
from supertokens_python.recipe.session import SessionContainer | ||
from supertokens_python.supertokens import manage_session_post_response | ||
|
||
from litestar import Response, Request | ||
from litestar.middleware.base import AbstractMiddleware | ||
from litestar.types import Scope, Receive, Send | ||
|
||
class Middleware(AbstractMiddleware): | ||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
st = Supertokens.get_instance() | ||
request = Request[Any, Any, Any](scope, receive, send) | ||
|
||
try: | ||
result = await st.middleware( | ||
LitestarRequest(request), | ||
LitestarResponse(Response[Any](content=None)), | ||
) | ||
except SuperTokensError as e: | ||
result = await st.handle_supertokens_error( | ||
LitestarRequest(request), | ||
e, | ||
LitestarResponse(Response[Any](content=None)), | ||
) | ||
|
||
if isinstance(result, LitestarResponse): | ||
if ( | ||
session_container := request.state.get("supertokens") | ||
) and isinstance(session_container, SessionContainer): | ||
manage_session_post_response(session_container, result) | ||
|
||
await result.response(scope, receive, send) | ||
return | ||
|
||
await self.app(scope, receive, send) | ||
Comment on lines
+28
to
+49
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I visited http://localhost:8000/test/secure without any session (based on config given by @Spectryx) returned: {
"status_code": 500,
"detail": "UnauthorisedError('Session does not exist. Are you sending the session tokens in the request with the appropriate token transfer method?')"
} I used the debugger and found that class Middleware(AbstractMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
st = Supertokens.get_instance()
request = Request[Any, Any, Any](scope, receive, send)
try:
result = await st.middleware(
LitestarRequest(request),
LitestarResponse(Response[Any](content=None)),
)
if isinstance(result, LitestarResponse):
if (
session_container := request.state.get("supertokens")
) and isinstance(session_container, SessionContainer):
manage_session_post_response(session_container, result)
await result.response(scope, receive, send)
return
async def send_wrapper(message: "Message") -> None:
await send(message)
await self.app(scope, receive, send_wrapper)
print("Sent response") # reaches this line
except SuperTokensError as e:
result = await st.handle_supertokens_error(
LitestarRequest(request),
e,
LitestarResponse(Response[Any](content=None)),
)
if isinstance(result, LitestarResponse):
await result.response(scope, receive, send)
return
except Exception as e:
print(e)
raise e
print("Middleware ran") # reaches this line
return Middleware The debugger reaches There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We wrap the chain of middlewares inside an exception handling Middleware. You can customize its behavior by defining exception handler functions mapped to either status codes or exception types. It's similar to Starlette in this regard. |
||
|
||
return Middleware |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from supertokens_python.framework.request import BaseRequest | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Request | ||
from supertokens_python.recipe.session.interfaces import SessionContainer | ||
|
||
try: | ||
from litestar.exceptions import SerializationException | ||
except ImportError: | ||
SerializationException = Exception # type: ignore | ||
|
||
|
||
class LitestarRequest(BaseRequest): | ||
def __init__(self, request: Request[Any, Any, Any]): | ||
super().__init__() | ||
self.request = request | ||
|
||
def get_query_param(self, key: str, default: str | None = None) -> Any: | ||
return self.request.query_params.get(key, default) # pyright: ignore | ||
|
||
def get_query_params(self) -> dict[str, list[Any]]: | ||
return self.request.query_params.dict() # pyright: ignore | ||
|
||
async def json(self) -> Any: | ||
try: | ||
return await self.request.json() | ||
except SerializationException: | ||
return {} | ||
|
||
def method(self) -> str: | ||
return self.request.method | ||
|
||
def get_cookie(self, key: str) -> str | None: | ||
return self.request.cookies.get(key) | ||
|
||
def get_header(self, key: str) -> str | None: | ||
return self.request.headers.get(key, None) | ||
|
||
def get_session(self) -> SessionContainer | None: | ||
return self.request.state.supertokens | ||
|
||
def set_session(self, session: SessionContainer): | ||
self.request.state.supertokens = session | ||
|
||
def set_session_as_none(self): | ||
self.request.state.supertokens = None | ||
|
||
def get_path(self) -> str: | ||
return self.request.url.path | ||
|
||
async def form_data(self) -> dict[str, list[Any]]: | ||
return (await self.request.form()).dict() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
from typing import Any, TYPE_CHECKING, cast | ||
from typing_extensions import Literal | ||
from supertokens_python.framework.response import BaseResponse | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Response | ||
|
||
|
||
class LitestarResponse(BaseResponse): | ||
def __init__(self, response: Response[Any]): | ||
super().__init__({}) | ||
self.response = response | ||
self.original = response | ||
self.parser_checked = False | ||
self.response_sent = False | ||
self.status_set = False | ||
|
||
def set_html_content(self, content: str): | ||
if not self.response_sent: | ||
body = bytes(content, "utf-8") | ||
self.set_header("Content-Length", str(len(body))) | ||
self.set_header("Content-Type", "text/html") | ||
self.response.body = body | ||
self.response_sent = True | ||
|
||
def set_cookie( | ||
self, | ||
key: str, | ||
value: str, | ||
expires: int, | ||
path: str = "/", | ||
domain: str | None = None, | ||
secure: bool = False, | ||
httponly: bool = False, | ||
samesite: str = "lax", | ||
): | ||
self.response.set_cookie( | ||
key=key, | ||
value=value, | ||
expires=expires, | ||
path=path, | ||
domain=domain, | ||
secure=secure, | ||
httponly=httponly, | ||
samesite=cast(Literal["lax", "strict", "none"], samesite), | ||
) | ||
|
||
def set_header(self, key: str, value: str): | ||
self.response.set_header(key, value) | ||
|
||
def get_header(self, key: str) -> str | None: | ||
return self.response.headers.get(key, None) | ||
|
||
def remove_header(self, key: str): | ||
del self.response.headers[key] | ||
|
||
def set_status_code(self, status_code: int): | ||
if not self.status_set: | ||
self.response.status_code = status_code | ||
self.status_code = status_code | ||
self.status_set = True | ||
|
||
def set_json_content(self, content: dict[str, Any]): | ||
if not self.response_sent: | ||
from litestar.serialization import encode_json | ||
|
||
body = encode_json(content) | ||
self.set_header("Content-Type", "application/json; charset=utf-8") | ||
self.set_header("Content-Length", str(len(body))) | ||
self.response.body = body | ||
self.response_sent = True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from __future__ import annotations | ||
from typing import Any, Callable, Coroutine, TYPE_CHECKING | ||
|
||
from supertokens_python.framework.litestar.litestar_request import LitestarRequest | ||
from supertokens_python.recipe.session import SessionRecipe | ||
from supertokens_python.types import MaybeAwaitable | ||
|
||
from ...interfaces import SessionContainer, SessionClaimValidator | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Request | ||
|
||
|
||
def verify_session( | ||
anti_csrf_check: bool | None = None, | ||
session_required: bool = True, | ||
override_global_claim_validators: Callable[ | ||
[list[SessionClaimValidator], SessionContainer, dict[str, Any]], | ||
MaybeAwaitable[list[SessionClaimValidator]], | ||
] | ||
| None = None, | ||
user_context: None | dict[str, Any] = None, | ||
) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]: | ||
async def func(request: Request[Any, Any, Any]) -> SessionContainer | None: | ||
litestar_request = LitestarRequest(request) | ||
recipe = SessionRecipe.get_instance() | ||
session = await recipe.verify_session( | ||
litestar_request, | ||
anti_csrf_check, | ||
session_required, | ||
override_global_claim_validators, | ||
user_context or {}, | ||
) | ||
|
||
if session: | ||
litestar_request.set_session(session) | ||
elif session_required: | ||
raise RuntimeError("Should never come here") | ||
else: | ||
litestar_request.set_session_as_none() | ||
|
||
return litestar_request.get_session() | ||
|
||
return func |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. Thanks!