Skip to content

Commit

Permalink
feat(JWT): Custom JWT payload classes (#3692)
Browse files Browse the repository at this point in the history
* feat(JWT): Custom JWT payload classes
  • Loading branch information
provinzkraut authored Aug 24, 2024
1 parent 23ba256 commit 9cff0a4
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
(PY_RE, r"advanced_alchemy\.config.common\.EngineT"),
(PY_RE, r"advanced_alchemy\.config.common\.SessionT"),
(PY_RE, r".*R"),
(PY_OBJ, r"litestar.security.jwt.auth.TokenT"),
]

# Warnings about missing references to those targets in the specified location will be ignored.
Expand Down
38 changes: 38 additions & 0 deletions docs/examples/security/jwt/custom_token_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import dataclasses
import secrets
from typing import Any, Dict

from litestar import Litestar, Request, get
from litestar.connection import ASGIConnection
from litestar.security.jwt import JWTAuth, Token


@dataclasses.dataclass
class CustomToken(Token):
token_flag: bool = False


@dataclasses.dataclass
class User:
id: str


async def retrieve_user_handler(token: CustomToken, connection: ASGIConnection) -> User:
return User(id=token.sub)


TOKEN_SECRET = secrets.token_hex()

jwt_auth = JWTAuth[User](
token_secret=TOKEN_SECRET,
retrieve_user_handler=retrieve_user_handler,
token_cls=CustomToken,
)


@get("/")
def handler(request: Request[User, CustomToken, Any]) -> Dict[str, Any]:
return {"id": request.user.id, "token_flag": request.auth.token_flag}


app = Litestar(middleware=[jwt_auth.middleware])
21 changes: 21 additions & 0 deletions docs/usage/security/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,24 @@ OAuth 2.0 Bearer password flows.

.. literalinclude:: /examples/security/jwt/using_oauth2_password_bearer.py
:caption: Using OAUTH2 Bearer Password


Using a custom token class
--------------------------

The token class used can be customized with arbitrary fields, by creating a subclass of
:class:`~.security.jwt.Token`, and specifying it on the backend:

.. literalinclude:: /examples/security/jwt/custom_token_cls.py
:caption: Using a custom token


The token will be converted from JSON into the appropriate type, including basic type
conversions.

.. important::
Complex type conversions, especially those including third libraries such as
Pydantic or attrs, as well as any custom ``type_decoders`` are not available for
converting the token. To support more complex conversions, the
:meth:`~.security.jwt.Token.encode` and :meth:`~.security.jwt.Token.decode` methods
must be overwritten in the subclass.
29 changes: 23 additions & 6 deletions litestar/security/jwt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, Sequence, cast

from typing_extensions import TypeVar

from litestar.datastructures import Cookie
from litestar.enums import MediaType
Expand All @@ -24,9 +26,10 @@


UserType = TypeVar("UserType")
TokenT = TypeVar("TokenT", bound=Token, default=Token)


class BaseJWTAuth(Generic[UserType], AbstractSecurityConfig[UserType, Token]):
class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, TokenT]):
"""Base class for JWT Auth backends"""

token_secret: str
Expand Down Expand Up @@ -63,6 +66,8 @@ class BaseJWTAuth(Generic[UserType], AbstractSecurityConfig[UserType, Token]):
Must inherit from :class:`JWTAuthenticationMiddleware`
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -114,6 +119,7 @@ def middleware(self) -> DefineMiddleware:
retrieve_user_handler=self.retrieve_user_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
)

def login(
Expand Down Expand Up @@ -179,6 +185,7 @@ def create_token(
token_audience: str | None = None,
token_unique_jwt_id: str | None = None,
token_extras: dict | None = None,
**kwargs: Any,
) -> str:
"""Create a Token instance from the passed in parameters, persists and returns it.
Expand All @@ -189,17 +196,19 @@ def create_token(
token_audience: An optional value for the token ``aud`` field.
token_unique_jwt_id: An optional value for the token ``jti`` field.
token_extras: An optional dictionary to include in the token ``extras`` field.
**kwargs: Additional attributes to set on the token
Returns:
The created token.
"""
token = Token(
token = self.token_cls(
sub=identifier,
exp=(datetime.now(timezone.utc) + (token_expiration or self.default_token_expiration)),
iss=token_issuer,
aud=token_audience,
jti=token_unique_jwt_id,
extras=token_extras or {},
**kwargs,
)
return token.encode(secret=self.token_secret, algorithm=self.algorithm)

Expand All @@ -217,7 +226,7 @@ def format_auth_header(self, encoded_token: str) -> str:


@dataclass
class JWTAuth(Generic[UserType], BaseJWTAuth[UserType]):
class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
"""JWT Authentication Configuration.
This class is the main entry point to the library, and it includes methods to create the middleware, provide login
Expand Down Expand Up @@ -279,10 +288,12 @@ class JWTAuth(Generic[UserType], BaseJWTAuth[UserType]):
Must inherit from :class:`JWTAuthenticationMiddleware`
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""


@dataclass
class JWTCookieAuth(Generic[UserType], BaseJWTAuth[UserType]):
class JWTCookieAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
"""JWT Cookie Authentication Configuration.
This class is an alternate entry point to the library, and it includes all the functionality of the :class:`JWTAuth`
Expand Down Expand Up @@ -357,6 +368,8 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies.
)
"""The authentication middleware class to use. Must inherit from :class:`JWTCookieAuthenticationMiddleware`
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -397,6 +410,7 @@ def middleware(self) -> DefineMiddleware:
retrieve_user_handler=self.retrieve_user_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
)

def login(
Expand Down Expand Up @@ -482,7 +496,7 @@ class OAuth2Login:


@dataclass
class OAuth2PasswordBearerAuth(Generic[UserType], BaseJWTAuth[UserType]):
class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
"""OAUTH2 Schema for Password Bearer Authentication.
This class implements an OAUTH2 authentication flow entry point to the library, and it includes all the
Expand Down Expand Up @@ -563,6 +577,8 @@ class OAuth2PasswordBearerAuth(Generic[UserType], BaseJWTAuth[UserType]):
Must inherit from :class:`JWTCookieAuthenticationMiddleware`
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""

@property
def middleware(self) -> DefineMiddleware:
Expand All @@ -583,6 +599,7 @@ def middleware(self) -> DefineMiddleware:
retrieve_user_handler=self.retrieve_user_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
)

@property
Expand Down
13 changes: 10 additions & 3 deletions litestar/security/jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware):
"auth_header",
"retrieve_user_handler",
"token_secret",
"token_cls",
)

def __init__(
Expand All @@ -43,6 +44,7 @@ def __init__(
retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]],
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
) -> None:
"""Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user
from persistence using the provided function.
Expand All @@ -57,8 +59,9 @@ def __init__(
retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user,
which can be any arbitrary value.
scopes: ASGI scopes processed by the authentication middleware.
token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
"""
super().__init__(
app=app,
Expand All @@ -71,6 +74,7 @@ def __init__(
self.auth_header = auth_header
self.retrieve_user_handler = retrieve_user_handler
self.token_secret = token_secret
self.token_cls = token_cls

async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult:
"""Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the
Expand Down Expand Up @@ -106,7 +110,7 @@ async def authenticate_token(
Returns:
AuthenticationResult
"""
token = Token.decode(
token = self.token_cls.decode(
encoded_token=encoded_token,
secret=self.token_secret,
algorithm=self.algorithm,
Expand Down Expand Up @@ -137,6 +141,7 @@ def __init__(
retrieve_user_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]],
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
) -> None:
"""Check incoming requests for an encoded token in the auth header or cookie name specified, and if present
retrieves the user from persistence using the provided function.
Expand All @@ -152,8 +157,9 @@ def __init__(
retrieve_user_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a user,
which can be any arbitrary value.
scopes: ASGI scopes processed by the authentication middleware.
token_secret: Secret for decoding the JWT token. This value should be equivalent to the secret used to
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
"""
super().__init__(
algorithm=algorithm,
Expand All @@ -165,6 +171,7 @@ def __init__(
retrieve_user_handler=retrieve_user_handler,
scopes=scopes,
token_secret=token_secret,
token_cls=token_cls,
)
self.auth_cookie_key = auth_cookie_key

Expand Down
36 changes: 23 additions & 13 deletions litestar/security/jwt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import dataclasses
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

import jwt
import msgspec

from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException

Expand Down Expand Up @@ -41,13 +42,13 @@ class Token:
"""Subject - usually a unique identifier of the user or equivalent entity."""
iat: datetime = field(default_factory=lambda: _normalize_datetime(datetime.now(timezone.utc)))
"""Issued at - should always be current now."""
iss: str | None = field(default=None)
iss: Optional[str] = field(default=None) # noqa: UP007
"""Issuer - optional unique identifier for the issuer."""
aud: str | None = field(default=None)
aud: Optional[str] = field(default=None) # noqa: UP007
"""Audience - intended audience."""
jti: str | None = field(default=None)
jti: Optional[str] = field(default=None) # noqa: UP007
"""JWT ID - a unique identifier of the JWT between different issuers."""
extras: dict[str, Any] = field(default_factory=dict)
extras: Dict[str, Any] = field(default_factory=dict) # noqa: UP006
"""Extra fields that were found on the JWT token."""

def __post_init__(self) -> None:
Expand Down Expand Up @@ -86,15 +87,22 @@ def decode(cls, encoded_token: str, secret: str, algorithm: str) -> Self:
NotAuthorizedException: If the token is invalid.
"""
try:
payload = jwt.decode(jwt=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False})
exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc)
iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc)
field_names = {f.name for f in dataclasses.fields(Token)}
extra_fields = payload.keys() - field_names
extras = payload.pop("extras", {})
payload: dict[str, Any] = jwt.decode(
jwt=encoded_token,
key=secret,
algorithms=[algorithm],
options={"verify_aud": False},
)
# msgspec can do these conversions as well, but to keep backwards
# compatibility, we do it ourselves, since the datetime parsing works a
# little bit different there
payload["exp"] = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
payload["iat"] = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)
extra_fields = payload.keys() - {f.name for f in dataclasses.fields(cls)}
extras = payload.setdefault("extras", {})
for key in extra_fields:
extras[key] = payload.pop(key)
return cls(exp=exp, iat=iat, **payload, extras=extras)
return msgspec.convert(payload, cls, strict=False)
except (KeyError, jwt.DecodeError, ImproperlyConfiguredException, jwt.exceptions.InvalidAlgorithmError) as e:
raise NotAuthorizedException("Invalid token") from e

Expand All @@ -113,7 +121,9 @@ def encode(self, secret: str, algorithm: str) -> str:
"""
try:
return jwt.encode(
payload={k: v for k, v in asdict(self).items() if v is not None}, key=secret, algorithm=algorithm
payload={k: v for k, v in asdict(self).items() if v is not None},
key=secret,
algorithm=algorithm,
)
except (jwt.DecodeError, NotImplementedError) as e:
raise ImproperlyConfiguredException("Failed to encode token") from e
Loading

0 comments on commit 9cff0a4

Please sign in to comment.