diff --git a/.gitignore b/.gitignore index 4a76c3a3e..7e18527a4 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ tests/resources/keys/*.pem .DS_Store .vscode .idea + +# snyk +.dccache \ No newline at end of file diff --git a/fence/__init__.py b/fence/__init__.py index e1aec601d..fdcc9943d 100755 --- a/fence/__init__.py +++ b/fence/__init__.py @@ -470,6 +470,7 @@ def _setup_oidc_clients(app): logger=logger, HTTP_PROXY=config.get("HTTP_PROXY"), idp=settings.get("name") or idp.title(), + arborist=app.arborist, ) clean_idp = idp.lower().replace(" ", "") setattr(app, f"{clean_idp}_client", client) diff --git a/fence/blueprints/login/base.py b/fence/blueprints/login/base.py index 08fcab61d..7cee07fbe 100644 --- a/fence/blueprints/login/base.py +++ b/fence/blueprints/login/base.py @@ -1,8 +1,12 @@ +import time +import base64 +import json +from urllib.parse import urlparse, urlencode, parse_qsl +import jwt +import requests import flask from cdislogging import get_logger from flask_restful import Resource -from urllib.parse import urlparse, urlencode, parse_qsl - from fence.auth import login_user from fence.blueprints.login.redirect import validate_redirect from fence.config import config @@ -20,7 +24,7 @@ def __init__(self, idp_name, client): Args: idp_name (str): name for the identity provider client (fence.resources.openid.idp_oauth2.Oauth2ClientBase): - Some instaniation of this base client class or a child class + Some instantiation of this base client class or a child class """ self.idp_name = idp_name self.client = client @@ -92,8 +96,27 @@ def __init__( self.is_mfa_enabled = "multifactor_auth_claim_info" in config[ "OPENID_CONNECT" ].get(self.idp_name, {}) + + # Config option to explicitly persist refresh tokens + self.persist_refresh_token = False + + self.read_authz_groups_from_tokens = False + self.app = app + # This block of code probably needs to be made more concise + if "persist_refresh_token" in config["OPENID_CONNECT"].get(self.idp_name, {}): + self.persist_refresh_token = config["OPENID_CONNECT"][self.idp_name][ + "persist_refresh_token" ] + + if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get( + self.idp_name, {} + ): + self.read_authz_groups_from_tokens = config["OPENID_CONNECT"][ + self.idp_name + ]["is_authz_groups_sync_enabled"] + def get(self): # Check if user granted access if flask.request.args.get("error"): @@ -119,7 +142,11 @@ def get(self): code = flask.request.args.get("code") result = self.client.get_auth_info(code) + + refresh_token = result.get("refresh_token") + username = result.get(self.username_field) + if not username: raise UserError( f"OAuth2 callback error: no '{self.username_field}' in {result}" ) @@ -129,11 +156,157 @@ def get(self): id_from_idp = result.get(self.id_from_idp_field) resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp) - self.post_login(user=flask.g.user, token_result=result, id_from_idp=id_from_idp) + + expires = self.extract_exp(refresh_token) + + # if the refresh token is not a JWT, or does not carry exp, + # default to now + REFRESH_TOKEN_EXPIRES_IN + if expires is None: + expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"] + + # Store refresh token in db + should_persist_token = ( + self.persist_refresh_token or self.read_authz_groups_from_tokens + ) + if should_persist_token: + # Ensure flask.g.user exists to avoid a potential AttributeError + if getattr(flask.g, "user", None): + self.client.store_refresh_token(flask.g.user, refresh_token, expires) + else: + logger.error(
"User information is missing from flask.g; cannot store refresh token." + ) + + self.post_login( + user=flask.g.user, + token_result=result, + id_from_idp=id_from_idp, + ) + return resp + def extract_exp(self, refresh_token): + """ + Extract the expiration time (`exp`) from a refresh token. + + This function attempts to retrieve the expiration time from the provided + refresh token using three methods: + + 1. Using PyJWT to decode the token (without signature verification). + 2. Introspecting the token (if supported by the identity provider). + 3. Manually base64 decoding the token's payload (if it's a JWT). + + **Disclaimer:** This function assumes that the refresh token is valid and + does not perform any JWT validation. For JWTs from an OpenID Connect (OIDC) + provider, validation should be done using the public keys provided by the + identity provider (from the JWKS endpoint) before using this function to + extract the expiration time. Without validation, the token's integrity and + authenticity cannot be guaranteed, which may expose your system to security + risks. Ensure validation is handled prior to calling this function, + especially in any public or production-facing contexts. + + Args: + refresh_token (str): The JWT refresh token from which to extract the expiration. + + Returns: + int or None: The expiration time (`exp`) in seconds since the epoch, + or None if extraction fails. + """ + + # Method 1: PyJWT + try: + # Skipping keys since we're not verifying the signature + decoded_refresh_token = jwt.decode( + refresh_token, + options={ + "verify_aud": False, + "verify_at_hash": False, + "verify_signature": False, + }, + algorithms=["RS256", "HS512"], + ) + exp = decoded_refresh_token.get("exp") + + if exp is not None: + return exp + except Exception as e: + logger.info(f"Refresh token expiry: Method (PyJWT) failed: {e}") + + # Method 2: Introspection + try: + introspection_response = self.introspect_token(refresh_token) + exp = introspection_response.get("exp") + + if exp is not None: + return exp + except Exception as e: + logger.info(f"Refresh token expiry: Method Introspection failed: {e}") + + # Method 3: Manual base64 decoding + try: + # Assuming the token is a JWT (header.payload.signature) + payload_encoded = refresh_token.split(".")[1] + # Add necessary padding for base64 decoding + payload_encoded += "=" * (4 - len(payload_encoded) % 4) + payload_decoded = base64.urlsafe_b64decode(payload_encoded) + payload_json = json.loads(payload_decoded) + exp = payload_json.get("exp") + + if exp is not None: + return exp + except Exception as e: + logger.info(f"Method 3 (Manual decoding) failed: {e}") + + # If all methods fail, return None + return None + + def introspect_token(self, token): + """Introspects an access token to determine its validity and retrieve associated metadata. + + This method sends a POST request to the introspection endpoint specified in the OpenID + discovery document. The request includes the provided token and client credentials, + allowing verification of the token's validity and retrieval of any additional metadata + (e.g., token expiry, scopes, or user information). + + Args: + token (str): The access token to be introspected. + + Returns: + dict or None: A dictionary containing the token's introspection data if the request + is successful and the response status code is 200. If the introspection fails or an + exception occurs, returns None. + + Raises: + Exception: Logs an error message if an error occurs during the introspection process. 
+ """ + try: + introspect_endpoint = self.client.get_value_from_discovery_doc( + "introspection_endpoint", "" + ) + + # Headers and payload for the introspection request + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data = { + "token": token, + "client_id": flask.session.get("client_id"), + "client_secret": flask.session.get("client_secret"), + } + + response = requests.post(introspect_endpoint, headers=headers, data=data) + + if response.status_code == 200: + return response.json() + else: + logger.info(f"Error introspecting token: {response.status_code}") + return None + + except Exception as e: + logger.info(f"Error introspecting token: {e}") + return None + def post_login(self, user=None, token_result=None, **kwargs): prepare_login_log(self.idp_name) + metrics.add_login_event( user_sub=flask.g.user.id, idp=self.idp_name, @@ -142,6 +315,13 @@ def post_login(self, user=None, token_result=None, **kwargs): client_id=flask.session.get("client_id"), ) + # this attribute is only applicable to some OAuth clients + # (e.g., not all clients need is_read_authz_groups_from_tokens_enabled) + if self.read_authz_groups_from_tokens: + self.client.update_user_authorization( + user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name + ) + if token_result: username = token_result.get(self.username_field) if self.is_mfa_enabled: diff --git a/fence/config-default.yaml b/fence/config-default.yaml index 721994bde..f25cf6f2b 100755 --- a/fence/config-default.yaml +++ b/fence/config-default.yaml @@ -94,6 +94,7 @@ DB_MIGRATION_POSTGRES_LOCK_KEY: 100 # - WARNING: Be careful changing the *_ALLOWED_SCOPES as you can break basic # and optional functionality # ////////////////////////////////////////////////////////////////////////////////////// + OPENID_CONNECT: # any OIDC IDP that does not differ from the generic implementation can be # configured without code changes @@ -115,6 +116,24 @@ OPENID_CONNECT: multifactor_auth_claim_info: # optional, include if you're using arborist to enforce mfa on a per-file level claim: '' # claims field that indicates mfa, either the acr or acm claim. values: [ "" ] # possible values that indicate mfa was used. At least one value configured here is required to be in the token + # When true, it allows refresh tokens to be stored even if is_authz_groups_sync_enabled is set false. + # When false, the system will only store refresh tokens if is_authz_groups_sync_enabled is enabled + persist_refresh_token: false + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the refresh token is stored, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: true + authz_groups_sync: + # This defines the prefix used to identify authorization groups. + group_prefix: "some_prefix" + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. 
+ verify_aud: true + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence # These Google values must be obtained from Google's Cloud Console # Follow: https://developers.google.com/identity/protocols/OpenIDConnect # diff --git a/fence/config.py b/fence/config.py index 7fa47c7cd..775296025 100644 --- a/fence/config.py +++ b/fence/config.py @@ -145,6 +145,13 @@ def post_process(self): f"IdP '{idp_id}' is using multifactor_auth_claim_info '{mfa_info['claim']}', which is neither AMR or ACR. Unable to determine if a user used MFA. Fence will continue and assume they have not used MFA." ) + groups_sync_enabled = idp.get("is_authz_groups_sync_enabled", False) + # when is_authz_groups_sync_enabled is true, authz_groups_sync (with a group_prefix) must be provided + if groups_sync_enabled and not idp.get("authz_groups_sync"): + error = f"Error: is_authz_groups_sync_enabled is enabled but the required authz_groups_sync values are not configured for idp: {idp_id}" + logger.error(error) + raise Exception(error) + self._validate_parent_child_studies(self._configs["dbGaP"]) @staticmethod diff --git a/fence/error_handler.py b/fence/error_handler.py index 5b3a0cfdb..6ac6f99dc 100644 --- a/fence/error_handler.py +++ b/fence/error_handler.py @@ -8,12 +8,13 @@ from fence.errors import APIError from fence.config import config +import traceback logger = get_logger(__name__) -def get_error_response(error): +def get_error_response(error: Exception): details, status_code = get_error_details_and_status(error) support_email = config.get("SUPPORT_EMAIL_FOR_ERRORS") app_name = config.get("APP_NAME", "Gen3 Data Commons") @@ -27,6 +28,11 @@ def get_error_response(error): ) ) + # TODO: Issue: Error messages are obfuscated, the line below needs to be + # uncommented when troubleshooting errors. + # Breaks tests if not commented out / removed. We need a fix for this. + # raise error + # don't include internal details in the public error message # to do this, only include error messages for known http status codes # that are less that 500 diff --git a/fence/job/visa_update_cronjob.py b/fence/job/access_token_updater.py similarity index 74% rename from fence/job/visa_update_cronjob.py rename to fence/job/access_token_updater.py index cac8d9182..6909357b4 100644 --- a/fence/job/visa_update_cronjob.py +++ b/fence/job/access_token_updater.py @@ -3,16 +3,18 @@ import time from cdislogging import get_logger +from flask import current_app from fence.config import config from fence.models import User from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient logger = get_logger(__name__, log_level="debug") -class Visa_Token_Update(object): +class AccessTokenUpdater(object): def __init__( self, chunk_size=None, @@ -20,6 +22,7 @@ def __init__( thread_pool_size=None, buffer_size=None, logger=logger, + arborist=None, ): """ args: @@ -44,17 +47,42 @@ def __init__( self.visa_types = config.get("USERSYNC", {}).get("visa_types", {}) + # Dict of all OIDC clients whose tokens need to be refreshed + self.oidc_clients_requiring_token_refresh = {} + + # Keep RAS as a special case, because it does not set group information configuration. # Initialize visa clients: oidc = config.get("OPENID_CONNECT", {}) + + if not isinstance(oidc, dict): + raise TypeError( + "Expected 'OPENID_CONNECT' configuration to be a dictionary."
+ ) + if "ras" not in oidc: self.logger.error("RAS client not configured") - self.ras_client = None else: - self.ras_client = RASClient( + ras_client = RASClient( oidc["ras"], HTTP_PROXY=config.get("HTTP_PROXY"), logger=logger, ) + self.oidc_clients_requiring_token_refresh["ras"] = ras_client + + self.arborist = arborist + + # Initialise a client for each OIDC IdP in oidc that has is_authz_groups_sync_enabled set to true, and add it + # to oidc_clients_requiring_token_refresh + for oidc_name, settings in oidc.items(): + if settings.get("is_authz_groups_sync_enabled", False): + oidc_client = OIDCClient( + settings=settings, + HTTP_PROXY=config.get("HTTP_PROXY"), + logger=logger, + idp=oidc_name, + arborist=arborist, + ) + self.oidc_clients_requiring_token_refresh[oidc_name] = oidc_client async def update_tokens(self, db_session): """ @@ -68,7 +96,8 @@ async def update_tokens(self, db_session): """ start_time = time.time() - self.logger.info("Initializing Visa Update Cronjob . . .") + # We are refreshing tokens here, not just visas + self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .") self.logger.info("Total concurrency size: {}".format(self.concurrency)) self.logger.info("Total thread pool size: {}".format(self.thread_pool_size)) self.logger.info("Total buffer size: {}".format(self.buffer_size)) @@ -139,13 +168,12 @@ async def worker(self, name, queue, updater_queue): queue.task_done() async def updater(self, name, updater_queue, db_session): - """ - Update visas in the updater_queue. - Note that only visas which pass validation will be saved. - """ while True: - user = await updater_queue.get() try: + user = await updater_queue.get() + if user is None: # Use None to signal termination + break + client = self._pick_client(user) if client: self.logger.info( @@ -160,30 +188,35 @@ async def updater(self, name, updater_queue, db_session): pkey_cache=self.pkey_cache, db_session=db_session, ) + else: self.logger.debug( f"Updater {name} NOT updating authorization for " f"user {user.username} because no client was found for IdP: {user.identity_provider}" ) + + updater_queue.task_done() + except Exception as exc: self.logger.error( f"Updater {name} could not update authorization " - f"for {user.username}. Error: {exc}. Continuing." + f"for {user.username if user else 'unknown user'}. Error: {exc}. Continuing." ) - pass - - updater_queue.task_done() + # Ensure task is marked done if exception occurs + updater_queue.task_done() def _pick_client(self, user): """ - Pick oidc client according to the identity provider + Select OIDC client based on identity provider.
""" - client = None - if ( - user.identity_provider - and getattr(user.identity_provider, "name") == self.ras_client.idp - ): - client = self.ras_client + + client = self.oidc_clients_requiring_token_refresh.get( + getattr(user.identity_provider, "name"), None + ) + if client: + self.logger.info(f"Picked client: {client.idp} for user {user.username}") + else: + self.logger.info(f"No client found for user {user.username}") return client def _pick_client_from_visa(self, visa): diff --git a/fence/resources/openid/cilogon_oauth2.py b/fence/resources/openid/cilogon_oauth2.py index 163663420..dcdd7224f 100644 --- a/fence/resources/openid/cilogon_oauth2.py +++ b/fence/resources/openid/cilogon_oauth2.py @@ -39,7 +39,7 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://cilogon.org/oauth2/certs" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("sub"): return {"sub": claims["sub"]} diff --git a/fence/resources/openid/cognito_oauth2.py b/fence/resources/openid/cognito_oauth2.py index 73038c87f..7d3924c55 100644 --- a/fence/resources/openid/cognito_oauth2.py +++ b/fence/resources/openid/cognito_oauth2.py @@ -45,7 +45,7 @@ def get_auth_info(self, code): try: token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) self.logger.info(f"Received id token from Cognito: {claims}") diff --git a/fence/resources/openid/google_oauth2.py b/fence/resources/openid/google_oauth2.py index b396fe9ca..edae5f64c 100644 --- a/fence/resources/openid/google_oauth2.py +++ b/fence/resources/openid/google_oauth2.py @@ -47,7 +47,7 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://www.googleapis.com/oauth2/v3/certs" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("email") and claims.get("email_verified"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/idp_oauth2.py b/fence/resources/openid/idp_oauth2.py index c2e497085..92181d027 100644 --- a/fence/resources/openid/idp_oauth2.py +++ b/fence/resources/openid/idp_oauth2.py @@ -1,10 +1,14 @@ from authlib.integrations.requests_client import OAuth2Session +from boto3 import client from cached_property import cached_property from flask import current_app from jose import jwt +from jose.exceptions import JWTError, JWTClaimsError import requests import time - +import datetime +import backoff +from fence.utils import DEFAULT_BACKOFF_SETTINGS from fence.errors import AuthError from fence.models import UpstreamRefreshToken @@ -15,7 +19,14 @@ class Oauth2ClientBase(object): """ def __init__( - self, settings, logger, idp, scope=None, discovery_url=None, HTTP_PROXY=None + self, + settings, + logger, + idp, + scope=None, + discovery_url=None, + HTTP_PROXY=None, + arborist=None, ): self.logger = logger self.settings = settings @@ -25,14 +36,17 @@ def __init__( scope=scope or settings.get("scope") or "openid", redirect_uri=settings["redirect_url"], ) + self.discovery_url = ( discovery_url or settings.get("discovery_url") or 
getattr(self, "DISCOVERY_URL", None) or "" ) - self.idp = idp # display name for use in logs and error messages + # display name for use in logs and error messages + self.idp = idp self.HTTP_PROXY = HTTP_PROXY + self.authz_groups_from_idp = [] if not self.discovery_url and not settings.get("discovery"): self.logger.warning( @@ -40,6 +54,12 @@ def __init__( f"Some calls for this client may fail if they rely on the OIDC Discovery page. Use 'discovery' to configure clients without a discovery page." ) + self.read_authz_groups_from_tokens = self.settings.get( + "is_authz_groups_sync_enabled", False + ) + + self.arborist = arborist + @cached_property def discovery_doc(self): return requests.get(self.discovery_url) @@ -53,6 +73,7 @@ def get_proxies(self): return None def get_token(self, token_endpoint, code): + return self.session.fetch_token( url=token_endpoint, code=code, proxies=self.get_proxies() ) @@ -63,6 +84,7 @@ def get_jwt_keys(self, jwks_uri): Return None if there is an error while retrieving keys from the api """ resp = requests.get(url=jwks_uri, proxies=self.get_proxies()) + if resp.status_code != requests.codes.ok: self.logger.error( "{} ERROR: Can not retrieve jwt keys from IdP's API {}".format( @@ -72,18 +94,101 @@ def get_jwt_keys(self, jwks_uri): return None return resp.json()["keys"] + def get_raw_token_claims(self, token_id): + """Extracts unvalidated claims from a JWT (JSON Web Token). + + This function decodes a JWT and extracts claims without verifying + the token's signature or audience. It is intended for cases where + access to the raw, unvalidated token claims is sufficient. + + Args: + token_id (str): The JWT token from which to extract claims. + + Returns: + dict: A dictionary of token claims if decoding is successful. + + Raises: + JWTError: If there is an error decoding the token without validation. + + Notes: + This function does not perform any validation of the token. It should + only be used in contexts where validation is not critical or is handled + elsewhere in the application. + """ + try: + # Decode without verification + unvalidated_claims = jwt.decode( + token_id, options={"verify_signature": False} + ) + self.logger.info("Raw token claims extracted successfully.") + return unvalidated_claims + except JWTError as e: + self.logger.error(f"Error extracting claims: {e}") + raise JWTError("Unable to decode the token without validation.") + + def decode_and_validate_token(self, token_id, keys, audience, verify_aud=True): + """Decodes and validates a JWT (JSON Web Token) using provided keys and audience. + + This function decodes a JWT and validates its signature and audience claim, + if required. It is typically used for tokens that require validation to + ensure integrity and authenticity. + + Args: + token_id (str): The JWT token to decode. + keys (list): A list of keys to use for decoding the token, usually + provided by the Identity Provider (IdP). + audience (str): The expected audience (`aud`) claim to verify within the token. + verify_aud (bool, optional): Flag to enable or disable audience verification. + Defaults to True. + + Returns: + dict: A dictionary of validated token claims if decoding and validation are successful. + + Raises: + JWTClaimsError: If the token's claims, such as audience, do not match the expected values. + JWTError: If there is an error with the JWT structure or verification. + + Notes: + - This function assumes the token is signed using the RS256 algorithm. 
+ - Audience verification (`aud`) is performed if `verify_aud` is set to True. + """ + try: + validated_claims = jwt.decode( + token_id, + keys, + options={"verify_aud": verify_aud, "verify_at_hash": False}, + algorithms=["RS256"], + audience=audience, + ) + self.logger.info("Token decoded and validated successfully.") + return validated_claims + except JWTClaimsError as e: + self.logger.error(f"Claim error: {e}") + raise JWTClaimsError(f"Invalid audience: {e}") + except JWTError as e: + self.logger.error(f"JWT error: {e}") + raise JWTError(f"JWT error occurred: {e}") + def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code): """ Get jwt identity claims """ + token = self.get_token(token_endpoint, code) + keys = self.get_jwt_keys(jwks_endpoint) - return jwt.decode( - token["id_token"], - keys, - options={"verify_aud": False, "verify_at_hash": False}, - algorithms=["RS256"], + refresh_token = token.get("refresh_token", None) + + # validate audience and hash. also ensure that the algorithm is correctly derived from the token. + # hash verification has not been implemented yet + verify_aud = self.settings.get("verify_aud", False) + audience = self.settings.get("audience", self.settings.get("client_id")) + return ( + self.decode_and_validate_token( + token["id_token"], keys, audience, verify_aud + ), + refresh_token, ) def get_value_from_discovery_doc(self, key, default_value): @@ -161,10 +266,28 @@ def get_auth_info(self, code): user OR "error" field with details of the error. """ user_id_field = self.settings.get("user_id_field", "sub") + try: token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) + + groups = None + group_prefix = None + + if self.read_authz_groups_from_tokens: + try: + groups = claims.get("groups") + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) + except KeyError as e: + self.logger.error( + f"Error: is_authz_groups_sync_enabled is enabled, however groups not found in claims: {e}" + ) + raise Exception(e) if claims.get(user_id_field): if user_id_field == "email" and not claims.get("email_verified"): @@ -172,6 +295,11 @@ def get_auth_info(self, code): return { user_id_field: claims[user_id_field], "mfa": self.has_mfa_claim(claims), + "refresh_token": refresh_token, + "iat": claims.get("iat"), + "exp": claims.get("exp"), + "groups": groups, + "group_prefix": group_prefix, } else: self.logger.exception( @@ -187,14 +315,15 @@ def get_access_token(self, user, token_endpoint, db_session=None): """ Get access_token using a refresh_token and store new refresh in upstream_refresh_token table. """ + # this function is not correct. use self.session.fetch_access_token, + # validate the token for audience and then return the validated token. + # Still store the refresh token. it will be needed for periodic re-fetching of information. 
refresh_token = None expires = None - # get refresh_token and expiration from db for row in sorted(user.upstream_refresh_tokens, key=lambda row: row.expires): refresh_token = row.refresh_token expires = row.expires - if time.time() > expires: # reset to check for next token refresh_token = None @@ -274,3 +403,150 @@ def store_refresh_token(self, user, refresh_token, expires, db_session=None): current_db_session = db_session.object_session(upstream_refresh_token) current_db_session.add(upstream_refresh_token) db_session.commit() + + def get_groups_from_token(self, decoded_id_token, group_prefix=""): + """Retrieve and format groups from the decoded token.""" + authz_groups_from_idp = decoded_id_token.get("groups", []) + if authz_groups_from_idp: + authz_groups_from_idp = [ + group.removeprefix(group_prefix).lstrip("/") + for group in authz_groups_from_idp + ] + return authz_groups_from_idp + + @backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS) + def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs): + """ + Update the user's authorization by refreshing their access token and synchronizing + their group memberships with Arborist. + + This method refreshes the user's access token using an identity provider (IdP), + retrieves and decodes the token, and optionally synchronizes the user's group + memberships between the IdP and Arborist if the `groups` configuration is enabled. + + Args: + user (User): The user object, which contains details like username and identity provider. + pkey_cache (dict): A cache of public keys used for verifying JWT signatures. + db_session (SQLAlchemy Session, optional): A database session object. If not provided, + it defaults to the scoped session of the current application context. + **kwargs: Additional keyword arguments. + + Raises: + Exception: If there is an issue with retrieving the access token, decoding the token, + or synchronizing the user's groups. + + Workflow: + 1. Retrieves the token endpoint and JWKS URI from the identity provider's discovery document. + 2. Uses the user's refresh token to get a new access token and persists it in the database. + 3. Decodes the ID token using the JWKS (JSON Web Key Set) retrieved from the IdP. + 4. If group synchronization is enabled: + a. Retrieves the list of groups from Arborist. + b. Retrieves the user's groups from the IdP. + c. Adds the user to groups in Arborist that match the groups from the IdP. + d. Removes the user from groups in Arborist that they are no longer part of in the IdP. + + Logging: + - Logs the group membership synchronization activities (adding/removing users from groups). + - Logs any issues encountered while refreshing the token or during group synchronization. + + Warnings: + - If groups are not received from the IdP but group synchronization is enabled, logs a warning. 
+ + """ + db_session = db_session or current_app.scoped_session() + + # Initialize the failure flag for group removal + removal_failed = False + + expires_at = None + + try: + token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") + + # this get_access_token also persists the refresh token in the db + token = self.get_access_token(user, token_endpoint, db_session) + jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") + keys = self.get_jwt_keys(jwks_endpoint) + expires_at = token["expires_at"] + verify_aud = self.settings.get("verify_aud", False) + audience = self.settings.get("audience", self.settings.get("client_id")) + decoded_token_id = self.decode_and_validate_token( + token_id=token["id_token"], + keys=keys, + audience=audience, + verify_aud=verify_aud, + ) + + except Exception as e: + err_msg = "Could not refresh token" + self.logger.exception("{}: {}".format(err_msg, e)) + raise + if self.read_authz_groups_from_tokens: + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) + + # grab all groups defined in arborist + arborist_groups = self.arborist.list_groups().get("groups") + + # groups defined in idp + authz_groups_from_idp = self.get_groups_from_token( + decoded_token_id, group_prefix + ) + + exp = datetime.datetime.fromtimestamp(expires_at, tz=datetime.timezone.utc) + + # if group name is in the list from arborist: + if authz_groups_from_idp: + authz_groups_from_idp = [ + group.removeprefix(group_prefix).lstrip("/") + for group in authz_groups_from_idp + ] + + idp_group_names = set(authz_groups_from_idp) + + # Add user to all matching groups from IDP + for arborist_group in arborist_groups: + if arborist_group["name"] in idp_group_names: + self.logger.info( + f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {exp}" + ) + self.arborist.add_user_to_group( + username=user.username, + group_name=arborist_group["name"], + expires_at=exp, + ) + + # Remove user from groups in Arborist that they are not part of in IDP + for arborist_group in arborist_groups: + if arborist_group["name"] not in idp_group_names: + if user.username in arborist_group.get("users", []): + try: + self.remove_user_from_arborist_group( + user.username, arborist_group["name"] + ) + except Exception as e: + self.logger.error( + f"Failed to remove {user.username} from group {arborist_group['name']}: {e}" + ) + removal_failed = ( + # Set the failure flag if any removal fails + True + ) + + else: + self.logger.warning( + f"is_authz_groups_sync_enabled feature is enabled, but did not receive groups from idp {self.idp} for user: {user.username}" + ) + + # Raise an exception if any group removal failed + if removal_failed: + raise Exception("One or more group removals failed.") + + def remove_user_from_arborist_group(self, username, group_name): + """ + Attempt to remove a user from an Arborist group, catching any errors to allow + processing of remaining groups. Logs errors and re-raises them after all removals are attempted. 
+ """ + self.logger.info(f"Removing {username} from group: {group_name}") + self.arborist.remove_user_from_group(username=username, group_name=group_name) diff --git a/fence/resources/openid/microsoft_oauth2.py b/fence/resources/openid/microsoft_oauth2.py index 916a4a2b1..960bd6e49 100755 --- a/fence/resources/openid/microsoft_oauth2.py +++ b/fence/resources/openid/microsoft_oauth2.py @@ -48,7 +48,7 @@ def get_auth_info(self, code): "jwks_uri", "https://login.microsoftonline.com/organizations/discovery/v2.0/keys", ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("email"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/okta_oauth2.py b/fence/resources/openid/okta_oauth2.py index 572031623..b26fb2d68 100644 --- a/fence/resources/openid/okta_oauth2.py +++ b/fence/resources/openid/okta_oauth2.py @@ -37,7 +37,7 @@ def get_auth_info(self, code): "jwks_uri", "", ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("email"): return {"email": claims["email"], "sub": claims.get("sub")} diff --git a/fence/resources/openid/orcid_oauth2.py b/fence/resources/openid/orcid_oauth2.py index ee8711f33..5894a3519 100644 --- a/fence/resources/openid/orcid_oauth2.py +++ b/fence/resources/openid/orcid_oauth2.py @@ -41,7 +41,7 @@ def get_auth_info(self, code): jwks_endpoint = self.get_value_from_discovery_doc( "jwks_uri", "https://orcid.org/oauth/jwks" ) - claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("sub"): return {"orcid": claims["sub"], "sub": claims["sub"]} diff --git a/fence/scripting/fence_create.py b/fence/scripting/fence_create.py index a4b15aff8..9a94e3601 100644 --- a/fence/scripting/fence_create.py +++ b/fence/scripting/fence_create.py @@ -38,7 +38,7 @@ generate_signed_refresh_token, issued_and_expiration_times, ) -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import AccessTokenUpdater from fence.models import ( Client, GoogleServiceAccount, @@ -1814,12 +1814,19 @@ def access_token_polling_job( thread_pool_size (int): number of Docker container CPU used for jwt verifcation buffer_size (int): max size of queue """ + # Instantiating a new client here because the existing + # client uses authz_provider + arborist = ArboristClient( + arborist_base_url=config["ARBORIST"], + logger=get_logger("user_syncer.arborist_client"), + ) driver = get_SQLAlchemyDriver(db) - job = Visa_Token_Update( + job = AccessTokenUpdater( chunk_size=int(chunk_size) if chunk_size else None, concurrency=int(concurrency) if concurrency else None, thread_pool_size=int(thread_pool_size) if thread_pool_size else None, buffer_size=int(buffer_size) if buffer_size else None, + arborist=arborist, ) with driver.session as db_session: loop = asyncio.get_event_loop() diff --git a/tests/conftest.py b/tests/conftest.py index 9baba01a1..90c81d2fa 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,6 +76,7 @@ "cilogon", "generic1", "generic2", + "generic3", ] @@ -396,6 +397,12 @@ def do_patch(urls_to_responses=None): defaults = { "arborist/health": {"GET": ("", 200)}, "arborist/auth/mapping": {"POST": ({}, "200")}, + "arborist/group": { + 
"GET": ( + {"groups": [{"name": "data_uploaders", "users": ["test_user"]}]}, + 200, + ) + }, } defaults.update(urls_to_responses) urls_to_responses = defaults @@ -479,6 +486,33 @@ def app(kid, rsa_private_key, rsa_public_key): mocker.unmock_functions() +@pytest.fixture +def mock_app(): + return MagicMock() + + +@pytest.fixture +def mock_user(): + return MagicMock() + + +@pytest.fixture +def mock_db_session(): + """Mock the database session.""" + db_session = MagicMock() + return db_session + + +@pytest.fixture +def expired_mock_user(): + """Mock a user object with upstream refresh tokens.""" + user = MagicMock() + user.upstream_refresh_tokens = [ + MagicMock(refresh_token="expired_token", expires=0), # Expired token + ] + return user + + @pytest.fixture(scope="function") def auth_client(request): """ diff --git a/tests/dbgap_sync/test_user_sync.py b/tests/dbgap_sync/test_user_sync.py index f85cc28e5..7cc565c4a 100644 --- a/tests/dbgap_sync/test_user_sync.py +++ b/tests/dbgap_sync/test_user_sync.py @@ -10,7 +10,7 @@ from fence import models from fence.resources.google.access_utils import GoogleUpdateException from fence.config import config -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import AccessTokenUpdater from fence.utils import DEFAULT_BACKOFF_SETTINGS from tests.dbgap_sync.conftest import ( @@ -998,7 +998,7 @@ def test_user_sync_with_visa_sync_job( # use refresh tokens from users to call access token polling "fence-create update-visa" # and sync authorization from visas - job = Visa_Token_Update() + job = AccessTokenUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, diff --git a/tests/job/test_access_token_updater.py b/tests/job/test_access_token_updater.py new file mode 100644 index 000000000..0ba9f6368 --- /dev/null +++ b/tests/job/test_access_token_updater.py @@ -0,0 +1,206 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, patch, MagicMock +from fence.models import User +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient +from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient +from fence.job.access_token_updater import AccessTokenUpdater + + +@pytest.fixture(scope="session", autouse=True) +def event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() + + +@pytest.fixture +def run_async(event_loop): + """Run an async coroutine in the current event loop.""" + + def _run(coro): + return event_loop.run_until_complete(coro) + + return _run + + +@pytest.fixture +def mock_db_session(): + """Fixture to mock the DB session.""" + return MagicMock() + + +@pytest.fixture +def mock_users(): + """Fixture to mock the user list.""" + user1 = MagicMock(spec=User) + user1.username = "testuser1" + user1.identity_provider.name = "ras" + + user2 = MagicMock(spec=User) + user2.username = "testuser2" + user2.identity_provider.name = "test_oidc" + + return [user1, user2] + + +@pytest.fixture +def mock_oidc_clients(): + """Fixture to mock OIDC clients.""" + ras_client = MagicMock(spec=RASClient) + ras_client.idp = "ras" + + oidc_client = MagicMock(spec=OIDCClient) + oidc_client.idp = "test_oidc" + + return [ras_client, oidc_client] + + +@pytest.fixture +def access_token_updater_config(mock_oidc_clients): + """Fixture to instantiate AccessTokenUpdater with mocked OIDC clients.""" + with patch( + "fence.config", + { + "OPENID_CONNECT": { + "ras": {}, + "test_oidc": {"groups": {"read_authz_groups_from_tokens": 
True}}, + }, + "ENABLE_AUTHZ_GROUPS_FROM_OIDC": True, + }, + ): + updater = AccessTokenUpdater() + + # Ensure this is a dictionary rather than a list + updater.oidc_clients_requiring_token_refresh = { + client.idp: client for client in mock_oidc_clients + } + + return updater + + +def test_get_user_from_db( + run_async, access_token_updater_config, mock_db_session, mock_users +): + """Test the get_user_from_db method.""" + mock_db_session.query().slice().all.return_value = mock_users + + users = run_async( + access_token_updater_config.get_user_from_db(mock_db_session, chunk_idx=0) + ) + assert len(users) == 2 + assert users[0].username == "testuser1" + assert users[1].username == "testuser2" + + +def test_producer(run_async, access_token_updater_config, mock_db_session, mock_users): + """Test the producer method.""" + queue = asyncio.Queue() + mock_db_session.query().slice().all.return_value = mock_users + + # Run producer to add users to queue + run_async(access_token_updater_config.producer(mock_db_session, queue, chunk_idx=0)) + + assert queue.qsize() == len(mock_users) + assert not queue.empty() + + # Dequeue to check correctness + user = run_async(queue.get()) + assert user.username == "testuser1" + + +def test_worker(run_async, access_token_updater_config, mock_users): + """Test the worker method.""" + queue = asyncio.Queue() + updater_queue = asyncio.Queue() + + # Add users to the queue + for user in mock_users: + run_async(queue.put(user)) + + # Run the worker to transfer users from queue to updater_queue + run_async(access_token_updater_config.worker("worker_1", queue, updater_queue)) + + assert updater_queue.qsize() == len(mock_users) + assert queue.empty() + + +async def updater_with_timeout(updater, queue, db_session, timeout=5): + return await asyncio.wait_for(updater(queue, db_session), timeout) + + +def test_updater( + run_async, + access_token_updater_config, + mock_users, + mock_db_session, + mock_oidc_clients, +): + """Test the updater method.""" + updater_queue = asyncio.Queue() + + # Add a user to the updater_queue + run_async(updater_queue.put(mock_users[0])) + + # Mock the client to return a valid update process + mock_oidc_clients[0].update_user_authorization = AsyncMock() + + # Ensure _pick_client returns the correct client + with patch.object( + access_token_updater_config, "_pick_client", return_value=mock_oidc_clients[0] + ): + # Signal the updater to stop after processing + run_async(updater_queue.put(None)) # This should be an awaited call + + # Run the updater to process the user and update authorization + run_async( + access_token_updater_config.updater( + "updater_1", updater_queue, mock_db_session + ) + ) + + # Verify that the OIDC client was called with the correct user + mock_oidc_clients[0].update_user_authorization.assert_called_once_with( + mock_users[0], + pkey_cache=access_token_updater_config.pkey_cache, + db_session=mock_db_session, + ) + + +def test_no_client_found(run_async, access_token_updater_config, mock_users): + """Test that updater does not crash if no client is found.""" + updater_queue = asyncio.Queue() + + # Modify the user to have an unrecognized identity provider + mock_users[0].identity_provider.name = "unknown_provider" + + run_async(updater_queue.put(mock_users[0])) # Ensure this is awaited + run_async(updater_queue.put(None)) # Signal the updater to terminate + + # Mock the client selection to return None + with patch.object(access_token_updater_config, "_pick_client", return_value=None): + # Run the updater and ensure it skips 
the user with no client + run_async( + access_token_updater_config.updater("updater_1", updater_queue, MagicMock()) + ) + + assert updater_queue.empty() # The user should still be dequeued + + +def test_pick_client( + run_async, access_token_updater_config, mock_users, mock_oidc_clients +): + """Test that the correct OIDC client is selected based on the user's IDP.""" + # Pick the client for a RAS user + client = access_token_updater_config._pick_client(mock_users[0]) + assert client.idp == "ras" + + # Pick the client for a test OIDC user + client = access_token_updater_config._pick_client(mock_users[1]) + assert client.idp == "test_oidc" + + # Ensure no client is returned for a user with no matching IDP + mock_users[0].identity_provider.name = "nonexistent_idp" + client = access_token_updater_config._pick_client(mock_users[0]) + assert client is None diff --git a/tests/login/test_base.py b/tests/login/test_base.py index a32452b2c..bf541f64a 100644 --- a/tests/login/test_base.py +++ b/tests/login/test_base.py @@ -1,6 +1,11 @@ +import pytest + from fence.blueprints.login import DefaultOAuth2Callback +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase, UpstreamRefreshToken from fence.config import config from unittest.mock import MagicMock, patch +from datetime import datetime, timedelta +import time @patch("fence.blueprints.login.base.prepare_login_log") @@ -25,12 +30,14 @@ def test_post_login_set_mfa(app, monkeypatch, mock_authn_user_flask_context): app.arborist = MagicMock() token_result = {"username": "lisasimpson", "mfa": True} callback.post_login(token_result=token_result) + app.arborist.grant_user_policy.assert_called_with( username=token_result["username"], policy_id="mfa_policy" ) token_result = {"username": "homersimpson", "mfa": False} callback.post_login(token_result=token_result) + app.arborist.revoke_user_policy.assert_called_with( username=token_result["username"], policy_id="mfa_policy" ) @@ -54,3 +61,7 @@ def test_post_login_no_mfa_enabled(app, monkeypatch, mock_authn_user_flask_conte token_result = {"username": "lisasimpson"} callback.post_login(token_result=token_result) app.arborist.revoke_user_policy.assert_not_called() + + + + diff --git a/tests/login/test_idp_oauth2.py b/tests/login/test_idp_oauth2.py index 40ae2349a..b5b229af6 100644 --- a/tests/login/test_idp_oauth2.py +++ b/tests/login/test_idp_oauth2.py @@ -1,7 +1,14 @@ import pytest +import datetime +from jose.exceptions import JWTClaimsError +from unittest.mock import ANY +from flask import Flask, g from cdislogging import get_logger +from unittest.mock import MagicMock, Mock, patch +from fence.resources.openid.idp_oauth2 import Oauth2ClientBase, AuthError +from fence.blueprints.login.base import DefaultOAuth2Callback +from fence.config import config -from fence import Oauth2ClientBase MOCK_SETTINGS_ACR = { "client_id": "client", @@ -39,11 +46,6 @@ def test_has_mfa_claim_acr(oauth_client_acr): assert has_mfa -def test_has_mfa_claim_acr(oauth_client_acr): - has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa"}) - assert has_mfa - - def test_has_mfa_claim_multiple_acr(oauth_client_acr): has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa otp duo"}) assert has_mfa @@ -83,3 +85,355 @@ def test_does_not_has_mfa_claim_amr(oauth_client_amr): def test_does_not_has_mfa_claim_multiple_amr(oauth_client_amr): has_mfa = oauth_client_amr.has_mfa_claim({"amr": ["pwd, trustme"]}) assert not has_mfa + + +# To test the store_refresh_token method of the Oauth2ClientBase class +def test_store_refresh_token(mock_user, 
mock_app): + """ + Test the `store_refresh_token` method of the `Oauth2ClientBase` class to ensure that + refresh tokens are correctly stored in the database using the `UpstreamRefreshToken` model. + """ + mock_logger = MagicMock() + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "groups": {"read_authz_groups_from_tokens": True, "group_prefix": "/"}, + "user_id_field": "sub", + } + + # Ensure oauth_client is correctly instantiated + oauth_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) + + refresh_token = "mock_refresh_token" + expires = 1700000000 + + # Patch the UpstreamRefreshToken to prevent actual database interactions + with patch( + "fence.resources.openid.idp_oauth2.UpstreamRefreshToken", autospec=True + ) as MockUpstreamRefreshToken: + # Mock the db_session's object_session method to return a mocked session object + mock_session = MagicMock() + mock_app.arborist.object_session.return_value = mock_session + + # Call the method to test + oauth_client.store_refresh_token( + mock_user, refresh_token, expires, db_session=mock_app.arborist + ) + + # Check if UpstreamRefreshToken was instantiated correctly + MockUpstreamRefreshToken.assert_called_once_with( + user=mock_user, + refresh_token=refresh_token, + expires=expires, + ) + + # Check if the mock session's `add` and `commit` methods were called + mock_app.arborist.object_session.assert_called_once() + mock_session.add.assert_called_once_with(MockUpstreamRefreshToken.return_value) + mock_app.arborist.commit.assert_called_once() + + +# To test if a user is granted access using the get_auth_info method in the Oauth2ClientBase +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("fence.resources.openid.idp_oauth2.jwt.decode") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch( + "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" +) +def test_get_auth_info_granted_access( + mock_get_value_from_discovery_doc, + mock_fetch_token, + mock_jwt_decode, + mock_get_jwt_keys, +): + """ + Test that the `get_auth_info` method correctly retrieves, processes, and decodes + an OAuth2 authentication token, including access, refresh, and ID tokens, while also + handling JWT decoding and discovery document lookups. + + Raises: + AssertionError: If the expected claims or tokens are not present in the returned authentication information. 
+ """ + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "is_authz_groups_sync_enabled": True, + "authz_groups_sync:": {"group_prefix": "/"}, + "user_id_field": "sub", + } + + # Mock logger + mock_logger = MagicMock() + + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) + + # Directly mock the return values for token_endpoint and jwks_uri + mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: ( + "http://localhost/token" if key == "token_endpoint" else "http://localhost/jwks" + ) + + # Setup mock response for fetch_token + mock_fetch_token.return_value = { + "access_token": "mock_access_token", + "id_token": "mock_id_token", + "refresh_token": "mock_refresh_token", + } + + # Setup mock JWT keys response + mock_get_jwt_keys.return_value = [ + {"kty": "RSA", "kid": "1e9gdk7", "use": "sig", "n": "example-key", "e": "AQAB"} + ] + + # Setup mock decoded JWT token + mock_jwt_decode.return_value = { + "sub": "mock_user_id", + "email_verified": True, + "iat": 1609459200, + "exp": 1609462800, + "groups": ["group1", "group2"], + } + + # Log mock setups + print( + f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}" + ) + print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}") + print(f"Mock fetch_token response: {mock_fetch_token.return_value}") + print(f"Mock JWT decode response: {mock_jwt_decode.return_value}") + + # Call the method + code = "mock_code" + auth_info = oauth2_client.get_auth_info(code) + print(f"Mock auth_info: {auth_info}") + + # Debug: Check if decode was called + print(f"JWT decode call count: {mock_jwt_decode.call_count}") + + # Assertions + assert "sub" in auth_info + assert auth_info["sub"] == "mock_user_id" + assert "refresh_token" in auth_info + assert auth_info["refresh_token"] == "mock_refresh_token" + assert "iat" in auth_info + assert auth_info["iat"] == 1609459200 + assert "exp" in auth_info + assert auth_info["exp"] == 1609462800 + assert "groups" in auth_info + assert auth_info["groups"] == ["group1", "group2"] + + +def test_get_access_token_expired(expired_mock_user, mock_db_session): + """ + Test that attempting to retrieve an access token for a user with an expired refresh token + results in an `AuthError`, the user's token is deleted, and the session is committed. + + + Raises: + AuthError: When the user does not have a valid, non-expired refresh token. 
+ """ + mock_settings = { + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "redirect_url": "http://localhost/callback", + "discovery_url": "http://localhost/.well-known/openid-configuration", + "is_authz_groups_sync_enabled": True, + "authz_groups_sync:": {"group_prefix": "/"}, + "user_id_field": "sub", + } + + # Initialize the Oauth2 client object + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=MagicMock(), idp="test_idp" + ) + + # Simulate the token expiration and user not having access + with pytest.raises(AuthError) as excinfo: + print("get_access_token about to be called") + oauth2_client.get_access_token( + expired_mock_user, + token_endpoint="https://token.endpoint", + db_session=mock_db_session, + ) + + print(f"Raised exception message: {excinfo.value}") + + assert "User doesn't have a valid, non-expired refresh token" in str(excinfo.value) + + mock_db_session.delete.assert_called() + mock_db_session.commit.assert_called() + + +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_auth_info") +def test_post_login_with_group_prefix(mock_get_auth_info, app): + """ + Test the `post_login` method of the `DefaultOAuth2Callback` class, ensuring that user groups + fetched from an identity provider (IdP) are processed correctly and prefixed before being added + to the user in the Arborist service. + """ + with app.app_context(): + yield + with patch.dict(config, {"ENABLE_AUTHZ_GROUPS_FROM_OIDC": True}, clear=False): + mock_user = MagicMock() + mock_user.username = "test_user" + mock_user.id = "user_id" + g.user = mock_user + + # Set up mock responses for user info and groups from the IdP + mock_get_auth_info.return_value = { + "username": "test_user", + "groups": ["group1", "group2", "covid/group3", "group4", "group5"], + "exp": datetime.datetime.now(tz=datetime.timezone.utc).timestamp(), + "group_prefix": "covid/", + } + + # Mock the Arborist client and its methods + mock_arborist = MagicMock() + mock_arborist.list_groups.return_value = { + "groups": [ + {"name": "group1"}, + {"name": "group2"}, + {"name": "group3"}, + {"name": "reviewers"}, + ] + } + mock_arborist.add_user_to_group = MagicMock() + mock_arborist.remove_user_from_group = MagicMock() + + # Mock the Flask app + app = MagicMock() + app.arborist = mock_arborist + + # Create the callback object with the mock app + callback = DefaultOAuth2Callback( + idp_name="generic3", client=MagicMock(), app=app + ) + + # Mock user and call post_login + mock_user = MagicMock() + mock_user.username = "test_user" + + # Simulate calling post_login + callback.post_login( + user=g.user, + token_result=mock_get_auth_info.return_value, + groups_from_idp=mock_get_auth_info.return_value["groups"], + group_prefix=mock_get_auth_info.return_value["group_prefix"], + expires_at=mock_get_auth_info.return_value["exp"], + username=mock_user.username, + ) + + # Assertions to check if groups were processed with the correct prefix + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group1", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), + ) + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group2", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), + ) + mock_arborist.add_user_to_group.assert_any_call( + username="test_user", + group_name="group3", + expires_at=datetime.datetime.fromtimestamp( + 
mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), + ) + + # Ensure the mock was called exactly three times (once for each group that was added) + assert mock_arborist.add_user_to_group.call_count == 3 + + +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch("fence.resources.openid.idp_oauth2.jwt.decode") # Mock jwt.decode +def test_jwt_audience_verification_fails( + mock_jwt_decode, mock_fetch_token, mock_get_jwt_keys +): + """ + Test the JWT audience verification failure scenario. + + This test mocks various components used in the OIDC flow to simulate the + process of obtaining a token, fetching JWKS (JSON Web Key Set), and verifying + the JWT token's claims. Specifically, it focuses on the audience verification + step and tests that an invalid audience raises the expected `JWTClaimsError`. + + + Raises: + JWTClaimsError: When the audience in the JWT token is invalid. + """ + # Mock fetch_token to simulate a successful token fetch + mock_fetch_token.return_value = { + "id_token": "mock-id-token", + "access_token": "mock_access_token", + "refresh_token": "mock-refresh-token", + } + + # Mock JWKS response + mock_jwks_response = { + "keys": [ + { + "kty": "RSA", + "kid": "test-key-id", + "use": "sig", + # Simulate RSA public key values + "n": "mock-n-value", + "e": "mock-e-value", + } + ] + } + + mock_get_jwt_keys.return_value = MagicMock() + mock_get_jwt_keys.return_value = mock_jwks_response + + # Mock jwt.decode to raise JWTClaimsError for audience verification failure + mock_jwt_decode.side_effect = JWTClaimsError("Invalid audience") + + # Setup the mock instance of Oauth2ClientBase + client = Oauth2ClientBase( + settings={ + "client_id": "mock-client-id", + "client_secret": "mock-client-secret", + "redirect_url": "mock-redirect-url", + "discovery_url": "http://localhost/discovery", + "audience": "expected-audience", + "verify_aud": True, + }, + logger=MagicMock(), + idp="mock-idp", + ) + + # Invoke the method and expect JWTClaimsError to be raised + with pytest.raises(JWTClaimsError, match="Invalid audience"): + client.get_jwt_claims_identity( + token_endpoint="https://token.endpoint", + jwks_endpoint="https://jwks.uri", + code="auth_code", + ) + + # Verify fetch_token was called correctly + mock_fetch_token.assert_called_once_with( + url="https://token.endpoint", code="auth_code", proxies=None + ) + + # Verify jwt.decode was called with the mock id_token and the mocked JWKS keys + mock_jwt_decode.assert_called_with( + "mock-id-token", # The mock token + mock_jwks_response, # The mocked keys + options={"verify_aud": True, "verify_at_hash": False}, + algorithms=["RS256"], + audience="expected-audience", + ) diff --git a/tests/login/test_login_shib.py b/tests/login/test_login_shib.py index db18aa483..f0335fcb3 100644 --- a/tests/login/test_login_shib.py +++ b/tests/login/test_login_shib.py @@ -1,6 +1,5 @@ from fence.config import config - def test_shib_redirect(client, app): r = client.get("/login/shib?redirect=http://localhost") assert r.status_code == 302 diff --git a/tests/login/test_microsoft_login.py b/tests/login/test_microsoft_login.py index 972b8a07f..a00d75463 100755 --- a/tests/login/test_microsoft_login.py +++ b/tests/login/test_microsoft_login.py @@ -34,9 +34,10 @@ def test_get_auth_info_missing_claim(microsoft_oauth2_client): """ return_value = {"not_email_claim": "user@contoso.com"} expected_value = {"error": "Can't get user's Microsoft email!"} + 
refresh_token = {} with patch( "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_claims_identity", - return_value=return_value, + return_value=(return_value,refresh_token) ): user_id = microsoft_oauth2_client.get_auth_info(code="123") assert user_id == expected_value # nosec diff --git a/tests/ras/test_ras.py b/tests/ras/test_ras.py index f3be7575c..c1439e056 100644 --- a/tests/ras/test_ras.py +++ b/tests/ras/test_ras.py @@ -25,7 +25,7 @@ from tests.utils import add_test_ras_user, TEST_RAS_USERNAME, TEST_RAS_SUB from tests.dbgap_sync.conftest import add_visa_manually -from fence.job.visa_update_cronjob import Visa_Token_Update +from fence.job.access_token_updater import AccessTokenUpdater import tests.utils from tests.conftest import get_subjects_to_passports @@ -713,7 +713,7 @@ def _get_userinfo(*args, **kwargs): mock_userinfo.side_effect = _get_userinfo # test "fence-create update-visa" - job = Visa_Token_Update() + job = AccessTokenUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, diff --git a/tests/test-fence-config.yaml b/tests/test-fence-config.yaml index 38ccbd147..bb055b835 100755 --- a/tests/test-fence-config.yaml +++ b/tests/test-fence-config.yaml @@ -141,6 +141,32 @@ OPENID_CONNECT: redirect_url: '{{BASE_URL}}/login/generic2/login' discovery: authorization_endpoint: 'https://generic2/authorization_endpoint' + generic3: + name: 'generic3' # optional; display name for this IDP + client_id: '' + client_secret: '' + redirect_url: '{{BASE_URL}}/login/generic3/login' # replace IDP name + # use `discovery` to configure IDPs that do not expose a discovery + # endpoint. One of `discovery_url` or `discovery` should be configured + discovery_url: 'https://localhost/.well-known/openid-configuration' + # When true, it allows refresh tokens to be stored even if is_authz_groups_sync_enabled is set false. + # When false, the system will only store refresh tokens if is_authz_groups_sync_enabled is enabled + persist_refresh_token: false + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: false + authz_groups_sync: + # This defines the prefix used to identify authorization groups. + group_prefix: /covid + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. + verify_aud: false + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence # these are the *possible* scopes a client can be given, NOT scopes that are # given to all clients. 
You can be more restrictive during client creation diff --git a/tests/test_metrics.py b/tests/test_metrics.py index be7d6b2ab..d0d47786d 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -519,6 +519,18 @@ def test_login_log_login_endpoint( get_auth_info_value = {"generic1_username": username} elif idp == "generic2": get_auth_info_value = {"sub": username} + elif idp == "generic3": + # get_auth_info_value specific to generic3 + # TODO: Need test when is_authz_groups_sync_enabled == true + get_auth_info_value = { + "username": username, + "sub": username, + "email_verified": True, + "iat": 1609459200, + "exp": 1609462800, + "refresh_token": "mock_refresh_token", + "groups": ["group1", "group2"], + } if idp in ["google", "microsoft", "okta", "synapse", "cognito"]: get_auth_info_value["email"] = username @@ -538,6 +550,7 @@ def test_login_log_login_endpoint( ) path = f"/login/{idp}/{callback_endpoint}" # SEE fence/blueprints/login/fence_login.py L91 response = client.get(path, headers=headers) + print(f"Response: {response.status_code}, Body: {response.data}") assert response.status_code == 200, response user_sub = db_session.query(User).filter(User.username == username).first().id audit_service_requests.post.assert_called_once_with(