From 55c50c34e93ac987f7868d7c38093aa610c4b7db Mon Sep 17 00:00:00 2001 From: Trent Smith <1429913+Bento007@users.noreply.github.com> Date: Tue, 29 Oct 2019 09:34:32 -0700 Subject: [PATCH] Checking for missing keys earlier (#299) * Checking for missing keys earlier This will help catch errors where public keys are not retrievable. This is a partial solution to #283. * Verify request to jwks endpoint Log the positive or negative response * linter * Move get_public_keys into its own function to improve caching --- fusillade/api/oauth.py | 6 ++-- fusillade/utils/security.py | 69 ++++++++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/fusillade/api/oauth.py b/fusillade/api/oauth.py index 72268633..76a713d1 100644 --- a/fusillade/api/oauth.py +++ b/fusillade/api/oauth.py @@ -13,7 +13,7 @@ from fusillade import Config from fusillade.errors import FusilladeHTTPException -from fusillade.utils.security import get_openid_config, get_public_keys +from fusillade.utils.security import get_openid_config, get_public_key def login(): @@ -183,9 +183,9 @@ def cb(): except requests.exceptions.HTTPError: return make_response(res.text, res.status_code, res.headers.items()) token_header = jwt.get_unverified_header(res.json()["id_token"]) - public_keys = get_public_keys(openid_provider) + public_key = get_public_key(openid_provider, token_header["kid"]) tok = jwt.decode(res.json()["id_token"], - key=public_keys[token_header["kid"]], + key=public_key, audience=oauth2_config[openid_provider]["client_id"]) assert tok["email_verified"] if redirect_uri: diff --git a/fusillade/utils/security.py b/fusillade/utils/security.py index 946f084f..4867c1aa 100644 --- a/fusillade/utils/security.py +++ b/fusillade/utils/security.py @@ -26,7 +26,7 @@ @functools.lru_cache(maxsize=32) -def get_openid_config(openid_provider=None): +def get_openid_config(openid_provider=None) -> dict: """ :param openid_provider: the openid provider's domain. @@ -43,7 +43,7 @@ def get_openid_config(openid_provider=None): return res.json() -def get_jwks_uri(openid_provider): +def get_jwks_uri(openid_provider) -> str: if openid_provider.endswith(gserviceaccount_domain): return f"https://www.googleapis.com/service_accounts/v1/jwk/{openid_provider}" else: @@ -51,22 +51,65 @@ def get_jwks_uri(openid_provider): @functools.lru_cache(maxsize=32) -def get_public_keys(openid_provider): +def get_public_keys(issuer: str) -> typing.Dict[str, bytearray]: """ - Fetches the public key from an OIDC Identity provider to verify the JWT. - :param openid_provider: the openid provider's domain. - :return: Public Keys + Fetches the public keys from an OIDC Identity provider to verify the JWT and caching for later use. + :param issuer: the openid provider's domain. + :param kid: the key identifier for verifying the JWT + :return: A Public Keys """ - keys = session.get(get_jwks_uri(openid_provider)).json()["keys"] + resp = session.get(get_jwks_uri(issuer)) + try: + resp.raise_for_status() + except requests.exceptions.HTTPError: + logger.error({"message": f"Get {get_jwks_uri(issuer)} Failed", + "text": resp.text, + "status_code": resp.status_code, + }) + raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.") + else: + logger.info({ + "message": f"Get {get_jwks_uri(issuer)} Succeeded", + "response": resp.json(), + "status_code": resp.status_code + }) + return { key["kid"]: rsa.RSAPublicNumbers( e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "==="), byteorder="big"), n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "==="), byteorder="big") ).public_key(backend=default_backend()) - for key in keys + for key in resp.json()["keys"] } +def get_public_key(issuer: str, kid: str) -> bytearray: + """ + Fetches the public keys from an OIDC Identity provider to verify the JWT. If the key is not found in the public + key cache, the cache is cleared and a retry is performed. + :param issuer: the openid provider's domain. + :param kid: the key identifier for verifying the JWT + :return: A Public Key + """ + public_keys = get_public_keys(issuer) + try: + return public_keys[kid] + except KeyError: + logger.error({"message": "Failed to fetched public key from openid provider.", + "public_keys": public_keys, + "issuer": issuer, + "kid": kid}) + logger.debug({"message": "Clearing public key cache."}) + get_public_keys.clear_cache() + public_keys = get_public_keys(issuer) + try: + return public_keys[kid] + except KeyError: + raise FusilladeHTTPException(401, + 'Unauthorized', + f"Unable to verify JWT. KID:{kid} does not exists for issuer:{issuer}.") + + def verify_jwt(token: str) -> typing.Optional[typing.Mapping]: """ Verify the JWT from the request. This is function is referenced in fusillade-api.yml @@ -84,15 +127,7 @@ def verify_jwt(token: str) -> typing.Optional[typing.Mapping]: raise FusilladeHTTPException(401, 'Unauthorized', 'Failed to decode token.') issuer = unverified_token['iss'] - public_keys = get_public_keys(issuer) - try: - public_key = public_keys[token_header["kid"]] - except KeyError: - logger.error({"message": "Failed to fetched public key from openid provider.", - "public_keys": public_keys, - "issuer": issuer, - "kid": token_header["kid"]}) - raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.") + public_key = get_public_key(issuer, token_header["kid"]) try: verified_tok = jwt.decode(token, key=public_key,