Skip to content

Commit

Permalink
Checking for missing keys earlier (#299)
Browse files Browse the repository at this point in the history
* Checking for missing keys earlier

This will help catch errors where public keys are not retrievable. This is a partial solution to #283.

* Verify request to jwks endpoint

Log the positive or negative response

* linter

* Move get_public_keys into its own function to improve caching
  • Loading branch information
Bento007 authored Oct 29, 2019
1 parent 58a72d4 commit 55c50c3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 20 deletions.
6 changes: 3 additions & 3 deletions fusillade/api/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from fusillade import Config
from fusillade.errors import FusilladeHTTPException
from fusillade.utils.security import get_openid_config, get_public_keys
from fusillade.utils.security import get_openid_config, get_public_key


def login():
Expand Down Expand Up @@ -183,9 +183,9 @@ def cb():
except requests.exceptions.HTTPError:
return make_response(res.text, res.status_code, res.headers.items())
token_header = jwt.get_unverified_header(res.json()["id_token"])
public_keys = get_public_keys(openid_provider)
public_key = get_public_key(openid_provider, token_header["kid"])
tok = jwt.decode(res.json()["id_token"],
key=public_keys[token_header["kid"]],
key=public_key,
audience=oauth2_config[openid_provider]["client_id"])
assert tok["email_verified"]
if redirect_uri:
Expand Down
69 changes: 52 additions & 17 deletions fusillade/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@functools.lru_cache(maxsize=32)
def get_openid_config(openid_provider=None):
def get_openid_config(openid_provider=None) -> dict:
"""
:param openid_provider: the openid provider's domain.
Expand All @@ -43,30 +43,73 @@ def get_openid_config(openid_provider=None):
return res.json()


def get_jwks_uri(openid_provider):
def get_jwks_uri(openid_provider) -> str:
if openid_provider.endswith(gserviceaccount_domain):
return f"https://www.googleapis.com/service_accounts/v1/jwk/{openid_provider}"
else:
return get_openid_config(openid_provider)["jwks_uri"]


@functools.lru_cache(maxsize=32)
def get_public_keys(openid_provider):
def get_public_keys(issuer: str) -> typing.Dict[str, bytearray]:
"""
Fetches the public key from an OIDC Identity provider to verify the JWT.
:param openid_provider: the openid provider's domain.
:return: Public Keys
Fetches the public keys from an OIDC Identity provider to verify the JWT and caching for later use.
:param issuer: the openid provider's domain.
:param kid: the key identifier for verifying the JWT
:return: A Public Keys
"""
keys = session.get(get_jwks_uri(openid_provider)).json()["keys"]
resp = session.get(get_jwks_uri(issuer))
try:
resp.raise_for_status()
except requests.exceptions.HTTPError:
logger.error({"message": f"Get {get_jwks_uri(issuer)} Failed",
"text": resp.text,
"status_code": resp.status_code,
})
raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.")
else:
logger.info({
"message": f"Get {get_jwks_uri(issuer)} Succeeded",
"response": resp.json(),
"status_code": resp.status_code
})

return {
key["kid"]: rsa.RSAPublicNumbers(
e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "==="), byteorder="big"),
n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "==="), byteorder="big")
).public_key(backend=default_backend())
for key in keys
for key in resp.json()["keys"]
}


def get_public_key(issuer: str, kid: str) -> bytearray:
"""
Fetches the public keys from an OIDC Identity provider to verify the JWT. If the key is not found in the public
key cache, the cache is cleared and a retry is performed.
:param issuer: the openid provider's domain.
:param kid: the key identifier for verifying the JWT
:return: A Public Key
"""
public_keys = get_public_keys(issuer)
try:
return public_keys[kid]
except KeyError:
logger.error({"message": "Failed to fetched public key from openid provider.",
"public_keys": public_keys,
"issuer": issuer,
"kid": kid})
logger.debug({"message": "Clearing public key cache."})
get_public_keys.clear_cache()
public_keys = get_public_keys(issuer)
try:
return public_keys[kid]
except KeyError:
raise FusilladeHTTPException(401,
'Unauthorized',
f"Unable to verify JWT. KID:{kid} does not exists for issuer:{issuer}.")


def verify_jwt(token: str) -> typing.Optional[typing.Mapping]:
"""
Verify the JWT from the request. This is function is referenced in fusillade-api.yml
Expand All @@ -84,15 +127,7 @@ def verify_jwt(token: str) -> typing.Optional[typing.Mapping]:
raise FusilladeHTTPException(401, 'Unauthorized', 'Failed to decode token.')

issuer = unverified_token['iss']
public_keys = get_public_keys(issuer)
try:
public_key = public_keys[token_header["kid"]]
except KeyError:
logger.error({"message": "Failed to fetched public key from openid provider.",
"public_keys": public_keys,
"issuer": issuer,
"kid": token_header["kid"]})
raise FusilladeHTTPException(503, 'Service Unavailable', "Failed to fetched public key from openid provider.")
public_key = get_public_key(issuer, token_header["kid"])
try:
verified_tok = jwt.decode(token,
key=public_key,
Expand Down

0 comments on commit 55c50c3

Please sign in to comment.