diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index 72ab8267..1fb5a301 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -32,28 +32,28 @@ def _login_service_default(self): """, ) - # Initialize value of auth_model_groups_key based on what is in claim_groups_key - @default('auth_model_groups_key') - def _auth_model_groups_key_default(self): + # Initialize value of auth_state_groups_key based on what is in claim_groups_key + @default('auth_state_groups_key') + def _auth_state_groups_key_default(self): if callable(self.claim_groups_key): # Automatically wrap the claim_gorups_key call so it gets what it thinks it should get - return lambda auth_model: self.claim_groups_key( - auth_model["auth_state"][self.user_auth_state_key] + return lambda auth_state: self.claim_groups_key( + auth_state[self.user_auth_state_key] ) else: - return f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}" + return f"{self.user_auth_state_key}.{self.claim_groups_key}" - # propagate any changes to claim_groups_key to auth_model_groups_key + # propagate any changes to claim_groups_key to auth_state_groups_key @observe("claim_groups_key") def _claim_groups_key_changed(self, change): if callable(change.new): # Automatically wrap the claim_gorups_key call so it gets what it thinks it should get - self.auth_model_groups_key = lambda auth_model: self.claim_groups_key( - auth_model["auth_state"][self.user_auth_state_key] + self.auth_state_groups_key = lambda auth_state: self.claim_groups_key( + auth_state[self.user_auth_state_key] ) else: - self.auth_model_groups_key = ( - f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}" + self.auth_state_groups_key = ( + f"{self.user_auth_state_key}.{self.claim_groups_key}" ) @default("http_client") diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 94c294d5..750a693d 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -348,16 +348,14 @@ class OAuthenticator(Authenticator): """, ) - auth_model_groups_key = Union( + auth_state_groups_key = Union( [Unicode(), Callable()], config=True, help=""" - Determine groups this user belongs based on contents of the auth model. + Determine groups this user belongs based on contents of auth_state. Can be a string key name (use periods for nested keys), or a callable - that accepts the auth model (as a dict) and returns the groups list. - - TODO: Document what auth_model actually looks like. + that accepts the auth state (as a dict) and returns the groups list. This configures how group membership in the upstream provider is determined for use by `allowed_groups`, `admin_groups`, etc. If `manage_groups` is True, @@ -1044,26 +1042,26 @@ def build_auth_state_dict(self, token_info, user_info): self.user_auth_state_key: user_info, } - def get_user_groups(self, auth_model: dict): + def get_user_groups(self, auth_state: dict): """ - Returns a set of groups the user belongs to based on auth_model_groups_key - and provided auth_model. + Returns a set of groups the user belongs to based on auth_state_groups_key + and provided auth_state. - - If auth_model_groups_key is a callable, it is meant to return the groups + - If auth_state_groups_key is a callable, it is meant to return the groups directly. - - If auth_model_groups_key is a nested dictionary key like + - If auth_state_groups_key is a nested dictionary key like "permissions.groups", this function returns - auth_model["permissions"]["groups"]. + auth_state["permissions"]["groups"]. """ - if callable(self.auth_model_groups_key): - return set(self.auth_model_groups_key(auth_model)) + if callable(self.auth_state_groups_key): + return set(self.auth_state_groups_key(auth_state)) try: return set( - reduce(dict.get, self.auth_model_groups_key.split("."), auth_model) + reduce(dict.get, self.auth_state_groups_key.split("."), auth_state) ) except TypeError: self.log.error( - f"The auth_model_groups_key {self.auth_model_groups_key} does not exist in the auth_model. Available keys are: {auth_model.keys()}" + f"The auth_state_groups_key {self.auth_state_groups_key} does not exist in the auth_model. Available keys are: {auth_state.keys()}" ) return set() @@ -1083,7 +1081,8 @@ async def update_auth_model(self, auth_model): Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ if self.manage_groups or self.admin_groups: - user_groups = self.get_user_groups(auth_model) + auth_state = auth_model["auth_state"] + user_groups = self.get_user_groups(auth_state) if self.manage_groups: auth_model["groups"] = sorted(user_groups) @@ -1172,7 +1171,8 @@ async def check_allowed(self, username, auth_model): # allow users who are members of allowed_groups if self.allowed_groups: - user_groups = self.get_user_groups(auth_model) + auth_state = auth_model["auth_state"] + user_groups = self.get_user_groups(auth_state) if any(user_groups & self.allowed_groups): return True diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 71f26683..b1abb4a8 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -335,9 +335,9 @@ async def test_generic_claim_groups_key_nested_strings( assert auth_model["admin"] -async def test_generic_auth_model_groups_key_callable(get_authenticator, generic_client): +async def test_generic_auth_state_groups_key_callable(get_authenticator, generic_client): c = Config() - c.GenericOAuthenticator.auth_model_groups_key = lambda r: r["auth_state"]["oauth_user"]["policies"]["roles"] + c.GenericOAuthenticator.auth_state_groups_key = lambda auth_state: auth_state["oauth_user"]["policies"]["roles"] c.GenericOAuthenticator.allowed_groups = ["super_user"] authenticator = get_authenticator(config=c) @@ -348,11 +348,11 @@ async def test_generic_auth_model_groups_key_callable(get_authenticator, generic assert auth_model -async def test_generic_auth_model_groups_key_nested_strings( +async def test_generic_auth_state_groups_key_nested_strings( get_authenticator, generic_client ): c = Config() - c.GenericOAuthenticator.auth_model_groups_key = "auth_state.oauth_user.permissions.groups" + c.GenericOAuthenticator.auth_state_groups_key = "oauth_user.permissions.groups" c.GenericOAuthenticator.admin_groups = ["super_user"] authenticator = get_authenticator(config=c)