Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[All] add OAuthenticator.modify_auth_state_hook, allow get_user_groups / auth_state_groups_key to be async #751

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import uuid
from functools import reduce
from inspect import isawaitable
from urllib.parse import quote, urlencode, urlparse, urlunparse

import jwt
Expand Down Expand Up @@ -361,8 +362,37 @@ class OAuthenticator(Authenticator):

Can be a string key name (use periods for nested keys), or a callable
that accepts the auth state (as a dict) and returns the groups list.
Callables may be async.

Requires `manage_groups` to also be `True`.

.. versionchanged:: 16.4

Added async support.
""",
)

modify_auth_state_hook = Callable(
config=True,
default_value=None,
allow_none=True,
help="""
Callable to modify `auth_state`.

Will be called with the Authenticator instance and the existing auth_state dictionary
and must return the new auth_state dictionary:

```
auth_state = [await] modify_auth_state_hook(authenticator, auth_state)
```

This hook is called _before_ populating group membership,
so can be used to make additional requests to populate additional fields
which may then be consumed by `auth_state_groups_key` to populate groups.

This hook may be async.

.. versionadded: 16.4
""",
)

Expand Down Expand Up @@ -1047,6 +1077,7 @@ async def token_to_user(self, token_info):
def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
May be async (requires oauthenticator >= 16.4).

Args:
token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token)
Expand All @@ -1062,6 +1093,9 @@ def build_auth_state_dict(self, token_info, user_info):
- self.user_auth_state_key: the full user_info response

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`

.. versionchanged:: 16.4
This method be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
Expand All @@ -1086,18 +1120,26 @@ def build_auth_state_dict(self, token_info, user_info):
self.user_auth_state_key: user_info,
}

def get_user_groups(self, auth_state: dict):
async def get_user_groups(self, auth_state: dict):
"""
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.

- If auth_state_groups_key is a callable, it returns the list of groups directly.
Callable may be async.
- If auth_state_groups_key is a nested dictionary key like
"permissions.groups", this function returns
auth_state["permissions"]["groups"].

.. versionchanged:: 16.4
This method may be async.
The base implementation is now async.
"""
if callable(self.auth_state_groups_key):
return set(self.auth_state_groups_key(auth_state))
groups = self.auth_state_groups_key(auth_state)
if isawaitable(groups):
groups = await groups
return set(groups)
try:
return set(
reduce(dict.get, self.auth_state_groups_key.split("."), auth_state)
Expand Down Expand Up @@ -1126,6 +1168,8 @@ async def update_auth_model(self, auth_model):
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups

auth_model["groups"] = sorted(user_groups)

Expand All @@ -1136,6 +1180,18 @@ async def update_auth_model(self, auth_model):
auth_model["admin"] = bool(user_groups & self.admin_groups)
return auth_model

async def _call_modify_auth_state_hook(self, auth_state):
"""Call the modify_auth_state_hook"""
try:
auth_state = self.modify_auth_state_hook(self, auth_state)
if isawaitable(auth_state):
auth_state = await auth_state
except Exception as e:
# let hook errors raise, nothing in auth should suppress errors
self.log.error(f"Error in modify_auth_state_hook: {e}")
raise
return auth_state

async def authenticate(self, handler, data=None, **kwargs):
"""
A JupyterHub Authenticator's authenticate method's job is:
Expand Down Expand Up @@ -1166,11 +1222,16 @@ async def authenticate(self, handler, data=None, **kwargs):
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if isawaitable(auth_state):
auth_state = await auth_state
if self.modify_auth_state_hook is not None:
auth_state = await self._call_modify_auth_state_hook(auth_state)
# build the auth model to be read if authentication goes right
auth_model = {
"name": username,
"admin": True if username in self.admin_users else None,
"auth_state": self.build_auth_state_dict(token_info, user_info),
"auth_state": auth_state,
}

# update the auth_model with info to later authorize the user in
Expand Down Expand Up @@ -1223,6 +1284,8 @@ async def check_allowed(self, username, auth_model):
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups
if any(user_groups & self.allowed_groups):
return True

Expand Down
45 changes: 44 additions & 1 deletion oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial

import jwt
from pytest import fixture, mark, raises
from pytest import fixture, mark, param, raises
from traitlets.config import Config

from ..generic import GenericOAuthenticator
Expand Down Expand Up @@ -315,6 +315,49 @@ async def test_generic_data(get_authenticator, generic_client):
assert auth_model


def sync_auth_state_hook(authenticator, auth_state):
auth_state["sync"] = True
auth_state["hook_groups"] = ["alpha", "beta", auth_state["oauth_user"]["username"]]
return auth_state


async def async_auth_state_hook(authenticator, auth_state):
auth_state["async"] = True
auth_state["hook_groups"] = [
"alpha",
"beta",
auth_state[authenticator.user_auth_state_key]["username"],
]
return auth_state


@mark.parametrize(
"auth_state_hook",
[param(sync_auth_state_hook, id="sync"), param(async_auth_state_hook, id="async")],
)
async def test_modify_auth_state_hook(
get_authenticator, generic_client, auth_state_hook
):
c = Config()
c.GenericOAuthenticator.allow_all = True
c.OAuthenticator.modify_auth_state_hook = auth_state_hook
c.OAuthenticator.auth_state_groups_key = "hook_groups"
c.OAuthenticator.manage_groups = True

authenticator = get_authenticator(config=c)
assert authenticator.modify_auth_state_hook is auth_state_hook

handled_user_model = user_model("user1")
handler = generic_client.handler_for_user(handled_user_model)
data = {"testing": "data"}
auth_model = await authenticator.authenticate(handler, data)
if auth_state_hook is sync_auth_state_hook:
assert auth_model["auth_state"]["sync"]
else:
assert auth_model["auth_state"]["async"]
assert sorted(auth_model["groups"]) == ["alpha", "beta", "user1"]


@mark.parametrize(
["allowed_scopes", "allowed"], [(["advanced"], False), (["basic"], True)]
)
Expand Down