Skip to content

Commit

Permalink
Convert claim_groups_key to auth_model_groups_key
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvipanda committed Apr 25, 2024
1 parent 63f6642 commit 86b5bdc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 26 deletions.
34 changes: 9 additions & 25 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jupyterhub.auth import LocalAuthenticator
from jupyterhub.traitlets import Callable
from tornado.httpclient import AsyncHTTPClient
from traitlets import Bool, Dict, Set, Unicode, Union, default
from traitlets import Bool, Dict, Set, Unicode, Union, default, observe

from .oauth2 import OAuthenticator

Expand All @@ -33,6 +33,14 @@ def _login_service_default(self):
""",
)

@observe("claim_groups_key")
def _claim_groups_key_changed(self, change):
if callable(change.new):
# Automatically wrap the claim_gorups_key call so it gets what it thinks it should get
self.auth_model_groups_key = lambda auth_model: self.claim_groups_key(auth_model["auth_state"][self.user_auth_state_key])
else:
self.auth_model_groups_key = f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}"

@default("http_client")
def _default_http_client(self):
return AsyncHTTPClient(
Expand Down Expand Up @@ -72,30 +80,6 @@ def _default_http_client(self):
""",
)

def get_user_groups(self, user_info):
"""
Returns a set of groups the user belongs to based on claim_groups_key
and provided user_info.
- If claim_groups_key is a callable, it is meant to return the groups
directly.
- If claim_groups_key is a nested dictionary key like
"permissions.groups", this function returns
user_info["permissions"]["groups"].
Note that this method is introduced by GenericOAuthenticator and not
present in the base class.
"""
if callable(self.claim_groups_key):
return set(self.claim_groups_key(user_info))
try:
return set(reduce(dict.get, self.claim_groups_key.split("."), user_info))
except TypeError:
self.log.error(
f"The claim_groups_key {self.claim_groups_key} does not exist in the user token"
)
return set()


class LocalGenericOAuthenticator(LocalAuthenticator, GenericOAuthenticator):
"""A version that mixes in local system user creation"""
40 changes: 39 additions & 1 deletion oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import json
import os
from functools import reduce
import uuid
from urllib.parse import quote, urlencode, urlparse, urlunparse

Expand Down Expand Up @@ -336,6 +337,22 @@ class OAuthenticator(Authenticator):
""",
)

auth_model_groups_key = Union(
[Unicode(), Callable()],
config=True,
help="""
Determine groups this user belongs based on contents of the auth model.
Can be a string key name (use periods for nested keys), or a callable
that accepts the auth model (as a dict) and returns the groups list.
TODO: Document what auth_model actually looks like.
This configures how group membership in the upstream provider is determined
for use by `allowed_groups`, `admin_groups`, etc. If `manage_groups` is True,
this will also determine users' _JupyterHub_ group membership.
""",
)

authorize_url = Unicode(
config=True,
Expand Down Expand Up @@ -1016,6 +1033,27 @@ def build_auth_state_dict(self, token_info, user_info):
self.user_auth_state_key: user_info,
}

def get_user_groups(self, auth_model: dict):
"""
Returns a set of groups the user belongs to based on claim_groups_key
and provided auth_model.
- If claim_groups_key is a callable, it is meant to return the groups
directly.
- If claim_groups_key is a nested dictionary key like
"permissions.groups", this function returns
auth_model["permissions"]["groups"].
"""
if callable(self.claim_groups_key):
return set(self.auth_model_groups_key(auth_model))
try:
return set(reduce(dict.get, self.auth_model_groups_key.split("."), auth_model))
except TypeError:
self.log.error(
f"The auth_model_groups_key {self.auth_model_groups_key} does not exist in the auth_model. Available keys are: {auth_model.keys()}"
)
return set()

async def update_auth_model(self, auth_model):
"""
Updates and returns the `auth_model` dict.
Expand All @@ -1033,7 +1071,7 @@ async def update_auth_model(self, auth_model):
"""
if self.manage_groups or self.admin_groups:
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = self.get_user_groups(user_info)
user_groups = self.get_user_groups(auth_model)

if self.manage_groups:
auth_model["groups"] = sorted(user_groups)
Expand Down

0 comments on commit 86b5bdc

Please sign in to comment.