Skip to content

Commit

Permalink
[wip] Refactor security assoc handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Aug 12, 2024
1 parent 287b889 commit 5989394
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 55 deletions.
24 changes: 5 additions & 19 deletions lib/galaxy/managers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from galaxy.managers.context import ProvidesAppContext
from galaxy.model import Group
from galaxy.model.base import transaction
from galaxy.model.db.role import get_roles_by_ids
from galaxy.model.db.user import get_users_by_ids
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.fields import Security
from galaxy.schema.groups import (
Expand Down Expand Up @@ -54,12 +52,11 @@ def create(self, trans: ProvidesAppContext, payload: GroupCreatePayload):

group = model.Group(name=name)
sa_session.add(group)
user_ids = payload.user_ids
users = get_users_by_ids(sa_session, user_ids)
role_ids = payload.role_ids
roles = get_roles_by_ids(sa_session, role_ids)
trans.app.security_agent.set_entity_group_associations(groups=[group], roles=roles, users=users)

with transaction(sa_session):
trans.app.security_agent.set_group_user_and_role_associations(
group, user_ids=payload.user_ids, role_ids=payload.role_ids
)
sa_session.commit()

encoded_id = Security.security.encode_id(group.id)
Expand Down Expand Up @@ -90,18 +87,7 @@ def update(self, trans: ProvidesAppContext, group_id: int, payload: GroupUpdateP
group.name = name
sa_session.add(group)

users = None
if payload.user_ids is not None:
users = get_users_by_ids(sa_session, payload.user_ids)

roles = None
if payload.role_ids is not None:
roles = get_roles_by_ids(sa_session, payload.role_ids)

self._app.security_agent.set_entity_group_associations(
groups=[group], roles=roles, users=users, delete_existing_assocs=False
)

self._app.security_agent.set_group_user_and_roles(group, user_ids=payload.user_ids, role_ids=payload.role_ids)
with transaction(sa_session):
sa_session.commit()

Expand Down
116 changes: 97 additions & 19 deletions lib/galaxy/model/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,31 @@
datetime,
timedelta,
)
from typing import List
from typing import (
List,
Optional,
)

from psycopg2.errors import (
ForeignKeyViolation,
UniqueViolation,
)
from sqlalchemy import (
and_,
delete,
false,
func,
insert,
not_,
or_,
select,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import text

import galaxy.model
from galaxy.exceptions import RequestParameterInvalidException
from galaxy.model import (
Dataset,
DatasetPermissions,
Expand Down Expand Up @@ -1447,6 +1458,91 @@ def get_showable_folders(
self.get_showable_folders(user, roles, folder, actions_to_check, showable_folders=showable_folders)
return showable_folders

# def set_user_group_and_role_associations(
# self, user_id: int, group_ids: Optional[List[int]] = None, role_ids: Optional[List[int]] = None
# ) -> None:
# """ Set user groups and user roles, replacing current associations."""
# self._set_user_groups(user_id, group_ids or [])
# self._set_user_roles(user_id, role_ids or [])
# self.sa_session.commit()
#
def set_group_user_and_role_associations(
# TODO set group type
self,
group,
*,
user_ids: Optional[List[int]] = None,
role_ids: Optional[List[int]] = None,
) -> None:
"""Set group users and group roles, replacing current associations."""
self._ensure_model_instance_has_id(group)
self._set_group_users(group.id, user_ids or [])
self._set_group_roles(group.id, role_ids or [])

#
# def set_role_user_and_group_associations(
# self, role_id: int, user_ids: Optional[List[int]] = None, group_ids: Optional[List[int]] = None
# ) -> None:
# """ Set role users and role groups, replacing current associations."""
# self._set_group_users(role_id, user_ids or [])
# self._set_group_roles(role_id, grour_ids or [])
# self.sa_session.commit()
#
# def _set_user_groups(self, user, groups):
# delete_stmt = delete(UserGroupAssociation).where(UserGroupAssociation.user_id == user.id)
# insert_values = [{"user_id": user.id, "group_id": group_id} for group_id in groups]
# self._set_associations(UserGroupAssociation, delete_stmt, insert_values)
#

def _ensure_model_instance_has_id(self, model_instance):
# If model_instance is new, it may have not been assigned a database id yet, which is required
# for creating association records. Flush if that's the case.
if model_instance.id is None:
self.sa_session.flush([model_instance])

def _set_group_users(self, group_id, users):
delete_stmt = delete(UserGroupAssociation).where(UserGroupAssociation.group_id == group_id)
insert_values = [{"group_id": group_id, "user_id": user_id} for user_id in users]
self._set_associations(UserGroupAssociation, delete_stmt, insert_values)

# def _set_user_roles(self, user, roles):
# delete_stmt = delete(UserRoleAssociation).where(UserRoleAssociation.user_id == user.id)
# insert_values = [{"user_id": user.id, "role_id": role_id} for role_id in roles]
# self._set_associations(UserRoleAssociation, delete_stmt, insert_values)
#
# def _set_role_users(self, role, users):
# delete_stmt = delete(UserRoleAssociation).where(UserRoleAssociation.role_id == role.id)
# insert_values = [{"role_id": role.id, "user_id": user_id} for user_id in users]
# self._set_associations(UserRoleAssociation, delete_stmt, insert_values)
#
def _set_group_roles(self, group_id, roles):
delete_stmt = delete(GroupRoleAssociation).where(GroupRoleAssociation.group_id == group_id)
insert_values = [{"group_id": group_id, "role_id": role_id} for role_id in roles]
self._set_associations(GroupRoleAssociation, delete_stmt, insert_values)

# def _set_role_groups(self, role, groups):
# delete_stmt = delete(GroupRoleAssociation).where(GroupRoleAssociation.role_id == role.id)
# insert_values = [{"role_id": role.id, "group_id": group_id} for group_id in groups]
# self._set_associations(GroupRoleAssociation, delete_stmt, insert_values)

def _set_associations(self, assoc_model, delete_stmt, insert_values):
# Ensure parent model has a database-assigned id
if assoc_model.id is None:
self.sa_session.flush(assoc_model)
# Delete current associations
self.sa_session.execute(delete_stmt)
# Create new associations
try:
self.sa_session.execute(insert(assoc_model), insert_values)
except IntegrityError as ie:
if isinstance(ie, UniqueViolation):
log.warning("Attempting to add a duplicate %s record(%s)", assoc_model, insert_values)
pass
elif isinstance(ie, ForeignKeyViolation):
raise RequestParameterInvalidException(ie)
else:
raise

def set_entity_user_associations(self, users=None, roles=None, groups=None, delete_existing_assocs=True):
users = users or []
roles = roles or []
Expand All @@ -1468,24 +1564,6 @@ def set_entity_user_associations(self, users=None, roles=None, groups=None, dele
for group in groups:
self.associate_components(user=user, group=group)

def set_entity_group_associations(self, groups=None, users=None, roles=None, delete_existing_assocs=True):
users = users or []
roles = roles or []
groups = groups or []
for group in groups:
if delete_existing_assocs:
flush_needed = False
for a in group.roles + group.users:
self.sa_session.delete(a)
flush_needed = True
if flush_needed:
with transaction(self.sa_session):
self.sa_session.commit()
for role in roles:
self.associate_components(group=group, role=role)
for user in users:
self.associate_components(group=group, user=user)

def set_entity_role_associations(self, roles=None, users=None, groups=None, delete_existing_assocs=True):
users = users or []
roles = roles or []
Expand Down
33 changes: 18 additions & 15 deletions lib/galaxy/webapps/galaxy/controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
util,
web,
)
from galaxy.exceptions import ActionInputError
from galaxy.exceptions import (
ActionInputError,
RequestParameterInvalidException,
)
from galaxy.managers.quotas import QuotaManager
from galaxy.model.base import transaction
from galaxy.model.index_filter_util import (
Expand Down Expand Up @@ -912,21 +915,21 @@ def manage_users_and_roles_for_group(self, trans, payload=None, **kwd):
],
}
else:
in_users = [
trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(x))
for x in util.listify(payload.get("in_users"))
]
in_roles = [
trans.sa_session.query(trans.app.model.Role).get(trans.security.decode_id(x))
for x in util.listify(payload.get("in_roles"))
]
if None in in_users or None in in_roles:
user_ids = [trans.security.decode_id(id) for id in util.listify(payload.get("in_users"))]
role_ids = [trans.security.decode_id(id) for id in util.listify(payload.get("in_roles"))]
try:
trans.app.security_agent.set_group_user_and_role_associations(
group, user_ids=user_ids, role_ids=role_ids
)
with transaction(trans.sa_session):
trans.sa_session.commit()

trans.sa_session.refresh(group)
return {
"message": f"Group '{group.name}' has been updated with {len(user_ids)} associated users and {len(role_ids)} associated roles."
}
except RequestParameterInvalidException:
return self.message_exception(trans, "One or more invalid user/role id has been provided.")
trans.app.security_agent.set_entity_group_associations(groups=[group], users=in_users, roles=in_roles)
trans.sa_session.refresh(group)
return {
"message": f"Group '{group.name}' has been updated with {len(in_users)} associated users and {len(in_roles)} associated roles."
}

@web.legacy_expose_api
@web.require_admin
Expand Down
5 changes: 3 additions & 2 deletions test/unit/app/managers/test_NotificationManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,9 @@ def _create_test_group(self, name: str, users: List[User], roles: List[Role]):
sa_session = self.trans.sa_session
group = Group(name=name)
sa_session.add(group)
self.trans.app.security_agent.set_entity_group_associations(groups=[group], roles=roles, users=users)
sa_session.flush()
user_ids = [user.id for user in users]
role_ids = [role.id for role in roles]
self.trans.app.security_agent.set_group_user_and_role_associations(group, user_ids=user_ids, role_ids=role_ids)
return group

def _create_test_role(self, name: str, users: List[User], groups: List[Group]):
Expand Down

0 comments on commit 5989394

Please sign in to comment.