Skip to content

Commit

Permalink
add modify_auth_state_hook
Browse files Browse the repository at this point in the history
for populating additional fields in auth_state
  • Loading branch information
minrk committed Aug 29, 2024
1 parent 6247c6e commit 881bc9d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
39 changes: 38 additions & 1 deletion oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,28 @@ class OAuthenticator(Authenticator):
""",
)

modify_auth_state_hook = Callable(
config=True,
default_value=None,
allow_none=True,
help="""
Callable to modify `auth_state`.
Will be called with the Authenticator instance and the existing auth_state dictionary
and must return the new auth_state dictionary:
```
auth_state = [await] modify_auth_state_hook(authenticator, auth_state)
```
This hook is called _before_ populating group membership,
so can be used to make additional requests to populate additional fields
which may then be consumed by `auth_state_groups_key` to populate groups.
This hook may be async.
""",
)

@observe("allowed_groups", "admin_groups", "auth_state_groups_key")
def _requires_manage_groups(self, change):
"""
Expand Down Expand Up @@ -1144,6 +1166,18 @@ async def update_auth_model(self, auth_model):
auth_model["admin"] = bool(user_groups & self.admin_groups)
return auth_model

async def _call_modify_auth_state_hook(self, auth_state):
"""Call the modify_auth_state_hook"""
try:
auth_state = self.modify_auth_state_hook(self, auth_state)
except Exception as e:
# let hook errors raise, nothing in auth should suppress errors
self.log.error(f"Error in modify_auth_state_hook: {e}")
raise
if isawaitable(auth_state):
auth_state = await auth_state
return auth_state

async def authenticate(self, handler, data=None, **kwargs):
"""
A JupyterHub Authenticator's authenticate method's job is:
Expand Down Expand Up @@ -1174,11 +1208,14 @@ async def authenticate(self, handler, data=None, **kwargs):
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if self.modify_auth_state_hook is not None:
auth_state = await self._call_modify_auth_state_hook(auth_state)
# build the auth model to be read if authentication goes right
auth_model = {
"name": username,
"admin": True if username in self.admin_users else None,
"auth_state": self.build_auth_state_dict(token_info, user_info),
"auth_state": auth_state,
}

# update the auth_model with info to later authorize the user in
Expand Down
45 changes: 44 additions & 1 deletion oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial

import jwt
from pytest import fixture, mark, raises
from pytest import fixture, mark, param, raises
from traitlets.config import Config

from ..generic import GenericOAuthenticator
Expand Down Expand Up @@ -315,6 +315,49 @@ async def test_generic_data(get_authenticator, generic_client):
assert auth_model


def sync_auth_state_hook(authenticator, auth_state):
auth_state["sync"] = True
auth_state["hook_groups"] = ["alpha", "beta", auth_state["oauth_user"]["username"]]
return auth_state


async def async_auth_state_hook(authenticator, auth_state):
auth_state["async"] = True
auth_state["hook_groups"] = [
"alpha",
"beta",
auth_state[authenticator.user_auth_state_key]["username"],
]
return auth_state


@mark.parametrize(
"auth_state_hook",
[param(sync_auth_state_hook, id="sync"), param(async_auth_state_hook, id="async")],
)
async def test_modify_auth_state_hook(
get_authenticator, generic_client, auth_state_hook
):
c = Config()
c.GenericOAuthenticator.allow_all = True
c.OAuthenticator.modify_auth_state_hook = auth_state_hook
c.OAuthenticator.auth_state_groups_key = "hook_groups"
c.OAuthenticator.manage_groups = True

authenticator = get_authenticator(config=c)
assert authenticator.modify_auth_state_hook is auth_state_hook

handled_user_model = user_model("user1")
handler = generic_client.handler_for_user(handled_user_model)
data = {"testing": "data"}
auth_model = await authenticator.authenticate(handler, data)
if auth_state_hook is sync_auth_state_hook:
assert auth_model["auth_state"]["sync"]
else:
assert auth_model["auth_state"]["async"]
assert sorted(auth_model["groups"]) == ["alpha", "beta", "user1"]


@mark.parametrize(
["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)]
)
Expand Down

0 comments on commit 881bc9d

Please sign in to comment.