Skip to content

Commit

Permalink
fix: remaining impl
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Dec 4, 2024
1 parent 79194a4 commit fbee6d6
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def end_session_post(

EndSessionCallable = Callable[
[Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]],
Awaitable[Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]],
Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]],
]


Expand Down
271 changes: 269 additions & 2 deletions supertokens_python/recipe/oauth2provider/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,275 @@
# License for the specific language governing permissions and limitations
# under the License.

from ..interfaces import APIInterface
from typing import Any, Dict, List, Optional, Union

from supertokens_python.recipe.session import SessionContainer
from supertokens_python.types import GeneralErrorResponse, User

from .utils import (
handle_login_internal_redirects,
handle_logout_internal_redirects,
login_get,
)
from ..interfaces import (
APIInterface,
APIOptions,
ActiveTokenResponse,
ErrorOAuth2Response,
FrontendRedirectResponse,
InactiveTokenResponse,
LoginInfo,
RedirectResponse,
RevokeTokenUsingAuthorizationHeader,
RevokeTokenUsingClientIDAndClientSecret,
TokenInfo,
)


class APIImplementation(APIInterface):
pass
async def login_get(
self,
login_challenge: str,
options: APIOptions,
session: Optional[SessionContainer] = None,
should_try_refresh: bool = False,
user_context: Dict[str, Any] = {},
) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
response = await login_get(
recipe_implementation=options.recipe_implementation,
login_challenge=login_challenge,
session=session,
should_try_refresh=should_try_refresh,
is_direct_call=True,
user_context=user_context,
)

if isinstance(response, ErrorOAuth2Response):
return response

resp_after_internal_redirects = await handle_login_internal_redirects(
response=response,
cookie=options.request.get_header("cookie") or "",
recipe_implementation=options.recipe_implementation,
session=session,
should_try_refresh=should_try_refresh,
user_context=user_context,
)

if isinstance(resp_after_internal_redirects, ErrorOAuth2Response):
return resp_after_internal_redirects

return FrontendRedirectResponse(
frontend_redirect_to=resp_after_internal_redirects.redirect_to,
cookies=resp_after_internal_redirects.cookies,
)

async def auth_get(
self,
params: Any,
cookie: Optional[str],
session: Optional[SessionContainer],
should_try_refresh: bool,
options: APIOptions,
user_context: Dict[str, Any] = {},
) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
response = await options.recipe_implementation.authorization(
params=params,
cookies=cookie,
session=session,
user_context=user_context,
)

if isinstance(response, ErrorOAuth2Response):
return response

return await handle_login_internal_redirects(
response=response,
recipe_implementation=options.recipe_implementation,
cookie=cookie or "",
session=session,
should_try_refresh=should_try_refresh,
user_context=user_context,
)

async def token_post(
self,
authorization_header: Optional[str],
body: Any,
options: APIOptions,
user_context: Dict[str, Any] = {},
) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]:
return await options.recipe_implementation.token_exchange(
authorization_header=authorization_header,
body=body,
user_context=user_context,
)

async def login_info_get(
self,
login_challenge: str,
options: APIOptions,
user_context: Dict[str, Any] = {},
) -> Union[LoginInfo, ErrorOAuth2Response, GeneralErrorResponse]:
login_res = await options.recipe_implementation.get_login_request(
challenge=login_challenge,
user_context=user_context,
)

if isinstance(login_res, ErrorOAuth2Response):
return login_res

client = login_res.client

return LoginInfo(
client_id=client.client_id,
client_name=client.client_name,
tos_uri=client.tos_uri,
policy_uri=client.policy_uri,
logo_uri=client.logo_uri,
client_uri=client.client_uri,
metadata=client.metadata,
)

async def user_info_get(
self,
access_token_payload: Dict[str, Any],
user: User,
scopes: List[str],
tenant_id: str,
options: APIOptions,
user_context: Dict[str, Any] = {},
) -> Union[Dict[str, Any], GeneralErrorResponse]:
return await options.recipe_implementation.build_user_info(
user=user,
access_token_payload=access_token_payload,
scopes=scopes,
tenant_id=tenant_id,
user_context=user_context,
)

async def revoke_token_post(
self,
token: str,
options: APIOptions,
user_context: Dict[str, Any] = {},
authorization_header: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]:
if authorization_header is not None:
return await options.recipe_implementation.revoke_token(
input=RevokeTokenUsingAuthorizationHeader(
token=token,
authorization_header=authorization_header,
),
user_context=user_context,
)
elif client_id is not None:
if client_secret is None:
raise Exception("client_secret is required")

return await options.recipe_implementation.revoke_token(
input=RevokeTokenUsingClientIDAndClientSecret(
token=token,
client_id=client_id,
client_secret=client_secret,
),
user_context=user_context,
)
else:
raise Exception(
"Either of 'authorization_header' or 'client_id' must be provided"
)

async def introspect_token_post(
self,
token: str,
scopes: Optional[List[str]],
options: APIOptions,
user_context: Dict[str, Any] = {},
) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]:
return await options.recipe_implementation.introspect_token(
token=token,
scopes=scopes,
user_context=user_context,
)

async def end_session_get(
self,
params: Dict[str, str],
options: APIOptions,
session: Optional[SessionContainer] = None,
should_try_refresh: bool = False,
user_context: Dict[str, Any] = {},
) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
response = await options.recipe_implementation.end_session(
params=params,
session=session,
should_try_refresh=should_try_refresh,
user_context=user_context,
)

if isinstance(response, ErrorOAuth2Response):
return response

return await handle_logout_internal_redirects(
response=response,
session=session,
recipe_implementation=options.recipe_implementation,
user_context=user_context,
)

async def end_session_post(
self,
params: Dict[str, str],
options: APIOptions,
session: Optional[SessionContainer] = None,
should_try_refresh: bool = False,
user_context: Dict[str, Any] = {},
) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
response = await options.recipe_implementation.end_session(
params=params,
session=session,
should_try_refresh=should_try_refresh,
user_context=user_context,
)

if isinstance(response, ErrorOAuth2Response):
return response

return await handle_logout_internal_redirects(
response=response,
session=session,
recipe_implementation=options.recipe_implementation,
user_context=user_context,
)

async def logout_post(
self,
logout_challenge: str,
options: APIOptions,
session: Optional[SessionContainer] = None,
user_context: Dict[str, Any] = {},
) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
if session is not None:
await session.revoke_session(user_context)

response = await options.recipe_implementation.accept_logout_request(
challenge=logout_challenge,
user_context=user_context,
)

if isinstance(response, ErrorOAuth2Response):
return response

res = await handle_logout_internal_redirects(
response=response,
recipe_implementation=options.recipe_implementation,
user_context=user_context,
)

if isinstance(res, ErrorOAuth2Response):
return res

return FrontendRedirectResponse(frontend_redirect_to=res.redirect_to)
6 changes: 3 additions & 3 deletions supertokens_python/recipe/oauth2provider/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ async def introspect_token(
token: str,
scopes: Optional[List[str]] = None,
user_context: Dict[str, Any] = {},
) -> Dict[str, Any]:
) -> Union[ActiveTokenResponse, InactiveTokenResponse]:
pass

@abstractmethod
Expand Down Expand Up @@ -726,7 +726,7 @@ async def end_session_get(
session: Optional[SessionContainer] = None,
should_try_refresh: bool = False,
user_context: Dict[str, Any] = {},
) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]:
) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
pass

@abstractmethod
Expand All @@ -737,7 +737,7 @@ async def end_session_post(
session: Optional[SessionContainer] = None,
should_try_refresh: bool = False,
user_context: Dict[str, Any] = {},
) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]:
) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]:
pass

@abstractmethod
Expand Down
19 changes: 12 additions & 7 deletions supertokens_python/recipe/oauth2provider/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
OAuth2Client,
TokenInfo,
UserInfoBuilderFunction,
ActiveTokenResponse,
InactiveTokenResponse,
)


Expand Down Expand Up @@ -443,11 +445,11 @@ async def token_exchange(
user_context=user_context,
)

if token_info.get("active"):
session_handle = token_info["sessionHandle"]
if isinstance(token_info, ActiveTokenResponse):
session_handle = token_info.payload["sessionHandle"]

client_info = await self.get_oauth2_client(
client_id=token_info["client_id"], user_context=user_context
client_id=token_info.payload["client_id"], user_context=user_context
)

if isinstance(client_info, ErrorOAuth2Response):
Expand All @@ -458,7 +460,7 @@ async def token_exchange(
)

client = client_info.client
user = await get_user(token_info["sub"])
user = await get_user(token_info.payload["sub"])

if not user:
return ErrorOAuth2Response(
Expand Down Expand Up @@ -826,7 +828,7 @@ async def introspect_token(
token: str,
scopes: Optional[List[str]] = None,
user_context: Dict[str, Any] = {},
) -> Dict[str, Any]:
) -> Union[ActiveTokenResponse, InactiveTokenResponse]:
# Determine if the token is an access token by checking if it doesn't start with "st_rt"
is_access_token = not token.startswith("st_rt")

Expand All @@ -845,7 +847,7 @@ async def introspect_token(
user_context=user_context,
)
except Exception:
return {"active": False}
return InactiveTokenResponse()

# For tokens that passed local validation or if it's a refresh token,
# validate the token with the database by calling the core introspection endpoint
Expand All @@ -858,7 +860,10 @@ async def introspect_token(
user_context=user_context,
)

return res
if res.get("active"):
return ActiveTokenResponse(payload=res)
else:
return InactiveTokenResponse()

async def end_session(
self,
Expand Down

0 comments on commit fbee6d6

Please sign in to comment.