diff --git a/okta/jwt.py b/okta/jwt.py index b30e5547..8dd87197 100644 --- a/okta/jwt.py +++ b/okta/jwt.py @@ -1,11 +1,12 @@ import json -from Cryptodome.PublicKey import RSA -from ast import literal_eval -import jose.jwk as jwk -import jose.jwt as jwt +import os import time import uuid -import os + +from ast import literal_eval +from Cryptodome.PublicKey import RSA +from jwcrypto.jwk import JWK, InvalidJWKType +from jwt import encode as jwt_encode class JWT(): @@ -63,32 +64,36 @@ def get_PEM_JWK(private_key): # if string repr, convert to dict object if isinstance(private_key, str): private_key = literal_eval(private_key) - # Create JWK using dict obj - my_jwk = jwk.construct(private_key, JWT.HASH_ALGORITHM) + # remove whitespace from key vaules + private_key = {k: ''.join(private_key[k].split()) for k in private_key} + # ensure private_key is JSON formatted + try: + json.loads(private_key) + except TypeError: + private_key = json.dumps(private_key) + try: + my_jwk = JWK.from_json(private_key) + except InvalidJWKType: + raise ValueError( + "JWK given is of the wrong type") else: # it's a PEM # check for filepath or explicit private key if isinstance(private_key, (str, bytes, os.PathLike)) and os.path.exists(private_key): - # open file if exists and import key + # open file if exists and read pem_file = open(private_key, 'r') - my_pem = RSA.import_key(pem_file.read()) + private_key = pem_file.read() pem_file.close() - else: - # convert given string to bytes and import key - private_key_bytes = bytes(private_key, 'ascii') - my_pem = RSA.import_key(private_key_bytes) - - if not my_pem: - # return error if import failed - return (None, ValueError( - "RSA Private Key given is of the wrong type")) - - if my_jwk: # was JWK provided - # get PEM using JWK - pem_bytes = my_jwk.to_pem(JWT.PEM_FORMAT) - my_pem = RSA.import_key(pem_bytes) - else: # was pem provided - # get JWK using PEM - my_jwk = jwk.construct(my_pem.export_key(), JWT.HASH_ALGORITHM) + # remove leading whitespaces from each line + my_pem = '\n'.join([line.strip() for line in private_key.splitlines()]) + my_pem = bytes(my_pem, 'ascii') + try: + my_jwk = JWK.from_pem(my_pem) + except ValueError: + raise ValueError( + "RSA Private Key given is of the wrong type") + + my_pem = my_jwk.export_to_pem(private_key=True, password=None) + my_pem = RSA.import_key(my_pem) return (my_pem, my_jwk) @@ -108,7 +113,7 @@ def create_token(org_url, client_id, private_key, kid=None): str: Generated JWT """ # Generate PEM and JWK - my_pem, my_jwk = JWT.get_PEM_JWK(private_key) + my_pem, _ = JWT.get_PEM_JWK(private_key) # Get current time and expiry time for token issued_time = int(time.time()) expiry_time = issued_time + JWT.ONE_HOUR @@ -142,5 +147,5 @@ def create_token(org_url, client_id, private_key, kid=None): if "kid" in headers: del headers["kid"] - token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM, headers=headers) + token = jwt_encode(claims, my_pem.export_key(), JWT.HASH_ALGORITHM, headers) return token diff --git a/requirements.txt b/requirements.txt index 157e1168..11e7b771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ pyyaml xmltodict yarl pycryptodomex -python-jose[cryptography] +jwcrypto +pyjwt aenum pydash flake8 diff --git a/setup.py b/setup.py index 02d01b7d..c0d4d221 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,8 @@ def get_version(): "xmltodict", "yarl", "pycryptodomex", - "python-jose", + "jwcrypto", + "pyjwt", "aenum==3.1.11", "pydash" ] diff --git a/tests/mocks.py b/tests/mocks.py index 32789744..7ae59b24 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -416,3 +416,6 @@ def mock_next_link(self_url: URL): KLElmMvzocvFaWKvup_a3vPaBi6y4K5kBiq60o-IDMGQ''', "kid": "5ashWt3LP1zkYwMGbfMsVizRfx52QTyky4GTHd9MykE" } + +SAMPLE_INVALID_JWK = {'foo':'bar'} +SAMPLE_INVALID_RSA = 'foobar' diff --git a/tests/unit/test_jwt.py b/tests/unit/test_jwt.py index 12d27ad1..f405e253 100644 --- a/tests/unit/test_jwt.py +++ b/tests/unit/test_jwt.py @@ -7,20 +7,20 @@ def test_private_key_with_kid_in_private_key(mocker): - mocked_encode = mocker.patch('jose.jwt.encode') + mocked_encode = mocker.patch('okta.jwt.jwt_encode') JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK_WITH_KID) expected_kid = mocks.SAMPLE_JWK_WITH_KID["kid"] - _, kwargs = mocked_encode.call_args + args = mocked_encode.call_args.args mocked_encode.assert_called_once() - assert "kid" in kwargs["headers"] - assert kwargs["headers"]["kid"] == expected_kid + assert "kid" in args[-1] + assert args[-1]["kid"] == expected_kid def test_private_key_with_kid_in_config(mocker): - mocked_encode = mocker.patch('jose.jwt.encode') + mocked_encode = mocker.patch('okta.jwt.jwt_encode') expected_kid = "test-kid" JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK, kid=expected_kid) - _, kwargs = mocked_encode.call_args + args = mocked_encode.call_args.args mocked_encode.assert_called_once() - assert "kid" in kwargs["headers"] - assert kwargs["headers"]["kid"] == expected_kid + assert "kid" in args[-1] + assert args[-1]["kid"] == expected_kid diff --git a/tests/unit/test_oauth.py b/tests/unit/test_oauth.py index af7a2d0f..2e608c0d 100644 --- a/tests/unit/test_oauth.py +++ b/tests/unit/test_oauth.py @@ -14,7 +14,7 @@ def test_private_key_PEM_JWK_dict(jwk_input): generated_pem, generated_jwk = JWT.get_PEM_JWK(jwk_input) assert generated_pem is not None and generated_jwk is not None - assert not generated_jwk.is_public() + assert generated_jwk.has_private def test_private_key_PEM_JWK_file(fs): @@ -24,11 +24,18 @@ def test_private_key_PEM_JWK_file(fs): generated_pem, generated_jwk = JWT.get_PEM_JWK(file_path) assert generated_pem is not None and generated_jwk is not None - assert not generated_jwk.is_public() + assert generated_jwk.has_private def test_private_key_PEM_JWK_explicit_string(): generated_pem, generated_jwk = JWT.get_PEM_JWK(mocks.SAMPLE_RSA) assert generated_pem is not None and generated_jwk is not None - assert not generated_jwk.is_public() + assert generated_jwk.has_private + + +@pytest.mark.parametrize("private_key", + [mocks.SAMPLE_INVALID_JWK, str(mocks.SAMPLE_INVALID_JWK), mocks.SAMPLE_INVALID_RSA]) +def test_invalid_private_key_PEM_JWK(private_key): + with pytest.raises(ValueError): + generated_pem, generated_jwk = JWT.get_PEM_JWK(private_key)