Skip to content

Commit

Permalink
Replaced python-jose with PyJWT (#1417)
Browse files Browse the repository at this point in the history
* Replaced python-jose with PyJWT

* Fixed failing tests

* Changed message back to old one

* Bumped several dependencies
  • Loading branch information
jamesstottmoj authored Dec 20, 2024
1 parent 16aa4cd commit 7a89463
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 46 deletions.
6 changes: 3 additions & 3 deletions controlpanel/api/jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# First-party/Local
from controlpanel.api.models import User
from controlpanel.jwt import JWT, JWTDecodeError
from controlpanel.jwt import JWT, DecodeError

M2M_CLAIM_FLAG = "client-credentials"

Expand Down Expand Up @@ -69,7 +69,7 @@ def authenticate(self, request):
else:
try:
jwt.validate()
except JWTDecodeError:
except DecodeError:
return None

return self._get_client(jwt), None
Expand All @@ -89,7 +89,7 @@ def _get_client(self, jwt):
return AuthenticatedServiceClient(jwt.payload)
else:
raise exceptions.AuthenticationFailed()
except JWTDecodeError:
except DecodeError:
raise exceptions.AuthenticationFailed(
"Failed to be authenticated due to JWT decoder error!"
)
Expand Down
45 changes: 18 additions & 27 deletions controlpanel/jwt.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
# Third-party
import requests
import jwt
import structlog
from django.conf import settings
from jose import jwt
from jose.exceptions import JWTError
from requests.exceptions import RequestException
from jwt.exceptions import DecodeError, InvalidTokenError, PyJWKClientError
from rest_framework import HTTP_HEADER_ENCODING

log = structlog.getLogger(__name__)


class JWTDecodeError(Exception):
pass


class JWT:
def __init__(self, raw_token):
self._header = None
Expand All @@ -25,7 +19,7 @@ def __init__(self, raw_token):
"algorithms": [settings.OIDC_RP_SIGN_ALGO],
"audience": settings.OIDC_CPANEL_API_AUDIENCE,
"options": {
"require_sub": True,
"require": ["sub"],
},
}

Expand All @@ -37,30 +31,27 @@ def header(self):
if not self._header:
try:
self._header = jwt.get_unverified_header(self._raw_token)
except jwt.JWTError:
except (DecodeError, InvalidTokenError):
return None
return self._header

@property
def jwk(self):
if not self._jwk and self.header:
try:
response = requests.get(self.jwks_url, verify=False)
response.raise_for_status()
except RequestException as error:
raise JWTDecodeError(f"Failed fetching JWK: {error}")
jwks_client = jwt.PyJWKClient(self.jwks_url)
jwk = jwks_client.get_signing_key_from_jwt(self._raw_token)

jwks = response.json()
if jwk.key_id != self.header["kid"]:
raise DecodeError(
f'No JWK with id {self.header["kid"]} found at {self.jwks_url} '
f"while decoding {self._raw_token}"
)

for jwk in jwks.get("keys", []):
if jwk["kid"] == self.header["kid"]:
self._jwk = jwk
return self._jwk
self._jwk = jwk.key

raise JWTDecodeError(
f'No JWK with id {self.header["kid"]} found at {self.jwks_url} '
f"while decoding {self._raw_token}"
)
except PyJWKClientError as error:
raise DecodeError(f"Failed fetching JWK: {error}")

return self._jwk

Expand All @@ -73,8 +64,8 @@ def payload(self):
key=self.jwk,
**self.decode_options,
)
except (JWTError, KeyError) as error:
raise JWTDecodeError(f"Failed decoding JWT: {error}")
except (DecodeError, KeyError) as error:
raise DecodeError(f"Failed decoding JWT: {error}")
return self._payload

def validate(self):
Expand All @@ -84,8 +75,8 @@ def validate(self):
key=self.jwk,
**self.decode_options,
)
except (JWTError, KeyError) as error:
raise JWTDecodeError(f"Failed decoding JWT: {error}")
except (DecodeError, KeyError) as error:
raise DecodeError(f"Failed decoding JWT: {error}")

@classmethod
def from_auth_header(cls, request):
Expand Down
2 changes: 1 addition & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ ipython==8.29.0
isort==5.13.2
pandas==2.2.0
pre-commit==3.7.1
pylint==3.0.3
pylint==3.3.2
pylint-django==2.5.5
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ gunicorn==22.0.0
Jinja2==3.1.4
kubernetes==31.0.0
MarkupSafe==2.1.5
model-bakery==1.17.0
model-bakery==1.20.0
moto[all]==5.0.18
mozilla-django-oidc==4.0.1
psycopg2-binary==2.9.9
psycopg2-binary==2.9.10
PyJWT==2.10.1
PyNaCl==1.5.0
pytest==8.0.0
pytest-django==4.9.0
python-dotenv==1.0.1
python-jose==3.3.0
pyyaml==6.0.2
rules==3.3
sentry-sdk==2.19.2
Expand Down
27 changes: 15 additions & 12 deletions tests/api/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Standard library
from unittest.mock import patch
from unittest.mock import MagicMock, patch

# Third-party
import jwt
import jwt.algorithms
import pytest
from jose import JWTError, jwk, jwt
from requests.exceptions import Timeout
from jwt.exceptions import DecodeError, PyJWKClientError

# First-party/Local
from controlpanel.api.models import User
Expand Down Expand Up @@ -46,13 +47,13 @@ def audience(settings):

@pytest.fixture(autouse=True)
def jwks():
key = jwk.construct(TEST_PUBLIC_KEY, "RS256")
jwk_dict = key.to_dict()
jwk_dict["kid"] = TEST_KID
with patch("controlpanel.jwt.requests") as requests:
response = requests.get.return_value
response.json.return_value = {"keys": [jwk_dict]}
yield requests
with patch("controlpanel.jwt.jwt.PyJWKClient") as client:
client_value = MagicMock()
client_value.get_signing_key_from_jwt.return_value = MagicMock(
key=TEST_PUBLIC_KEY, key_id=TEST_KID
)
client.return_value = client_value
yield client


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -120,12 +121,14 @@ def test_token_auth(api_request, auth_header, status):


def test_bad_request_for_jwks(api_request, jwks):
jwks.get.side_effect = Timeout("test_bad_request_for_jwks")
jwks.return_value.get_signing_key_from_jwt.side_effect = PyJWKClientError(
"test_bad_request_for_jwks"
)
assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403


def test_decode_jwt_error(api_request):
with patch("controlpanel.jwt.jwt") as jwt:
jwt.decode.side_effect = JWTError("test_decode_jwt_error")
jwt.decode.side_effect = DecodeError("test_decode_jwt_error")

assert api_request(HTTP_AUTHORIZATION=f"Bearer {token()}").status_code == 403

0 comments on commit 7a89463

Please sign in to comment.