Skip to content

Commit

Permalink
Use auth_state, not auth_model
Browse files Browse the repository at this point in the history
auth_model is currently not documented nor exposed to customizable
code (without inheriting from a class and modifying it). So instead
of documenting auth_model and trying to keep that stable, we rely instead
on auth_model being populated.
  • Loading branch information
yuvipanda committed May 2, 2024
1 parent 8eb9304 commit b337015
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 32 deletions.
22 changes: 11 additions & 11 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@ def _login_service_default(self):
""",
)

# Initialize value of auth_model_groups_key based on what is in claim_groups_key
@default('auth_model_groups_key')
def _auth_model_groups_key_default(self):
# Initialize value of auth_state_groups_key based on what is in claim_groups_key
@default('auth_state_groups_key')
def _auth_state_groups_key_default(self):
if callable(self.claim_groups_key):
# Automatically wrap the claim_gorups_key call so it gets what it thinks it should get
return lambda auth_model: self.claim_groups_key(
auth_model["auth_state"][self.user_auth_state_key]
return lambda auth_state: self.claim_groups_key(
auth_state[self.user_auth_state_key]
)
else:
return f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}"
return f"{self.user_auth_state_key}.{self.claim_groups_key}"

# propagate any changes to claim_groups_key to auth_model_groups_key
# propagate any changes to claim_groups_key to auth_state_groups_key
@observe("claim_groups_key")
def _claim_groups_key_changed(self, change):
if callable(change.new):
# Automatically wrap the claim_gorups_key call so it gets what it thinks it should get
self.auth_model_groups_key = lambda auth_model: self.claim_groups_key(
auth_model["auth_state"][self.user_auth_state_key]
self.auth_state_groups_key = lambda auth_state: self.claim_groups_key(
auth_state[self.user_auth_state_key]
)
else:
self.auth_model_groups_key = (
f"auth_state.{self.user_auth_state_key}.{self.claim_groups_key}"
self.auth_state_groups_key = (
f"{self.user_auth_state_key}.{self.claim_groups_key}"
)

@default("http_client")
Expand Down
34 changes: 17 additions & 17 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,14 @@ class OAuthenticator(Authenticator):
""",
)

auth_model_groups_key = Union(
auth_state_groups_key = Union(
[Unicode(), Callable()],
config=True,
help="""
Determine groups this user belongs based on contents of the auth model.
Determine groups this user belongs based on contents of auth_state.
Can be a string key name (use periods for nested keys), or a callable
that accepts the auth model (as a dict) and returns the groups list.
TODO: Document what auth_model actually looks like.
that accepts the auth state (as a dict) and returns the groups list.
This configures how group membership in the upstream provider is determined
for use by `allowed_groups`, `admin_groups`, etc. If `manage_groups` is True,
Expand Down Expand Up @@ -1044,26 +1042,26 @@ def build_auth_state_dict(self, token_info, user_info):
self.user_auth_state_key: user_info,
}

def get_user_groups(self, auth_model: dict):
def get_user_groups(self, auth_state: dict):
"""
Returns a set of groups the user belongs to based on auth_model_groups_key
and provided auth_model.
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.
- If auth_model_groups_key is a callable, it is meant to return the groups
- If auth_state_groups_key is a callable, it is meant to return the groups
directly.
- If auth_model_groups_key is a nested dictionary key like
- If auth_state_groups_key is a nested dictionary key like
"permissions.groups", this function returns
auth_model["permissions"]["groups"].
auth_state["permissions"]["groups"].
"""
if callable(self.auth_model_groups_key):
return set(self.auth_model_groups_key(auth_model))
if callable(self.auth_state_groups_key):
return set(self.auth_state_groups_key(auth_state))
try:
return set(
reduce(dict.get, self.auth_model_groups_key.split("."), auth_model)
reduce(dict.get, self.auth_state_groups_key.split("."), auth_state)
)
except TypeError:
self.log.error(
f"The auth_model_groups_key {self.auth_model_groups_key} does not exist in the auth_model. Available keys are: {auth_model.keys()}"
f"The auth_state_groups_key {self.auth_state_groups_key} does not exist in the auth_model. Available keys are: {auth_state.keys()}"
)
return set()

Expand All @@ -1083,7 +1081,8 @@ async def update_auth_model(self, auth_model):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
if self.manage_groups or self.admin_groups:
user_groups = self.get_user_groups(auth_model)
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)

if self.manage_groups:
auth_model["groups"] = sorted(user_groups)
Expand Down Expand Up @@ -1172,7 +1171,8 @@ async def check_allowed(self, username, auth_model):

# allow users who are members of allowed_groups
if self.allowed_groups:
user_groups = self.get_user_groups(auth_model)
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if any(user_groups & self.allowed_groups):
return True

Expand Down
8 changes: 4 additions & 4 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ async def test_generic_claim_groups_key_nested_strings(
assert auth_model["admin"]


async def test_generic_auth_model_groups_key_callable(get_authenticator, generic_client):
async def test_generic_auth_state_groups_key_callable(get_authenticator, generic_client):
c = Config()
c.GenericOAuthenticator.auth_model_groups_key = lambda r: r["auth_state"]["oauth_user"]["policies"]["roles"]
c.GenericOAuthenticator.auth_state_groups_key = lambda auth_state: auth_state["oauth_user"]["policies"]["roles"]
c.GenericOAuthenticator.allowed_groups = ["super_user"]
authenticator = get_authenticator(config=c)

Expand All @@ -348,11 +348,11 @@ async def test_generic_auth_model_groups_key_callable(get_authenticator, generic
assert auth_model


async def test_generic_auth_model_groups_key_nested_strings(
async def test_generic_auth_state_groups_key_nested_strings(
get_authenticator, generic_client
):
c = Config()
c.GenericOAuthenticator.auth_model_groups_key = "auth_state.oauth_user.permissions.groups"
c.GenericOAuthenticator.auth_state_groups_key = "oauth_user.permissions.groups"
c.GenericOAuthenticator.admin_groups = ["super_user"]
authenticator = get_authenticator(config=c)

Expand Down

0 comments on commit b337015

Please sign in to comment.