Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add cache control for jwks endpoint #429

Merged
merged 10 commits into from
Sep 27, 2023
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0



## [0.15.3] - 2023-09-24
## [0.15.3] - 2023-09-25

- Handle 429 rate limiting from SaaS core instances
- Add `Cache-Control` header for jwks endpoint `/jwt/jwks.json`
- Add `validity_in_secs` to the return value of overridable `get_jwks` recipe function.
- This can be used to control the `Cache-Control` header mentioned above.
- It defaults to `60` or the value set in the cache-control header returned by the core
- This is optional (so you are not required to update your overrides). Returning `None` means that the header won't be set


## [0.15.2] - 2023-09-23

Expand Down
21 changes: 14 additions & 7 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def __get_headers_with_api_version(self, path: NormalisedURLPath):

async def send_get_request(
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
):
) -> Dict[str, Any]:
if params is None:
params = {}

Expand All @@ -149,7 +149,7 @@ async def send_post_request(
path: NormalisedURLPath,
data: Union[Dict[str, Any], None] = None,
test: bool = False,
):
) -> Dict[str, Any]:
if data is None:
data = {}

Expand All @@ -171,7 +171,7 @@ async def f(url: str) -> Response:

async def send_delete_request(
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
):
) -> Dict[str, Any]:
if params is None:
params = {}

Expand All @@ -187,7 +187,7 @@ async def f(url: str) -> Response:

async def send_put_request(
self, path: NormalisedURLPath, data: Union[Dict[str, Any], None] = None
):
) -> Dict[str, Any]:
if data is None:
data = {}

Expand Down Expand Up @@ -226,7 +226,7 @@ async def __send_request_helper(
http_function: Callable[[str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
) -> Dict[str, Any]:
if no_of_tries == 0:
raise_general_exception("No SuperTokens core available to query")

Expand Down Expand Up @@ -285,10 +285,17 @@ async def __send_request_helper(
+ response.text # type: ignore
)

res: Dict[str, Any] = {"_headers": dict(response.headers)}

if response.headers.get("content-type", "").startswith("text"):
res["_text"] = response.text

try:
return response.json()
res.update(response.json())
except JSONDecodeError:
return response.text
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is changing the semantics of how things were done. Please fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


return res

except (ConnectionError, NetworkError, ConnectTimeout) as _:
return await self.__send_request_helper(
Expand Down
6 changes: 6 additions & 0 deletions supertokens_python/recipe/jwt/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ async def jwks_get(
self, api_options: APIOptions, user_context: Dict[str, Any]
) -> JWKSGetResponse:
response = await api_options.recipe_implementation.get_jwks(user_context)

if response.validity_in_secs is not None:
api_options.response.set_header(
"Cache-Control", f"max-age={response.validity_in_secs}, must-revalidate"
)

return JWKSGetResponse(response.keys)
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion supertokens_python/recipe/jwt/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class CreateJwtResultUnsupportedAlgorithm:


class GetJWKSResult:
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, keys: List[JsonWebKey]):
def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int]):
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int]):
def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int] = None):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.keys = keys
self.validity_in_secs = validity_in_secs


class RecipeInterface(ABC):
Expand Down
21 changes: 20 additions & 1 deletion supertokens_python/recipe/jwt/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from supertokens_python.normalised_url_path import NormalisedURLPath
from supertokens_python.querier import Querier
import re

if TYPE_CHECKING:
from .utils import JWTConfig
Expand All @@ -32,6 +33,10 @@
from .interfaces import JsonWebKey


# This corresponds to the dynamicSigningKeyOverlapMS in the core
DEFAULT_JWKS_MAX_AGE = 60


class RecipeImplementation(RecipeInterface):
def __init__(self, querier: Querier, config: JWTConfig, app_info: AppInfo):
super().__init__()
Expand Down Expand Up @@ -69,11 +74,25 @@ async def get_jwks(self, user_context: Dict[str, Any]) -> GetJWKSResult:
NormalisedURLPath("/.well-known/jwks.json"), {}
)

validity_in_secs = DEFAULT_JWKS_MAX_AGE
cache_control = response["_headers"].get("Cache-Control")

if cache_control is not None:
pattern = r",?\s*max-age=(\d+)(?:,|$)"
max_age_header = re.match(pattern, cache_control)
if max_age_header is not None:
validity_in_secs = int(max_age_header.group(1))
try:
validity_in_secs = int(validity_in_secs)
except Exception:
validity_in_secs = DEFAULT_JWKS_MAX_AGE

keys: List[JsonWebKey] = []
for key in response["keys"]:
keys.append(
JsonWebKey(
key["kty"], key["kid"], key["n"], key["e"], key["alg"], key["use"]
)
)
return GetJWKSResult(keys)

return GetJWKSResult(keys, validity_in_secs)
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ async def associate_user_to_tenant(
AssociateUserToTenantPhoneNumberAlreadyExistsError,
AssociateUserToTenantThirdPartyUserAlreadyExistsError,
]:
response: Dict[str, Any] = await self.querier.send_post_request(
response = await self.querier.send_post_request(
NormalisedURLPath(
f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user"
),
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/session/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ async def regenerate_access_token(
) -> Union[RegenerateAccessTokenOkResult, None]:
if new_access_token_payload is None:
new_access_token_payload = {}
response: Dict[str, Any] = await self.querier.send_post_request(
response = await self.querier.send_post_request(
NormalisedURLPath("/recipe/session/regenerate"),
{"accessToken": access_token, "userDataInJWT": new_access_token_payload},
)
Expand Down
4 changes: 0 additions & 4 deletions supertokens_python/recipe/session/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ async def create_new_session(
},
)

response.pop("status", None)

return CreateOrRefreshAPIResponse(
CreateOrRefreshAPIResponseSession(
response["session"]["handle"],
Expand Down Expand Up @@ -281,7 +279,6 @@ async def get_session(
NormalisedURLPath("/recipe/session/verify"), data
)
if response["status"] == "OK":
response.pop("status", None)
return GetSessionAPIResponse(
GetSessionAPIResponseSession(
response["session"]["handle"],
Expand Down Expand Up @@ -351,7 +348,6 @@ async def refresh_session(
NormalisedURLPath("/recipe/session/refresh"), data
)
if response["status"] == "OK":
response.pop("status", None)
return CreateOrRefreshAPIResponse(
CreateOrRefreshAPIResponseSession(
response["session"]["handle"],
Expand Down
8 changes: 4 additions & 4 deletions supertokens_python/recipe/userroles/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def add_role_to_user(
NormalisedURLPath(f"{tenant_id}/recipe/user/role"),
params,
)
if response.get("status") == "OK":
if response["status"] == "OK":
return AddRoleToUserOkResult(
did_user_already_have_role=response["didUserAlreadyHaveRole"]
)
Expand Down Expand Up @@ -93,7 +93,7 @@ async def get_users_that_have_role(
NormalisedURLPath(f"{tenant_id}/recipe/role/users"),
params,
)
if response.get("status") == "OK":
if response["status"] == "OK":
return GetUsersThatHaveRoleOkResult(users=response["users"])
return UnknownRoleError()

Expand All @@ -115,7 +115,7 @@ async def get_permissions_for_role(
response = await self.querier.send_get_request(
NormalisedURLPath("/recipe/role/permissions"), params
)
if response.get("status") == "OK":
if response["status"] == "OK":
return GetPermissionsForRoleOkResult(permissions=response["permissions"])
return UnknownRoleError()

Expand All @@ -126,7 +126,7 @@ async def remove_permissions_from_role(
response = await self.querier.send_post_request(
NormalisedURLPath("/recipe/role/permissions/remove"), params
)
if response.get("status") == "OK":
if response["status"] == "OK":
return RemovePermissionsFromRoleOkResult()
return UnknownRoleError()

Expand Down
42 changes: 40 additions & 2 deletions tests/jwt/test_get_JWKS.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

from _pytest.fixtures import fixture
from fastapi import FastAPI
from typing import Optional, Dict, Any
from pytest import mark
from starlette.requests import Request
from starlette.testclient import TestClient
from supertokens_python import InputAppInfo, SupertokensConfig, init
from supertokens_python.framework.fastapi import get_middleware
from supertokens_python.recipe import jwt
from supertokens_python.recipe.jwt.interfaces import APIInterface
from supertokens_python.recipe.jwt.interfaces import APIInterface, RecipeInterface
from supertokens_python.recipe.session.asyncio import create_new_session
from tests.utils import clean_st, reset, setup_st, start_st

Expand Down Expand Up @@ -83,6 +84,20 @@ async def test_that_default_getJWKS_api_does_not_work_when_disabled(


async def test_that_default_getJWKS_works_fine(driver_config_client: TestClient):
custom_validity: Optional[int] = -1 # -1 means no override

def func_override(oi: RecipeInterface):
oi_get_jwks = oi.get_jwks

async def get_jwks(user_context: Dict[str, Any]):
res = await oi_get_jwks(user_context)
if custom_validity != -1:
res.validity_in_secs = custom_validity
return res

oi.get_jwks = get_jwks
return oi

init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
Expand All @@ -91,12 +106,35 @@ async def test_that_default_getJWKS_works_fine(driver_config_client: TestClient)
website_domain="supertokens.io",
),
framework="fastapi",
recipe_list=[jwt.init()],
recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=func_override))],
)
start_st()

response = driver_config_client.get(url="/auth/jwt/jwks.json")

# Default:
assert response.status_code == 200
data = response.json()
assert len(data["keys"]) > 0

assert response.headers["cache-control"] == "max-age=60, must-revalidate"

# Override cache control:
custom_validity = 1
response = driver_config_client.get(url="/auth/jwt/jwks.json")

assert response.status_code == 200
data = response.json()
assert len(data["keys"]) > 0

assert response.headers["cache-control"] == "max-age=1, must-revalidate"

# Disable cache control:
custom_validity = None
response = driver_config_client.get(url="/auth/jwt/jwks.json")

assert response.status_code == 200
data = response.json()
assert len(data["keys"]) > 0

assert "cache-control" not in response.headers
42 changes: 42 additions & 0 deletions tests/test_querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import asyncio
import respx
import httpx
import json
from supertokens_python import init, SupertokensConfig
from supertokens_python.querier import Querier, NormalisedURLPath

Expand Down Expand Up @@ -148,3 +149,44 @@ async def call_api(id_: int):
assert call_count2 == 6

assert api.call_count == 12


async def test_querier_text_and_headers():
args = get_st_init_args([session.init()])
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
init(**args) # type: ignore
start_st()

Querier.api_version = "3.0"
q = Querier.get_instance()

with respx_mock() as mocker:
text = "foo"
mocker.get("http://localhost:6789/text-api").mock(
httpx.Response(200, text=text, headers={"greet": "hello"})
)

res = await q.send_get_request(NormalisedURLPath("/text-api"), {})
assert res == {
"_text": "foo",
"_headers": {
"greet": "hello",
"content-type": "text/plain; charset=utf-8",
"content-length": str(len("foo")),
},
}

body = {"bar": "baz"}
mocker.get("http://localhost:6789/json-api").mock(
httpx.Response(200, json=body, headers={"greet": "hi"})
)

res = await q.send_get_request(NormalisedURLPath("/json-api"), {})
assert res == {
"bar": "baz",
"_headers": {
"greet": "hi",
"content-type": "application/json",
"content-length": str(len(json.dumps(body))),
},
}