From 881bc9d489e1410de20dbf920f248dcf3d717d3f Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 29 Aug 2024 11:20:53 +0200 Subject: [PATCH] add modify_auth_state_hook for populating additional fields in auth_state --- oauthenticator/oauth2.py | 39 +++++++++++++++++++++++- oauthenticator/tests/test_generic.py | 45 +++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index d743795e..fa7145dc 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -368,6 +368,28 @@ class OAuthenticator(Authenticator): """, ) + modify_auth_state_hook = Callable( + config=True, + default_value=None, + allow_none=True, + help=""" + Callable to modify `auth_state`. + + Will be called with the Authenticator instance and the existing auth_state dictionary + and must return the new auth_state dictionary: + + ``` + auth_state = [await] modify_auth_state_hook(authenticator, auth_state) + ``` + + This hook is called _before_ populating group membership, + so can be used to make additional requests to populate additional fields + which may then be consumed by `auth_state_groups_key` to populate groups. + + This hook may be async. + """, + ) + @observe("allowed_groups", "admin_groups", "auth_state_groups_key") def _requires_manage_groups(self, change): """ @@ -1144,6 +1166,18 @@ async def update_auth_model(self, auth_model): auth_model["admin"] = bool(user_groups & self.admin_groups) return auth_model + async def _call_modify_auth_state_hook(self, auth_state): + """Call the modify_auth_state_hook""" + try: + auth_state = self.modify_auth_state_hook(self, auth_state) + except Exception as e: + # let hook errors raise, nothing in auth should suppress errors + self.log.error(f"Error in modify_auth_state_hook: {e}") + raise + if isawaitable(auth_state): + auth_state = await auth_state + return auth_state + async def authenticate(self, handler, data=None, **kwargs): """ A JupyterHub Authenticator's authenticate method's job is: @@ -1174,11 +1208,14 @@ async def authenticate(self, handler, data=None, **kwargs): if refresh_token: token_info["refresh_token"] = refresh_token + auth_state = self.build_auth_state_dict(token_info, user_info) + if self.modify_auth_state_hook is not None: + auth_state = await self._call_modify_auth_state_hook(auth_state) # build the auth model to be read if authentication goes right auth_model = { "name": username, "admin": True if username in self.admin_users else None, - "auth_state": self.build_auth_state_dict(token_info, user_info), + "auth_state": auth_state, } # update the auth_model with info to later authorize the user in diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 1f656a47..11da1026 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -3,7 +3,7 @@ from functools import partial import jwt -from pytest import fixture, mark, raises +from pytest import fixture, mark, param, raises from traitlets.config import Config from ..generic import GenericOAuthenticator @@ -315,6 +315,49 @@ async def test_generic_data(get_authenticator, generic_client): assert auth_model +def sync_auth_state_hook(authenticator, auth_state): + auth_state["sync"] = True + auth_state["hook_groups"] = ["alpha", "beta", auth_state["oauth_user"]["username"]] + return auth_state + + +async def async_auth_state_hook(authenticator, auth_state): + auth_state["async"] = True + auth_state["hook_groups"] = [ + "alpha", + "beta", + auth_state[authenticator.user_auth_state_key]["username"], + ] + return auth_state + + +@mark.parametrize( + "auth_state_hook", + [param(sync_auth_state_hook, id="sync"), param(async_auth_state_hook, id="async")], +) +async def test_modify_auth_state_hook( + get_authenticator, generic_client, auth_state_hook +): + c = Config() + c.GenericOAuthenticator.allow_all = True + c.OAuthenticator.modify_auth_state_hook = auth_state_hook + c.OAuthenticator.auth_state_groups_key = "hook_groups" + c.OAuthenticator.manage_groups = True + + authenticator = get_authenticator(config=c) + assert authenticator.modify_auth_state_hook is auth_state_hook + + handled_user_model = user_model("user1") + handler = generic_client.handler_for_user(handled_user_model) + data = {"testing": "data"} + auth_model = await authenticator.authenticate(handler, data) + if auth_state_hook is sync_auth_state_hook: + assert auth_model["auth_state"]["sync"] + else: + assert auth_model["auth_state"]["async"] + assert sorted(auth_model["groups"]) == ["alpha", "beta", "user1"] + + @mark.parametrize( ["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)] )