diff --git a/lib/galaxy/managers/groups.py b/lib/galaxy/managers/groups.py index 8edb50218203..cfd1acc94a65 100644 --- a/lib/galaxy/managers/groups.py +++ b/lib/galaxy/managers/groups.py @@ -13,8 +13,6 @@ from galaxy.managers.context import ProvidesAppContext from galaxy.model import Group from galaxy.model.base import transaction -from galaxy.model.db.role import get_roles_by_ids -from galaxy.model.db.user import get_users_by_ids from galaxy.model.scoped_session import galaxy_scoped_session from galaxy.schema.fields import Security from galaxy.schema.groups import ( @@ -54,12 +52,11 @@ def create(self, trans: ProvidesAppContext, payload: GroupCreatePayload): group = model.Group(name=name) sa_session.add(group) - user_ids = payload.user_ids - users = get_users_by_ids(sa_session, user_ids) - role_ids = payload.role_ids - roles = get_roles_by_ids(sa_session, role_ids) - trans.app.security_agent.set_entity_group_associations(groups=[group], roles=roles, users=users) + with transaction(sa_session): + trans.app.security_agent.set_group_user_and_role_associations( + group, user_ids=payload.user_ids, role_ids=payload.role_ids + ) sa_session.commit() encoded_id = Security.security.encode_id(group.id) @@ -90,18 +87,7 @@ def update(self, trans: ProvidesAppContext, group_id: int, payload: GroupUpdateP group.name = name sa_session.add(group) - users = None - if payload.user_ids is not None: - users = get_users_by_ids(sa_session, payload.user_ids) - - roles = None - if payload.role_ids is not None: - roles = get_roles_by_ids(sa_session, payload.role_ids) - - self._app.security_agent.set_entity_group_associations( - groups=[group], roles=roles, users=users, delete_existing_assocs=False - ) - + self._app.security_agent.set_group_user_and_roles(group, user_ids=payload.user_ids, role_ids=payload.role_ids) with transaction(sa_session): sa_session.commit() diff --git a/lib/galaxy/model/security.py b/lib/galaxy/model/security.py index 09b425dcd8eb..f7cc4f31120c 100644 --- a/lib/galaxy/model/security.py +++ b/lib/galaxy/model/security.py @@ -4,20 +4,31 @@ datetime, timedelta, ) -from typing import List +from typing import ( + List, + Optional, +) +from psycopg2.errors import ( + ForeignKeyViolation, + UniqueViolation, +) from sqlalchemy import ( and_, + delete, false, func, + insert, not_, or_, select, ) +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import joinedload from sqlalchemy.sql import text import galaxy.model +from galaxy.exceptions import RequestParameterInvalidException from galaxy.model import ( Dataset, DatasetPermissions, @@ -1447,6 +1458,91 @@ def get_showable_folders( self.get_showable_folders(user, roles, folder, actions_to_check, showable_folders=showable_folders) return showable_folders + # def set_user_group_and_role_associations( + # self, user_id: int, group_ids: Optional[List[int]] = None, role_ids: Optional[List[int]] = None + # ) -> None: + # """ Set user groups and user roles, replacing current associations.""" + # self._set_user_groups(user_id, group_ids or []) + # self._set_user_roles(user_id, role_ids or []) + # self.sa_session.commit() + # + def set_group_user_and_role_associations( + # TODO set group type + self, + group, + *, + user_ids: Optional[List[int]] = None, + role_ids: Optional[List[int]] = None, + ) -> None: + """Set group users and group roles, replacing current associations.""" + self._ensure_model_instance_has_id(group) + self._set_group_users(group.id, user_ids or []) + self._set_group_roles(group.id, role_ids or []) + + # + # def set_role_user_and_group_associations( + # self, role_id: int, user_ids: Optional[List[int]] = None, group_ids: Optional[List[int]] = None + # ) -> None: + # """ Set role users and role groups, replacing current associations.""" + # self._set_group_users(role_id, user_ids or []) + # self._set_group_roles(role_id, grour_ids or []) + # self.sa_session.commit() + # + # def _set_user_groups(self, user, groups): + # delete_stmt = delete(UserGroupAssociation).where(UserGroupAssociation.user_id == user.id) + # insert_values = [{"user_id": user.id, "group_id": group_id} for group_id in groups] + # self._set_associations(UserGroupAssociation, delete_stmt, insert_values) + # + + def _ensure_model_instance_has_id(self, model_instance): + # If model_instance is new, it may have not been assigned a database id yet, which is required + # for creating association records. Flush if that's the case. + if model_instance.id is None: + self.sa_session.flush([model_instance]) + + def _set_group_users(self, group_id, users): + delete_stmt = delete(UserGroupAssociation).where(UserGroupAssociation.group_id == group_id) + insert_values = [{"group_id": group_id, "user_id": user_id} for user_id in users] + self._set_associations(UserGroupAssociation, delete_stmt, insert_values) + + # def _set_user_roles(self, user, roles): + # delete_stmt = delete(UserRoleAssociation).where(UserRoleAssociation.user_id == user.id) + # insert_values = [{"user_id": user.id, "role_id": role_id} for role_id in roles] + # self._set_associations(UserRoleAssociation, delete_stmt, insert_values) + # + # def _set_role_users(self, role, users): + # delete_stmt = delete(UserRoleAssociation).where(UserRoleAssociation.role_id == role.id) + # insert_values = [{"role_id": role.id, "user_id": user_id} for user_id in users] + # self._set_associations(UserRoleAssociation, delete_stmt, insert_values) + # + def _set_group_roles(self, group_id, roles): + delete_stmt = delete(GroupRoleAssociation).where(GroupRoleAssociation.group_id == group_id) + insert_values = [{"group_id": group_id, "role_id": role_id} for role_id in roles] + self._set_associations(GroupRoleAssociation, delete_stmt, insert_values) + + # def _set_role_groups(self, role, groups): + # delete_stmt = delete(GroupRoleAssociation).where(GroupRoleAssociation.role_id == role.id) + # insert_values = [{"role_id": role.id, "group_id": group_id} for group_id in groups] + # self._set_associations(GroupRoleAssociation, delete_stmt, insert_values) + + def _set_associations(self, assoc_model, delete_stmt, insert_values): + # Ensure parent model has a database-assigned id + if assoc_model.id is None: + self.sa_session.flush(assoc_model) + # Delete current associations + self.sa_session.execute(delete_stmt) + # Create new associations + try: + self.sa_session.execute(insert(assoc_model), insert_values) + except IntegrityError as ie: + if isinstance(ie, UniqueViolation): + log.warning("Attempting to add a duplicate %s record(%s)", assoc_model, insert_values) + pass + elif isinstance(ie, ForeignKeyViolation): + raise RequestParameterInvalidException(ie) + else: + raise + def set_entity_user_associations(self, users=None, roles=None, groups=None, delete_existing_assocs=True): users = users or [] roles = roles or [] @@ -1468,24 +1564,6 @@ def set_entity_user_associations(self, users=None, roles=None, groups=None, dele for group in groups: self.associate_components(user=user, group=group) - def set_entity_group_associations(self, groups=None, users=None, roles=None, delete_existing_assocs=True): - users = users or [] - roles = roles or [] - groups = groups or [] - for group in groups: - if delete_existing_assocs: - flush_needed = False - for a in group.roles + group.users: - self.sa_session.delete(a) - flush_needed = True - if flush_needed: - with transaction(self.sa_session): - self.sa_session.commit() - for role in roles: - self.associate_components(group=group, role=role) - for user in users: - self.associate_components(group=group, user=user) - def set_entity_role_associations(self, roles=None, users=None, groups=None, delete_existing_assocs=True): users = users or [] roles = roles or [] diff --git a/lib/galaxy/webapps/galaxy/controllers/admin.py b/lib/galaxy/webapps/galaxy/controllers/admin.py index 7e788bb1cc0f..43cdfcbf1550 100644 --- a/lib/galaxy/webapps/galaxy/controllers/admin.py +++ b/lib/galaxy/webapps/galaxy/controllers/admin.py @@ -13,7 +13,10 @@ util, web, ) -from galaxy.exceptions import ActionInputError +from galaxy.exceptions import ( + ActionInputError, + RequestParameterInvalidException, +) from galaxy.managers.quotas import QuotaManager from galaxy.model.base import transaction from galaxy.model.index_filter_util import ( @@ -912,21 +915,21 @@ def manage_users_and_roles_for_group(self, trans, payload=None, **kwd): ], } else: - in_users = [ - trans.sa_session.query(trans.app.model.User).get(trans.security.decode_id(x)) - for x in util.listify(payload.get("in_users")) - ] - in_roles = [ - trans.sa_session.query(trans.app.model.Role).get(trans.security.decode_id(x)) - for x in util.listify(payload.get("in_roles")) - ] - if None in in_users or None in in_roles: + user_ids = [trans.security.decode_id(id) for id in util.listify(payload.get("in_users"))] + role_ids = [trans.security.decode_id(id) for id in util.listify(payload.get("in_roles"))] + try: + trans.app.security_agent.set_group_user_and_role_associations( + group, user_ids=user_ids, role_ids=role_ids + ) + with transaction(trans.sa_session): + trans.sa_session.commit() + + trans.sa_session.refresh(group) + return { + "message": f"Group '{group.name}' has been updated with {len(user_ids)} associated users and {len(role_ids)} associated roles." + } + except RequestParameterInvalidException: return self.message_exception(trans, "One or more invalid user/role id has been provided.") - trans.app.security_agent.set_entity_group_associations(groups=[group], users=in_users, roles=in_roles) - trans.sa_session.refresh(group) - return { - "message": f"Group '{group.name}' has been updated with {len(in_users)} associated users and {len(in_roles)} associated roles." - } @web.legacy_expose_api @web.require_admin diff --git a/test/unit/app/managers/test_NotificationManager.py b/test/unit/app/managers/test_NotificationManager.py index 6e0c36397c95..76e934cc9e6f 100644 --- a/test/unit/app/managers/test_NotificationManager.py +++ b/test/unit/app/managers/test_NotificationManager.py @@ -524,8 +524,9 @@ def _create_test_group(self, name: str, users: List[User], roles: List[Role]): sa_session = self.trans.sa_session group = Group(name=name) sa_session.add(group) - self.trans.app.security_agent.set_entity_group_associations(groups=[group], roles=roles, users=users) - sa_session.flush() + user_ids = [user.id for user in users] + role_ids = [role.id for role in roles] + self.trans.app.security_agent.set_group_user_and_role_associations(group, user_ids=user_ids, role_ids=role_ids) return group def _create_test_role(self, name: str, users: List[User], groups: List[Group]):