From fb7fc5110a338dde4ced003a3bf86d9b19c5829b Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 15 Apr 2023 15:33:25 +0200 Subject: [PATCH 1/2] feat: add litestar framework --- dev-requirements.txt | 5 +- setup.py | 5 + supertokens_python/__init__.py | 3 +- .../framework/litestar/__init__.py | 5 + .../framework/litestar/framework.py | 17 + .../framework/litestar/litestar_middleware.py | 51 + .../framework/litestar/litestar_request.py | 56 + .../framework/litestar/litestar_response.py | 72 ++ supertokens_python/framework/types.py | 2 - .../recipe/session/asyncio/__init__.py | 7 +- .../session/framework/litestar/__init__.py | 44 + supertokens_python/supertokens.py | 10 +- supertokens_python/types.py | 3 + supertokens_python/utils.py | 7 +- tests/litestar/__init__.py | 3 + tests/litestar/test_litestar.py | 1013 +++++++++++++++++ 16 files changed, 1288 insertions(+), 15 deletions(-) create mode 100644 supertokens_python/framework/litestar/__init__.py create mode 100644 supertokens_python/framework/litestar/framework.py create mode 100644 supertokens_python/framework/litestar/litestar_middleware.py create mode 100644 supertokens_python/framework/litestar/litestar_request.py create mode 100644 supertokens_python/framework/litestar/litestar_response.py create mode 100644 supertokens_python/recipe/session/framework/litestar/__init__.py create mode 100644 tests/litestar/__init__.py create mode 100644 tests/litestar/test_litestar.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 04017720e..67ce3c5cd 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -30,7 +30,7 @@ itsdangerous==2.1.2 Jinja2==3.1.1 jsonschema==3.2.0 lazy-object-proxy==1.7.1 -Mako==1.2.0 +Mako==1.2.4 Markdown==3.3.6 MarkupSafe==2.1.1 mccabe==0.6.1 @@ -48,7 +48,7 @@ py==1.11.0 pycodestyle==2.8.0 pycparser==2.21 pycryptodome==3.10.4 -pydantic==1.9.0 +pydantic>=1.10.0, <2.0.0 PyJWT==2.0.1 pylint==2.12.2 pyparsing==3.0.7 @@ -84,3 +84,4 @@ uvicorn==0.18.2 Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 +litestar>=2.0.0alpha4, <3.0.0 diff --git a/setup.py b/setup.py index 8fef020d7..1ac0fb0ac 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,11 @@ # we use to develop the SDK with otherwise we get # a bunch of type errors on make dev-install depending # on changes in these frameworks + "litestar": ( + [ + "litestar>=2.0.0alpha4, <3.0.0", + ] + ), "fastapi": ( [ "respx==0.19.2", diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 519e7411e..f9afc66bd 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,6 +17,7 @@ from . import supertokens from .recipe_module import RecipeModule +from supertokens_python.types import SupportedFrameworks InputAppInfo = supertokens.InputAppInfo Supertokens = supertokens.Supertokens @@ -26,7 +27,7 @@ def init( app_info: InputAppInfo, - framework: Literal["fastapi", "flask", "django"], + framework: SupportedFrameworks, supertokens_config: SupertokensConfig, recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]], mode: Union[Literal["asgi", "wsgi"], None] = None, diff --git a/supertokens_python/framework/litestar/__init__.py b/supertokens_python/framework/litestar/__init__.py new file mode 100644 index 000000000..5b417132a --- /dev/null +++ b/supertokens_python/framework/litestar/__init__.py @@ -0,0 +1,5 @@ +from supertokens_python.framework.litestar import litestar_middleware + +get_middleware = litestar_middleware.get_middleware + +__all__ = ("get_middleware",) diff --git a/supertokens_python/framework/litestar/framework.py b/supertokens_python/framework/litestar/framework.py new file mode 100644 index 000000000..2c2a79967 --- /dev/null +++ b/supertokens_python/framework/litestar/framework.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from supertokens_python.framework.types import Framework + +if TYPE_CHECKING: + from litestar import Request + + +class LitestarFramework(Framework): + def wrap_request(self, unwrapped: Request[Any, Any, Any]): + from supertokens_python.framework.litestar.litestar_request import ( + LitestarRequest, + ) + + return LitestarRequest(unwrapped) diff --git a/supertokens_python/framework/litestar/litestar_middleware.py b/supertokens_python/framework/litestar/litestar_middleware.py new file mode 100644 index 000000000..03828ec71 --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_middleware.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from litestar.middleware.base import AbstractMiddleware + + +@lru_cache +def get_middleware() -> type[AbstractMiddleware]: + from supertokens_python import Supertokens + from supertokens_python.exceptions import SuperTokensError + from supertokens_python.framework.litestar.litestar_request import LitestarRequest + from supertokens_python.framework.litestar.litestar_response import LitestarResponse + from supertokens_python.recipe.session import SessionContainer + from supertokens_python.supertokens import manage_session_post_response + + from litestar import Response, Request + from litestar.middleware.base import AbstractMiddleware + from litestar.types import Scope, Receive, Send + + class Middleware(AbstractMiddleware): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + st = Supertokens.get_instance() + request = Request[Any, Any, Any](scope, receive, send) + + try: + result = await st.middleware( + LitestarRequest(request), + LitestarResponse(Response[Any](content=None)), + ) + except SuperTokensError as e: + result = await st.handle_supertokens_error( + LitestarRequest(request), + e, + LitestarResponse(Response[Any](content=None)), + ) + + if isinstance(result, LitestarResponse): + if ( + session_container := request.state.get("supertokens") + ) and isinstance(session_container, SessionContainer): + manage_session_post_response(session_container, result) + + await result.response(scope, receive, send) + return + + await self.app(scope, receive, send) + + return Middleware diff --git a/supertokens_python/framework/litestar/litestar_request.py b/supertokens_python/framework/litestar/litestar_request.py new file mode 100644 index 000000000..533d6c4ad --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_request.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from supertokens_python.framework.request import BaseRequest + +if TYPE_CHECKING: + from litestar import Request + from supertokens_python.recipe.session.interfaces import SessionContainer + +try: + from litestar.exceptions import SerializationException +except ImportError: + SerializationException = Exception # type: ignore + + +class LitestarRequest(BaseRequest): + def __init__(self, request: Request[Any, Any, Any]): + super().__init__() + self.request = request + + def get_query_param(self, key: str, default: str | None = None) -> Any: + return self.request.query_params.get(key, default) # pyright: ignore + + def get_query_params(self) -> dict[str, list[Any]]: + return self.request.query_params.dict() # pyright: ignore + + async def json(self) -> Any: + try: + return await self.request.json() + except SerializationException: + return {} + + def method(self) -> str: + return self.request.method + + def get_cookie(self, key: str) -> str | None: + return self.request.cookies.get(key) + + def get_header(self, key: str) -> str | None: + return self.request.headers.get(key, None) + + def get_session(self) -> SessionContainer | None: + return self.request.state.supertokens + + def set_session(self, session: SessionContainer): + self.request.state.supertokens = session + + def set_session_as_none(self): + self.request.state.supertokens = None + + def get_path(self) -> str: + return self.request.url.path + + async def form_data(self) -> dict[str, list[Any]]: + return (await self.request.form()).dict() diff --git a/supertokens_python/framework/litestar/litestar_response.py b/supertokens_python/framework/litestar/litestar_response.py new file mode 100644 index 000000000..59378f679 --- /dev/null +++ b/supertokens_python/framework/litestar/litestar_response.py @@ -0,0 +1,72 @@ +from __future__ import annotations +from typing import Any, TYPE_CHECKING, cast +from typing_extensions import Literal +from supertokens_python.framework.response import BaseResponse + +if TYPE_CHECKING: + from litestar import Response + + +class LitestarResponse(BaseResponse): + def __init__(self, response: Response[Any]): + super().__init__({}) + self.response = response + self.original = response + self.parser_checked = False + self.response_sent = False + self.status_set = False + + def set_html_content(self, content: str): + if not self.response_sent: + body = bytes(content, "utf-8") + self.set_header("Content-Length", str(len(body))) + self.set_header("Content-Type", "text/html") + self.response.body = body + self.response_sent = True + + def set_cookie( + self, + key: str, + value: str, + expires: int, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ): + self.response.set_cookie( + key=key, + value=value, + expires=expires, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + samesite=cast(Literal["lax", "strict", "none"], samesite), + ) + + def set_header(self, key: str, value: str): + self.response.set_header(key, value) + + def get_header(self, key: str) -> str | None: + return self.response.headers.get(key, None) + + def remove_header(self, key: str): + del self.response.headers[key] + + def set_status_code(self, status_code: int): + if not self.status_set: + self.response.status_code = status_code + self.status_code = status_code + self.status_set = True + + def set_json_content(self, content: dict[str, Any]): + if not self.response_sent: + from litestar.serialization import encode_json + + body = encode_json(content) + self.set_header("Content-Type", "application/json; charset=utf-8") + self.set_header("Content-Length", str(len(body))) + self.response.body = body + self.response_sent = True diff --git a/supertokens_python/framework/types.py b/supertokens_python/framework/types.py index 0e48bc490..e9d428c8e 100644 --- a/supertokens_python/framework/types.py +++ b/supertokens_python/framework/types.py @@ -18,8 +18,6 @@ from supertokens_python.framework.request import BaseRequest -frameworks = ["fastapi", "flask", "django"] - class FrameworkEnum(Enum): FASTAPI = 1 diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index f572c57a3..4c2e96f2c 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -65,7 +65,7 @@ async def create_new_session( if not hasattr(request, "wrapper_used") or not request.wrapper_used: request = FRAMEWORKS[ - SessionRecipe.get_instance().app_info.framework + SessionRecipe.get_instance().app_info.framework # pyright: ignore ].wrap_request(request) return await SessionRecipe.get_instance().recipe_implementation.create_new_session( @@ -251,7 +251,7 @@ async def get_session( user_context = {} if not hasattr(request, "wrapper_used") or not request.wrapper_used: request = FRAMEWORKS[ - SessionRecipe.get_instance().app_info.framework + SessionRecipe.get_instance().app_info.framework # pyright: ignore ].wrap_request(request) session_recipe_impl = SessionRecipe.get_instance().recipe_implementation @@ -278,7 +278,7 @@ async def refresh_session( user_context = {} if not hasattr(request, "wrapper_used") or not request.wrapper_used: request = FRAMEWORKS[ - SessionRecipe.get_instance().app_info.framework + SessionRecipe.get_instance().app_info.framework # pyright: ignore ].wrap_request(request) return await SessionRecipe.get_instance().recipe_implementation.refresh_session( @@ -331,6 +331,7 @@ async def get_session_information( ) -> Union[SessionInformationResult, None]: if user_context is None: user_context = {} + return await SessionRecipe.get_instance().recipe_implementation.get_session_information( session_handle, user_context ) diff --git a/supertokens_python/recipe/session/framework/litestar/__init__.py b/supertokens_python/recipe/session/framework/litestar/__init__.py new file mode 100644 index 000000000..5c53e19e5 --- /dev/null +++ b/supertokens_python/recipe/session/framework/litestar/__init__.py @@ -0,0 +1,44 @@ +from __future__ import annotations +from typing import Any, Callable, Coroutine, TYPE_CHECKING + +from supertokens_python.framework.litestar.litestar_request import LitestarRequest +from supertokens_python.recipe.session import SessionRecipe +from supertokens_python.types import MaybeAwaitable + +from ...interfaces import SessionContainer, SessionClaimValidator + +if TYPE_CHECKING: + from litestar import Request + + +def verify_session( + anti_csrf_check: bool | None = None, + session_required: bool = True, + override_global_claim_validators: Callable[ + [list[SessionClaimValidator], SessionContainer, dict[str, Any]], + MaybeAwaitable[list[SessionClaimValidator]], + ] + | None = None, + user_context: None | dict[str, Any] = None, +) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]: + async def func(request: Request[Any, Any, Any]) -> SessionContainer | None: + litestar_request = LitestarRequest(request) + recipe = SessionRecipe.get_instance() + session = await recipe.verify_session( + litestar_request, + anti_csrf_check, + session_required, + override_global_claim_validators, + user_context or {}, + ) + + if session: + litestar_request.set_session(session) + elif session_required: + raise RuntimeError("Should never come here") + else: + litestar_request.set_session_as_none() + + return litestar_request.get_session() + + return func diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 0a0e88227..b2d1d109a 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -20,7 +20,7 @@ from typing_extensions import Literal from supertokens_python.logger import get_maybe_none_as_str, log_debug_message - +from supertokens_python.types import SupportedFrameworks from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, USER_DELETE, USERS from .exceptions import SuperTokensError from .interfaces import ( @@ -95,7 +95,7 @@ def __init__( app_name: str, api_domain: str, website_domain: str, - framework: Literal["fastapi", "flask", "django"], + framework: SupportedFrameworks, api_gateway_path: str, api_base_path: str, website_base_path: str, @@ -117,7 +117,7 @@ def __init__( self.website_base_path = NormalisedURLPath(website_base_path) if mode is not None: self.mode = mode - elif framework == "fastapi": + elif framework in ("fastapi", "litestar"): mode = "asgi" else: mode = "wsgi" @@ -145,7 +145,7 @@ class Supertokens: def __init__( self, app_info: InputAppInfo, - framework: Literal["fastapi", "flask", "django"], + framework: SupportedFrameworks, supertokens_config: SupertokensConfig, recipe_list: List[Callable[[AppInfo], RecipeModule]], mode: Union[Literal["asgi", "wsgi"], None], @@ -199,7 +199,7 @@ def __init__( @staticmethod def init( app_info: InputAppInfo, - framework: Literal["fastapi", "flask", "django"], + framework: SupportedFrameworks, supertokens_config: SupertokensConfig, recipe_list: List[Callable[[AppInfo], RecipeModule]], mode: Union[Literal["asgi", "wsgi"], None], diff --git a/supertokens_python/types.py b/supertokens_python/types.py index ed08effd2..6ed093566 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -13,9 +13,12 @@ # under the License. from abc import ABC, abstractmethod from typing import Any, Awaitable, Dict, List, TypeVar, Union +from typing_extensions import Literal _T = TypeVar("_T") +SupportedFrameworks = Literal["fastapi", "flask", "django", "litestar"] + class ThirdPartyInfo: def __init__(self, third_party_user_id: str, third_party_id: str): diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 836958f75..c9ae0a2e8 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -41,13 +41,15 @@ from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework +from supertokens_python.framework.litestar.framework import LitestarFramework from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER from .exceptions import raise_general_exception -from .types import MaybeAwaitable +from .framework.types import Framework +from .types import MaybeAwaitable, SupportedFrameworks _T = TypeVar("_T") @@ -55,10 +57,11 @@ pass -FRAMEWORKS = { +FRAMEWORKS: dict[SupportedFrameworks, Framework] = { "fastapi": FastapiFramework(), "flask": FlaskFramework(), "django": DjangoFramework(), + "litestar": LitestarFramework(), } diff --git a/tests/litestar/__init__.py b/tests/litestar/__init__.py new file mode 100644 index 000000000..e0e2433cc --- /dev/null +++ b/tests/litestar/__init__.py @@ -0,0 +1,3 @@ +import nest_asyncio # type: ignore + +nest_asyncio.apply() # type: ignore diff --git a/tests/litestar/test_litestar.py b/tests/litestar/test_litestar.py new file mode 100644 index 000000000..24500529f --- /dev/null +++ b/tests/litestar/test_litestar.py @@ -0,0 +1,1013 @@ +# pyright: reportUnknownMemberType=false, reportGeneralTypeIssues=false +from __future__ import annotations +import json +from typing import Any, Dict, Union + +from litestar import get, post, Litestar, Request, MediaType +from litestar.di import Provide +from litestar.testing import TestClient +from pytest import fixture, mark, skip + +from supertokens_python import InputAppInfo, SupertokensConfig, init +from supertokens_python.framework import BaseRequest +from supertokens_python.framework.litestar import get_middleware +from supertokens_python.querier import Querier +from supertokens_python.recipe import emailpassword, session +from supertokens_python.recipe import thirdparty +from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard.interfaces import RecipeInterface +from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.emailpassword.interfaces import ( + APIInterface as EPAPIInterface, +) +from supertokens_python.recipe.emailpassword.interfaces import APIOptions +from supertokens_python.recipe.passwordless import PasswordlessRecipe, ContactConfig +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import ( + create_new_session, + get_session, + refresh_session, +) +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.session.framework.litestar import verify_session +from supertokens_python.recipe.session.interfaces import APIInterface +from supertokens_python.recipe.session.interfaces import APIOptions as SessionAPIOptions +from supertokens_python.utils import is_version_gte +from tests.utils import ( + TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH, + TEST_DRIVER_CONFIG_COOKIE_DOMAIN, + TEST_DRIVER_CONFIG_COOKIE_SAME_SITE, + TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH, + assert_info_clears_tokens, + clean_st, + extract_all_cookies, + extract_info, + get_st_init_args, + reset, + setup_st, + start_st, + create_users, +) + + +def get_token_transfer_method(*args: Any) -> Any: + return "cookie" + + +def override_dashboard_functions(original_implementation: RecipeInterface): + async def should_allow_access( + request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + ): + auth_header = request.get_header("authorization") + return auth_header == "Bearer testapikey" + + original_implementation.should_allow_access = should_allow_access + return original_implementation + + +def setup_function(_): + reset() + clean_st() + setup_st() + + +def teardown_function(_): + reset() + clean_st() + + +@fixture(scope="function") +async def litestar_test_client() -> TestClient[Litestar]: + @get("/login") + async def login(request: Request[Any, Any, Any]) -> dict[str, Any]: + user_id = "userId" + await create_new_session(request, user_id, {}, {}) + return {"userId": user_id} + + @post("/refresh") + async def custom_refresh(request: Request[Any, Any, Any]) -> dict[str, Any]: + await refresh_session(request) + return {} + + @get("/info") + async def info_get(request: Request[Any, Any, Any]) -> dict[str, Any]: + await get_session(request, True) + return {} + + @get("/custom/info") + def custom_info() -> dict[str, Any]: + return {} + + @get("/handle") + async def handle_get(request: Request[Any, Any, Any]) -> dict[str, Any]: + session: Union[None, SessionContainer] = await get_session(request, True) + if session is None: + raise RuntimeError("Should never come here") + return {"s": session.get_handle()} + + @get( + "/handle-session-optional", + dependencies={"session": Provide(verify_session(session_required=False))}, + ) + async def handle_get_optional(session: SessionContainer) -> dict[str, Any]: + + if session is None: + return {"s": "empty session"} + + return {"s": session.get_handle()} + + @post("/logout") + async def custom_logout(request: Request[Any, Any, Any]) -> dict[str, Any]: + session: Union[None, SessionContainer] = await get_session(request, True) + if session is None: + raise RuntimeError("Should never come here") + await session.revoke_session() + return {} + + @post("/create", media_type=MediaType.TEXT) + async def _create(request: Request[Any, Any, Any]) -> str: + await create_new_session(request, "userId", {}, {}) + return "" + + @post("/create-throw") + async def _create_throw(request: Request[Any, Any, Any]) -> None: + await create_new_session(request, "userId", {}, {}) + raise UnauthorisedError("unauthorised") + + app = Litestar( + route_handlers=[ + login, + custom_logout, + custom_refresh, + custom_info, + info_get, + handle_get, + handle_get_optional, + _create, + _create_throw, + ], + middleware=[get_middleware()], + ) + + return TestClient(app) + + +def apis_override_session(param: APIInterface): + param.disable_refresh_post = True + return param + + +@mark.asyncio +async def test_login_refresh(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + override=session.InputOverrideConfig(apis=apis_override_session), + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + + with litestar_test_client as client: + response_3 = client.post( + url="/refresh", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sRefreshToken": cookies_1["sRefreshToken"]["value"], + }, + ) + cookies_3 = extract_all_cookies(response_3) + + assert cookies_3["sAccessToken"]["value"] != cookies_1["sAccessToken"]["value"] + assert cookies_3["sRefreshToken"]["value"] != cookies_1["sRefreshToken"]["value"] + assert response_3.headers.get("anti-csrf") is not None + assert cookies_3["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_3["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_3["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_3["sAccessToken"]["httponly"] + assert cookies_3["sRefreshToken"]["httponly"] + assert ( + cookies_3["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_3["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + + +@mark.asyncio +async def test_login_logout(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.post( + url="/logout", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + cookies_2 = extract_all_cookies(response_2) + assert response_2.headers.get("anti-csrf") is None + assert cookies_2["sAccessToken"]["value"] == "" + assert cookies_2["sRefreshToken"]["value"] == "" + assert cookies_2["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_2["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_2["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_2["sAccessToken"]["httponly"] + assert cookies_2["sRefreshToken"]["httponly"] + assert ( + cookies_2["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_2["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_2["sAccessToken"]["secure"] is None + assert cookies_2["sRefreshToken"]["secure"] is None + + +@mark.asyncio +async def test_login_info(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.get( + url="/info", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + cookies_2 = extract_all_cookies(response_2) + assert not cookies_2 + + +@mark.asyncio +async def test_login_handle(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_2 = client.get( + url="/handle", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + "sAccessToken": cookies_1["sAccessToken"]["value"], + }, + ) + result_dict = json.loads(response_2.content) + assert "s" in result_dict + + +@mark.asyncio +async def test_login_refresh_error_handler(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response_1 = client.get("/login") + cookies_1 = extract_all_cookies(response_1) + + assert response_1.headers.get("anti-csrf") is not None + assert cookies_1["sAccessToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sRefreshToken"]["domain"] == TEST_DRIVER_CONFIG_COOKIE_DOMAIN + assert cookies_1["sAccessToken"]["path"] == TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH + assert cookies_1["sRefreshToken"]["path"] == TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH + assert cookies_1["sAccessToken"]["httponly"] + assert cookies_1["sRefreshToken"]["httponly"] + assert ( + cookies_1["sAccessToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert ( + cookies_1["sRefreshToken"]["samesite"].lower() + == TEST_DRIVER_CONFIG_COOKIE_SAME_SITE + ) + assert cookies_1["sAccessToken"]["secure"] is None + assert cookies_1["sRefreshToken"]["secure"] is None + + with litestar_test_client as client: + response_3 = client.post( + url="/refresh", + headers={"anti-csrf": response_1.headers.get("anti-csrf")}, + cookies={ + # no cookies + }, + ) + assert response_3.status_code == 401 # not authorized because no refresh tokens + + +@mark.asyncio +async def test_custom_response(litestar_test_client: TestClient[Litestar]): + def override_email_password_apis(original_implementation: EPAPIInterface): + original_func = original_implementation.email_exists_get + + async def email_exists_get( + email: str, api_options: APIOptions, user_context: Dict[str, Any] + ): + response_dict = {"custom": True} + api_options.response.set_status_code(203) + api_options.response.set_json_content(response_dict) + return await original_func(email, api_options, user_context) + + original_implementation.email_exists_get = email_exists_get + return original_implementation + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + emailpassword.init( + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis + ) + ) + ], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response = client.get( + url="/auth/signup/email/exists?email=test@example.com", + ) + + dict_response = json.loads(response.text) + assert response.status_code == 203 + assert dict_response["custom"] + + +@mark.asyncio +async def test_optional_session(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[session.init(get_token_transfer_method=get_token_transfer_method)], + mode="asgi", + ) + start_st() + + with litestar_test_client as client: + response = client.get( + url="handle-session-optional", + ) + + dict_response = json.loads(response.text) + assert response.status_code == 200 + assert dict_response["s"] == "empty session" + + +@mark.asyncio +@mark.parametrize("token_transfer_method", ["cookie", "header"]) +async def test_should_clear_all_response_during_refresh_if_unauthorized( + litestar_test_client: TestClient[Litestar], token_transfer_method: str +): + def override_session_apis(oi: APIInterface): + oi_refresh_post = oi.refresh_post + + async def refresh_post( + api_options: SessionAPIOptions, user_context: Dict[str, Any] + ): + await oi_refresh_post(api_options, user_context) + raise UnauthorisedError("unauthorized", clear_tokens=True) + + oi.refresh_post = refresh_post + return oi + + init( + **get_st_init_args( + [ + session.init( + anti_csrf="VIA_TOKEN", + override=session.InputOverrideConfig(apis=override_session_apis), + ) + ] + ) + ) + start_st() + + with litestar_test_client as client: + res = client.post("/create", headers={"st-auth-mode": token_transfer_method}) + info = extract_info(res) # pyright: ignore + + assert info["accessTokenFromAny"] is not None + assert info["refreshTokenFromAny"] is not None + + headers: Dict[str, Any] = {} + cookies: Dict[str, Any] = {} + + if token_transfer_method == "header": + headers.update({"authorization": f"Bearer {info['refreshTokenFromAny']}"}) + else: + cookies.update( + {"sRefreshToken": info["refreshTokenFromAny"], "sIdRefreshToken": "asdf"} + ) + + if info["antiCsrf"] is not None: + headers.update({"anti-csrf": info["antiCsrf"]}) + + with litestar_test_client as client: + res = client.post("/auth/session/refresh", headers=headers, cookies=cookies) + info = extract_info(res) # pyright: ignore + + assert res.status_code == 401 + assert_info_clears_tokens(info, token_transfer_method) + + +@mark.asyncio +@mark.parametrize("token_transfer_method", ["cookie", "header"]) +async def test_revoking_session_after_create_new_session_with_throwing_unauthorized_error( + litestar_test_client: TestClient[Litestar], token_transfer_method: str +): + init( + **get_st_init_args( + [ + session.init( + anti_csrf="VIA_TOKEN", + ) + ] + ) + ) + start_st() + + with litestar_test_client as client: + res = client.post( + "/create-throw", headers={"st-auth-mode": token_transfer_method} + ) + info = extract_info(res) # pyright: ignore + + assert res.status_code == 401 + assert_info_clears_tokens(info, token_transfer_method) + + +@mark.asyncio +async def test_search_with_email_t(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "t"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 5 + + +@mark.asyncio +async def test_search_with_email_multiple_email_entry( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "iresh;john"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 1 + + +@mark.asyncio +async def test_search_with_email_iresh(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + emailpassword.init(), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(emailpassword=True) + query = {"limit": "10", "email": "iresh"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0 + + +@mark.asyncio +async def test_search_with_phone_plus_one(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(passwordless=True) + query = {"limit": "10", "phone": "+1"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 3 + + +@mark.asyncio +async def test_search_with_phone_one_bracket( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(passwordless=True) + query = {"limit": "10", "phone": "1("} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0 + + +@mark.asyncio +async def test_search_with_provider_google(litestar_test_client: TestClient[Litestar]): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.Apple( + client_id="4398792-io.supertokens.example.service", + client_key_id="7M48Y4RYDL", + client_team_id="YWQCXGJRJL", + client_private_key="-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + ), + thirdparty.Google( + client_id="467101b197249757c71f", + client_secret="e97051221f4b6426e8fe8d51486396703012f5bd", + ), + thirdparty.Github( + client_id="1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com", + client_secret="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", + ), + ] + ) + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(thirdparty=True) + query = {"limit": "10", "provider": "google"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 3 + + +@mark.asyncio +async def test_search_with_provider_google_and_phone_1( + litestar_test_client: TestClient[Litestar], +): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + api_base_path="/auth", + ), + framework="litestar", + recipe_list=[ + session.init( + anti_csrf="VIA_TOKEN", + cookie_domain="supertokens.io", + get_token_transfer_method=get_token_transfer_method, + ), + DashboardRecipe.init( + api_key="testapikey", + override=InputOverrideConfig(functions=override_dashboard_functions), + ), + PasswordlessRecipe.init( + contact_config=ContactConfig(contact_method="EMAIL"), + flow_type="USER_INPUT_CODE", + ), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.Apple( + client_id="4398792-io.supertokens.example.service", + client_key_id="7M48Y4RYDL", + client_team_id="YWQCXGJRJL", + client_private_key="-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + ), + thirdparty.Google( + client_id="467101b197249757c71f", + client_secret="e97051221f4b6426e8fe8d51486396703012f5bd", + ), + thirdparty.Github( + client_id="1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com", + client_secret="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", + ), + ] + ) + ), + ], + mode="asgi", + ) + start_st() + querier = Querier.get_instance(DashboardRecipe.recipe_id) + cdi_version = await querier.get_api_version() + if not cdi_version: + skip() + if not is_version_gte(cdi_version, "2.20"): + skip() + await create_users(thirdparty=True, passwordless=True) + query = {"limit": "10", "provider": "google", "phone": "1"} + with litestar_test_client as client: + res = client.get( + "/auth/dashboard/api/users", + headers={ + "Authorization": "Bearer testapikey", + "Content-Type": "application/json", + }, + params=query, + ) + info = extract_info(res) # pyright: ignore + assert res.status_code == 200 + assert len(info["body"]["users"]) == 0 From cec2b63ad040c99e64fac402bd5e3eca1a0bf740 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 22 Apr 2023 07:42:29 +0200 Subject: [PATCH 2/2] updated tests --- tests/litestar/test_litestar.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/litestar/test_litestar.py b/tests/litestar/test_litestar.py index 24500529f..b51bfd4b6 100644 --- a/tests/litestar/test_litestar.py +++ b/tests/litestar/test_litestar.py @@ -55,7 +55,7 @@ def get_token_transfer_method(*args: Any) -> Any: def override_dashboard_functions(original_implementation: RecipeInterface): - async def should_allow_access( + def should_allow_access( request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") @@ -77,7 +77,7 @@ def teardown_function(_): @fixture(scope="function") -async def litestar_test_client() -> TestClient[Litestar]: +def litestar_test_client() -> TestClient[Litestar]: @get("/login") async def login(request: Request[Any, Any, Any]) -> dict[str, Any]: user_id = "userId" @@ -109,7 +109,7 @@ async def handle_get(request: Request[Any, Any, Any]) -> dict[str, Any]: "/handle-session-optional", dependencies={"session": Provide(verify_session(session_required=False))}, ) - async def handle_get_optional(session: SessionContainer) -> dict[str, Any]: + def handle_get_optional(session: SessionContainer) -> dict[str, Any]: if session is None: return {"s": "empty session"} @@ -157,8 +157,7 @@ def apis_override_session(param: APIInterface): return param -@mark.asyncio -async def test_login_refresh(litestar_test_client: TestClient[Litestar]): +def test_login_refresh(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -228,8 +227,7 @@ async def test_login_refresh(litestar_test_client: TestClient[Litestar]): ) -@mark.asyncio -async def test_login_logout(litestar_test_client: TestClient[Litestar]): +def test_login_logout(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -301,8 +299,7 @@ async def test_login_logout(litestar_test_client: TestClient[Litestar]): assert cookies_2["sRefreshToken"]["secure"] is None -@mark.asyncio -async def test_login_info(litestar_test_client: TestClient[Litestar]): +def test_login_info(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -357,8 +354,7 @@ async def test_login_info(litestar_test_client: TestClient[Litestar]): assert not cookies_2 -@mark.asyncio -async def test_login_handle(litestar_test_client: TestClient[Litestar]): +def test_login_handle(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -413,8 +409,7 @@ async def test_login_handle(litestar_test_client: TestClient[Litestar]): assert "s" in result_dict -@mark.asyncio -async def test_login_refresh_error_handler(litestar_test_client: TestClient[Litestar]): +def test_login_refresh_error_handler(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -468,8 +463,7 @@ async def test_login_refresh_error_handler(litestar_test_client: TestClient[Lite assert response_3.status_code == 401 # not authorized because no refresh tokens -@mark.asyncio -async def test_custom_response(litestar_test_client: TestClient[Litestar]): +def test_custom_response(litestar_test_client: TestClient[Litestar]): def override_email_password_apis(original_implementation: EPAPIInterface): original_func = original_implementation.email_exists_get @@ -514,8 +508,7 @@ async def email_exists_get( assert dict_response["custom"] -@mark.asyncio -async def test_optional_session(litestar_test_client: TestClient[Litestar]): +def test_optional_session(litestar_test_client: TestClient[Litestar]): init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( @@ -540,9 +533,8 @@ async def test_optional_session(litestar_test_client: TestClient[Litestar]): assert dict_response["s"] == "empty session" -@mark.asyncio @mark.parametrize("token_transfer_method", ["cookie", "header"]) -async def test_should_clear_all_response_during_refresh_if_unauthorized( +def test_should_clear_all_response_during_refresh_if_unauthorized( litestar_test_client: TestClient[Litestar], token_transfer_method: str ): def override_session_apis(oi: APIInterface): @@ -597,9 +589,8 @@ async def refresh_post( assert_info_clears_tokens(info, token_transfer_method) -@mark.asyncio @mark.parametrize("token_transfer_method", ["cookie", "header"]) -async def test_revoking_session_after_create_new_session_with_throwing_unauthorized_error( +def test_revoking_session_after_create_new_session_with_throwing_unauthorized_error( litestar_test_client: TestClient[Litestar], token_transfer_method: str ): init(