diff --git a/docs/conf.py b/docs/conf.py index 13b43c800b..db77ddf4ff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -215,6 +215,7 @@ (PY_RE, r".*R"), (PY_OBJ, r"litestar.security.jwt.auth.TokenT"), (PY_CLASS, "ExceptionToProblemDetailMapType"), + (PY_CLASS, "litestar.security.jwt.token.JWTDecodeOptions"), ] # Warnings about missing references to those targets in the specified location will be ignored. diff --git a/docs/examples/security/jwt/custom_decode_payload.py b/docs/examples/security/jwt/custom_decode_payload.py new file mode 100644 index 0000000000..6ae2df58d7 --- /dev/null +++ b/docs/examples/security/jwt/custom_decode_payload.py @@ -0,0 +1,28 @@ +import dataclasses +from typing import Any, List, Optional, Sequence, Union + +from litestar.security.jwt.token import JWTDecodeOptions, Token + + +@dataclasses.dataclass +class CustomToken(Token): + @classmethod + def decode_payload( + cls, + encoded_token: str, + secret: str, + algorithms: List[str], + issuer: Optional[List[str]] = None, + audience: Union[str, Sequence[str], None] = None, + options: Optional[JWTDecodeOptions] = None, + ) -> Any: + payload = super().decode_payload( + encoded_token=encoded_token, + secret=secret, + algorithms=algorithms, + issuer=issuer, + audience=audience, + options=options, + ) + payload["sub"] = payload["sub"].split("@", maxsplit=1)[1] + return payload diff --git a/docs/examples/security/jwt/verify_issuer_audience.py b/docs/examples/security/jwt/verify_issuer_audience.py new file mode 100644 index 0000000000..2b340fc0e9 --- /dev/null +++ b/docs/examples/security/jwt/verify_issuer_audience.py @@ -0,0 +1,32 @@ +import dataclasses +import secrets +from typing import Any, Dict + +from litestar import Litestar, Request, get +from litestar.connection import ASGIConnection +from litestar.security.jwt import JWTAuth, Token + + +@dataclasses.dataclass +class User: + id: str + + +async def retrieve_user_handler(token: Token, connection: ASGIConnection) -> User: + return User(id=token.sub) + + +jwt_auth = JWTAuth[User]( + token_secret=secrets.token_hex(), + retrieve_user_handler=retrieve_user_handler, + accepted_audiences=["https://api.testserver.local"], + accepted_issuers=["https://auth.testserver.local"], +) + + +@get("/") +def handler(request: Request[User, Token, Any]) -> Dict[str, Any]: + return {"id": request.user.id} + + +app = Litestar([handler], middleware=[jwt_auth.middleware]) diff --git a/docs/usage/security/jwt.rst b/docs/usage/security/jwt.rst index c1a96610df..bc0d23e9e3 100644 --- a/docs/usage/security/jwt.rst +++ b/docs/usage/security/jwt.rst @@ -65,3 +65,32 @@ conversions. converting the token. To support more complex conversions, the :meth:`~.security.jwt.Token.encode` and :meth:`~.security.jwt.Token.decode` methods must be overwritten in the subclass. + + +Verifying issuer and audience +----------------------------- + +To verify the JWT ``iss`` (*issuer*) and ``aud`` (*audience*) claim, a list of accepted +issuers or audiences can bet set on the authentication backend. When a JWT is decoded, +the issuer or audience on the token is compared to the list of accepted issuers / +audiences. If the value in the token does not match any value in the respective list, +a :exc:`NotAuthorizedException` will be raised, returning a response with a +``401 Unauthorized`` status. + + +.. literalinclude:: /examples/security/jwt/verify_issuer_audience.py + :caption: Verifying issuer and audience + + +Customizing token validation +---------------------------- + +Token decoding / validation can be further customized by overriding the +:meth:`~.security.jwt.Token.decode_payload` method. It will be called by +:meth:`~.security.jwt.Token.decode` with the encoded token string, and must return a +dictionary representing the decoded payload, which will then used by +:meth:`~.security.jwt.Token.decode` to construct an instance of the token class. + + +.. literalinclude:: /examples/security/jwt/custom_decode_payload.py + :caption: Customizing payload decoding diff --git a/litestar/security/jwt/auth.py b/litestar/security/jwt/auth.py index c1758a9bcd..3ebf9767c5 100644 --- a/litestar/security/jwt/auth.py +++ b/litestar/security/jwt/auth.py @@ -68,6 +68,27 @@ class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, To """ token_cls: type[Token] = Token """Target type the JWT payload will be converted into""" + accepted_audiences: Sequence[str] | None = None + """Audiences to accept when verifying the token. If given, and the audience in the + token does not match, a 401 response is returned + """ + accepted_issuers: Sequence[str] | None = None + """Issuers to accept when verifying the token. If given, and the issuer in the + token does not match, a 401 response is returned + """ + require_claims: Sequence[str] | None = None + """Require these claims to be present in the JWT payload. If any of those claims + is missing, a 401 response is returned + """ + verify_expiry: bool = True + """Verify that the value of the ``exp`` (*expiration*) claim is in the future""" + verify_not_before: bool = True + """Verify that the value of the ``nbf`` (*not before*) claim is in the past""" + strict_audience: bool = False + """Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 + """ @property def openapi_components(self) -> Components: @@ -120,6 +141,12 @@ def middleware(self) -> DefineMiddleware: scopes=self.scopes, token_secret=self.token_secret, token_cls=self.token_cls, + token_issuer=self.accepted_issuers, + token_audience=self.accepted_audiences, + require_claims=self.require_claims, + verify_expiry=self.verify_expiry, + verify_not_before=self.verify_not_before, + strict_audience=self.strict_audience, ) def login( @@ -290,6 +317,27 @@ class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]): """ token_cls: type[Token] = Token """Target type the JWT payload will be converted into""" + accepted_audiences: Sequence[str] | None = None + """Audiences to accept when verifying the token. If given, and the audience in the + token does not match, a 401 response is returned + """ + accepted_issuers: Sequence[str] | None = None + """Issuers to accept when verifying the token. If given, and the issuer in the + token does not match, a 401 response is returned + """ + require_claims: Sequence[str] | None = None + """Require these claims to be present in the JWT payload. If any of those claims + is missing, a 401 response is returned + """ + verify_expiry: bool = True + """Verify that the value of the ``exp`` (*expiration*) claim is in the future""" + verify_not_before: bool = True + """Verify that the value of the ``nbf`` (*not before*) claim is in the past""" + strict_audience: bool = False + """Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 + """ @dataclass @@ -370,6 +418,27 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies. """ token_cls: type[Token] = Token """Target type the JWT payload will be converted into""" + accepted_audiences: Sequence[str] | None = None + """Audiences to accept when verifying the token. If given, and the audience in the + token does not match, a 401 response is returned + """ + accepted_issuers: Sequence[str] | None = None + """Issuers to accept when verifying the token. If given, and the issuer in the + token does not match, a 401 response is returned + """ + require_claims: Sequence[str] | None = None + """Require these claims to be present in the JWT payload. If any of those claims + is missing, a 401 response is returned + """ + verify_expiry: bool = True + """Verify that the value of the ``exp`` (*expiration*) claim is in the future""" + verify_not_before: bool = True + """Verify that the value of the ``nbf`` (*not before*) claim is in the past""" + strict_audience: bool = False + """Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 + """ @property def openapi_components(self) -> Components: @@ -411,6 +480,12 @@ def middleware(self) -> DefineMiddleware: scopes=self.scopes, token_secret=self.token_secret, token_cls=self.token_cls, + token_issuer=self.accepted_issuers, + token_audience=self.accepted_audiences, + require_claims=self.require_claims, + verify_expiry=self.verify_expiry, + verify_not_before=self.verify_not_before, + strict_audience=self.strict_audience, ) def login( @@ -579,6 +654,27 @@ class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, """ token_cls: type[Token] = Token """Target type the JWT payload will be converted into""" + accepted_audiences: Sequence[str] | None = None + """Audiences to accept when verifying the token. If given, and the audience in the + token does not match, a 401 response is returned + """ + accepted_issuers: Sequence[str] | None = None + """Issuers to accept when verifying the token. If given, and the issuer in the + token does not match, a 401 response is returned + """ + require_claims: Sequence[str] | None = None + """Require these claims to be present in the JWT payload. If any of those claims + is missing, a 401 response is returned + """ + verify_expiry: bool = True + """Verify that the value of the ``exp`` (*expiration*) claim is in the future""" + verify_not_before: bool = True + """Verify that the value of the ``nbf`` (*not before*) claim is in the past""" + strict_audience: bool = False + """Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 + """ @property def middleware(self) -> DefineMiddleware: @@ -600,6 +696,12 @@ def middleware(self) -> DefineMiddleware: scopes=self.scopes, token_secret=self.token_secret, token_cls=self.token_cls, + token_issuer=self.accepted_issuers, + token_audience=self.accepted_audiences, + require_claims=self.require_claims, + verify_expiry=self.verify_expiry, + verify_not_before=self.verify_not_before, + strict_audience=self.strict_audience, ) @property diff --git a/litestar/security/jwt/middleware.py b/litestar/security/jwt/middleware.py index 1acd03d1c7..6426a2158f 100644 --- a/litestar/security/jwt/middleware.py +++ b/litestar/security/jwt/middleware.py @@ -31,6 +31,12 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware): "retrieve_user_handler", "token_secret", "token_cls", + "token_audience", + "token_issuer", + "require_claims", + "verify_expiry", + "verify_not_before", + "strict_audience", ) def __init__( @@ -45,6 +51,12 @@ def __init__( scopes: Scopes, token_secret: str, token_cls: type[Token] = Token, + token_audience: Sequence[str] | None = None, + token_issuer: Sequence[str] | None = None, + require_claims: Sequence[str] | None = None, + verify_expiry: bool = True, + verify_not_before: bool = True, + strict_audience: bool = False, ) -> None: """Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user from persistence using the provided function. @@ -62,6 +74,18 @@ def __init__( token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to encode it. token_cls: Token class used when encoding / decoding JWTs + token_audience: Verify the audience when decoding the token. If the audience + in the token does not match any audience given, raise a + :exc:`NotAuthorizedException` + token_issuer: Verify the issuer when decoding the token. If the issuer in + the token does not match any issuer given, raise a + :exc:`NotAuthorizedException` + require_claims: Require these claims to be present in the JWT payload + verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future + verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past + strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 """ super().__init__( app=app, @@ -75,6 +99,12 @@ def __init__( self.retrieve_user_handler = retrieve_user_handler self.token_secret = token_secret self.token_cls = token_cls + self.token_audience = token_audience + self.token_issuer = token_issuer + self.require_claims = require_claims + self.verify_expiry = verify_expiry + self.verify_not_before = verify_not_before + self.strict_audience = strict_audience async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult: """Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the @@ -114,6 +144,12 @@ async def authenticate_token( encoded_token=encoded_token, secret=self.token_secret, algorithm=self.algorithm, + audience=self.token_audience, + issuer=self.token_issuer, + require_claims=self.require_claims, + verify_exp=self.verify_expiry, + verify_nbf=self.verify_not_before, + strict_audience=self.strict_audience, ) user = await self.retrieve_user_handler(token, connection) @@ -142,6 +178,12 @@ def __init__( scopes: Scopes, token_secret: str, token_cls: type[Token] = Token, + token_audience: Sequence[str] | None = None, + token_issuer: Sequence[str] | None = None, + require_claims: Sequence[str] | None = None, + verify_expiry: bool = True, + verify_not_before: bool = True, + strict_audience: bool = False, ) -> None: """Check incoming requests for an encoded token in the auth header or cookie name specified, and if present retrieves the user from persistence using the provided function. @@ -160,6 +202,18 @@ def __init__( token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to encode it. token_cls: Token class used when encoding / decoding JWTs + token_audience: Verify the audience when decoding the token. If the audience + in the token does not match any audience given, raise a + :exc:`NotAuthorizedException` + token_issuer: Verify the issuer when decoding the token. If the issuer in + the token does not match any issuer given, raise a + :exc:`NotAuthorizedException` + require_claims: Require these claims to be present in the JWT payload + verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future + verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past + strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and + not a list of values, and matches ``audience`` exactly. Requires that + ``accepted_audiences`` is a sequence of length 1 """ super().__init__( algorithm=algorithm, @@ -172,6 +226,12 @@ def __init__( scopes=scopes, token_secret=token_secret, token_cls=token_cls, + token_audience=token_audience, + token_issuer=token_issuer, + require_claims=require_claims, + verify_expiry=verify_expiry, + verify_not_before=verify_not_before, + strict_audience=strict_audience, ) self.auth_cookie_key = auth_cookie_key diff --git a/litestar/security/jwt/token.py b/litestar/security/jwt/token.py index 215b1f2362..a7df89d8c4 100644 --- a/litestar/security/jwt/token.py +++ b/litestar/security/jwt/token.py @@ -3,7 +3,7 @@ import dataclasses from dataclasses import asdict, dataclass, field from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, TypedDict import jwt import msgspec @@ -13,8 +13,10 @@ if TYPE_CHECKING: from typing_extensions import Self - -__all__ = ("Token",) +__all__ = ( + "Token", + "JWTDecodeOptions", +) def _normalize_datetime(value: datetime) -> datetime: @@ -32,6 +34,17 @@ def _normalize_datetime(value: datetime) -> datetime: return value.replace(microsecond=0) +class JWTDecodeOptions(TypedDict, total=False): + """``options`` for PyJWTs :func:`jwt.decode`""" + + verify_aud: bool + verify_iss: bool + verify_exp: bool + verify_nbf: bool + strict_aud: bool + require: list[str] + + @dataclass class Token: """JWT Token DTO.""" @@ -72,13 +85,59 @@ def __post_init__(self) -> None: raise ImproperlyConfiguredException("iat must be a current or past time") @classmethod - def decode(cls, encoded_token: str, secret: str, algorithm: str) -> Self: - """Decode a passed in token string and returns a Token instance. + def decode_payload( + cls, + encoded_token: str, + secret: str, + algorithms: list[str], + issuer: list[str] | None = None, + audience: str | Sequence[str] | None = None, + options: JWTDecodeOptions | None = None, + ) -> Any: + """Decode and verify the JWT and return its payload""" + return jwt.decode( + jwt=encoded_token, + key=secret, + algorithms=algorithms, + issuer=issuer, + audience=audience, + options=options, # type: ignore[arg-type] + ) + + @classmethod + def decode( + cls, + encoded_token: str, + secret: str, + algorithm: str, + audience: str | Sequence[str] | None = None, + issuer: str | Sequence[str] | None = None, + require_claims: Sequence[str] | None = None, + verify_exp: bool = True, + verify_nbf: bool = True, + strict_audience: bool = False, + ) -> Self: + """Decode a passed in token string and return a Token instance. Args: encoded_token: A base64 string containing an encoded JWT. - secret: The secret with which the JWT is encoded. It may optionally be an individual JWK or JWS set dict + secret: The secret with which the JWT is encoded. algorithm: The algorithm used to encode the JWT. + audience: Verify the audience when decoding the token. If the audience in + the token does not match any audience given, raise a + :exc:`NotAuthorizedException` + issuer: Verify the issuer when decoding the token. If the issuer in the + token does not match any issuer given, raise a + :exc:`NotAuthorizedException` + require_claims: Verify that the given claims are present in the token + verify_exp: Verify that the value of the ``exp`` (*expiration*) claim is in + the future + verify_nbf: Verify that the value of the ``nbf`` (*not before*) claim is in + the past + strict_audience: Verify that the value of the ``aud`` (*audience*) claim is + a single value, and not a list of values, and matches ``audience`` + exactly. Requires the value passed to the ``audience`` to be a sequence + of length 1 Returns: A decoded Token instance. @@ -86,12 +145,34 @@ def decode(cls, encoded_token: str, secret: str, algorithm: str) -> Self: Raises: NotAuthorizedException: If the token is invalid. """ + + options: JWTDecodeOptions = { + "verify_aud": bool(audience), + "verify_iss": bool(issuer), + } + if require_claims: + options["require"] = list(require_claims) + if verify_exp is False: + options["verify_exp"] = False + if verify_nbf is False: + options["verify_nbf"] = False + if strict_audience: + if audience is None or (not isinstance(audience, str) and len(audience) != 1): + raise ValueError("When using 'strict_audience=True', 'audience' must be a sequence of length 1") + options["strict_aud"] = True + # although not documented, pyjwt requires audience to be a string if + # using the strict_aud option + if not isinstance(audience, str): + audience = audience[0] + try: - payload: dict[str, Any] = jwt.decode( - jwt=encoded_token, - key=secret, + payload = cls.decode_payload( + encoded_token=encoded_token, + secret=secret, algorithms=[algorithm], - options={"verify_aud": False}, + audience=audience, + issuer=list(issuer) if issuer else None, + options=options, ) # msgspec can do these conversions as well, but to keep backwards # compatibility, we do it ourselves, since the datetime parsing works a @@ -105,9 +186,8 @@ def decode(cls, encoded_token: str, secret: str, algorithm: str) -> Self: return msgspec.convert(payload, cls, strict=False) except ( KeyError, - jwt.DecodeError, + jwt.exceptions.InvalidTokenError, ImproperlyConfiguredException, - jwt.exceptions.InvalidAlgorithmError, msgspec.ValidationError, ) as e: raise NotAuthorizedException("Invalid token") from e diff --git a/tests/examples/test_security/test_jwt/test_verify_issuer_audience.py b/tests/examples/test_security/test_jwt/test_verify_issuer_audience.py new file mode 100644 index 0000000000..c6dec15017 --- /dev/null +++ b/tests/examples/test_security/test_jwt/test_verify_issuer_audience.py @@ -0,0 +1,18 @@ +from litestar.testing import TestClient + + +def test_app() -> None: + from docs.examples.security.jwt.verify_issuer_audience import app, jwt_auth + + valid_token = jwt_auth.create_token( + "foo", + token_audience=jwt_auth.accepted_audiences[0], + token_issuer=jwt_auth.accepted_issuers[0], + ) + invalid_token = jwt_auth.create_token("foo") + + with TestClient(app) as client: + response = client.get("/", headers={"Authorization": jwt_auth.format_auth_header(valid_token)}) + assert response.status_code == 200 + response = client.get("/", headers={"Authorization": jwt_auth.format_auth_header(invalid_token)}) + assert response.status_code == 401 diff --git a/tests/unit/test_security/test_jwt/test_auth.py b/tests/unit/test_security/test_jwt/test_auth.py index cdda690b6b..d3aff92829 100644 --- a/tests/unit/test_security/test_jwt/test_auth.py +++ b/tests/unit/test_security/test_jwt/test_auth.py @@ -2,7 +2,7 @@ import secrets import string from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from uuid import uuid4 import jwt @@ -10,12 +10,13 @@ import pytest from hypothesis import given, settings from hypothesis.strategies import dictionaries, integers, none, one_of, sampled_from, text, timedeltas +from typing_extensions import TypeAlias from litestar import Litestar, Request, Response, get from litestar.security.jwt import JWTAuth, JWTCookieAuth, OAuth2PasswordBearerAuth, Token from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED from litestar.stores.memory import MemoryStore -from litestar.testing import create_test_client +from litestar.testing import TestClient, create_test_client from tests.models import User, UserFactory if TYPE_CHECKING: @@ -580,3 +581,262 @@ def handler() -> None: with create_test_client(route_handlers=[handler]) as client: response = client.get("/", headers={"Authorization": header}) assert response.status_code == 401 + + +@pytest.mark.parametrize( + "accepted_issuers, signing_issuer, expected_status_code", + [ + (["issuer_a"], "issuer_a", 200), + (["issuer_a", "issuer_b"], "issuer_a", 200), + (["issuer_a", "issuer_b"], "issuer_b", 200), + (["issuer_b"], "issuer_a", 401), + ], +) +@pytest.mark.parametrize("auth_cls", [JWTAuth, JWTCookieAuth, OAuth2PasswordBearerAuth]) +async def test_jwt_auth_verify_issuer( + auth_cls: Any, + accepted_issuers: List[str], + signing_issuer: str, + expected_status_code: int, +) -> None: + async def retrieve_user_handler(token: Token, _: "ASGIConnection") -> Any: + return object() + + token_secret = secrets.token_hex() + + if auth_cls is OAuth2PasswordBearerAuth: + jwt_auth = auth_cls( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + token_url="http://testserver.local", + accepted_issuers=accepted_issuers, + ) + else: + jwt_auth = auth_cls[Any]( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + accepted_issuers=accepted_issuers, + ) + + @get("/", middleware=[jwt_auth.middleware]) + def handler() -> None: + return None + + header = jwt_auth.format_auth_header( + jwt_auth.create_token( + identifier="foo", + token_issuer=signing_issuer, + ), + ) + + with create_test_client(route_handlers=[handler]) as client: + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "accepted_audiences, token_audience, expected_status_code", + [ + (["audience_a"], "audience_a", 200), + (["audience_a", "audience_b"], "audience_a", 200), + (["audience_a", "audience_b"], "audience_b", 200), + (["audience_b"], "audience_a", 401), + ], +) +@pytest.mark.parametrize("auth_cls", [JWTAuth, JWTCookieAuth, OAuth2PasswordBearerAuth]) +async def test_jwt_auth_verify_audience( + auth_cls: Any, + accepted_audiences: List[str], + token_audience: str, + expected_status_code: int, +) -> None: + async def retrieve_user_handler(token: Token, _: "ASGIConnection") -> Any: + return object() + + token_secret = secrets.token_hex() + + if auth_cls is OAuth2PasswordBearerAuth: + jwt_auth = auth_cls( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + token_url="http://testserver.local", + accepted_audiences=accepted_audiences, + ) + else: + jwt_auth = auth_cls[Any]( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + accepted_audiences=accepted_audiences, + ) + + @get("/", middleware=[jwt_auth.middleware]) + def handler() -> None: + return None + + header = jwt_auth.format_auth_header( + jwt_auth.create_token( + identifier="foo", + token_audience=token_audience, + ), + ) + + with create_test_client(route_handlers=[handler]) as client: + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code + + +CreateJWTApp: TypeAlias = Callable[..., Tuple[JWTAuth, TestClient]] + + +@pytest.fixture() +def create_jwt_app(auth_cls: Any, request: pytest.FixtureRequest) -> CreateJWTApp: + def create(**kwargs: Any) -> Tuple[JWTAuth, TestClient]: + async def retrieve_user_handler(token: Token, _: "ASGIConnection") -> Any: + return object() + + if auth_cls is OAuth2PasswordBearerAuth: + jwt_auth = auth_cls( + token_secret=secrets.token_hex(), + retrieve_user_handler=retrieve_user_handler, + token_url="http://testserver.local", + **kwargs, + ) + else: + jwt_auth = auth_cls[Any]( + token_secret=secrets.token_hex(), retrieve_user_handler=retrieve_user_handler, **kwargs + ) + + @get("/", middleware=[jwt_auth.middleware]) + def handler() -> None: + return None + + client = create_test_client(route_handlers=[handler]).__enter__() + request.addfinalizer(client.__exit__) + + return jwt_auth, client + + return create + + +@pytest.fixture(params=[JWTAuth, JWTCookieAuth, OAuth2PasswordBearerAuth]) +def auth_cls(request: pytest.FixtureRequest) -> Any: + return request.param + + +@pytest.mark.parametrize( + "accepted_audiences, token_audience, expected_status_code", + [ + (["audience_a"], "audience_a", 200), + ("audience_a", "audience_a", 200), + (["audience_a"], ["audience_a", "audience_b"], 401), + (["audience_b"], "audience_a", 401), + ], +) +async def test_jwt_auth_strict_audience( + accepted_audiences: List[str], + token_audience: str, + expected_status_code: int, + create_jwt_app: CreateJWTApp, +) -> None: + jwt_auth, client = create_jwt_app(strict_audience=True, accepted_audiences=accepted_audiences) + + header = jwt_auth.format_auth_header( + jwt_auth.create_token( + identifier="foo", + token_audience=token_audience, + ), + ) + + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "require_claims, token_claims, expected_status_code", + [ + (["aud"], {"token_audience": "foo"}, 200), + (["aud"], {}, 401), + ([], {}, 200), + ], +) +async def test_jwt_auth_require_claims( + require_claims: List[str], + token_claims: Dict[str, str], + expected_status_code: int, + create_jwt_app: CreateJWTApp, +) -> None: + jwt_auth, client = create_jwt_app(require_claims=require_claims) + + header = jwt_auth.format_auth_header( + jwt_auth.create_token( + identifier="foo", + **token_claims, # type: ignore[arg-type] + ), + ) + + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "token_expiration, verify_expiry, expected_status_code", + [ + pytest.param((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp(), True, 200, id="valid-verify"), + pytest.param((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp(), False, 200, id="valid-no_verify"), + pytest.param( + (datetime.now(tz=timezone.utc) - timedelta(days=1)).timestamp(), False, 200, id="invalid-no_verify" + ), + pytest.param((datetime.now(tz=timezone.utc) - timedelta(days=1)).timestamp(), True, 401, id="invalid-verify"), + ], +) +async def test_jwt_auth_verify_exp( + token_expiration: datetime, + verify_expiry: bool, + expected_status_code: int, + create_jwt_app: CreateJWTApp, +) -> None: + @dataclasses.dataclass + class CustomToken(Token): + def __post_init__(self) -> None: + pass + + jwt_auth, client = create_jwt_app(verify_expiry=verify_expiry, token_cls=CustomToken) + + header = jwt_auth.format_auth_header( + CustomToken( + sub="foo", + exp=token_expiration, + ).encode(jwt_auth.token_secret, jwt_auth.algorithm), + ) + + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "token_nbf, verify_not_before, expected_status_code", + [ + pytest.param((datetime.now(tz=timezone.utc) - timedelta(days=1)).timestamp(), True, 200, id="valid-verify"), + pytest.param((datetime.now(tz=timezone.utc) - timedelta(days=1)).timestamp(), False, 200, id="valid-no_verify"), + pytest.param( + (datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp(), False, 200, id="invalid-no_verify" + ), + pytest.param((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp(), True, 401, id="invalid-verify"), + ], +) +async def test_jwt_auth_verify_nbf( + token_nbf: datetime, + verify_not_before: bool, + expected_status_code: int, + create_jwt_app: CreateJWTApp, +) -> None: + @dataclasses.dataclass() + class CustomToken(Token): + nbf: Optional[float] = None + + jwt_auth, client = create_jwt_app(verify_not_before=verify_not_before, token_cls=CustomToken) + + header = jwt_auth.format_auth_header(jwt_auth.create_token("foo", nbf=token_nbf)) + + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == expected_status_code diff --git a/tests/unit/test_security/test_jwt/test_token.py b/tests/unit/test_security/test_jwt/test_token.py index ac8c5658db..1637ab99f6 100644 --- a/tests/unit/test_security/test_jwt/test_token.py +++ b/tests/unit/test_security/test_jwt/test_token.py @@ -1,8 +1,11 @@ +from __future__ import annotations + +import dataclasses import secrets import sys from dataclasses import asdict from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Optional +from typing import Any, Sequence from uuid import uuid4 import jwt @@ -12,6 +15,7 @@ from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException from litestar.security.jwt import Token +from litestar.security.jwt.token import JWTDecodeOptions @pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"]) @@ -23,10 +27,10 @@ @pytest.mark.parametrize("token_extras", [None, {"email": "test@test.com"}]) def test_token( algorithm: str, - token_issuer: Optional[str], - token_audience: Optional[str], - token_unique_jwt_id: Optional[str], - token_extras: Optional[Dict[str, Any]], + token_issuer: str | None, + token_audience: str | None, + token_unique_jwt_id: str | None, + token_extras: dict[str, Any] | None, ) -> None: token_secret = secrets.token_hex() token = Token( @@ -164,3 +168,57 @@ def test_extra_fields() -> None: encoded_token = jwt.encode(payload=raw_token, key=token_secret, algorithm="HS256") token = Token.decode(encoded_token=encoded_token, secret=token_secret, algorithm="HS256") assert token.extras == {} + + +@pytest.mark.parametrize("audience", [None, ["foo", "bar"]]) +def test_strict_aud_with_multiple_audiences_raises(audience: str | list[str]) -> None: + with pytest.raises(ValueError, match="When using 'strict_audience=True'"): + Token.decode( + "", + secret="", + algorithm="HS256", + audience=audience, + strict_audience=True, + ) + + +@pytest.mark.parametrize("audience", ["foo", ["foo", "bar"]]) +def test_strict_aud_with_one_element_sequence(audience: str | list[str]) -> None: + # when validating with strict audience, PyJWT requires that the 'audience' parameter + # is passed as a string - one element lists are not allowed. Since we allow these + # generally, we convert them to a string in this case + secret = secrets.token_hex() + encoded = Token(exp=datetime.now() + timedelta(days=1), sub="foo", aud="foo").encode(secret, "HS256") + Token.decode( + encoded, + secret=secret, + algorithm="HS256", + audience=["foo"], + strict_audience=True, + ) + + +def test_custom_decode_payload() -> None: + @dataclasses.dataclass + class CustomToken(Token): + @classmethod + def decode_payload( + cls, + encoded_token: str, + secret: str, + algorithms: list[str], + issuer: list[str] | None = None, + audience: str | Sequence[str] | None = None, + options: JWTDecodeOptions | None = None, + ) -> Any: + payload = super().decode_payload( + encoded_token=encoded_token, + secret=secret, + algorithms=algorithms, + ) + payload["sub"] = "some-random-value" + return payload + + _secret = secrets.token_hex() + encoded = CustomToken(exp=datetime.now() + timedelta(days=1), sub="foo").encode(_secret, "HS256") + assert CustomToken.decode(encoded, secret=_secret, algorithm="HS256").sub == "some-random-value"