diff --git a/controlpanel/jwt.py b/controlpanel/jwt.py index 04cdb301..fc846abb 100644 --- a/controlpanel/jwt.py +++ b/controlpanel/jwt.py @@ -40,7 +40,12 @@ def jwk(self): if not self._jwk and self.header: try: jwks_client = jwt.PyJWKClient(self.jwks_url) - self._jwk = jwks_client.get_signing_key_from_jwt(self._raw_token).key + jwk = jwks_client.get_signing_key_from_jwt(self._raw_token) + + if jwk.key_id != self.header["kid"]: + raise DecodeError("Key ID mismatch") + + self._jwk = jwk.key except PyJWKClientError as error: raise DecodeError(f"Failed fetching JWK: {error}") diff --git a/tests/api/test_authentication.py b/tests/api/test_authentication.py index a3513c1b..9c9afccf 100644 --- a/tests/api/test_authentication.py +++ b/tests/api/test_authentication.py @@ -1,10 +1,11 @@ # Standard library -from unittest.mock import patch +from unittest.mock import MagicMock, patch # Third-party +import jwt +import jwt.algorithms import pytest -from jose import JWTError, jwk, jwt -from requests.exceptions import Timeout +from jwt.exceptions import DecodeError, PyJWKClientError # First-party/Local from controlpanel.api.models import User @@ -46,13 +47,13 @@ def audience(settings): @pytest.fixture(autouse=True) def jwks(): - key = jwk.construct(TEST_PUBLIC_KEY, "RS256") - jwk_dict = key.to_dict() - jwk_dict["kid"] = TEST_KID - with patch("controlpanel.jwt.requests") as requests: - response = requests.get.return_value - response.json.return_value = {"keys": [jwk_dict]} - yield requests + with patch("controlpanel.jwt.jwt.PyJWKClient") as client: + client_value = MagicMock() + client_value.get_signing_key_from_jwt.return_value = MagicMock( + key=TEST_PUBLIC_KEY, key_id=TEST_KID + ) + client.return_value = client_value + yield client @pytest.fixture(autouse=True) @@ -120,12 +121,14 @@ def test_token_auth(api_request, auth_header, status): def test_bad_request_for_jwks(api_request, jwks): - jwks.get.side_effect = Timeout("test_bad_request_for_jwks") + jwks.return_value.get_signing_key_from_jwt.side_effect = PyJWKClientError( + "test_bad_request_for_jwks" + ) assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403 def test_decode_jwt_error(api_request): with patch("controlpanel.jwt.jwt") as jwt: - jwt.decode.side_effect = JWTError("test_decode_jwt_error") + jwt.decode.side_effect = DecodeError("test_decode_jwt_error") assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403