Skip to content

Commit

Permalink
Merge branch 'develop' into 2284-keda-conda-store-worker-hpa
Browse files Browse the repository at this point in the history
  • Loading branch information
pt247 authored May 28, 2024
2 parents bcddb7b + 363cb0d commit 40ee18f
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import os
import time
import urllib
from functools import reduce

from jupyterhub import scopes
from jupyterhub.traitlets import Callable
from oauthenticator.generic import GenericOAuthenticator
from traitlets import Bool, Unicode, Union
Expand All @@ -28,41 +30,101 @@ class KeyCloakOAuthenticator(GenericOAuthenticator):
reset_managed_roles_on_startup = Bool(True)

async def update_auth_model(self, auth_model):
    """Updates and returns the auth_model dict.

    This function is called every time a user authenticates with JupyterHub,
    as in every time a user logs in to Nebari.

    It fetches the user's roles on the `jupyterhub` Keycloak client, together
    with the scopes stored in the role attributes, and returns the updated
    auth model, which in turn updates the roles/scopes of the user. When a
    user's roles/scopes are changed, the change takes effect only after the
    user logs in to Nebari again.

    Args:
        auth_model: the auth model dict for the user that is logging in.

    Returns:
        The auth model with its "roles" entry rebuilt from Keycloak.
    """
    start = time.time()
    self.log.info("Updating user auth model")
    auth_model = await super().update_auth_model(auth_model)
    user_id = auth_model["auth_state"]["oauth_user"]["sub"]
    token = await self._get_token()

    jupyterhub_client_id = await self._get_jupyterhub_client_id(token=token)
    user_info = auth_model["auth_state"][self.user_auth_state_key]
    user_roles_from_claims = self._get_user_roles(user_info=user_info)

    # Only the Keycloak round-trips are timed separately, so the log line
    # below shows how much of the login latency they account for.
    keycloak_api_call_start = time.time()
    user_roles = await self._get_client_roles_for_user(
        user_id=user_id, client_id=jupyterhub_client_id, token=token
    )
    user_roles_rich = await self._get_roles_with_attributes(
        roles=user_roles, client_id=jupyterhub_client_id, token=token
    )
    keycloak_api_call_time_taken = time.time() - keycloak_api_call_start

    user_roles_rich_names = {role["name"] for role in user_roles_rich}
    # Roles that appear in the claims but not among the jupyterhub client
    # roles are kept by name only; they carry no attributes, so no scopes
    # are derived for them. The difference is computed once, not per item.
    claim_only_role_names = user_roles_from_claims - user_roles_rich_names
    user_roles_non_jhub_client = [
        {"name": role}
        for role in user_roles_from_claims
        if role in claim_only_role_names
    ]
    auth_model["roles"] = [
        {
            "name": role["name"],
            "description": role.get("description"),
            "scopes": self._get_scope_from_role(role),
        }
        for role in [*user_roles_rich, *user_roles_non_jhub_client]
    ]
    # note: because the roles check is comprehensive, we need to re-add the admin and user roles
    if auth_model["admin"]:
        auth_model["roles"].append({"name": "admin"})
    # check_allowed must be awaited: testing it un-awaited is always truthy.
    if await self.check_allowed(auth_model["name"], auth_model):
        auth_model["roles"].append({"name": "user"})
    execution_time = time.time() - start
    self.log.info(
        f"Auth model update complete, time taken: {execution_time}s "
        f"time taken for keycloak api call: {keycloak_api_call_time_taken}s "
        f"delta between full execution and keycloak call: {execution_time - keycloak_api_call_time_taken}s"
    )
    return auth_model

async def load_managed_roles(self):
if not self.manage_roles:
raise ValueError(
"Managed roles can only be loaded when `manage_roles` is True"
)
token = await self._get_token()
async def _get_jupyterhub_client_roles(self, jupyterhub_client_id, token):
    """Return the roles of the 'jupyterhub' client, each fetched with its attributes.

    Args:
        jupyterhub_client_id: internal Keycloak id of the 'jupyterhub' client.
        token: access token passed through to the Keycloak API calls.
    """
    # The listing covers roles like "jupyterhub_admin", "jupyterhub_developer",
    # "dask_gateway_developer"; each one is then re-fetched individually.
    roles_endpoint = f"clients/{jupyterhub_client_id}/roles"
    listed_roles = await self._fetch_api(endpoint=roles_endpoint, token=token)
    return await self._get_roles_with_attributes(
        listed_roles, client_id=jupyterhub_client_id, token=token
    )

async def _get_jupyterhub_client_id(self, token):
    """Return the internal Keycloak "id" of the client whose clientId is "jupyterhub".

    Args:
        token: access token passed through to the Keycloak API call.

    Returns:
        The "id" field of the single matching client.

    Raises:
        ValueError: if the clients listing does not contain exactly one
            client with clientId "jupyterhub".
    """
    # Get the clients list to find the "id" of "jupyterhub" client.
    clients_data = await self._fetch_api(endpoint="clients/", token=token)
    jupyterhub_clients = [
        client for client in clients_data if client["clientId"] == "jupyterhub"
    ]
    # Explicit check instead of `assert`: assertions are stripped under -O,
    # which would turn a missing client into an IndexError and silently
    # pick the first entry when there are duplicates.
    if len(jupyterhub_clients) != 1:
        raise ValueError(
            "Expected exactly one Keycloak client with clientId 'jupyterhub', "
            f"found {len(jupyterhub_clients)}"
        )
    return jupyterhub_clients[0]["id"]

# Includes roles like "jupyterhub_admin", "jupyterhub_developer", "dask_gateway_developer"
client_roles = await self._fetch_api(
endpoint=f"clients/{jupyterhub_client_id}/roles", token=token
async def load_managed_roles(self):
self.log.info("Loading managed roles")
if not self.manage_roles:
raise ValueError(
"Managed roles can only be loaded when `manage_roles` is True"
)
token = await self._get_token()
jupyterhub_client_id = await self._get_jupyterhub_client_id(token=token)
client_roles_rich = await self._get_jupyterhub_client_roles(
jupyterhub_client_id=jupyterhub_client_id, token=token
)
# Includes roles like "default-roles-nebari", "offline_access", "uma_authorization"
realm_roles = await self._fetch_api(endpoint="roles", token=token)
roles = {
role["name"]: {"name": role["name"], "description": role["description"]}
for role in [*realm_roles, *client_roles]
role["name"]: {
"name": role["name"],
"description": role["description"],
"scopes": self._get_scope_from_role(role),
}
for role in [*realm_roles, *client_roles_rich]
}
# we could use either `name` (e.g. "developer") or `path` ("/developer");
# since the default claim key returns `path`, it seems preferable.
Expand All @@ -76,7 +138,7 @@ async def load_managed_roles(self):
# fetch role assignments to users
users = await self._fetch_api(f"roles/{role_name}/users", token=token)
role["users"] = [user["username"] for user in users]
for client_role in client_roles:
for client_role in client_roles_rich:
role_name = client_role["name"]
role = roles[role_name]
# fetch role assignments to groups
Expand All @@ -92,6 +154,49 @@ async def load_managed_roles(self):

return list(roles.values())

def _get_scope_from_role(self, role):
    """Return the validated scopes of a role if its component is jupyterhub.

    Scopes are read from the role's "scopes" attribute. Every value of the
    attribute is split on commas, surrounding whitespace is stripped and
    empty entries are dropped before validation.

    Args:
        role: role dict; may lack "attributes" or have it set to None.

    Returns:
        The result of `validate_scopes` for the parsed scopes, or an empty
        list when the role is not a jupyterhub role or declares no scopes.
    """
    attributes = role.get("attributes") or {}
    role_scopes = attributes.get("scopes") or []
    component = attributes.get("component")
    # Attributes are returned as a single-element array, unless `##` delimiter is used in Keycloak
    # See this: https://stackoverflow.com/questions/68954733/keycloak-client-role-attribute-array
    if component != ["jupyterhub"]:
        return []
    # Parse every array element (not just the first) so that multi-valued
    # attributes are honoured, and tolerate "a, b" or a trailing comma,
    # which would otherwise yield an invalid scope and discard the whole set.
    parsed_scopes = [
        scope.strip()
        for value in role_scopes
        for scope in value.split(",")
        if scope.strip()
    ]
    if not parsed_scopes:
        return []
    return self.validate_scopes(parsed_scopes)

def validate_scopes(self, role_scopes):
    """Sanity check scopes that were provided by the user through Keycloak.

    Returns `role_scopes` unchanged when the check passes, and an empty list
    (after logging an error) when a scope is not found.
    """
    self.log.info(f"Validating role scopes: {role_scopes}")
    try:
        # `_check_scopes_exist` is not a public function, but there is no
        # alternative way to verify scopes, and the check is needed because
        # invalid scopes could cause the hub pod to fail.
        scopes._check_scopes_exist(role_scopes)
    except scopes.ScopeNotFound as e:
        self.log.error(f"Invalid scopes, skipping: {role_scopes} ({e})")
        return []
    else:
        return role_scopes

async def _get_roles_with_attributes(self, roles: list, client_id: str, token: str):
    """Fetch every given role by id so that its attributes are included.

    Args:
        roles: iterable of role dicts, each with at least an "id" key.
        client_id: client id sent as the `client` query parameter.
        token: access token passed through to the Keycloak API calls.

    Returns:
        A list with one fetched role per input role, in the same order.
    """
    roles_rich = []
    for role in roles:
        # The requests are awaited one after another. If this takes too much
        # time, which isn't the case right now, they could be issued
        # concurrently instead.
        role_rich = await self._fetch_api(
            endpoint=f"roles-by-id/{role['id']}?client={client_id}", token=token
        )
        roles_rich.append(role_rich)
    return roles_rich

async def _get_client_roles_for_user(self, user_id, client_id, token):
    """Return the composite role mappings of a user for the given client."""
    mappings_endpoint = f"users/{user_id}/role-mappings/clients/{client_id}/composite"
    return await self._fetch_api(endpoint=mappings_endpoint, token=token)

def _get_user_roles(self, user_info):
if callable(self.claim_roles_key):
return set(self.claim_roles_key(user_info))
Expand Down
11 changes: 11 additions & 0 deletions tests/tests_deployment/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from tests.tests_deployment.keycloak_utils import delete_client_keycloak_test_roles


@pytest.fixture()
def cleanup_keycloak_roles():
    """Delete the test roles of the 'jupyterhub' Keycloak client after the test."""
    # No setup is needed; hand control to the test straight away.
    yield
    # Teardown: remove the roles that the test may have created.
    delete_client_keycloak_test_roles(client_name="jupyterhub")
96 changes: 96 additions & 0 deletions tests/tests_deployment/keycloak_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os
import pathlib

from _nebari.config import read_configuration
from _nebari.keycloak import get_keycloak_admin_from_config
from nebari.plugins import nebari_plugin_manager


def get_keycloak_client_details_by_name(client_name, keycloak_admin=None):
    """Return the first client whose "clientId" equals `client_name`, else None.

    A Keycloak admin is created via `get_keycloak_admin()` when none is given.
    """
    admin = keycloak_admin if keycloak_admin else get_keycloak_admin()
    matching_clients = (
        entry for entry in admin.get_clients() if entry["clientId"] == client_name
    )
    return next(matching_clients, None)


def get_keycloak_user_details_by_name(username, keycloak_admin=None):
    """Return the first user whose "username" equals `username`, else None.

    A Keycloak admin is created via `get_keycloak_admin()` when none is given.
    """
    admin = keycloak_admin if keycloak_admin else get_keycloak_admin()
    matching_users = (
        entry for entry in admin.get_users() if entry["username"] == username
    )
    return next(matching_users, None)


def get_keycloak_role_details_by_name(roles, role_name):
    """Return the first role in `roles` whose "name" equals `role_name`, else None."""
    return next((entry for entry in roles if entry["name"] == role_name), None)


def get_keycloak_admin():
    """Build a Keycloak admin from the Nebari configuration file.

    The file path is taken from the NEBARI_CONFIG_PATH environment variable,
    falling back to "nebari-config.yaml"; the file must exist.
    """
    schema = nebari_plugin_manager.config_schema
    config_path = os.environ.get("NEBARI_CONFIG_PATH", "nebari-config.yaml")
    assert pathlib.Path(config_path).exists()
    return get_keycloak_admin_from_config(read_configuration(config_path, schema))


def create_keycloak_client_role(
    client_id: str, role_name: str, scopes: str, component: str
):
    """Create a client role with scopes/component attributes and return its details.

    The created role is looked up by name in the client's role listing, so
    the returned value is whatever that listing contains for `role_name`.
    """
    admin = get_keycloak_admin()
    role_payload = {
        "name": role_name,
        "description": f"{role_name} description",
        # Attribute values are sent as single-element lists.
        "attributes": {"scopes": [scopes], "component": [component]},
    }
    admin.create_client_role(client_id, payload=role_payload)
    return get_keycloak_role_details_by_name(
        admin.get_client_roles(client_id=client_id), role_name
    )


def assign_keycloak_client_role_to_user(username: str, client_name: str, role: dict):
    """Assign the given role of the named client to the named user."""
    admin = get_keycloak_admin()
    # Both lookups reuse the same admin and resolve names to Keycloak ids.
    user_id = get_keycloak_user_details_by_name(
        username=username, keycloak_admin=admin
    )["id"]
    client_id = get_keycloak_client_details_by_name(
        client_name=client_name, keycloak_admin=admin
    )["id"]
    admin.assign_client_role(user_id=user_id, client_id=client_id, roles=[role])


def create_keycloak_role(client_name: str, role_name: str, scopes: str, component: str):
    """Create a Keycloak role on the named client, with scopes and component
    set in its attributes, and return the created role's details.
    """
    admin = get_keycloak_admin()
    client = get_keycloak_client_details_by_name(
        client_name=client_name, keycloak_admin=admin
    )
    # The role API needs the client's internal id rather than its name.
    internal_client_id = client["id"]
    return create_keycloak_client_role(
        internal_client_id, role_name=role_name, scopes=scopes, component=component
    )


def delete_client_keycloak_test_roles(client_name):
    """Delete every role of the named client whose name starts with "test"."""
    admin = get_keycloak_admin()
    client = get_keycloak_client_details_by_name(
        client_name=client_name, keycloak_admin=admin
    )
    internal_client_id = client["id"]
    # Lazily pick the names to delete; roles with other prefixes are left alone.
    test_role_names = (
        entry["name"]
        for entry in admin.get_client_roles(client_id=internal_client_id)
        if entry["name"].startswith("test")
    )
    for test_role_name in test_role_names:
        admin.delete_client_role(
            client_role_id=internal_client_id,
            role_name=test_role_name,
        )
54 changes: 53 additions & 1 deletion tests/tests_deployment/test_jupyterhub_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest

from tests.tests_deployment import constants
from tests.tests_deployment.utils import get_jupyterhub_session
from tests.tests_deployment.keycloak_utils import (
assign_keycloak_client_role_to_user,
create_keycloak_role,
)
from tests.tests_deployment.utils import create_jupyterhub_token, get_jupyterhub_session


@pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning")
Expand Down Expand Up @@ -30,6 +34,54 @@ def test_jupyterhub_loads_roles_from_keycloak():
}


@pytest.mark.parametrize(
    "component,scopes,expected_scopes_difference",
    (
        [
            "jupyterhub",
            "read:users:shares,read:groups:shares,users:shares",
            {"read:groups:shares", "users:shares", "read:users:shares"},
        ],
        # `set()` rather than `{}`: an empty dict never compares equal to a set.
        [
            "invalid-component",
            "read:users:shares,read:groups:shares,users:shares",
            set(),
        ],
        ["invalid-component", "admin:invalid-scope", set()],
    ),
)
@pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning")
@pytest.mark.filterwarnings(
    "ignore:.*auto_refresh_token is deprecated:DeprecationWarning"
)
def test_keycloak_roles_attributes_parsed_as_jhub_scopes(
    component, scopes, expected_scopes_difference, cleanup_keycloak_roles
):
    """Check that scopes in a Keycloak role's attributes reach the user's token.

    A token is created before and after a role with the given `component`
    and `scopes` attributes is assigned to the user; the scopes gained by
    the second token must equal `expected_scopes_difference`.
    """
    # check token scopes before role creation and assignment
    token_response_before = create_jupyterhub_token(
        note="before-role-creation-and-assignment"
    )
    token_scopes_before = set(token_response_before.json()["scopes"])
    # create keycloak role with jupyterhub scopes in attributes
    role = create_keycloak_role(
        client_name="jupyterhub",
        # Note: we're clearing this role after every test case, and we're clearing
        # it by name, so it must start with test- to be deleted afterward
        role_name="test-custom-role",
        scopes=scopes,
        component=component,
    )
    assert role
    # assign created role to the user
    assign_keycloak_client_role_to_user(
        constants.KEYCLOAK_USERNAME, client_name="jupyterhub", role=role
    )
    token_response_after = create_jupyterhub_token(
        note="after-role-creation-and-assignment"
    )
    token_scopes_after = set(token_response_after.json()["scopes"])
    # verify new scopes added/removed
    # The result is kept in its own variable: overwriting the parametrized
    # expectation would reduce the assertion to comparing a value with itself.
    actual_scopes_difference = token_scopes_after - token_scopes_before
    # Comparing token scopes for the user before and after role assignment
    assert actual_scopes_difference == expected_scopes_difference


@pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning")
def test_jupyterhub_loads_groups_from_keycloak():
session = get_jupyterhub_session()
Expand Down
9 changes: 6 additions & 3 deletions tests/tests_deployment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,24 @@ def get_jupyterhub_session():
return session


def create_jupyterhub_token(note):
    """Request a new JupyterHub API token and return the raw HTTP response.

    Args:
        note: note attached to the token request.

    Returns:
        The response of the POST to the user's `tokens` endpoint; callers
        read fields such as "token" or "scopes" from its JSON body.
    """
    session = get_jupyterhub_session()
    xsrf_token = session.cookies.get("_xsrf")
    headers = {"Referer": f"https://{constants.NEBARI_HOSTNAME}/hub/token"}
    # The XSRF header is only sent when the session actually holds the cookie.
    if xsrf_token:
        headers["X-XSRFToken"] = xsrf_token
    data = {"note": note, "expires_in": None}
    return session.post(
        f"https://{constants.NEBARI_HOSTNAME}/hub/api/users/{constants.KEYCLOAK_USERNAME}/tokens",
        headers=headers,
        json=data,
        verify=False,
    )


def get_jupyterhub_token(note="jupyterhub-tests-deployment"):
    """Create a JupyterHub API token and return only the token string."""
    response = create_jupyterhub_token(note=note)
    return response.json()["token"]


def monkeypatch_ssl_context():
Expand Down

0 comments on commit 40ee18f

Please sign in to comment.