From 7a894635de663379c0a23d01f9e4932002a3c0f0 Mon Sep 17 00:00:00 2001 From: James Stott <158563996+jamesstottmoj@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:17:46 +0000 Subject: [PATCH] Replaced python-jose with PyJWT (#1417) * Replaced python-jose with PyJWT * Fixed failing tests * Changed message back to old one * Bumped several dependencies --- controlpanel/api/jwt_auth.py | 6 ++--- controlpanel/jwt.py | 45 +++++++++++++------------------- requirements.dev.txt | 2 +- requirements.txt | 6 ++--- tests/api/test_authentication.py | 27 ++++++++++--------- 5 files changed, 40 insertions(+), 46 deletions(-) diff --git a/controlpanel/api/jwt_auth.py b/controlpanel/api/jwt_auth.py index 4f5676845..031a887d0 100644 --- a/controlpanel/api/jwt_auth.py +++ b/controlpanel/api/jwt_auth.py @@ -7,7 +7,7 @@ # First-party/Local from controlpanel.api.models import User -from controlpanel.jwt import JWT, JWTDecodeError +from controlpanel.jwt import JWT, DecodeError M2M_CLAIM_FLAG = "client-credentials" @@ -69,7 +69,7 @@ def authenticate(self, request): else: try: jwt.validate() - except JWTDecodeError: + except DecodeError: return None return self._get_client(jwt), None @@ -89,7 +89,7 @@ def _get_client(self, jwt): return AuthenticatedServiceClient(jwt.payload) else: raise exceptions.AuthenticationFailed() - except JWTDecodeError: + except DecodeError: raise exceptions.AuthenticationFailed( "Failed to be authenticated due to JWT decoder error!" ) diff --git a/controlpanel/jwt.py b/controlpanel/jwt.py index 6e47c1292..bcf10a598 100644 --- a/controlpanel/jwt.py +++ b/controlpanel/jwt.py @@ -1,19 +1,13 @@ # Third-party -import requests +import jwt import structlog from django.conf import settings -from jose import jwt -from jose.exceptions import JWTError -from requests.exceptions import RequestException +from jwt.exceptions import DecodeError, InvalidTokenError, PyJWKClientError from rest_framework import HTTP_HEADER_ENCODING log = structlog.getLogger(__name__) -class JWTDecodeError(Exception): - pass - - class JWT: def __init__(self, raw_token): self._header = None @@ -25,7 +19,7 @@ def __init__(self, raw_token): "algorithms": [settings.OIDC_RP_SIGN_ALGO], "audience": settings.OIDC_CPANEL_API_AUDIENCE, "options": { - "require_sub": True, + "require": ["sub"], }, } @@ -37,7 +31,7 @@ def header(self): if not self._header: try: self._header = jwt.get_unverified_header(self._raw_token) - except jwt.JWTError: + except (DecodeError, InvalidTokenError): return None return self._header @@ -45,22 +39,19 @@ def header(self): def jwk(self): if not self._jwk and self.header: try: - response = requests.get(self.jwks_url, verify=False) - response.raise_for_status() - except RequestException as error: - raise JWTDecodeError(f"Failed fetching JWK: {error}") + jwks_client = jwt.PyJWKClient(self.jwks_url) + jwk = jwks_client.get_signing_key_from_jwt(self._raw_token) - jwks = response.json() + if jwk.key_id != self.header["kid"]: + raise DecodeError( + f'No JWK with id {self.header["kid"]} found at {self.jwks_url} ' + f"while decoding {self._raw_token}" + ) - for jwk in jwks.get("keys", []): - if jwk["kid"] == self.header["kid"]: - self._jwk = jwk - return self._jwk + self._jwk = jwk.key - raise JWTDecodeError( - f'No JWK with id {self.header["kid"]} found at {self.jwks_url} ' - f"while decoding {self._raw_token}" - ) + except PyJWKClientError as error: + raise DecodeError(f"Failed fetching JWK: {error}") return self._jwk @@ -73,8 +64,8 @@ def payload(self): key=self.jwk, **self.decode_options, ) - except (JWTError, KeyError) as error: - raise JWTDecodeError(f"Failed decoding JWT: {error}") + except (DecodeError, KeyError) as error: + raise DecodeError(f"Failed decoding JWT: {error}") return self._payload def validate(self): @@ -84,8 +75,8 @@ def validate(self): key=self.jwk, **self.decode_options, ) - except (JWTError, KeyError) as error: - raise JWTDecodeError(f"Failed decoding JWT: {error}") + except (DecodeError, KeyError) as error: + raise DecodeError(f"Failed decoding JWT: {error}") @classmethod def from_auth_header(cls, request): diff --git a/requirements.dev.txt b/requirements.dev.txt index 4bba247fb..7904885e5 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -6,5 +6,5 @@ ipython==8.29.0 isort==5.13.2 pandas==2.2.0 pre-commit==3.7.1 -pylint==3.0.3 +pylint==3.3.2 pylint-django==2.5.5 diff --git a/requirements.txt b/requirements.txt index c7ad14afb..d0225670c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,15 +23,15 @@ gunicorn==22.0.0 Jinja2==3.1.4 kubernetes==31.0.0 MarkupSafe==2.1.5 -model-bakery==1.17.0 +model-bakery==1.20.0 moto[all]==5.0.18 mozilla-django-oidc==4.0.1 -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.10 +PyJWT==2.10.1 PyNaCl==1.5.0 pytest==8.0.0 pytest-django==4.9.0 python-dotenv==1.0.1 -python-jose==3.3.0 pyyaml==6.0.2 rules==3.3 sentry-sdk==2.19.2 diff --git a/tests/api/test_authentication.py b/tests/api/test_authentication.py index a3513c1b5..9c9afccf6 100644 --- a/tests/api/test_authentication.py +++ b/tests/api/test_authentication.py @@ -1,10 +1,11 @@ # Standard library -from unittest.mock import patch +from unittest.mock import MagicMock, patch # Third-party +import jwt +import jwt.algorithms import pytest -from jose import JWTError, jwk, jwt -from requests.exceptions import Timeout +from jwt.exceptions import DecodeError, PyJWKClientError # First-party/Local from controlpanel.api.models import User @@ -46,13 +47,13 @@ def audience(settings): @pytest.fixture(autouse=True) def jwks(): - key = jwk.construct(TEST_PUBLIC_KEY, "RS256") - jwk_dict = key.to_dict() - jwk_dict["kid"] = TEST_KID - with patch("controlpanel.jwt.requests") as requests: - response = requests.get.return_value - response.json.return_value = {"keys": [jwk_dict]} - yield requests + with patch("controlpanel.jwt.jwt.PyJWKClient") as client: + client_value = MagicMock() + client_value.get_signing_key_from_jwt.return_value = MagicMock( + key=TEST_PUBLIC_KEY, key_id=TEST_KID + ) + client.return_value = client_value + yield client @pytest.fixture(autouse=True) @@ -120,12 +121,14 @@ def test_token_auth(api_request, auth_header, status): def test_bad_request_for_jwks(api_request, jwks): - jwks.get.side_effect = Timeout("test_bad_request_for_jwks") + jwks.return_value.get_signing_key_from_jwt.side_effect = PyJWKClientError( + "test_bad_request_for_jwks" + ) assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403 def test_decode_jwt_error(api_request): with patch("controlpanel.jwt.jwt") as jwt: - jwt.decode.side_effect = JWTError("test_decode_jwt_error") + jwt.decode.side_effect = DecodeError("test_decode_jwt_error") assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403