diff --git a/docs/conf.py b/docs/conf.py index 1f1946901a..bcdee8ef27 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -213,6 +213,7 @@ (PY_RE, r"advanced_alchemy\.config.common\.EngineT"), (PY_RE, r"advanced_alchemy\.config.common\.SessionT"), (PY_RE, r".*R"), + (PY_OBJ, r"litestar.security.jwt.auth.TokenT"), ] # Warnings about missing references to those targets in the specified location will be ignored. diff --git a/docs/examples/security/jwt/custom_token_cls.py b/docs/examples/security/jwt/custom_token_cls.py new file mode 100644 index 0000000000..164bb74cb9 --- /dev/null +++ b/docs/examples/security/jwt/custom_token_cls.py @@ -0,0 +1,38 @@ +import dataclasses +import secrets +from typing import Any, Dict + +from litestar import Litestar, Request, get +from litestar.connection import ASGIConnection +from litestar.security.jwt import JWTAuth, Token + + +@dataclasses.dataclass +class CustomToken(Token): + token_flag: bool = False + + +@dataclasses.dataclass +class User: + id: str + + +async def retrieve_user_handler(token: CustomToken, connection: ASGIConnection) -> User: + return User(id=token.sub) + + +TOKEN_SECRET = secrets.token_hex() + +jwt_auth = JWTAuth[User]( + token_secret=TOKEN_SECRET, + retrieve_user_handler=retrieve_user_handler, + token_cls=CustomToken, +) + + +@get("/") +def handler(request: Request[User, CustomToken, Any]) -> Dict[str, Any]: + return {"id": request.user.id, "token_flag": request.auth.token_flag} + + +app = Litestar(middleware=[jwt_auth.middleware]) diff --git a/docs/usage/security/jwt.rst b/docs/usage/security/jwt.rst index e18aef17c3..c1a96610df 100644 --- a/docs/usage/security/jwt.rst +++ b/docs/usage/security/jwt.rst @@ -44,3 +44,24 @@ OAuth 2.0 Bearer password flows. .. literalinclude:: /examples/security/jwt/using_oauth2_password_bearer.py :caption: Using OAUTH2 Bearer Password + + +Using a custom token class +-------------------------- + +The token class used can be customized with arbitrary fields, by creating a subclass of +:class:`~.security.jwt.Token`, and specifying it on the backend: + +.. literalinclude:: /examples/security/jwt/custom_token_cls.py + :caption: Using a custom token + + +The token will be converted from JSON into the appropriate type, including basic type +conversions. + +.. important:: + Complex type conversions, especially those including third libraries such as + Pydantic or attrs, as well as any custom ``type_decoders`` are not available for + converting the token. To support more complex conversions, the + :meth:`~.security.jwt.Token.encode` and :meth:`~.security.jwt.Token.decode` methods + must be overwritten in the subclass. diff --git a/litestar/security/jwt/auth.py b/litestar/security/jwt/auth.py index 2a0f09497f..c1758a9bcd 100644 --- a/litestar/security/jwt/auth.py +++ b/litestar/security/jwt/auth.py @@ -2,7 +2,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, Sequence, cast + +from typing_extensions import TypeVar from litestar.datastructures import Cookie from litestar.enums import MediaType @@ -24,9 +26,10 @@ UserType = TypeVar("UserType") +TokenT = TypeVar("TokenT", bound=Token, default=Token) -class BaseJWTAuth(Generic[UserType], AbstractSecurityConfig[UserType, Token]): +class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, TokenT]): """Base class for JWT Auth backends""" token_secret: str @@ -63,6 +66,8 @@ class BaseJWTAuth(Generic[UserType], AbstractSecurityConfig[UserType, Token]): Must inherit from :class:`JWTAuthenticationMiddleware` """ + token_cls: type[Token] = Token + """Target type the JWT payload will be converted into""" @property def openapi_components(self) -> Components: @@ -114,6 +119,7 @@ def middleware(self) -> DefineMiddleware: retrieve_user_handler=self.retrieve_user_handler, scopes=self.scopes, token_secret=self.token_secret, + token_cls=self.token_cls, ) def login( @@ -179,6 +185,7 @@ def create_token( token_audience: str | None = None, token_unique_jwt_id: str | None = None, token_extras: dict | None = None, + **kwargs: Any, ) -> str: """Create a Token instance from the passed in parameters, persists and returns it. @@ -189,17 +196,19 @@ def create_token( token_audience: An optional value for the token ``aud`` field. token_unique_jwt_id: An optional value for the token ``jti`` field. token_extras: An optional dictionary to include in the token ``extras`` field. + **kwargs: Additional attributes to set on the token Returns: The created token. """ - token = Token( + token = self.token_cls( sub=identifier, exp=(datetime.now(timezone.utc) + (token_expiration or self.default_token_expiration)), iss=token_issuer, aud=token_audience, jti=token_unique_jwt_id, extras=token_extras or {}, + **kwargs, ) return token.encode(secret=self.token_secret, algorithm=self.algorithm) @@ -217,7 +226,7 @@ def format_auth_header(self, encoded_token: str) -> str: @dataclass -class JWTAuth(Generic[UserType], BaseJWTAuth[UserType]): +class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]): """JWT Authentication Configuration. This class is the main entry point to the library, and it includes methods to create the middleware, provide login @@ -279,10 +288,12 @@ class JWTAuth(Generic[UserType], BaseJWTAuth[UserType]): Must inherit from :class:`JWTAuthenticationMiddleware` """ + token_cls: type[Token] = Token + """Target type the JWT payload will be converted into""" @dataclass -class JWTCookieAuth(Generic[UserType], BaseJWTAuth[UserType]): +class JWTCookieAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]): """JWT Cookie Authentication Configuration. This class is an alternate entry point to the library, and it includes all the functionality of the :class:`JWTAuth` @@ -357,6 +368,8 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies. ) """The authentication middleware class to use. Must inherit from :class:`JWTCookieAuthenticationMiddleware` """ + token_cls: type[Token] = Token + """Target type the JWT payload will be converted into""" @property def openapi_components(self) -> Components: @@ -397,6 +410,7 @@ def middleware(self) -> DefineMiddleware: retrieve_user_handler=self.retrieve_user_handler, scopes=self.scopes, token_secret=self.token_secret, + token_cls=self.token_cls, ) def login( @@ -482,7 +496,7 @@ class OAuth2Login: @dataclass -class OAuth2PasswordBearerAuth(Generic[UserType], BaseJWTAuth[UserType]): +class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]): """OAUTH2 Schema for Password Bearer Authentication. This class implements an OAUTH2 authentication flow entry point to the library, and it includes all the @@ -563,6 +577,8 @@ class OAuth2PasswordBearerAuth(Generic[UserType], BaseJWTAuth[UserType]): Must inherit from :class:`JWTCookieAuthenticationMiddleware` """ + token_cls: type[Token] = Token + """Target type the JWT payload will be converted into""" @property def middleware(self) -> DefineMiddleware: @@ -583,6 +599,7 @@ def middleware(self) -> DefineMiddleware: retrieve_user_handler=self.retrieve_user_handler, scopes=self.scopes, token_secret=self.token_secret, + token_cls=self.token_cls, ) @property diff --git a/litestar/security/jwt/middleware.py b/litestar/security/jwt/middleware.py index 84326da7a3..1acd03d1c7 100644 --- a/litestar/security/jwt/middleware.py +++ b/litestar/security/jwt/middleware.py @@ -30,6 +30,7 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware): "auth_header", "retrieve_user_handler", "token_secret", + "token_cls", ) def __init__( @@ -43,6 +44,7 @@ def __init__( retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]], scopes: Scopes, token_secret: str, + token_cls: type[Token] = Token, ) -> None: """Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user from persistence using the provided function. @@ -57,8 +59,9 @@ def __init__( retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user, which can be any arbitrary value. scopes: ASGI scopes processed by the authentication middleware. - token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to + token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to encode it. + token_cls: Token class used when encoding / decoding JWTs """ super().__init__( app=app, @@ -71,6 +74,7 @@ def __init__( self.auth_header = auth_header self.retrieve_user_handler = retrieve_user_handler self.token_secret = token_secret + self.token_cls = token_cls async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult: """Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the @@ -106,7 +110,7 @@ async def authenticate_token( Returns: AuthenticationResult """ - token = Token.decode( + token = self.token_cls.decode( encoded_token=encoded_token, secret=self.token_secret, algorithm=self.algorithm, @@ -137,6 +141,7 @@ def __init__( retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]], scopes: Scopes, token_secret: str, + token_cls: type[Token] = Token, ) -> None: """Check incoming requests for an encoded token in the auth header or cookie name specified, and if present retrieves the user from persistence using the provided function. @@ -152,8 +157,9 @@ def __init__( retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user, which can be any arbitrary value. scopes: ASGI scopes processed by the authentication middleware. - token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to + token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to encode it. + token_cls: Token class used when encoding / decoding JWTs """ super().__init__( algorithm=algorithm, @@ -165,6 +171,7 @@ def __init__( retrieve_user_handler=retrieve_user_handler, scopes=scopes, token_secret=token_secret, + token_cls=token_cls, ) self.auth_cookie_key = auth_cookie_key diff --git a/litestar/security/jwt/token.py b/litestar/security/jwt/token.py index 8187b1a73d..46237229e0 100644 --- a/litestar/security/jwt/token.py +++ b/litestar/security/jwt/token.py @@ -3,9 +3,10 @@ import dataclasses from dataclasses import asdict, dataclass, field from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Optional import jwt +import msgspec from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException @@ -41,13 +42,13 @@ class Token: """Subject - usually a unique identifier of the user or equivalent entity.""" iat: datetime = field(default_factory=lambda: _normalize_datetime(datetime.now(timezone.utc))) """Issued at - should always be current now.""" - iss: str | None = field(default=None) + iss: Optional[str] = field(default=None) # noqa: UP007 """Issuer - optional unique identifier for the issuer.""" - aud: str | None = field(default=None) + aud: Optional[str] = field(default=None) # noqa: UP007 """Audience - intended audience.""" - jti: str | None = field(default=None) + jti: Optional[str] = field(default=None) # noqa: UP007 """JWT ID - a unique identifier of the JWT between different issuers.""" - extras: dict[str, Any] = field(default_factory=dict) + extras: Dict[str, Any] = field(default_factory=dict) # noqa: UP006 """Extra fields that were found on the JWT token.""" def __post_init__(self) -> None: @@ -86,15 +87,22 @@ def decode(cls, encoded_token: str, secret: str, algorithm: str) -> Self: NotAuthorizedException: If the token is invalid. """ try: - payload = jwt.decode(jwt=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False}) - exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc) - iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc) - field_names = {f.name for f in dataclasses.fields(Token)} - extra_fields = payload.keys() - field_names - extras = payload.pop("extras", {}) + payload: dict[str, Any] = jwt.decode( + jwt=encoded_token, + key=secret, + algorithms=[algorithm], + options={"verify_aud": False}, + ) + # msgspec can do these conversions as well, but to keep backwards + # compatibility, we do it ourselves, since the datetime parsing works a + # little bit different there + payload["exp"] = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + payload["iat"] = datetime.fromtimestamp(payload["iat"], tz=timezone.utc) + extra_fields = payload.keys() - {f.name for f in dataclasses.fields(cls)} + extras = payload.setdefault("extras", {}) for key in extra_fields: extras[key] = payload.pop(key) - return cls(exp=exp, iat=iat, **payload, extras=extras) + return msgspec.convert(payload, cls, strict=False) except (KeyError, jwt.DecodeError, ImproperlyConfiguredException, jwt.exceptions.InvalidAlgorithmError) as e: raise NotAuthorizedException("Invalid token") from e @@ -113,7 +121,9 @@ def encode(self, secret: str, algorithm: str) -> str: """ try: return jwt.encode( - payload={k: v for k, v in asdict(self).items() if v is not None}, key=secret, algorithm=algorithm + payload={k: v for k, v in asdict(self).items() if v is not None}, + key=secret, + algorithm=algorithm, ) except (jwt.DecodeError, NotImplementedError) as e: raise ImproperlyConfiguredException("Failed to encode token") from e diff --git a/tests/unit/test_security/test_jwt/test_auth.py b/tests/unit/test_security/test_jwt/test_auth.py index fd2e8cbe4a..e314350483 100644 --- a/tests/unit/test_security/test_jwt/test_auth.py +++ b/tests/unit/test_security/test_jwt/test_auth.py @@ -1,3 +1,5 @@ +import dataclasses +import secrets import string from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Dict, Optional @@ -128,6 +130,56 @@ def login_handler() -> Response["User"]: assert response.status_code == HTTP_401_UNAUTHORIZED +@pytest.mark.parametrize("auth_cls", [JWTAuth, JWTCookieAuth, OAuth2PasswordBearerAuth]) +async def test_jwt_auth_custom_token_cls(auth_cls: Any) -> None: + @dataclasses.dataclass + class CustomToken(Token): + random_field: int = 1 + + async def retrieve_user_handler(token: CustomToken, _: "ASGIConnection") -> Any: + return object() + + token_secret = secrets.token_hex() + + if auth_cls is OAuth2PasswordBearerAuth: + jwt_auth = auth_cls( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + token_cls=CustomToken, + token_url="http://testserver.local", + ) + else: + jwt_auth = auth_cls[Any]( + token_secret=token_secret, + retrieve_user_handler=retrieve_user_handler, + token_cls=CustomToken, + ) + + @get("/", middleware=[jwt_auth.middleware]) + def handler(request: Request[Any, CustomToken, Any]) -> Dict[str, Any]: + return { + "is_token_cls": isinstance(request.auth, CustomToken), + "token": dataclasses.asdict(request.auth), + } + + header = jwt_auth.format_auth_header( + jwt_auth.create_token( + "foo", + token_extras={"foo": "bar"}, + # pass a string here as value to ensure things get converted properly + random_field="2", + ), + ) + + with create_test_client(route_handlers=[handler]) as client: + response = client.get("/", headers={"Authorization": header}) + assert response.status_code == 200 + response_data = response.json() + assert response_data["is_token_cls"] is True + assert response_data["token"]["extras"] == {"foo": "bar"} + assert response_data["token"]["random_field"] == 2 + + @given( algorithm=sampled_from( [