From 6247c6e95140d4a7f256ca075f54e79bf6fa6975 Mon Sep 17 00:00:00 2001 From: Min RK Date: Fri, 23 Aug 2024 09:40:56 +0200 Subject: [PATCH 1/4] allow get_user_groups / auth_state_groups_key to be async --- oauthenticator/oauth2.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index c36fb848..d743795e 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -9,6 +9,7 @@ import os import uuid from functools import reduce +from inspect import isawaitable from urllib.parse import quote, urlencode, urlparse, urlunparse import jwt @@ -361,6 +362,7 @@ class OAuthenticator(Authenticator): Can be a string key name (use periods for nested keys), or a callable that accepts the auth state (as a dict) and returns the groups list. + Callables may be async. Requires `manage_groups` to also be `True`. """, @@ -1086,18 +1088,22 @@ def build_auth_state_dict(self, token_info, user_info): self.user_auth_state_key: user_info, } - def get_user_groups(self, auth_state: dict): + async def get_user_groups(self, auth_state: dict): """ Returns a set of groups the user belongs to based on auth_state_groups_key and provided auth_state. - If auth_state_groups_key is a callable, it returns the list of groups directly. + Callable may be async. - If auth_state_groups_key is a nested dictionary key like "permissions.groups", this function returns auth_state["permissions"]["groups"]. """ if callable(self.auth_state_groups_key): - return set(self.auth_state_groups_key(auth_state)) + groups = self.auth_state_groups_key(auth_state) + if isawaitable(groups): + groups = await groups + return set(groups) try: return set( reduce(dict.get, self.auth_state_groups_key.split("."), auth_state) @@ -1126,6 +1132,8 @@ async def update_auth_model(self, auth_model): if self.manage_groups: auth_state = auth_model["auth_state"] user_groups = self.get_user_groups(auth_state) + if isawaitable(user_groups): + user_groups = await user_groups auth_model["groups"] = sorted(user_groups) @@ -1223,6 +1231,8 @@ async def check_allowed(self, username, auth_model): if self.manage_groups and self.allowed_groups: auth_state = auth_model["auth_state"] user_groups = self.get_user_groups(auth_state) + if isawaitable(user_groups): + user_groups = await user_groups if any(user_groups & self.allowed_groups): return True From 881bc9d489e1410de20dbf920f248dcf3d717d3f Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 29 Aug 2024 11:20:53 +0200 Subject: [PATCH 2/4] add modify_auth_state_hook for populating additional fields in auth_state --- oauthenticator/oauth2.py | 39 +++++++++++++++++++++++- oauthenticator/tests/test_generic.py | 45 +++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index d743795e..fa7145dc 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -368,6 +368,28 @@ class OAuthenticator(Authenticator): """, ) + modify_auth_state_hook = Callable( + config=True, + default_value=None, + allow_none=True, + help=""" + Callable to modify `auth_state`. + + Will be called with the Authenticator instance and the existing auth_state dictionary + and must return the new auth_state dictionary: + + ``` + auth_state = [await] modify_auth_state_hook(authenticator, auth_state) + ``` + + This hook is called _before_ populating group membership, + so can be used to make additional requests to populate additional fields + which may then be consumed by `auth_state_groups_key` to populate groups. + + This hook may be async. + """, + ) + @observe("allowed_groups", "admin_groups", "auth_state_groups_key") def _requires_manage_groups(self, change): """ @@ -1144,6 +1166,18 @@ async def update_auth_model(self, auth_model): auth_model["admin"] = bool(user_groups & self.admin_groups) return auth_model + async def _call_modify_auth_state_hook(self, auth_state): + """Call the modify_auth_state_hook""" + try: + auth_state = self.modify_auth_state_hook(self, auth_state) + except Exception as e: + # let hook errors raise, nothing in auth should suppress errors + self.log.error(f"Error in modify_auth_state_hook: {e}") + raise + if isawaitable(auth_state): + auth_state = await auth_state + return auth_state + async def authenticate(self, handler, data=None, **kwargs): """ A JupyterHub Authenticator's authenticate method's job is: @@ -1174,11 +1208,14 @@ async def authenticate(self, handler, data=None, **kwargs): if refresh_token: token_info["refresh_token"] = refresh_token + auth_state = self.build_auth_state_dict(token_info, user_info) + if self.modify_auth_state_hook is not None: + auth_state = await self._call_modify_auth_state_hook(auth_state) # build the auth model to be read if authentication goes right auth_model = { "name": username, "admin": True if username in self.admin_users else None, - "auth_state": self.build_auth_state_dict(token_info, user_info), + "auth_state": auth_state, } # update the auth_model with info to later authorize the user in diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 1f656a47..11da1026 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -3,7 +3,7 @@ from functools import partial import jwt -from pytest import fixture, mark, raises +from pytest import fixture, mark, param, raises from traitlets.config import Config from ..generic import GenericOAuthenticator @@ -315,6 +315,49 @@ async def test_generic_data(get_authenticator, generic_client): assert auth_model +def sync_auth_state_hook(authenticator, auth_state): + auth_state["sync"] = True + auth_state["hook_groups"] = ["alpha", "beta", auth_state["oauth_user"]["username"]] + return auth_state + + +async def async_auth_state_hook(authenticator, auth_state): + auth_state["async"] = True + auth_state["hook_groups"] = [ + "alpha", + "beta", + auth_state[authenticator.user_auth_state_key]["username"], + ] + return auth_state + + +@mark.parametrize( + "auth_state_hook", + [param(sync_auth_state_hook, id="sync"), param(async_auth_state_hook, id="async")], +) +async def test_modify_auth_state_hook( + get_authenticator, generic_client, auth_state_hook +): + c = Config() + c.GenericOAuthenticator.allow_all = True + c.OAuthenticator.modify_auth_state_hook = auth_state_hook + c.OAuthenticator.auth_state_groups_key = "hook_groups" + c.OAuthenticator.manage_groups = True + + authenticator = get_authenticator(config=c) + assert authenticator.modify_auth_state_hook is auth_state_hook + + handled_user_model = user_model("user1") + handler = generic_client.handler_for_user(handled_user_model) + data = {"testing": "data"} + auth_model = await authenticator.authenticate(handler, data) + if auth_state_hook is sync_auth_state_hook: + assert auth_model["auth_state"]["sync"] + else: + assert auth_model["auth_state"]["async"] + assert sorted(auth_model["groups"]) == ["alpha", "beta", "user1"] + + @mark.parametrize( ["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)] ) From ae48b338d3736d59f2a20205dc98a224a8a59a5c Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 29 Aug 2024 11:55:00 +0200 Subject: [PATCH 3/4] allow build_auth_state to be async add versionchanged notes --- oauthenticator/oauth2.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index fa7145dc..d85e70c5 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -365,6 +365,10 @@ class OAuthenticator(Authenticator): Callables may be async. Requires `manage_groups` to also be `True`. + + .. versionchanged:: 16.4 + + Added async support. """, ) @@ -387,6 +391,8 @@ class OAuthenticator(Authenticator): which may then be consumed by `auth_state_groups_key` to populate groups. This hook may be async. + + .. versionadded: 16.4 """, ) @@ -1071,6 +1077,7 @@ async def token_to_user(self, token_info): def build_auth_state_dict(self, token_info, user_info): """ Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call. + May be async (requires oauthenticator >= 16.4). Args: token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token) @@ -1086,6 +1093,9 @@ def build_auth_state_dict(self, token_info, user_info): - self.user_auth_state_key: the full user_info response Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + + .. versionchanged:: 16.4 + This method be async. """ # We know for sure the `access_token` key exists, oterwise we would have errored out already @@ -1120,6 +1130,10 @@ async def get_user_groups(self, auth_state: dict): - If auth_state_groups_key is a nested dictionary key like "permissions.groups", this function returns auth_state["permissions"]["groups"]. + + .. versionchanged:: 16.4 + This method may be async. + The base implementation is now async. """ if callable(self.auth_state_groups_key): groups = self.auth_state_groups_key(auth_state) @@ -1209,6 +1223,8 @@ async def authenticate(self, handler, data=None, **kwargs): token_info["refresh_token"] = refresh_token auth_state = self.build_auth_state_dict(token_info, user_info) + if isawaitable(auth_state): + auth_state = await auth_state if self.modify_auth_state_hook is not None: auth_state = await self._call_modify_auth_state_hook(auth_state) # build the auth model to be read if authentication goes right From 010c2889c8ea3a2477e71ba1c13a3b63b755f346 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 29 Aug 2024 13:02:55 +0200 Subject: [PATCH 4/4] make sure to catch async errors, too --- oauthenticator/oauth2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index d85e70c5..e04ea952 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1184,12 +1184,12 @@ async def _call_modify_auth_state_hook(self, auth_state): """Call the modify_auth_state_hook""" try: auth_state = self.modify_auth_state_hook(self, auth_state) + if isawaitable(auth_state): + auth_state = await auth_state except Exception as e: # let hook errors raise, nothing in auth should suppress errors self.log.error(f"Error in modify_auth_state_hook: {e}") raise - if isawaitable(auth_state): - auth_state = await auth_state return auth_state async def authenticate(self, handler, data=None, **kwargs):