diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/04-auth.py b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/04-auth.py index 082268a107..bc6fb6a721 100644 --- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/04-auth.py +++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/files/jupyterhub/04-auth.py @@ -1,8 +1,10 @@ import json import os +import time import urllib from functools import reduce +from jupyterhub import scopes from jupyterhub.traitlets import Callable from oauthenticator.generic import GenericOAuthenticator from traitlets import Bool, Unicode, Union @@ -28,24 +30,72 @@ class KeyCloakOAuthenticator(GenericOAuthenticator): reset_managed_roles_on_startup = Bool(True) async def update_auth_model(self, auth_model): + """Updates and returns the auth_model dict. + This function is called every time a user authenticates with JupyterHub, as in + every time a user login to Nebari. + + It will fetch the roles and their corresponding scopes from keycloak + and return updated auth model which will updates roles/scopes for the + user. When a user's roles/scopes are updated, they take in-affect only + after they log in to Nebari. + """ + start = time.time() + self.log.info("Updating user auth model") auth_model = await super().update_auth_model(auth_model) + user_id = auth_model["auth_state"]["oauth_user"]["sub"] + token = await self._get_token() + + jupyterhub_client_id = await self._get_jupyterhub_client_id(token=token) user_info = auth_model["auth_state"][self.user_auth_state_key] - user_roles = self._get_user_roles(user_info) - auth_model["roles"] = [{"name": role_name} for role_name in user_roles] + user_roles_from_claims = self._get_user_roles(user_info=user_info) + keycloak_api_call_start = time.time() + user_roles = await self._get_client_roles_for_user( + user_id=user_id, client_id=jupyterhub_client_id, token=token + ) + user_roles_rich = await self._get_roles_with_attributes( + roles=user_roles, client_id=jupyterhub_client_id, token=token + ) + keycloak_api_call_time_taken = time.time() - keycloak_api_call_start + user_roles_rich_names = {role["name"] for role in user_roles_rich} + user_roles_non_jhub_client = [ + {"name": role} + for role in user_roles_from_claims + if role in (user_roles_from_claims - user_roles_rich_names) + ] + auth_model["roles"] = [ + { + "name": role["name"], + "description": role.get("description"), + "scopes": self._get_scope_from_role(role), + } + for role in [*user_roles_rich, *user_roles_non_jhub_client] + ] # note: because the roles check is comprehensive, we need to re-add the admin and user roles if auth_model["admin"]: auth_model["roles"].append({"name": "admin"}) - if self.check_allowed(auth_model["name"], auth_model): + if await self.check_allowed(auth_model["name"], auth_model): auth_model["roles"].append({"name": "user"}) + execution_time = time.time() - start + self.log.info( + f"Auth model update complete, time taken: {execution_time}s " + f"time taken for keycloak api call: {keycloak_api_call_time_taken}s " + f"delta between full execution and keycloak call: {execution_time - keycloak_api_call_time_taken}s" + ) return auth_model - async def load_managed_roles(self): - if not self.manage_roles: - raise ValueError( - "Managed roles can only be loaded when `manage_roles` is True" - ) - token = await self._get_token() + async def _get_jupyterhub_client_roles(self, jupyterhub_client_id, token): + """Get roles for the client named 'jupyterhub'.""" + # Includes roles like "jupyterhub_admin", "jupyterhub_developer", "dask_gateway_developer" + + client_roles = await self._fetch_api( + endpoint=f"clients/{jupyterhub_client_id}/roles", token=token + ) + client_roles_rich = await self._get_roles_with_attributes( + client_roles, client_id=jupyterhub_client_id, token=token + ) + return client_roles_rich + async def _get_jupyterhub_client_id(self, token): # Get the clients list to find the "id" of "jupyterhub" client. clients_data = await self._fetch_api(endpoint="clients/", token=token) jupyterhub_clients = [ @@ -53,16 +103,28 @@ async def load_managed_roles(self): ] assert len(jupyterhub_clients) == 1 jupyterhub_client_id = jupyterhub_clients[0]["id"] + return jupyterhub_client_id - # Includes roles like "jupyterhub_admin", "jupyterhub_developer", "dask_gateway_developer" - client_roles = await self._fetch_api( - endpoint=f"clients/{jupyterhub_client_id}/roles", token=token + async def load_managed_roles(self): + self.log.info("Loading managed roles") + if not self.manage_roles: + raise ValueError( + "Managed roles can only be loaded when `manage_roles` is True" + ) + token = await self._get_token() + jupyterhub_client_id = await self._get_jupyterhub_client_id(token=token) + client_roles_rich = await self._get_jupyterhub_client_roles( + jupyterhub_client_id=jupyterhub_client_id, token=token ) # Includes roles like "default-roles-nebari", "offline_access", "uma_authorization" realm_roles = await self._fetch_api(endpoint="roles", token=token) roles = { - role["name"]: {"name": role["name"], "description": role["description"]} - for role in [*realm_roles, *client_roles] + role["name"]: { + "name": role["name"], + "description": role["description"], + "scopes": self._get_scope_from_role(role), + } + for role in [*realm_roles, *client_roles_rich] } # we could use either `name` (e.g. "developer") or `path` ("/developer"); # since the default claim key returns `path`, it seems preferable. @@ -76,7 +138,7 @@ async def load_managed_roles(self): # fetch role assignments to users users = await self._fetch_api(f"roles/{role_name}/users", token=token) role["users"] = [user["username"] for user in users] - for client_role in client_roles: + for client_role in client_roles_rich: role_name = client_role["name"] role = roles[role_name] # fetch role assignments to groups @@ -92,6 +154,49 @@ async def load_managed_roles(self): return list(roles.values()) + def _get_scope_from_role(self, role): + """Return scopes from role if the component is jupyterhub""" + role_scopes = role.get("attributes", {}).get("scopes", []) + component = role.get("attributes", {}).get("component") + # Attributes are returned as a single-element array, unless `##` delimiter is used in Keycloak + # See this: https://stackoverflow.com/questions/68954733/keycloak-client-role-attribute-array + if component == ["jupyterhub"] and role_scopes: + return self.validate_scopes(role_scopes[0].split(",")) + else: + return [] + + def validate_scopes(self, role_scopes): + """Validate role scopes to sanity check user provided scopes from keycloak""" + self.log.info(f"Validating role scopes: {role_scopes}") + try: + # This is not a public function, but there isn't any alternative + # method to verify scopes, and we do need to do this sanity check + # as a invalid scopes could cause hub pod to fail + scopes._check_scopes_exist(role_scopes) + return role_scopes + except scopes.ScopeNotFound as e: + self.log.error(f"Invalid scopes, skipping: {role_scopes} ({e})") + return [] + + async def _get_roles_with_attributes(self, roles: dict, client_id: str, token: str): + """This fetches all roles by id to fetch there attributes.""" + roles_rich = [] + for role in roles: + # If this takes too much time, which isn't the case right now, we can + # also do multi-threaded requests + role_rich = await self._fetch_api( + endpoint=f"roles-by-id/{role['id']}?client={client_id}", token=token + ) + roles_rich.append(role_rich) + return roles_rich + + async def _get_client_roles_for_user(self, user_id, client_id, token): + user_roles = await self._fetch_api( + endpoint=f"users/{user_id}/role-mappings/clients/{client_id}/composite", + token=token, + ) + return user_roles + def _get_user_roles(self, user_info): if callable(self.claim_roles_key): return set(self.claim_roles_key(user_info)) diff --git a/tests/tests_deployment/conftest.py b/tests/tests_deployment/conftest.py new file mode 100644 index 0000000000..fa71302823 --- /dev/null +++ b/tests/tests_deployment/conftest.py @@ -0,0 +1,11 @@ +import pytest + +from tests.tests_deployment.keycloak_utils import delete_client_keycloak_test_roles + + +@pytest.fixture() +def cleanup_keycloak_roles(): + # setup + yield + # teardown + delete_client_keycloak_test_roles(client_name="jupyterhub") diff --git a/tests/tests_deployment/keycloak_utils.py b/tests/tests_deployment/keycloak_utils.py new file mode 100644 index 0000000000..6e6f6c21e6 --- /dev/null +++ b/tests/tests_deployment/keycloak_utils.py @@ -0,0 +1,96 @@ +import os +import pathlib + +from _nebari.config import read_configuration +from _nebari.keycloak import get_keycloak_admin_from_config +from nebari.plugins import nebari_plugin_manager + + +def get_keycloak_client_details_by_name(client_name, keycloak_admin=None): + if not keycloak_admin: + keycloak_admin = get_keycloak_admin() + clients = keycloak_admin.get_clients() + for client in clients: + if client["clientId"] == client_name: + return client + + +def get_keycloak_user_details_by_name(username, keycloak_admin=None): + if not keycloak_admin: + keycloak_admin = get_keycloak_admin() + users = keycloak_admin.get_users() + for user in users: + if user["username"] == username: + return user + + +def get_keycloak_role_details_by_name(roles, role_name): + for role in roles: + if role["name"] == role_name: + return role + + +def get_keycloak_admin(): + config_schema = nebari_plugin_manager.config_schema + config_filepath = os.environ.get("NEBARI_CONFIG_PATH", "nebari-config.yaml") + assert pathlib.Path(config_filepath).exists() + config = read_configuration(config_filepath, config_schema) + return get_keycloak_admin_from_config(config) + + +def create_keycloak_client_role( + client_id: str, role_name: str, scopes: str, component: str +): + keycloak_admin = get_keycloak_admin() + keycloak_admin.create_client_role( + client_id, + payload={ + "name": role_name, + "description": f"{role_name} description", + "attributes": {"scopes": [scopes], "component": [component]}, + }, + ) + client_roles = keycloak_admin.get_client_roles(client_id=client_id) + return get_keycloak_role_details_by_name(client_roles, role_name) + + +def assign_keycloak_client_role_to_user(username: str, client_name: str, role: dict): + """Given a keycloak role and client name, assign that to the user""" + keycloak_admin = get_keycloak_admin() + user_details = get_keycloak_user_details_by_name( + username=username, keycloak_admin=keycloak_admin + ) + client_details = get_keycloak_client_details_by_name( + client_name=client_name, keycloak_admin=keycloak_admin + ) + keycloak_admin.assign_client_role( + user_id=user_details["id"], client_id=client_details["id"], roles=[role] + ) + + +def create_keycloak_role(client_name: str, role_name: str, scopes: str, component: str): + """Create a role keycloak role for the given client with scopes and + component set in attributes + """ + keycloak_admin = get_keycloak_admin() + client_details = get_keycloak_client_details_by_name( + client_name=client_name, keycloak_admin=keycloak_admin + ) + return create_keycloak_client_role( + client_details["id"], role_name=role_name, scopes=scopes, component=component + ) + + +def delete_client_keycloak_test_roles(client_name): + keycloak_admin = get_keycloak_admin() + client_details = get_keycloak_client_details_by_name( + client_name=client_name, keycloak_admin=keycloak_admin + ) + client_roles = keycloak_admin.get_client_roles(client_id=client_details["id"]) + for role in client_roles: + if not role["name"].startswith("test"): + continue + keycloak_admin.delete_client_role( + client_role_id=client_details["id"], + role_name=role["name"], + ) diff --git a/tests/tests_deployment/test_jupyterhub_api.py b/tests/tests_deployment/test_jupyterhub_api.py index faa6e82f53..5f02c72dd4 100644 --- a/tests/tests_deployment/test_jupyterhub_api.py +++ b/tests/tests_deployment/test_jupyterhub_api.py @@ -1,7 +1,11 @@ import pytest from tests.tests_deployment import constants -from tests.tests_deployment.utils import get_jupyterhub_session +from tests.tests_deployment.keycloak_utils import ( + assign_keycloak_client_role_to_user, + create_keycloak_role, +) +from tests.tests_deployment.utils import create_jupyterhub_token, get_jupyterhub_session @pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning") @@ -30,6 +34,54 @@ def test_jupyterhub_loads_roles_from_keycloak(): } +@pytest.mark.parametrize( + "component,scopes,expected_scopes_difference", + ( + [ + "jupyterhub", + "read:users:shares,read:groups:shares,users:shares", + {"read:groups:shares", "users:shares", "read:users:shares"}, + ], + ["invalid-component", "read:users:shares,read:groups:shares,users:shares", {}], + ["invalid-component", "admin:invalid-scope", {}], + ), +) +@pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning") +@pytest.mark.filterwarnings( + "ignore:.*auto_refresh_token is deprecated:DeprecationWarning" +) +def test_keycloak_roles_attributes_parsed_as_jhub_scopes( + component, scopes, expected_scopes_difference, cleanup_keycloak_roles +): + # check token scopes before role creation and assignment + token_response_before = create_jupyterhub_token( + note="before-role-creation-and-assignment" + ) + token_scopes_before = set(token_response_before.json()["scopes"]) + # create keycloak role with jupyterhub scopes in attributes + role = create_keycloak_role( + client_name="jupyterhub", + # Note: we're clearing this role after every test case, and we're clearing + # it by name, so it must start with test- to be deleted afterward + role_name="test-custom-role", + scopes=scopes, + component=component, + ) + assert role + # assign created role to the user + assign_keycloak_client_role_to_user( + constants.KEYCLOAK_USERNAME, client_name="jupyterhub", role=role + ) + token_response_after = create_jupyterhub_token( + note="after-role-creation-and-assignment" + ) + token_scopes_after = set(token_response_after.json()["scopes"]) + # verify new scopes added/removed + expected_scopes_difference = token_scopes_after - token_scopes_before + # Comparing token scopes for the user before and after role assignment + assert expected_scopes_difference == expected_scopes_difference + + @pytest.mark.filterwarnings("ignore::urllib3.exceptions.InsecureRequestWarning") def test_jupyterhub_loads_groups_from_keycloak(): session = get_jupyterhub_session() diff --git a/tests/tests_deployment/utils.py b/tests/tests_deployment/utils.py index f37523d920..b0965dd1ae 100644 --- a/tests/tests_deployment/utils.py +++ b/tests/tests_deployment/utils.py @@ -26,21 +26,24 @@ def get_jupyterhub_session(): return session -def get_jupyterhub_token(note="jupyterhub-tests-deployment"): +def create_jupyterhub_token(note): session = get_jupyterhub_session() xsrf_token = session.cookies.get("_xsrf") headers = {"Referer": f"https://{constants.NEBARI_HOSTNAME}/hub/token"} if xsrf_token: headers["X-XSRFToken"] = xsrf_token data = {"note": note, "expires_in": None} - r = session.post( + return session.post( f"https://{constants.NEBARI_HOSTNAME}/hub/api/users/{constants.KEYCLOAK_USERNAME}/tokens", headers=headers, json=data, verify=False, ) - return r.json()["token"] + +def get_jupyterhub_token(note="jupyterhub-tests-deployment"): + response = create_jupyterhub_token(note=note) + return response.json()["token"] def monkeypatch_ssl_context():