-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
OZ-671: Superset config to use 'authlib' instead of Flask OIDC (#87)
- Loading branch information
Showing
2 changed files
with
73 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,50 @@ | ||
from flask import redirect, request | ||
from flask_appbuilder.security.manager import AUTH_OID | ||
from math import log | ||
from superset.security import SupersetSecurityManager | ||
from flask_oidc import OpenIDConnect | ||
from flask_appbuilder.security.views import AuthOIDView | ||
from flask_login import login_user | ||
from urllib.parse import quote | ||
from flask_appbuilder.views import ModelView, SimpleFormView, expose | ||
import logging | ||
logger = logging.getLogger(__name__) | ||
|
||
class AuthOIDCView(AuthOIDView): | ||
def add_role_if_missing(self, sm, user_id, role_name): | ||
found_role = sm.find_role(role_name) | ||
session = sm.get_session | ||
user = session.query(sm.user_model).get(user_id) | ||
if found_role and found_role not in user.roles: | ||
user.roles += [found_role] | ||
session.commit() | ||
|
||
@expose('/login/', methods=['GET', 'POST']) | ||
def login(self, flag=True): | ||
sm = self.appbuilder.sm | ||
oidc = sm.oid | ||
|
||
|
||
@self.appbuilder.sm.oid.require_login | ||
def handle_login(): | ||
user = sm.auth_user_oid(oidc.user_getfield('email')) | ||
if user is None: | ||
info = oidc.user_getinfo(['preferred_username', 'given_name', 'family_name', 'email','roles']) | ||
user = sm.add_user(info.get('preferred_username'), info.get('given_name'), info.get('family_name'), info.get('email'), sm.find_role('Gamma')) | ||
role_info = oidc.user_getinfo(['roles']) | ||
if role_info is not None: | ||
for role in role_info['roles']: | ||
self.add_role_if_missing(sm, user.id, role) | ||
login_user(user, remember=False) | ||
return redirect(self.appbuilder.get_url_for_index) | ||
|
||
return handle_login() | ||
|
||
@expose('/logout/', methods=['GET', 'POST']) | ||
def logout(self): | ||
|
||
oidc = self.appbuilder.sm.oid | ||
|
||
oidc.logout() | ||
super(AuthOIDCView, self).logout() | ||
from flask_appbuilder.security.views import AuthOAuthView | ||
from flask_appbuilder.baseviews import expose | ||
import time | ||
from flask import ( | ||
redirect, | ||
request | ||
) | ||
|
||
class CustomAuthOAuthView(AuthOAuthView): | ||
|
||
@expose("/logout/") | ||
def logout(self, provider="keycloak", register=None): | ||
provider_obj = self.appbuilder.sm.oauth_remotes[provider] | ||
redirect_url = request.url_root.strip('/') + self.appbuilder.get_url_for_login | ||
|
||
return redirect(oidc.client_secrets.get('issuer') + '/protocol/openid-connect/logout?redirect_uri=' + quote(redirect_url)) | ||
|
||
class OIDCSecurityManager(SupersetSecurityManager): | ||
authoidview = AuthOIDCView | ||
def __init__(self,appbuilder): | ||
super(OIDCSecurityManager, self).__init__(appbuilder) | ||
if self.auth_type == AUTH_OID: | ||
self.oid = OpenIDConnect(self.appbuilder.get_app) | ||
url = ("logout?client_id={}&post_logout_redirect_uri={}".format( | ||
provider_obj.client_id, | ||
redirect_url | ||
)) | ||
|
||
ret = super().logout() | ||
time.sleep(1) | ||
|
||
return redirect("{}{}".format(provider_obj.api_base_url, url)) | ||
|
||
|
||
class CustomSecurityManager(SupersetSecurityManager): | ||
# override the logout function | ||
authoauthview = CustomAuthOAuthView | ||
|
||
def oauth_user_info(self, provider, response=None): | ||
logging.debug("Oauth2 provider: {0}.".format(provider)) | ||
if provider == 'keycloak': | ||
# superset_roles: list[str] = ["Admin", "Alpha", "Gamma", "Public", "granter", "sql_lab"] | ||
me = self.appbuilder.sm.oauth_remotes[provider].get('userinfo').json() | ||
roles = ["public", ] | ||
if "roles" in me: | ||
role_prefix = "superset-" | ||
roles = [r[len(role_prefix):].lower() for r in me.get("roles", []) if r.startswith(role_prefix)] | ||
|
||
return { | ||
"username": me.get("preferred_username", ""), | ||
"first_name": me.get("given_name", ""), | ||
"last_name": me.get("family_name", ""), | ||
"email": me.get("email", ""), | ||
"role_keys": roles, | ||
} | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters