Skip to content

Commit

Permalink
feat: add pyjwt requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
mumarkhan999 committed Jun 21, 2023
1 parent 1f1cc25 commit e7937f5
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 178 deletions.
148 changes: 60 additions & 88 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
This handles validating messages sent by the tool and generating
access token with LTI scopes.
"""
import codecs
import copy
import time
import json
import math
import time
import sys

import jwt
from Cryptodome.PublicKey import RSA
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
from jwkest.jwk import RSAKey, load_jwks_from_url
from jwkest.jws import JWS, NoSuitableSigningKeys
from jwkest.jwt import JWT
from jwt.api_jwk import PyJWK

from . import exceptions

Expand Down Expand Up @@ -47,14 +46,9 @@ def __init__(self, public_key=None, keyset_url=None):
# Import from public key
if public_key:
try:
new_key = RSAKey(use='sig')

# Unescape key before importing it
raw_key = codecs.decode(public_key, 'unicode_escape')

# Import Key and save to internal state
new_key.load_key(RSA.import_key(raw_key))
self.public_key = new_key
algo_obj = jwt.get_algorithm_by_name('RS256')
self.public_key = PyJWK.from_json(algo_obj.to_jwk(public_key))
except ValueError as err:
raise exceptions.InvalidRsaKey() from err

Expand All @@ -69,7 +63,7 @@ def _get_keyset(self, kid=None):

if self.keyset_url:
try:
keys = load_jwks_from_url(self.keyset_url)
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
except Exception as err:
# Broad Exception is required here because jwkest raises
# an Exception object explicitly.
Expand All @@ -78,13 +72,13 @@ def _get_keyset(self, kid=None):
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys)

if self.public_key and kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
if kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid
# Add to keyset
keyset.append(self.public_key)

Expand All @@ -100,32 +94,24 @@ def validate_and_decode(self, token):
iss, sub, exp, aud and jti claims.
"""
try:
# Get KID from JWT header
jwt = JWT().unpack(token)

# Verify message signature
message = JWS().verify_compact(
token,
keys=self._get_keyset(
jwt.headers.get('kid')
)
)

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
raise exceptions.TokenSignatureExpired()

# TODO: Validate other JWT claims

# Else returns decoded message
return message

except NoSuitableSigningKeys as err:
raise exceptions.NoSuitableKeys() from err
except (BadSyntax, WrongNumberOfParts) as err:
raise exceptions.MalformedJwtToken() from err
except BadSignature as err:
raise exceptions.BadJwtSignature() from err
key_set = self._get_keyset()
if not key_set:
raise exceptions.NoSuitableKeys()
for i in range(len(key_set)):
try:
message = jwt.decode(
token,
key=key_set[i],
algorithms=['RS256', 'RS512',],
options={'verify_signature': True}
)
return message
except Exception:
if i == len(key_set) - 1:
raise
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error


class PlatformKeyHandler:
Expand All @@ -144,14 +130,9 @@ def __init__(self, key_pem, kid=None):
if key_pem:
# Import JWK from RSA key
try:
self.key = RSAKey(
# Using the same key ID as client id
# This way we can easily serve multiple public
# keys on teh same endpoint and keep all
# LTI 1.3 blocks working
kid=kid,
key=RSA.import_key(key_pem)
)
algo = jwt.get_algorithm_by_name('RS256')
key_data = algo.prepare_key(rsa_key.export_key('PEM').decode())
self.key = algo.prepare_key(key_pem)
except ValueError as err:
raise exceptions.InvalidRsaKey() from err

Expand All @@ -167,28 +148,26 @@ def encode_and_sign(self, message, expiration=None):
# Set iat and exp if expiration is set
if expiration:
_message.update({
"iat": int(round(time.time())),
"exp": int(round(time.time()) + expiration),
"iat": int(math.floor(time.time())),
"exp": int(math.floor(time.time()) + expiration),
})

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
_jws = JWS(_message, alg="RS256", cty="JWT")

# Encode and sign LTI message
return _jws.sign_compact([self.key])
return jwt.encode(_message, self.key, algorithm="RS256")

def get_public_jwk(self):
"""
Export Public JWK
"""
public_keys = jwk.KEYS()
jwk = {"keys": []}

# Only append to keyset if a key exists
if self.key:
public_keys.append(self.key)

return json.loads(public_keys.dump_jwks())
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
return jwk

def validate_and_decode(self, token, iss=None, aud=None):
"""
Expand All @@ -197,29 +176,22 @@ def validate_and_decode(self, token, iss=None, aud=None):
Validates a token sent by the tool using the platform's RSA Key.
Optionally validate iss and aud claims if provided.
"""
if not self.key:
raise exceptions.RsaKeyNotSet()
try:
# Verify message signature
message = JWS().verify_compact(token, keys=[self.key])

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
raise exceptions.TokenSignatureExpired()

# Validate issuer claim (if present)
if iss:
if 'iss' not in message or message['iss'] != iss:
raise exceptions.InvalidClaimValue('The required iss claim is either missing or does '
'not match the expected iss value.')

# Validate audience claim (if present)
if aud:
if 'aud' not in message or aud not in message['aud']:
raise exceptions.InvalidClaimValue('The required aud claim is missing.')

# Else return token contents
message = jwt.decode(
token,
key=self.key.public_key(),
audience=aud,
issuer=iss,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': True if aud else False
}
)
return message

except NoSuitableSigningKeys as err:
raise exceptions.NoSuitableKeys() from err
except BadSyntax as err:
raise exceptions.MalformedJwtToken() from err
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
46 changes: 32 additions & 14 deletions lti_consumer/lti_1p3/tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
Unit tests for LTI 1.3 consumer implementation
"""

import json
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse

import ddt
import jwt
import sys
from Cryptodome.PublicKey import RSA
from django.test.testcases import TestCase
from edx_django_utils.cache import get_cache_key, TieredCache
from jwkest.jwk import load_jwks
from jwkest.jws import JWS
from jwt.api_jwk import PyJWKSet

from lti_consumer.data import Lti1p3LaunchData
from lti_consumer.lti_1p3 import exceptions
Expand All @@ -34,7 +34,9 @@
STATE = "ABCD"
# Consider storing a fixed key
RSA_KEY_ID = "1"
RSA_KEY = RSA.generate(2048).export_key('PEM')
RSA_KEY = RSA.generate(2048)
RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM')
RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM')


# Test classes
Expand All @@ -53,11 +55,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

def _setup_lti_launch_data(self):
Expand Down Expand Up @@ -102,9 +104,25 @@ def _decode_token(self, token):
This also tests the public keyset function.
"""
public_keyset = self.lti_consumer.get_public_keyset()
key_set = load_jwks(json.dumps(public_keyset))

return JWS().verify_compact(token, keys=key_set)
keyset = PyJWKSet.from_dict(public_keyset).keys

for i in range(len(keyset)):
try:
message = jwt.decode(
token,
key=keyset[i].key,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message
except Exception as token_error:
if i < len(keyset) - 1:
continue
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error

@ddt.data(
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
Expand Down Expand Up @@ -526,7 +544,7 @@ def test_access_token_invalid_jwt(self):
"scope": "",
}

with self.assertRaises(exceptions.MalformedJwtToken):
with self.assertRaises(jwt.exceptions.InvalidTokenError):
self.lti_consumer.access_token(request_data)

def test_access_token(self):
Expand Down Expand Up @@ -641,11 +659,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down Expand Up @@ -884,11 +902,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down
Loading

0 comments on commit e7937f5

Please sign in to comment.