-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automated copy feature/OIDC oauth groups #1199
base: master
Are you sure you want to change the base?
Changes from all commits
25d1654
96300d4
d8d0955
5dfdd7b
05c04c8
724b76f
dd8efdd
388cf69
d77cec3
bcbcee5
c5c0519
5ec4fb7
6f30aeb
10c59ec
d3419d5
3692d79
2984f05
b125321
18336df
4645494
7008e94
d0074f5
55cfdc4
ab7dcfa
8520425
7774fc9
7300d99
3869656
39d5217
ab6e17d
f0f9d28
0d72ec7
659bf5a
d791800
f1b8e31
bd28aba
67c318c
436c08a
d546b02
cd810a1
a5ccac6
90667c3
3cb11a7
4a206ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,3 +108,6 @@ tests/resources/keys/*.pem | |
.DS_Store | ||
.vscode | ||
.idea | ||
|
||
# snyk | ||
.dccache |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
import time | ||
import base64 | ||
import json | ||
from urllib.parse import urlparse, urlencode, parse_qsl | ||
import jwt | ||
import requests | ||
import flask | ||
from cdislogging import get_logger | ||
from flask_restful import Resource | ||
from urllib.parse import urlparse, urlencode, parse_qsl | ||
|
||
from fence.auth import login_user | ||
from fence.blueprints.login.redirect import validate_redirect | ||
from fence.config import config | ||
|
@@ -20,7 +24,7 @@ def __init__(self, idp_name, client): | |
Args: | ||
idp_name (str): name for the identity provider | ||
client (fence.resources.openid.idp_oauth2.Oauth2ClientBase): | ||
Some instaniation of this base client class or a child class | ||
Some instantiation of this base client class or a child class | ||
""" | ||
self.idp_name = idp_name | ||
self.client = client | ||
|
@@ -92,8 +96,27 @@ def __init__( | |
self.is_mfa_enabled = "multifactor_auth_claim_info" in config[ | ||
"OPENID_CONNECT" | ||
].get(self.idp_name, {}) | ||
|
||
# Config option to explicitly persist refresh tokens | ||
self.persist_refresh_token = False | ||
|
||
self.read_authz_groups_from_tokens = False | ||
|
||
self.app = app | ||
|
||
# This block of code probably need to be made more concise | ||
if "persist_refresh_token" in config["OPENID_CONNECT"].get(self.idp_name, {}): | ||
self.persist_refresh_token = config["OPENID_CONNECT"][self.idp_name][ | ||
"persist_refresh_token" | ||
] | ||
|
||
if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get( | ||
self.idp_name, {} | ||
): | ||
self.read_authz_groups_from_tokens = config["OPENID_CONNECT"][ | ||
self.idp_name | ||
]["is_authz_groups_sync_enabled"] | ||
|
||
def get(self): | ||
# Check if user granted access | ||
if flask.request.args.get("error"): | ||
|
@@ -119,7 +142,11 @@ def get(self): | |
|
||
code = flask.request.args.get("code") | ||
result = self.client.get_auth_info(code) | ||
|
||
refresh_token = result.get("refresh_token") | ||
|
||
username = result.get(self.username_field) | ||
|
||
if not username: | ||
raise UserError( | ||
f"OAuth2 callback error: no '{self.username_field}' in {result}" | ||
|
@@ -129,11 +156,157 @@ def get(self): | |
id_from_idp = result.get(self.id_from_idp_field) | ||
|
||
resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp) | ||
self.post_login(user=flask.g.user, token_result=result, id_from_idp=id_from_idp) | ||
|
||
expires = self.extract_exp(refresh_token) | ||
|
||
# if the access token is not a JWT, or does not carry exp, | ||
# default to now + REFRESH_TOKEN_EXPIRES_IN | ||
if expires is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should add a log here so it's clear in an audit trail that we assumed a specific expiration |
||
expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"] | ||
|
||
# Store refresh token in db | ||
should_persist_token = ( | ||
self.persist_refresh_token or self.read_authz_groups_from_tokens | ||
) | ||
if should_persist_token: | ||
# Ensure flask.g.user exists to avoid a potential AttributeError | ||
if getattr(flask.g, "user", None): | ||
self.client.store_refresh_token(flask.g.user, refresh_token, expires) | ||
else: | ||
logger.error( | ||
"User information is missing from flask.g; cannot store refresh token." | ||
) | ||
|
||
self.post_login( | ||
user=flask.g.user, | ||
token_result=result, | ||
id_from_idp=id_from_idp, | ||
) | ||
|
||
return resp | ||
|
||
def extract_exp(self, refresh_token): | ||
""" | ||
Extract the expiration time (`exp`) from a refresh token. | ||
|
||
This function attempts to retrieve the expiration time from the provided | ||
refresh token using three methods: | ||
|
||
1. Using PyJWT to decode the token (without signature verification). | ||
2. Introspecting the token (if supported by the identity provider). | ||
3. Manually base64 decoding the token's payload (if it's a JWT). | ||
|
||
**Disclaimer:** This function assumes that the refresh token is valid and | ||
does not perform any JWT validation. For JWTs from an OpenID Connect (OIDC) | ||
provider, validation should be done using the public keys provided by the | ||
identity provider (from the JWKS endpoint) before using this function to | ||
extract the expiration time. Without validation, the token's integrity and | ||
authenticity cannot be guaranteed, which may expose your system to security | ||
risks. Ensure validation is handled prior to calling this function, | ||
especially in any public or production-facing contexts. | ||
|
||
Args: | ||
refresh_token (str): The JWT refresh token from which to extract the expiration. | ||
|
||
Returns: | ||
int or None: The expiration time (`exp`) in seconds since the epoch, | ||
or None if extraction fails. | ||
""" | ||
|
||
# Method 1: PyJWT | ||
try: | ||
# Skipping keys since we're not verifying the signature | ||
decoded_refresh_token = jwt.decode( | ||
refresh_token, | ||
options={ | ||
"verify_aud": False, | ||
"verify_at_hash": False, | ||
"verify_signature": False, | ||
}, | ||
algorithms=["RS256", "HS512"], | ||
) | ||
exp = decoded_refresh_token.get("exp") | ||
|
||
if exp is not None: | ||
return exp | ||
except Exception as e: | ||
logger.info(f"Refresh token expiry: Method (PyJWT) failed: {e}") | ||
|
||
# Method 2: Introspection | ||
try: | ||
introspection_response = self.introspect_token(refresh_token) | ||
exp = introspection_response.get("exp") | ||
|
||
if exp is not None: | ||
return exp | ||
except Exception as e: | ||
logger.info(f"Refresh token expiry: Method Introspection failed: {e}") | ||
|
||
# Method 3: Manual base64 decoding | ||
try: | ||
# Assuming the token is a JWT (header.payload.signature) | ||
payload_encoded = refresh_token.split(".")[1] | ||
# Add necessary padding for base64 decoding | ||
payload_encoded += "=" * (4 - len(payload_encoded) % 4) | ||
payload_decoded = base64.urlsafe_b64decode(payload_encoded) | ||
payload_json = json.loads(payload_decoded) | ||
exp = payload_json.get("exp") | ||
|
||
if exp is not None: | ||
return exp | ||
except Exception as e: | ||
logger.info(f"Method 3 (Manual decoding) failed: {e}") | ||
|
||
# If all methods fail, return None | ||
return None | ||
|
||
def introspect_token(self, token): | ||
"""Introspects an access token to determine its validity and retrieve associated metadata. | ||
|
||
This method sends a POST request to the introspection endpoint specified in the OpenID | ||
discovery document. The request includes the provided token and client credentials, | ||
allowing verification of the token's validity and retrieval of any additional metadata | ||
(e.g., token expiry, scopes, or user information). | ||
|
||
Args: | ||
token (str): The access token to be introspected. | ||
|
||
Returns: | ||
dict or None: A dictionary containing the token's introspection data if the request | ||
is successful and the response status code is 200. If the introspection fails or an | ||
exception occurs, returns None. | ||
|
||
Raises: | ||
Exception: Logs an error message if an error occurs during the introspection process. | ||
""" | ||
try: | ||
introspect_endpoint = self.client.get_value_from_discovery_doc( | ||
"introspection_endpoint", "" | ||
) | ||
|
||
# Headers and payload for the introspection request | ||
headers = {"Content-Type": "application/x-www-form-urlencoded"} | ||
data = { | ||
"token": token, | ||
"client_id": flask.session.get("client_id"), | ||
"client_secret": flask.session.get("client_secret"), | ||
} | ||
|
||
response = requests.post(introspect_endpoint, headers=headers, data=data) | ||
|
||
if response.status_code == 200: | ||
return response.json() | ||
else: | ||
logger.info(f"Error introspecting token: {response.status_code}") | ||
return None | ||
|
||
except Exception as e: | ||
logger.info(f"Error introspecting token: {e}") | ||
return None | ||
|
||
def post_login(self, user=None, token_result=None, **kwargs): | ||
prepare_login_log(self.idp_name) | ||
|
||
metrics.add_login_event( | ||
user_sub=flask.g.user.id, | ||
idp=self.idp_name, | ||
|
@@ -142,6 +315,13 @@ def post_login(self, user=None, token_result=None, **kwargs): | |
client_id=flask.session.get("client_id"), | ||
) | ||
|
||
# this attribute is only applicable to some OAuth clients | ||
# (e.g., not all clients need is_read_authz_groups_from_tokens_enabled) | ||
if self.read_authz_groups_from_tokens: | ||
self.client.update_user_authorization( | ||
user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name | ||
) | ||
|
||
if token_result: | ||
username = token_result.get(self.username_field) | ||
if self.is_mfa_enabled: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,13 @@ | |
|
||
from fence.errors import APIError | ||
from fence.config import config | ||
import traceback | ||
|
||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
def get_error_response(error): | ||
def get_error_response(error: Exception): | ||
details, status_code = get_error_details_and_status(error) | ||
support_email = config.get("SUPPORT_EMAIL_FOR_ERRORS") | ||
app_name = config.get("APP_NAME", "Gen3 Data Commons") | ||
|
@@ -27,6 +28,11 @@ def get_error_response(error): | |
) | ||
) | ||
|
||
# TODO: Issue: Error messages are obfuscated, the line below needs be | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we get this fixed before we merge? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
# uncommented when troubleshooting errors. | ||
# Breaks tests if not commented out / removed. We need a fix for this. | ||
# raise error | ||
|
||
# don't include internal details in the public error message | ||
# to do this, only include error messages for known http status codes | ||
# that are less that 500 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we do this cleanup before we merge?