From 22ba13287874a7a22748b451c7d84f12b2f8ccab Mon Sep 17 00:00:00 2001 From: William Woodruff Date: Tue, 18 Jul 2023 11:52:48 -0400 Subject: [PATCH] api_jwt: add a `strict_aud` option (#902) * api_jwt: add a `strict_aud` option Signed-off-by: William Woodruff * CHANGELOG: record changes Signed-off-by: William Woodruff --------- Signed-off-by: William Woodruff --- CHANGELOG.rst | 2 ++ docs/api.rst | 1 + jwt/api_jwt.py | 22 +++++++++++- tests/test_api_jwt.py | 79 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 35c7fa95..b5cffbdf 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,8 @@ Fixed Added ~~~~~ +- Add ``strict_aud`` as an option to ``jwt.decode`` by @woodruffw in `#902 `__ + `v2.7.0 `__ ----------------------------------------------------------------------- diff --git a/docs/api.rst b/docs/api.rst index 919b6af9..8a51097d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -53,6 +53,7 @@ API Reference * ``verify_exp=verify_signature`` check that ``exp`` (expiration) claim value is in the future * ``verify_iat=verify_signature`` check that ``iat`` (issued at) claim value is an integer * ``verify_nbf=verify_signature`` check that ``nbf`` (not before) claim value is in the past + * ``strict_aud=False`` check that the ``aud`` claim is a single value (not a list), and matches ``audience`` exactly .. warning:: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 49d1b488..48d739ad 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -251,7 +251,9 @@ def _validate_claims( self._validate_iss(payload, issuer) if options["verify_aud"]: - self._validate_aud(payload, audience) + self._validate_aud( + payload, audience, strict=options.get("strict_aud", False) + ) def _validate_required_claims( self, @@ -307,6 +309,8 @@ def _validate_aud( self, payload: dict[str, Any], audience: str | Iterable[str] | None, + *, + strict: bool = False, ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: @@ -322,6 +326,22 @@ def _validate_aud( audience_claims = payload["aud"] + # In strict mode, we forbid list matching: the supplied audience + # must be a string, and it must exactly match the audience claim. + if strict: + # Only a single audience is allowed in strict mode. + if not isinstance(audience, str): + raise InvalidAudienceError("Invalid audience (strict)") + + # Only a single audience claim is allowed in strict mode. + if not isinstance(audience_claims, str): + raise InvalidAudienceError("Invalid claim format in token (strict)") + + if audience != audience_claims: + raise InvalidAudienceError("Audience doesn't match (strict)") + + return + if isinstance(audience_claims, str): audience_claims = [audience_claims] if not isinstance(audience_claims, list): diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 0d534446..82b92994 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -723,3 +723,82 @@ def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload): jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar") assert len(record) == 1 assert "foo" in str(record[0].message) + + def test_decode_strict_aud_forbids_list_audience(self, jwt, payload): + secret = "secret" + payload["aud"] = "urn:foo" + jwt_message = jwt.encode(payload, secret) + + # Decodes without `strict_aud`. + jwt.decode( + jwt_message, + secret, + audience=["urn:foo", "urn:bar"], + options={"strict_aud": False}, + algorithms=["HS256"], + ) + + # Fails with `strict_aud`. + with pytest.raises(InvalidAudienceError, match=r"Invalid audience \(strict\)"): + jwt.decode( + jwt_message, + secret, + audience=["urn:foo", "urn:bar"], + options={"strict_aud": True}, + algorithms=["HS256"], + ) + + def test_decode_strict_aud_forbids_list_claim(self, jwt, payload): + secret = "secret" + payload["aud"] = ["urn:foo", "urn:bar"] + jwt_message = jwt.encode(payload, secret) + + # Decodes without `strict_aud`. + jwt.decode( + jwt_message, + secret, + audience="urn:foo", + options={"strict_aud": False}, + algorithms=["HS256"], + ) + + # Fails with `strict_aud`. + with pytest.raises( + InvalidAudienceError, match=r"Invalid claim format in token \(strict\)" + ): + jwt.decode( + jwt_message, + secret, + audience="urn:foo", + options={"strict_aud": True}, + algorithms=["HS256"], + ) + + def test_decode_strict_aud_does_not_match(self, jwt, payload): + secret = "secret" + payload["aud"] = "urn:foo" + jwt_message = jwt.encode(payload, secret) + + with pytest.raises( + InvalidAudienceError, match=r"Audience doesn't match \(strict\)" + ): + jwt.decode( + jwt_message, + secret, + audience="urn:bar", + options={"strict_aud": True}, + algorithms=["HS256"], + ) + + def test_decode_strict_ok(self, jwt, payload): + secret = "secret" + payload["aud"] = "urn:foo" + jwt_message = jwt.encode(payload, secret) + + jwt.decode( + jwt_message, + secret, + audience="urn:foo", + options={"strict_aud": True}, + algorithms=["HS256"], + )