Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add litestar framework #311

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ itsdangerous==2.1.2
Jinja2==3.1.1
jsonschema==3.2.0
lazy-object-proxy==1.7.1
Mako==1.2.0
Mako==1.2.4
Markdown==3.3.6
MarkupSafe==2.1.1
mccabe==0.6.1
Expand All @@ -48,7 +48,7 @@ py==1.11.0
pycodestyle==2.8.0
pycparser==2.21
pycryptodome==3.10.4
pydantic==1.9.0
pydantic>=1.10.0, <2.0.0
PyJWT==2.0.1
pylint==2.12.2
pyparsing==3.0.7
Expand Down Expand Up @@ -84,3 +84,4 @@ uvicorn==0.18.2
Werkzeug==2.0.3
wrapt==1.13.3
zipp==3.7.0
litestar>=2.0.0alpha4, <3.0.0
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
# we use to develop the SDK with otherwise we get
# a bunch of type errors on make dev-install depending
# on changes in these frameworks
"litestar": (
[
"litestar>=2.0.0alpha4, <3.0.0",
]
),
"fastapi": (
[
"respx==0.19.2",
Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from . import supertokens
from .recipe_module import RecipeModule
from supertokens_python.types import SupportedFrameworks

InputAppInfo = supertokens.InputAppInfo
Supertokens = supertokens.Supertokens
Expand All @@ -26,7 +27,7 @@

def init(
app_info: InputAppInfo,
framework: Literal["fastapi", "flask", "django"],
framework: SupportedFrameworks,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Thanks!

supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None] = None,
Expand Down
5 changes: 5 additions & 0 deletions supertokens_python/framework/litestar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from supertokens_python.framework.litestar import litestar_middleware

get_middleware = litestar_middleware.get_middleware

__all__ = ("get_middleware",)
17 changes: 17 additions & 0 deletions supertokens_python/framework/litestar/framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from supertokens_python.framework.types import Framework

if TYPE_CHECKING:
from litestar import Request


class LitestarFramework(Framework):
def wrap_request(self, unwrapped: Request[Any, Any, Any]):
from supertokens_python.framework.litestar.litestar_request import (
LitestarRequest,
)

return LitestarRequest(unwrapped)
51 changes: 51 additions & 0 deletions supertokens_python/framework/litestar/litestar_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from litestar.middleware.base import AbstractMiddleware


@lru_cache
def get_middleware() -> type[AbstractMiddleware]:
from supertokens_python import Supertokens
from supertokens_python.exceptions import SuperTokensError
from supertokens_python.framework.litestar.litestar_request import LitestarRequest
from supertokens_python.framework.litestar.litestar_response import LitestarResponse
from supertokens_python.recipe.session import SessionContainer
from supertokens_python.supertokens import manage_session_post_response

from litestar import Response, Request
from litestar.middleware.base import AbstractMiddleware
from litestar.types import Scope, Receive, Send

class Middleware(AbstractMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
st = Supertokens.get_instance()
request = Request[Any, Any, Any](scope, receive, send)

try:
result = await st.middleware(
LitestarRequest(request),
LitestarResponse(Response[Any](content=None)),
)
except SuperTokensError as e:
result = await st.handle_supertokens_error(
LitestarRequest(request),
e,
LitestarResponse(Response[Any](content=None)),
)

if isinstance(result, LitestarResponse):
if (
session_container := request.state.get("supertokens")
) and isinstance(session_container, SessionContainer):
manage_session_post_response(session_container, result)

await result.response(scope, receive, send)
return

await self.app(scope, receive, send)
Comment on lines +28 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except SuperTokensError as e isn't working as expected.

I visited http://localhost:8000/test/secure without any session (based on config given by @Spectryx) returned:

{
  "status_code": 500,
  "detail": "UnauthorisedError('Session does not exist. Are you sending the session tokens in the request with the appropriate token transfer method?')"
}

I used the debugger and found that await self.app(...) raises UnauthorizedError from verify_session and that isn't handled here. So I changed it to:

    class Middleware(AbstractMiddleware):
        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            st = Supertokens.get_instance()
            request = Request[Any, Any, Any](scope, receive, send)

            try:
                result = await st.middleware(
                    LitestarRequest(request),
                    LitestarResponse(Response[Any](content=None)),
                )
                if isinstance(result, LitestarResponse):
                    if (
                        session_container := request.state.get("supertokens")
                    ) and isinstance(session_container, SessionContainer):
                        manage_session_post_response(session_container, result)

                    await result.response(scope, receive, send)
                    return

                async def send_wrapper(message: "Message") -> None:
                    await send(message)

                await self.app(scope, receive, send_wrapper)
                print("Sent response") # reaches this line
            except SuperTokensError as e:
                result = await st.handle_supertokens_error(
                    LitestarRequest(request),
                    e,
                    LitestarResponse(Response[Any](content=None)),
                )
                if isinstance(result, LitestarResponse):
                    await result.response(scope, receive, send)
                    return
            except Exception as e:
                print(e)
                raise e
                
            print("Middleware ran") # reaches this line 

    return Middleware

The debugger reaches print("Sent response") line despite the error response. This seems different from other frameworks. How does Litestar handle errors? Is https://docs.litestar.dev/2/usage/exceptions.html the only way?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We wrap the chain of middlewares inside an exception handling Middleware. You can customize its behavior by defining exception handler functions mapped to either status codes or exception types.

It's similar to Starlette in this regard.


return Middleware
56 changes: 56 additions & 0 deletions supertokens_python/framework/litestar/litestar_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from supertokens_python.framework.request import BaseRequest

if TYPE_CHECKING:
from litestar import Request
from supertokens_python.recipe.session.interfaces import SessionContainer

try:
from litestar.exceptions import SerializationException
except ImportError:
SerializationException = Exception # type: ignore


class LitestarRequest(BaseRequest):
def __init__(self, request: Request[Any, Any, Any]):
super().__init__()
self.request = request

def get_query_param(self, key: str, default: str | None = None) -> Any:
return self.request.query_params.get(key, default) # pyright: ignore

def get_query_params(self) -> dict[str, list[Any]]:
return self.request.query_params.dict() # pyright: ignore

async def json(self) -> Any:
try:
return await self.request.json()
except SerializationException:
return {}

def method(self) -> str:
return self.request.method

def get_cookie(self, key: str) -> str | None:
return self.request.cookies.get(key)

def get_header(self, key: str) -> str | None:
return self.request.headers.get(key, None)

def get_session(self) -> SessionContainer | None:
return self.request.state.supertokens

def set_session(self, session: SessionContainer):
self.request.state.supertokens = session

def set_session_as_none(self):
self.request.state.supertokens = None

def get_path(self) -> str:
return self.request.url.path

async def form_data(self) -> dict[str, list[Any]]:
return (await self.request.form()).dict()
72 changes: 72 additions & 0 deletions supertokens_python/framework/litestar/litestar_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, cast
from typing_extensions import Literal
from supertokens_python.framework.response import BaseResponse

if TYPE_CHECKING:
from litestar import Response


class LitestarResponse(BaseResponse):
def __init__(self, response: Response[Any]):
super().__init__({})
self.response = response
self.original = response
self.parser_checked = False
self.response_sent = False
self.status_set = False

def set_html_content(self, content: str):
if not self.response_sent:
body = bytes(content, "utf-8")
self.set_header("Content-Length", str(len(body)))
self.set_header("Content-Type", "text/html")
self.response.body = body
self.response_sent = True

def set_cookie(
self,
key: str,
value: str,
expires: int,
path: str = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: str = "lax",
):
self.response.set_cookie(
key=key,
value=value,
expires=expires,
path=path,
domain=domain,
secure=secure,
httponly=httponly,
samesite=cast(Literal["lax", "strict", "none"], samesite),
)

def set_header(self, key: str, value: str):
self.response.set_header(key, value)

def get_header(self, key: str) -> str | None:
return self.response.headers.get(key, None)

def remove_header(self, key: str):
del self.response.headers[key]

def set_status_code(self, status_code: int):
if not self.status_set:
self.response.status_code = status_code
self.status_code = status_code
self.status_set = True

def set_json_content(self, content: dict[str, Any]):
if not self.response_sent:
from litestar.serialization import encode_json

body = encode_json(content)
self.set_header("Content-Type", "application/json; charset=utf-8")
self.set_header("Content-Length", str(len(body)))
self.response.body = body
self.response_sent = True
2 changes: 0 additions & 2 deletions supertokens_python/framework/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from supertokens_python.framework.request import BaseRequest

frameworks = ["fastapi", "flask", "django"]


class FrameworkEnum(Enum):
FASTAPI = 1
Expand Down
7 changes: 4 additions & 3 deletions supertokens_python/recipe/session/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def create_new_session(

if not hasattr(request, "wrapper_used") or not request.wrapper_used:
request = FRAMEWORKS[
SessionRecipe.get_instance().app_info.framework
SessionRecipe.get_instance().app_info.framework # pyright: ignore
].wrap_request(request)

return await SessionRecipe.get_instance().recipe_implementation.create_new_session(
Expand Down Expand Up @@ -251,7 +251,7 @@ async def get_session(
user_context = {}
if not hasattr(request, "wrapper_used") or not request.wrapper_used:
request = FRAMEWORKS[
SessionRecipe.get_instance().app_info.framework
SessionRecipe.get_instance().app_info.framework # pyright: ignore
].wrap_request(request)

session_recipe_impl = SessionRecipe.get_instance().recipe_implementation
Expand All @@ -278,7 +278,7 @@ async def refresh_session(
user_context = {}
if not hasattr(request, "wrapper_used") or not request.wrapper_used:
request = FRAMEWORKS[
SessionRecipe.get_instance().app_info.framework
SessionRecipe.get_instance().app_info.framework # pyright: ignore
].wrap_request(request)

return await SessionRecipe.get_instance().recipe_implementation.refresh_session(
Expand Down Expand Up @@ -331,6 +331,7 @@ async def get_session_information(
) -> Union[SessionInformationResult, None]:
if user_context is None:
user_context = {}

return await SessionRecipe.get_instance().recipe_implementation.get_session_information(
session_handle, user_context
)
Expand Down
44 changes: 44 additions & 0 deletions supertokens_python/recipe/session/framework/litestar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations
from typing import Any, Callable, Coroutine, TYPE_CHECKING

from supertokens_python.framework.litestar.litestar_request import LitestarRequest
from supertokens_python.recipe.session import SessionRecipe
from supertokens_python.types import MaybeAwaitable

from ...interfaces import SessionContainer, SessionClaimValidator

if TYPE_CHECKING:
from litestar import Request


def verify_session(
anti_csrf_check: bool | None = None,
session_required: bool = True,
override_global_claim_validators: Callable[
[list[SessionClaimValidator], SessionContainer, dict[str, Any]],
MaybeAwaitable[list[SessionClaimValidator]],
]
| None = None,
user_context: None | dict[str, Any] = None,
) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]:
async def func(request: Request[Any, Any, Any]) -> SessionContainer | None:
litestar_request = LitestarRequest(request)
recipe = SessionRecipe.get_instance()
session = await recipe.verify_session(
litestar_request,
anti_csrf_check,
session_required,
override_global_claim_validators,
user_context or {},
)

if session:
litestar_request.set_session(session)
elif session_required:
raise RuntimeError("Should never come here")
else:
litestar_request.set_session_as_none()

return litestar_request.get_session()

return func
10 changes: 5 additions & 5 deletions supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing_extensions import Literal

from supertokens_python.logger import get_maybe_none_as_str, log_debug_message

from supertokens_python.types import SupportedFrameworks
from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, USER_DELETE, USERS
from .exceptions import SuperTokensError
from .interfaces import (
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
app_name: str,
api_domain: str,
website_domain: str,
framework: Literal["fastapi", "flask", "django"],
framework: SupportedFrameworks,
api_gateway_path: str,
api_base_path: str,
website_base_path: str,
Expand All @@ -117,7 +117,7 @@ def __init__(
self.website_base_path = NormalisedURLPath(website_base_path)
if mode is not None:
self.mode = mode
elif framework == "fastapi":
elif framework in ("fastapi", "litestar"):
mode = "asgi"
else:
mode = "wsgi"
Expand Down Expand Up @@ -145,7 +145,7 @@ class Supertokens:
def __init__(
self,
app_info: InputAppInfo,
framework: Literal["fastapi", "flask", "django"],
framework: SupportedFrameworks,
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
Expand Down Expand Up @@ -199,7 +199,7 @@ def __init__(
@staticmethod
def init(
app_info: InputAppInfo,
framework: Literal["fastapi", "flask", "django"],
framework: SupportedFrameworks,
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
Expand Down
Loading