Skip to content

Commit

Permalink
api_jwt: add a strict_aud option (#902)
Browse files Browse the repository at this point in the history
* api_jwt: add a `strict_aud` option

Signed-off-by: William Woodruff <[email protected]>

* CHANGELOG: record changes

Signed-off-by: William Woodruff <[email protected]>

---------

Signed-off-by: William Woodruff <[email protected]>
  • Loading branch information
woodruffw authored Jul 18, 2023
1 parent 6db5df7 commit 22ba132
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Fixed
Added
~~~~~

- Add ``strict_aud`` as an option to ``jwt.decode`` by @woodruffw in `#902 <https://github.com/jpadilla/pyjwt/pull/902>`__

`v2.7.0 <https://github.com/jpadilla/pyjwt/compare/2.6.0...2.7.0>`__
-----------------------------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ API Reference
* ``verify_exp=verify_signature`` check that ``exp`` (expiration) claim value is in the future
* ``verify_iat=verify_signature`` check that ``iat`` (issued at) claim value is an integer
* ``verify_nbf=verify_signature`` check that ``nbf`` (not before) claim value is in the past
* ``strict_aud=False`` check that the ``aud`` claim is a single value (not a list), and matches ``audience`` exactly

.. warning::

Expand Down
22 changes: 21 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ def _validate_claims(
self._validate_iss(payload, issuer)

if options["verify_aud"]:
self._validate_aud(payload, audience)
self._validate_aud(
payload, audience, strict=options.get("strict_aud", False)
)

def _validate_required_claims(
self,
Expand Down Expand Up @@ -307,6 +309,8 @@ def _validate_aud(
self,
payload: dict[str, Any],
audience: str | Iterable[str] | None,
*,
strict: bool = False,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
Expand All @@ -322,6 +326,22 @@ def _validate_aud(

audience_claims = payload["aud"]

# In strict mode, we forbid list matching: the supplied audience
# must be a string, and it must exactly match the audience claim.
if strict:
# Only a single audience is allowed in strict mode.
if not isinstance(audience, str):
raise InvalidAudienceError("Invalid audience (strict)")

# Only a single audience claim is allowed in strict mode.
if not isinstance(audience_claims, str):
raise InvalidAudienceError("Invalid claim format in token (strict)")

if audience != audience_claims:
raise InvalidAudienceError("Audience doesn't match (strict)")

return

if isinstance(audience_claims, str):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
Expand Down
79 changes: 79 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,82 @@ def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload):
jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar")
assert len(record) == 1
assert "foo" in str(record[0].message)

def test_decode_strict_aud_forbids_list_audience(self, jwt, payload):
secret = "secret"
payload["aud"] = "urn:foo"
jwt_message = jwt.encode(payload, secret)

# Decodes without `strict_aud`.
jwt.decode(
jwt_message,
secret,
audience=["urn:foo", "urn:bar"],
options={"strict_aud": False},
algorithms=["HS256"],
)

# Fails with `strict_aud`.
with pytest.raises(InvalidAudienceError, match=r"Invalid audience \(strict\)"):
jwt.decode(
jwt_message,
secret,
audience=["urn:foo", "urn:bar"],
options={"strict_aud": True},
algorithms=["HS256"],
)

def test_decode_strict_aud_forbids_list_claim(self, jwt, payload):
secret = "secret"
payload["aud"] = ["urn:foo", "urn:bar"]
jwt_message = jwt.encode(payload, secret)

# Decodes without `strict_aud`.
jwt.decode(
jwt_message,
secret,
audience="urn:foo",
options={"strict_aud": False},
algorithms=["HS256"],
)

# Fails with `strict_aud`.
with pytest.raises(
InvalidAudienceError, match=r"Invalid claim format in token \(strict\)"
):
jwt.decode(
jwt_message,
secret,
audience="urn:foo",
options={"strict_aud": True},
algorithms=["HS256"],
)

def test_decode_strict_aud_does_not_match(self, jwt, payload):
secret = "secret"
payload["aud"] = "urn:foo"
jwt_message = jwt.encode(payload, secret)

with pytest.raises(
InvalidAudienceError, match=r"Audience doesn't match \(strict\)"
):
jwt.decode(
jwt_message,
secret,
audience="urn:bar",
options={"strict_aud": True},
algorithms=["HS256"],
)

def test_decode_strict_ok(self, jwt, payload):
secret = "secret"
payload["aud"] = "urn:foo"
jwt_message = jwt.encode(payload, secret)

jwt.decode(
jwt_message,
secret,
audience="urn:foo",
options={"strict_aud": True},
algorithms=["HS256"],
)

0 comments on commit 22ba132

Please sign in to comment.