Skip to content

Commit

Permalink
feat(JWT): Customised token verification (#3695)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Aug 27, 2024
1 parent 8cdc43d commit 44819d0
Show file tree
Hide file tree
Showing 10 changed files with 687 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
(PY_RE, r".*R"),
(PY_OBJ, r"litestar.security.jwt.auth.TokenT"),
(PY_CLASS, "ExceptionToProblemDetailMapType"),
(PY_CLASS, "litestar.security.jwt.token.JWTDecodeOptions"),
]

# Warnings about missing references to those targets in the specified location will be ignored.
Expand Down
28 changes: 28 additions & 0 deletions docs/examples/security/jwt/custom_decode_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import dataclasses
from typing import Any, List, Optional, Sequence, Union

from litestar.security.jwt.token import JWTDecodeOptions, Token


@dataclasses.dataclass
class CustomToken(Token):
@classmethod
def decode_payload(
cls,
encoded_token: str,
secret: str,
algorithms: List[str],
issuer: Optional[List[str]] = None,
audience: Union[str, Sequence[str], None] = None,
options: Optional[JWTDecodeOptions] = None,
) -> Any:
payload = super().decode_payload(
encoded_token=encoded_token,
secret=secret,
algorithms=algorithms,
issuer=issuer,
audience=audience,
options=options,
)
payload["sub"] = payload["sub"].split("@", maxsplit=1)[1]
return payload
32 changes: 32 additions & 0 deletions docs/examples/security/jwt/verify_issuer_audience.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import dataclasses
import secrets
from typing import Any, Dict

from litestar import Litestar, Request, get
from litestar.connection import ASGIConnection
from litestar.security.jwt import JWTAuth, Token


@dataclasses.dataclass
class User:
id: str


async def retrieve_user_handler(token: Token, connection: ASGIConnection) -> User:
return User(id=token.sub)


jwt_auth = JWTAuth[User](
token_secret=secrets.token_hex(),
retrieve_user_handler=retrieve_user_handler,
accepted_audiences=["https://api.testserver.local"],
accepted_issuers=["https://auth.testserver.local"],
)


@get("/")
def handler(request: Request[User, Token, Any]) -> Dict[str, Any]:
return {"id": request.user.id}


app = Litestar([handler], middleware=[jwt_auth.middleware])
29 changes: 29 additions & 0 deletions docs/usage/security/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,32 @@ conversions.
converting the token. To support more complex conversions, the
:meth:`~.security.jwt.Token.encode` and :meth:`~.security.jwt.Token.decode` methods
must be overwritten in the subclass.


Verifying issuer and audience
-----------------------------

To verify the JWT ``iss`` (*issuer*) and ``aud`` (*audience*) claim, a list of accepted
issuers or audiences can bet set on the authentication backend. When a JWT is decoded,
the issuer or audience on the token is compared to the list of accepted issuers /
audiences. If the value in the token does not match any value in the respective list,
a :exc:`NotAuthorizedException` will be raised, returning a response with a
``401 Unauthorized`` status.


.. literalinclude:: /examples/security/jwt/verify_issuer_audience.py
:caption: Verifying issuer and audience


Customizing token validation
----------------------------

Token decoding / validation can be further customized by overriding the
:meth:`~.security.jwt.Token.decode_payload` method. It will be called by
:meth:`~.security.jwt.Token.decode` with the encoded token string, and must return a
dictionary representing the decoded payload, which will then used by
:meth:`~.security.jwt.Token.decode` to construct an instance of the token class.


.. literalinclude:: /examples/security/jwt/custom_decode_payload.py
:caption: Customizing payload decoding
102 changes: 102 additions & 0 deletions litestar/security/jwt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, To
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -120,6 +141,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

def login(
Expand Down Expand Up @@ -290,6 +317,27 @@ class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""


@dataclass
Expand Down Expand Up @@ -370,6 +418,27 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies.
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def openapi_components(self) -> Components:
Expand Down Expand Up @@ -411,6 +480,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

def login(
Expand Down Expand Up @@ -579,6 +654,27 @@ class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType,
"""
token_cls: type[Token] = Token
"""Target type the JWT payload will be converted into"""
accepted_audiences: Sequence[str] | None = None
"""Audiences to accept when verifying the token. If given, and the audience in the
token does not match, a 401 response is returned
"""
accepted_issuers: Sequence[str] | None = None
"""Issuers to accept when verifying the token. If given, and the issuer in the
token does not match, a 401 response is returned
"""
require_claims: Sequence[str] | None = None
"""Require these claims to be present in the JWT payload. If any of those claims
is missing, a 401 response is returned
"""
verify_expiry: bool = True
"""Verify that the value of the ``exp`` (*expiration*) claim is in the future"""
verify_not_before: bool = True
"""Verify that the value of the ``nbf`` (*not before*) claim is in the past"""
strict_audience: bool = False
"""Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""

@property
def middleware(self) -> DefineMiddleware:
Expand All @@ -600,6 +696,12 @@ def middleware(self) -> DefineMiddleware:
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
token_issuer=self.accepted_issuers,
token_audience=self.accepted_audiences,
require_claims=self.require_claims,
verify_expiry=self.verify_expiry,
verify_not_before=self.verify_not_before,
strict_audience=self.strict_audience,
)

@property
Expand Down
60 changes: 60 additions & 0 deletions litestar/security/jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware):
"retrieve_user_handler",
"token_secret",
"token_cls",
"token_audience",
"token_issuer",
"require_claims",
"verify_expiry",
"verify_not_before",
"strict_audience",
)

def __init__(
Expand All @@ -45,6 +51,12 @@ def __init__(
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
token_audience: Sequence[str] | None = None,
token_issuer: Sequence[str] | None = None,
require_claims: Sequence[str] | None = None,
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
) -> None:
"""Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user
from persistence using the provided function.
Expand All @@ -62,6 +74,18 @@ def __init__(
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
token_audience: Verify the audience when decoding the token. If the audience
in the token does not match any audience given, raise a
:exc:`NotAuthorizedException`
token_issuer: Verify the issuer when decoding the token. If the issuer in
the token does not match any issuer given, raise a
:exc:`NotAuthorizedException`
require_claims: Require these claims to be present in the JWT payload
verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future
verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""
super().__init__(
app=app,
Expand All @@ -75,6 +99,12 @@ def __init__(
self.retrieve_user_handler = retrieve_user_handler
self.token_secret = token_secret
self.token_cls = token_cls
self.token_audience = token_audience
self.token_issuer = token_issuer
self.require_claims = require_claims
self.verify_expiry = verify_expiry
self.verify_not_before = verify_not_before
self.strict_audience = strict_audience

async def authenticate_request(self, connection: ASGIConnection[Any, Any, Any, Any]) -> AuthenticationResult:
"""Given an HTTP Connection, parse the JWT api key stored in the header and retrieve the user correlating to the
Expand Down Expand Up @@ -114,6 +144,12 @@ async def authenticate_token(
encoded_token=encoded_token,
secret=self.token_secret,
algorithm=self.algorithm,
audience=self.token_audience,
issuer=self.token_issuer,
require_claims=self.require_claims,
verify_exp=self.verify_expiry,
verify_nbf=self.verify_not_before,
strict_audience=self.strict_audience,
)

user = await self.retrieve_user_handler(token, connection)
Expand Down Expand Up @@ -142,6 +178,12 @@ def __init__(
scopes: Scopes,
token_secret: str,
token_cls: type[Token] = Token,
token_audience: Sequence[str] | None = None,
token_issuer: Sequence[str] | None = None,
require_claims: Sequence[str] | None = None,
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
) -> None:
"""Check incoming requests for an encoded token in the auth header or cookie name specified, and if present
retrieves the user from persistence using the provided function.
Expand All @@ -160,6 +202,18 @@ def __init__(
token_secret: Secret for decoding the JWT. This value should be equivalent to the secret used to
encode it.
token_cls: Token class used when encoding / decoding JWTs
token_audience: Verify the audience when decoding the token. If the audience
in the token does not match any audience given, raise a
:exc:`NotAuthorizedException`
token_issuer: Verify the issuer when decoding the token. If the issuer in
the token does not match any issuer given, raise a
:exc:`NotAuthorizedException`
require_claims: Require these claims to be present in the JWT payload
verify_expiry: Verify that the value of the ``exp`` (*expiration*) claim is in the future
verify_not_before: Verify that the value of the ``nbf`` (*not before*) claim is in the past
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
"""
super().__init__(
algorithm=algorithm,
Expand All @@ -172,6 +226,12 @@ def __init__(
scopes=scopes,
token_secret=token_secret,
token_cls=token_cls,
token_audience=token_audience,
token_issuer=token_issuer,
require_claims=require_claims,
verify_expiry=verify_expiry,
verify_not_before=verify_not_before,
strict_audience=strict_audience,
)
self.auth_cookie_key = auth_cookie_key

Expand Down
Loading

0 comments on commit 44819d0

Please sign in to comment.