From d839bdac3e43d41f302421c84841e2df435405cc Mon Sep 17 00:00:00 2001 From: Michael Chisholm Date: Thu, 3 Oct 2024 00:56:53 -0400 Subject: [PATCH] fix: repo design harmonizing, misc updates --- src/dioptra/restapi/db/repository/groups.py | 41 +++-- src/dioptra/restapi/db/repository/queues.py | 11 +- src/dioptra/restapi/db/repository/users.py | 161 +++----------------- src/dioptra/restapi/db/repository/utils.py | 61 +++++--- tests/unit/restapi/test_group_repository.py | 4 + tests/unit/restapi/test_user_repository.py | 153 +++++-------------- 6 files changed, 144 insertions(+), 287 deletions(-) diff --git a/src/dioptra/restapi/db/repository/groups.py b/src/dioptra/restapi/db/repository/groups.py index e90f6101b..053ab55a1 100644 --- a/src/dioptra/restapi/db/repository/groups.py +++ b/src/dioptra/restapi/db/repository/groups.py @@ -35,6 +35,7 @@ assert_user_exists, check_user_collision, construct_sql_query_filters, + get_group_id, group_exists, user_exists, ) @@ -193,6 +194,22 @@ def get_by_filters_paged( page_length: int, deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, ) -> tuple[Sequence[User], int]: + """ + Get some groups according to search criteria. + + Args: + filters: A structure representing search criteria. See + parse_search_text(). + page_start: A row index where the returned page should start + page_length: A row count representing the page length; use <= 0 + for unlimited length + deletion_policy: Whether to look at deleted groups, non-deleted + groups, or all groups + + Returns: + A 2-tuple including a page of Group objects, and a count of the + total number of groups matching the criteria + """ sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS) count_stmt = sa.select(sa.func.count()).select_from(Group) @@ -214,7 +231,9 @@ def get_by_filters_paged( page_stmt = _apply_deletion_policy(page_stmt, deletion_policy) # *must* enforce a sort order for consistent paging page_stmt = page_stmt.order_by(Group.group_id) - page_stmt = page_stmt.offset(page_start).limit(page_length) + page_stmt = page_stmt.offset(page_start) + if page_length > 0: + page_stmt = page_stmt.limit(page_length) groups = self.session.scalars(page_stmt).all() @@ -245,13 +264,13 @@ def num_groups( return num_groups - def num_members(self, group: Group) -> int: + def num_members(self, group: Group | int) -> int: """ Get the number of members in the given group. This is done in a way that's hopefully more efficient than len(group.members). Args: - group: A group + group: A Group object or group_id integer primary key value Returns: A member count @@ -264,10 +283,11 @@ def num_members(self, group: Group) -> int: # len(group.members) might require actually reading all the rows and # translating them into objects. I hope to avoid that. + group_id = get_group_id(group) num_members_stmt = ( sa.select(sa.func.count()) .select_from(GroupMember) - .where(GroupMember.group_id == group.group_id) + .where(GroupMember.group_id == group_id) ) num_members = self.session.scalar(num_members_stmt) @@ -278,13 +298,13 @@ def num_members(self, group: Group) -> int: return num_members - def num_managers(self, group: Group) -> int: + def num_managers(self, group: Group | int) -> int: """ Get the number of managers in the given group. This is done in a way that's hopefully more efficient than len(group.managers). Args: - group: A group + group: A Group object or group_id integer primary key value Returns: A manager count @@ -297,10 +317,11 @@ def num_managers(self, group: Group) -> int: # len(group.members) might require actually reading all the rows and # translating them into objects. I hope to avoid that. + group_id = get_group_id(group) num_members_stmt = ( sa.select(sa.func.count()) .select_from(GroupManager) - .where(GroupManager.group_id == group.group_id) + .where(GroupManager.group_id == group_id) ) num_managers = self.session.scalar(num_members_stmt) @@ -317,8 +338,7 @@ def add_manager( """ Add a user to the managership of the given group. If the user is already a manager, this is a no-op. Permissions are ignored in that - case. To modify permissions for an existing manager, see - UserRepository.set_manager_permissions(). + case. Args: group: A group @@ -408,8 +428,7 @@ def add_member( """ Add a user to the membership of the given group. If the user is already a member, this is a no-op. Permissions are ignored in that - case. To modify permissions for an existing member, see - UserRepository.set_member_permissions(). + case. Args: group: A group diff --git a/src/dioptra/restapi/db/repository/queues.py b/src/dioptra/restapi/db/repository/queues.py index 162db6650..841c22d60 100644 --- a/src/dioptra/restapi/db/repository/queues.py +++ b/src/dioptra/restapi/db/repository/queues.py @@ -182,7 +182,7 @@ def get( Get the latest snapshot of the given queue resource. Args: - resource_ids: An ID or iterable of IDs of queue resource IDs + resource_ids: A single or iterable of queue resource IDs deletion_policy: Whether to look at deleted queues, non-deleted queue, or all queues @@ -290,14 +290,15 @@ def get_by_filters_paged( deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, ) -> tuple[Sequence[Queue], int]: """ - Get a page of queues according to more complex criteria. + Get some queues according to search criteria. Args: group: Limit queues to those owned by this group; None to not limit the search filters: Search criteria, see parse_search_text() page_start: Zero-based row index where the page should start - page_length: Maximum number of rows in the page + page_length: Maximum number of rows in the page; use <= 0 for + unlimited length sort_by: Sort criterion; must be a key of SORTABLE_FIELDS. None to sort in an implementation-dependent way. descending: Whether to sort in descending order; only applicable @@ -364,7 +365,9 @@ def get_by_filters_paged( sort_criteria = Queue.resource_snapshot_id page_stmt = page_stmt.order_by(sort_criteria) - page_stmt = page_stmt.offset(page_start).limit(page_length) + page_stmt = page_stmt.offset(page_start) + if page_length > 0: + page_stmt = page_stmt.limit(page_length) queues = self.session.scalars(page_stmt).all() diff --git a/src/dioptra/restapi/db/repository/users.py b/src/dioptra/restapi/db/repository/users.py index 399010f1d..1cb58d7ee 100644 --- a/src/dioptra/restapi/db/repository/users.py +++ b/src/dioptra/restapi/db/repository/users.py @@ -35,6 +35,8 @@ assert_user_exists, check_user_collision, construct_sql_query_filters, + get_group_id, + get_user_id, user_exists, ) @@ -229,7 +231,8 @@ def get_by_filters_paged( filters: A structure representing search criteria. See parse_search_text(). page_start: A row index where the returned page should start - page_length: A row count representing the page length + page_length: A row count representing the page length; use <= 0 + for unlimited length deletion_policy: Whether to look at deleted users, non-deleted users, or all users @@ -258,7 +261,9 @@ def get_by_filters_paged( page_stmt = _apply_deletion_policy(page_stmt, deletion_policy) # *must* enforce a sort order for consistent paging page_stmt = page_stmt.order_by(User.user_id) - page_stmt = page_stmt.offset(page_start).limit(page_length) + page_stmt = page_stmt.offset(page_start) + if page_length > 0: + page_stmt = page_stmt.limit(page_length) users = self.session.scalars(page_stmt).all() @@ -289,47 +294,15 @@ def num_users( return num_users - def get_page( - self, - page_start: int, - page_length: int, - deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, - ) -> Sequence[User]: - """ - Get a "page" of users. The page start is a user index (not a page - index). Use index zero to start at the beginning. Using a page start - that's beyond the last user will result in an empty sequence, not an - error. The number of users returned will not be larger than - page_length (if it is > 0), but might be smaller. - - Args: - page_start: A user start index (use 0 for the first page) - page_length: Page length, in terms of number of users per page. - If <= 0, don't limit page length (get all remaining users) - deletion_policy: Whether to look at deleted users, non-deleted - users, or all users - - Returns: - A sequence of users - """ - stmt = sa.select(User) - stmt = _apply_deletion_policy(stmt, deletion_policy) - - # *must* enforce a sort order for consistent paging - stmt = stmt.order_by(User.user_id).offset(page_start) - - if page_length > 0: - stmt = stmt.limit(page_length) - - return self.session.scalars(stmt).all() - - def get_member_permissions(self, user: User, group: Group) -> GroupMember | None: + def get_member_permissions( + self, user: User | int, group: Group | int + ) -> GroupMember | None: """ Get a user's permissions with respect to the given group. Args: - group: A group - user: A user + group: A Group object or group_id integer primary key value + user: A User object or user_id integer primary key value Returns: A GroupMember object which contains the permissions, or None if @@ -341,17 +314,22 @@ def get_member_permissions(self, user: User, group: Group) -> GroupMember | None assert_group_exists(self.session, group, DeletionPolicy.NOT_DELETED) assert_user_exists(self.session, user, DeletionPolicy.NOT_DELETED) - membership = self.session.get(GroupMember, (user.user_id, group.group_id)) + group_id = get_group_id(group) + user_id = get_user_id(user) + + membership = self.session.get(GroupMember, (user_id, group_id)) return membership - def get_manager_permissions(self, user: User, group: Group) -> GroupManager | None: + def get_manager_permissions( + self, user: User | int, group: Group | int + ) -> GroupManager | None: """ Get a user's group manager permissions with respect to the given group. Args: - group: A group - user: A user + group: A Group object or group_id integer primary key value + user: A User object or user_id integer primary key value Returns: A GroupManager object which contains the permissions, or None if @@ -363,101 +341,12 @@ def get_manager_permissions(self, user: User, group: Group) -> GroupManager | No assert_group_exists(self.session, group, DeletionPolicy.NOT_DELETED) assert_user_exists(self.session, user, DeletionPolicy.NOT_DELETED) - manager = self.session.get(GroupManager, (user.user_id, group.group_id)) - - return manager + group_id = get_group_id(group) + user_id = get_user_id(user) - def set_member_permissions( - self, - user: User, - group: Group, - read: bool | None = None, - write: bool | None = None, - share_read: bool | None = None, - share_write: bool | None = None, - ) -> None: - """ - Set a user's permissions with respect to the given group. Use None as - a permission value if needed, to leave a permission as-is. - - Args: - group: A group - user: A user - read: The read permission to set - write: The write permission to set - share_read: The share_read permission to set - share_write: The share_write permission to set - - Raises: - Exception: If the user or group don't exist, or if the given user - is not a member of the given group - """ + manager = self.session.get(GroupManager, (user_id, group_id)) - # TODO: is this method really necessary? Users could call - # get_member_permissions() to get a GroupMember object and then - # modify permissions themselves. - - assert_group_exists(self.session, group, DeletionPolicy.NOT_DELETED) - assert_user_exists(self.session, user, DeletionPolicy.NOT_DELETED) - - membership = self.session.get(GroupMember, (user.user_id, group.group_id)) - - if membership: - if read is not None: - membership.read = read - if write is not None: - membership.write = write - if share_read is not None: - membership.share_read = share_read - if share_write is not None: - membership.share_write = share_write - - else: - raise Exception( - f"Not a member: user={user.user_id}, group={group.group_id}" - ) - - def set_manager_permissions( - self, - user: User, - group: Group, - owner: bool | None = None, - admin: bool | None = None, - ) -> None: - """ - Set a manager's permissions with respect to the given group. Use None - as a permission value if needed, to leave a permission as-is. - - Args: - group: A group - user: A user - owner: The owner permission to set - admin: The admin permission to set - - Raises: - Exception: if the user or group don't exist, or if the given user - is not a manager of the given group - """ - - # TODO: is this method really necessary? Users could call - # get_manager_permissions() to get a GroupManager object and then - # modify permissions themselves. - - assert_group_exists(self.session, group, DeletionPolicy.NOT_DELETED) - assert_user_exists(self.session, user, DeletionPolicy.NOT_DELETED) - - manager = self.session.get(GroupManager, (user.user_id, group.group_id)) - - if manager: - if owner is not None: - manager.owner = owner - if admin is not None: - manager.admin = admin - - else: - raise Exception( - f"Not a manager: user={user.user_id}, group={group.group_id}" - ) + return manager def _apply_deletion_policy( diff --git a/src/dioptra/restapi/db/repository/utils.py b/src/dioptra/restapi/db/repository/utils.py index 6ef2fd76f..7bc541e03 100644 --- a/src/dioptra/restapi/db/repository/utils.py +++ b/src/dioptra/restapi/db/repository/utils.py @@ -84,6 +84,26 @@ class DeletionPolicy(enum.Enum): DELETED = enum.auto() +def get_user_id(user: User | int) -> int | None: + """ + Helper for APIs which allow a User domain object or user_id integer + primary key value. This normalizes the value to the user_id value, or + None (if a User object was passed with a null .user_id attribute). + + Args: + user: A User object or user_id integer primary key value + + Returns: + A user ID or None + """ + if isinstance(user, int): + user_id = user + else: + user_id = user.user_id + + return user_id + + def get_group_id(group: Group | int) -> int | None: """ Helper for APIs which allow a Group domain object or group_id integer @@ -91,7 +111,7 @@ def get_group_id(group: Group | int) -> int | None: None (if a Group object was passed with a null .group_id attribute). Args: - group: A group object, group_id integer primary key value + group: A Group object or group_id integer primary key value Returns: A group ID or None @@ -143,19 +163,22 @@ def get_resource_id(resource: Resource | ResourceSnapshot | int) -> int | None: return resource_id -def user_exists(session: CompatibleSession[S], user: User) -> ExistenceResult: +def user_exists(session: CompatibleSession[S], user: User | int) -> ExistenceResult: """ Check whether the given user exists in the database, and if so, whether it was deleted or not. Args: session: An SQLAlchemy session - user: A User object + user: A User object or user_id integer primary key value Returns: One of the ExistenceResult enum values """ - if user.user_id is None: + + user_id = get_user_id(user) + + if user_id is None: exists = ExistenceResult.DOES_NOT_EXIST else: # May as well get existence + deletion status in one query. I think @@ -164,7 +187,7 @@ def user_exists(session: CompatibleSession[S], user: User) -> ExistenceResult: stmt = ( sa.select(User.user_id, UserLock.user_lock_type) .outerjoin(UserLock) - .where(User.user_id == user.user_id) + .where(User.user_id == user_id) ) results = session.execute(stmt) # will need to change if a user may have multiple lock types @@ -306,7 +329,7 @@ def snapshot_exists(session: CompatibleSession[S], snapshot: ResourceSnapshot) - def assert_user_exists( - session: CompatibleSession[S], user: User, deletion_policy: DeletionPolicy + session: CompatibleSession[S], user: User | int, deletion_policy: DeletionPolicy ) -> None: """ Check whether the given user exists in the database. This function accepts @@ -321,7 +344,7 @@ def assert_user_exists( Args: session: An SQLAlchemy session - user: A User object + user: A User object or user_id integer primary key value deletion_policy: One of the DeletionPolicy enum values Raises: @@ -329,10 +352,13 @@ def assert_user_exists( """ existence_result = user_exists(session, user) - user_id = "" if user.user_id is None else user.user_id - user_name = "" if user.username is None else user.username + user_id = get_user_id(user) + if isinstance(user, int): + obj_id = str(user_id) + else: + obj_id = f"{user_id}/{user.username}" - _assert_exists(deletion_policy, existence_result, "User", f"{user_id}/{user_name}") + _assert_exists(deletion_policy, existence_result, "User", obj_id) def assert_group_exists( @@ -442,7 +468,7 @@ def assert_snapshot_exists( def assert_user_does_not_exist( - session: CompatibleSession[S], user: User, deletion_policy: DeletionPolicy + session: CompatibleSession[S], user: User | int, deletion_policy: DeletionPolicy ) -> None: """ Check whether the given user exists in the database. This function accepts @@ -458,7 +484,7 @@ def assert_user_does_not_exist( Args: session: An SQLAlchemy session - user: A User object + user: A User object or user_id integer primary key value deletion_policy: One of the DeletionPolicy enum values Raises: @@ -466,12 +492,13 @@ def assert_user_does_not_exist( """ existence_result = user_exists(session, user) - user_id = "" if user.user_id is None else user.user_id - user_name = "" if user.username is None else user.username + user_id = get_user_id(user) + if isinstance(user, int): + obj_id = str(user_id) + else: + obj_id = f"{user_id}/{user.username}" - _assert_does_not_exist( - deletion_policy, existence_result, "User", f"{user_id}/{user_name}" - ) + _assert_does_not_exist(deletion_policy, existence_result, "User", obj_id) def assert_group_does_not_exist( diff --git a/tests/unit/restapi/test_group_repository.py b/tests/unit/restapi/test_group_repository.py index 9e4b82e76..8213e16cc 100644 --- a/tests/unit/restapi/test_group_repository.py +++ b/tests/unit/restapi/test_group_repository.py @@ -242,6 +242,7 @@ def test_group_num_groups(group_repo, account, db): def test_group_num_members(group_repo, user_repo, account, db): assert group_repo.num_members(account.group) == 1 + assert group_repo.num_members(account.group.group_id) == 1 u2 = User("user2", "password2", "user2@example.org") user_repo.create(u2, account.group) @@ -251,6 +252,7 @@ def test_group_num_members(group_repo, user_repo, account, db): db.session.commit() assert group_repo.num_members(account.group) == 2 + assert group_repo.num_members(account.group.group_id) == 2 def test_group_num_members_not_exist(group_repo, account): @@ -263,6 +265,7 @@ def test_group_num_members_not_exist(group_repo, account): def test_group_num_managers(group_repo, user_repo, account, db): assert group_repo.num_managers(account.group) == 1 + assert group_repo.num_managers(account.group.group_id) == 1 u2 = User("user2", "password2", "user2@example.org") user_repo.create(u2, account.group) @@ -272,6 +275,7 @@ def test_group_num_managers(group_repo, user_repo, account, db): db.session.commit() assert group_repo.num_managers(account.group) == 2 + assert group_repo.num_managers(account.group.group_id) == 2 def test_group_num_managers_not_exist(group_repo, account): diff --git a/tests/unit/restapi/test_user_repository.py b/tests/unit/restapi/test_user_repository.py index 16428a7a4..e8ced69dd 100644 --- a/tests/unit/restapi/test_user_repository.py +++ b/tests/unit/restapi/test_user_repository.py @@ -317,20 +317,25 @@ def test_user_get_page(user_repo, account, db): # Unsure sqlalchemy always returns lists... just gather from its sequences # into lists, to make sure we can compare them. - page = list(user_repo.get_page(0, 3)) + page, count = list(user_repo.get_by_filters_paged([], 0, 3)) assert page == users[:3] + assert count == 11 - page = list(user_repo.get_page(1, 4)) + page, count = list(user_repo.get_by_filters_paged([], 1, 4)) assert page == users[1:5] + assert count == 11 - page = list(user_repo.get_page(8, 5)) + page, count = list(user_repo.get_by_filters_paged([], 8, 5)) assert page == users[8:] + assert count == 11 - page = list(user_repo.get_page(20, 20)) + page, count = list(user_repo.get_by_filters_paged([], 20, 20)) assert page == [] + assert count == 11 - page = list(user_repo.get_page(0, 0)) + page, count = list(user_repo.get_by_filters_paged([], 0, 0)) assert page == users + assert count == 11 def test_user_get_page_deleted(user_repo, account, db): @@ -355,14 +360,17 @@ def test_user_get_page_deleted(user_repo, account, db): # Unsure sqlalchemy always returns lists... just gather from its sequences # into lists, to make sure we can compare them. - page = list(user_repo.get_page(0, 3, DeletionPolicy.NOT_DELETED)) + page, count = list(user_repo.get_by_filters_paged([], 0, 3, DeletionPolicy.NOT_DELETED)) assert page == [users[0], users[1], users[3]] + assert count == 8 - page = list(user_repo.get_page(0, 3, DeletionPolicy.DELETED)) + page, count = list(user_repo.get_by_filters_paged([], 0, 3, DeletionPolicy.DELETED)) assert page == [users[2], users[6], users[8]] + assert count == 3 - page = list(user_repo.get_page(0, 3, DeletionPolicy.ANY)) + page, count = list(user_repo.get_by_filters_paged([], 0, 3, DeletionPolicy.ANY)) assert page == users[:3] + assert count == 11 def test_user_get_member_permissions(user_repo, account, db): @@ -378,16 +386,23 @@ def test_user_get_member_permissions(user_repo, account, db): assert not perms.share_read assert perms.share_write + # Get by ID + perms = user_repo.get_member_permissions(u2.user_id, account.group) + assert perms.read + assert not perms.write + assert not perms.share_read + assert perms.share_write + def test_user_get_member_permissions_not_exist(user_repo, account): u2 = User("user2", "password2", "user2@example.org") g2 = Group("group2", u2) with pytest.raises(Exception): - user_repo.get_member_permissions(u2, account.group) + user_repo.get_member_permissions(u2, account.group.group_id) with pytest.raises(Exception): - user_repo.get_member_permissions(account.user, g2) + user_repo.get_member_permissions(account.user.user_id, g2) with pytest.raises(Exception): user_repo.get_member_permissions(u2, g2) @@ -402,6 +417,14 @@ def test_user_get_manager_permissions(user_repo, account): assert mgr.admin assert mgr.owner + # Get by ID + mgr = user_repo.get_manager_permissions(account.user.user_id, account.group) + + assert mgr.user == account.user + assert mgr.group == account.group + assert mgr.admin + assert mgr.owner + def test_user_get_manager_permissions_not_exist(user_repo, account): u2 = User("user2", "password2", "user2@example.org") @@ -411,7 +434,7 @@ def test_user_get_manager_permissions_not_exist(user_repo, account): user_repo.get_manager_permissions(u2, account.group) with pytest.raises(Exception): - user_repo.get_manager_permissions(account.user, g2) + user_repo.get_manager_permissions(account.user.user_id, g2) def test_user_get_manager_permissions_not_manager(user_repo, account, db): @@ -422,111 +445,3 @@ def test_user_get_manager_permissions_not_manager(user_repo, account, db): mgr = user_repo.get_manager_permissions(u2, account.group) assert not mgr - - -def test_user_set_member_permissions(user_repo, account, db): - - user_repo.set_member_permissions( - account.user, - account.group, - read=True, - write=False, - share_read=False, - share_write=True, - ) - db.session.commit() - - membership = user_repo.get_member_permissions(account.user, account.group) - assert membership.read - assert not membership.write - assert not membership.share_read - assert membership.share_write - - # Leave some perms None and ensure they don't change - user_repo.set_member_permissions( - account.user, account.group, read=False, share_read=True - ) - db.session.commit() - - membership = user_repo.get_member_permissions(account.user, account.group) - assert not membership.read - assert not membership.write - assert membership.share_read - assert membership.share_write - - -def test_user_set_member_permissions_membership_not_exist( - user_repo, group_repo, account, db -): - - u2 = User("user2", "password2", "user2@example.org") - g2 = Group("group2", u2) - group_repo.create(g2) - db.session.commit() - - with pytest.raises(Exception): - user_repo.set_member_permissions( - u2, - account.group, - read=False, - write=True, - share_read=True, - share_write=False, - ) - - -def test_user_set_member_permissions_user_group_not_exist(user_repo, account): - - u2 = User("user2", "password2", "user2@example.org") - g2 = Group("group2", u2) - - with pytest.raises(Exception): - user_repo.set_member_permissions(u2, account.group) - - with pytest.raises(Exception): - user_repo.set_member_permissions(account.user, g2) - - -def test_user_set_manager_permissions(user_repo, account, db): - - user_repo.set_manager_permissions( - account.user, account.group, owner=False, admin=True - ) - db.session.commit() - - manager = user_repo.get_manager_permissions(account.user, account.group) - - assert not manager.owner - assert manager.admin - - # Leave some perms None and ensure they don't change - user_repo.set_manager_permissions(account.user, account.group, admin=False) - db.session.commit() - - manager = user_repo.get_manager_permissions(account.user, account.group) - - assert not manager.owner - assert not manager.admin - - -def test_user_set_manager_permissions_managership_not_exist(user_repo, account, db): - - # u2 is a regular member, not a manager - u2 = User("user2", "password2", "user2@example.org") - user_repo.create(u2, account.group) - db.session.commit() - - with pytest.raises(Exception): - user_repo.set_manager_permissions(u2, account.group) - - -def test_user_set_manager_permissions_user_group_not_exist(user_repo, account): - - u2 = User("user2", "password2", "user2@example.org") - g2 = Group("group2", u2) - - with pytest.raises(Exception): - user_repo.set_manager_permissions(u2, account.group) - - with pytest.raises(Exception): - user_repo.set_manager_permissions(account.user, g2)