From 724c97bcf9c660a2e40302d56bd8e61ad30453f5 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 17:31:36 +0530 Subject: [PATCH] fix: tests --- .../recipe/oauth2provider/interfaces.py | 18 ++++++----- .../oauth2provider/recipe_implementation.py | 30 ++++++++++++++----- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index bbbb1e59..c903e3bc 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -166,24 +166,28 @@ def __init__( @staticmethod def from_json(json: Dict[str, Any]): return TokenInfo( - access_token=json["access_token"], + access_token=json.get("access_token"), expires_in=json["expires_in"], - id_token=json["id_token"], - refresh_token=json["refresh_token"], + id_token=json.get("id_token"), + refresh_token=json.get("refresh_token"), scope=json["scope"], token_type=json["token_type"], ) def to_json(self) -> Dict[str, Any]: - return { + result = { "status": "OK", - "access_token": self.access_token, "expires_in": self.expires_in, - "id_token": self.id_token, - "refresh_token": self.refresh_token, "scope": self.scope, "token_type": self.token_type, } + if self.access_token is not None: + result["access_token"] = self.access_token + if self.id_token is not None: + result["id_token"] = self.id_token + if self.refresh_token is not None: + result["refresh_token"] = self.refresh_token + return result class LoginInfo: diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index e12d811e..a3252e22 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -633,12 +633,26 @@ async def validate_oauth2_access_token( # Verify token signature using session recipe's JWKS session_recipe = SessionRecipe.get_instance() matching_keys = get_latest_keys(session_recipe.config) - payload = jwt.decode( - token, - matching_keys[0].key, - algorithms=["RS256"], - options={"verify_signature": True, "verify_exp": True}, - ) + err: Optional[Exception] = None + + payload: Dict[str, Any] = {} + + for matching_key in matching_keys: + err = None + try: + payload = jwt.decode( + token, + matching_key.key, + algorithms=["RS256"], + options={"verify_signature": True, "verify_exp": True}, + ) + except Exception as e: + err = e + continue + break + + if err is not None: + raise err if payload.get("stt") != 1: raise Exception("Wrong token type") @@ -845,7 +859,7 @@ async def introspect_token( # If it fails, the token is not active, and we return early if is_access_token: try: - payload = await self.validate_oauth2_access_token( + await self.validate_oauth2_access_token( token=token, requirements=( OAuth2TokenValidationRequirements(scopes=scopes) @@ -855,7 +869,7 @@ async def introspect_token( check_database=False, user_context=user_context, ) - return ActiveTokenResponse(payload=payload) + except Exception: return InactiveTokenResponse()