Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Dec 13, 2024
1 parent 213b9fe commit 92a7a0b
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 24 deletions.
4 changes: 3 additions & 1 deletion supertokens_python/framework/flask/flask_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def set_json_content(self, content: Dict[str, Any]):
self.response_sent = True

def redirect(self, url: str) -> BaseResponse:
self.response.headers.set("Location", url)
self.set_header("Location", url)
self.set_status_code(302)
self.response.data = b""
self.response_sent = True
return self
4 changes: 2 additions & 2 deletions supertokens_python/recipe/oauth2provider/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ async def auth_get(
return None

original_url = api_options.request.get_original_url()
split_url = original_url.split("?")
params = dict(parse_qsl(split_url[1])) if len(split_url) > 1 else {}
split_url = original_url.split("?", 1)
params = dict(parse_qsl(split_url[1], True)) if len(split_url) > 1 else {}

session = None
should_try_refresh = False
Expand Down
10 changes: 6 additions & 4 deletions supertokens_python/recipe/oauth2provider/api/end_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ async def end_session_get(
return None

orig_url = api_options.request.get_original_url()
split_url = orig_url.split("?")
params = dict(urllib.parse.parse_qsl(split_url[1]))
split_url = orig_url.split("?", 1)
params = (
dict(urllib.parse.parse_qsl(split_url[1], True)) if len(split_url) > 1 else {}
)

return await end_session_common(
params, api_implementation.end_session_get, api_options, user_context
Expand Down Expand Up @@ -110,9 +112,9 @@ async def end_session_common(
)

if isinstance(response, RedirectResponse):
options.response.redirect(response.redirect_to)
return options.response.redirect(response.redirect_to)
elif isinstance(response, ErrorOAuth2Response):
send_non_200_response(
return send_non_200_response(
{
"error": response.error,
"error_description": response.error_description,
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/oauth2provider/api/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def token_post(

authorization_header = api_options.request.get_header("authorization")

body = await api_options.request.json()
body = await api_options.request.get_json_or_form_data()

response = await api_implementation.token_post(
authorization_header=authorization_header,
Expand Down
10 changes: 8 additions & 2 deletions supertokens_python/recipe/oauth2provider/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ def to_json(self) -> Dict[str, Any]:
result: Dict[str, Any] = {
"frontendRedirectTo": self.frontend_redirect_to,
}
if self.cookies is not None:
result["cookies"] = self.cookies
return result


Expand Down Expand Up @@ -337,6 +335,14 @@ def __init__(
self.scopes = scopes
self.audience = audience

@staticmethod
def from_json(json: Dict[str, Any]):
return OAuth2TokenValidationRequirements(
client_id=json.get("clientId"),
scopes=json.get("scopes"),
audience=json.get("audience"),
)


class FrontendRedirectionURLTypeLogin:
def __init__(
Expand Down
33 changes: 24 additions & 9 deletions supertokens_python/recipe/oauth2provider/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ async def authorization(

payloads = None

if not params.get("client_id") or not isinstance(params.get("client_id"), str):
if params.get("client_id") is None or not isinstance(
params.get("client_id"), str
):
return ErrorOAuth2Response(
status_code=400,
error="invalid_request",
Expand Down Expand Up @@ -644,7 +646,11 @@ async def validate_oauth2_access_token(
token,
matching_key.key,
algorithms=["RS256"],
options={"verify_signature": True, "verify_exp": True},
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False,
},
)
except Exception as e:
err = e
Expand Down Expand Up @@ -908,15 +914,24 @@ async def end_session(
# CASE 3: end_session request with a logout_verifier (after accepting the logout request)
# - Redirects to the post_logout_redirect_uri or the default logout fallback page.

request_body: Dict[str, Any] = {}

if params.get("client_id") is not None:
request_body["clientId"] = params.get("client_id")
if params.get("id_token_hint") is not None:
request_body["idTokenHint"] = params.get("id_token_hint")
if params.get("post_logout_redirect_uri") is not None:
request_body["postLogoutRedirectUri"] = params.get(
"post_logout_redirect_uri"
)
if params.get("state") is not None:
request_body["state"] = params.get("state")
if params.get("logout_verifier") is not None:
request_body["logoutVerifier"] = params.get("logout_verifier")

resp = await self.querier.send_get_request(
NormalisedURLPath("/recipe/oauth/sessions/logout"),
{
"clientId": params.get("client_id"),
"idTokenHint": params.get("id_token_hint"),
"postLogoutRedirectUri": params.get("post_logout_redirect_uri"),
"state": params.get("state"),
"logoutVerifier": params.get("logout_verifier"),
},
request_body,
user_context=user_context,
)

Expand Down
21 changes: 16 additions & 5 deletions tests/test-server/oauth2provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from flask import Flask, request, jsonify
from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput
from supertokens_python.recipe.oauth2provider.interfaces import (
CreateOAuth2ClientInput,
OAuth2TokenValidationRequirements,
UpdateOAuth2ClientInput,
)
import supertokens_python.recipe.oauth2provider.syncio as OAuth2Provider


Expand Down Expand Up @@ -38,7 +42,8 @@ def update_oauth2_client_api(): # type: ignore
print("OAuth2Provider:updateOAuth2Client", request.json)

response = OAuth2Provider.update_oauth2_client(
params=request.json["input"], user_context=request.json.get("userContext")
params=UpdateOAuth2ClientInput.from_json(request.json.get("input", {})),
user_context=request.json.get("userContext"),
)
return jsonify(response.to_json())

Expand All @@ -60,11 +65,17 @@ def validate_oauth2_access_token_api(): # type: ignore

response = OAuth2Provider.validate_oauth2_access_token(
token=request.json["token"],
requirements=request.json["requirements"],
check_database=request.json["checkDatabase"],
requirements=(
OAuth2TokenValidationRequirements.from_json(
request.json["requirements"]
)
if "requirements" in request.json
else None
),
check_database=request.json.get("checkDatabase"),
user_context=request.json.get("userContext"),
)
return jsonify(response)
return jsonify({**response, "status": "OK"})

@app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore
def validate_oauth2_refresh_token_api(): # type: ignore
Expand Down

0 comments on commit 92a7a0b

Please sign in to comment.