Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve groupby operation in GET /auth/group/ #1745

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions fractal_server/app/routes/auth/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import join
from sqlmodel import col
from sqlmodel import select

Expand All @@ -32,30 +34,42 @@ async def get_list_user_groups(
db: AsyncSession = Depends(get_async_db),
) -> list[UserGroupRead]:
"""
FIXME docstring
FIXME docstring x
"""

# Get all groups
stm_all_groups = select(UserGroup)
res = await db.execute(stm_all_groups)
groups = res.scalars().all()

if user_ids is True:
# Get all user/group links
stm_all_links = select(LinkUserGroup)
res = await db.execute(stm_all_links)
links = res.scalars().all()

# FIXME GROUPS: this must be optimized
for ind, group in enumerate(groups):
groups[ind] = dict(
group.model_dump(),
user_ids=[
link.user_id for link in links if link.group_id == group.id
],
)
if user_ids is False:
# Get all groups, sorted by `id`
stm_groups = select(UserGroup).order_by("id")
res = await db.execute(stm_groups)
groups = res.scalars().all()
return groups

else:

SEPARATOR = ","

return groups
stm = (
select(
UserGroup,
func.aggregate_strings(LinkUserGroup.user_id, SEPARATOR),
)
.select_from(
join(
LinkUserGroup,
UserGroup,
LinkUserGroup.group_id == UserGroup.id,
)
)
.group_by(LinkUserGroup.group_id)
.order_by(UserGroup.id)
)
res = await db.execute(stm)
enriched_groups = []
for row in res.all():
group, user_ids_string = row[:]
user_ids = [int(_id) for _id in user_ids_string.split(SEPARATOR)]
enriched_groups.append(dict(group.model_dump(), user_ids=user_ids))
return enriched_groups


@router_group.get(
Expand Down
21 changes: 15 additions & 6 deletions tests/no_version/test_auth_groups_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def test_user_group_crud(registered_superuser_client):
# Preliminary: register two new users
credentials_user_A = dict(email="[email protected]", password="12345")
credentials_user_B = dict(email="[email protected]", password="12345")
credentials_user_C = dict(email="[email protected]", password="12345")
res = await registered_superuser_client.post(
f"{PREFIX}/register/", json=credentials_user_A
)
Expand All @@ -100,31 +101,38 @@ async def test_user_group_crud(registered_superuser_client):
)
assert res.status_code == 201
user_B_id = res.json()["id"]
res = await registered_superuser_client.post(
f"{PREFIX}/register/", json=credentials_user_C
)
assert res.status_code == 201
user_C_id = res.json()["id"]

# Create group 1 with user A
# Create group 1
res = await registered_superuser_client.post(
f"{PREFIX}/group/", json=dict(name="group 1")
)
assert res.status_code == 201
group_1_id = res.json()["id"]

# Create group 2 with users A and B
# Create group 2
res = await registered_superuser_client.post(
f"{PREFIX}/group/",
json=dict(name="group 2"),
)
assert res.status_code == 201
group_2_id = res.json()["id"]

# Add user A to group 1
# Add users A and B and C to group 1
res = await registered_superuser_client.patch(
f"{PREFIX}/group/{group_1_id}/",
json=dict(new_user_ids=[user_A_id, user_B_id]),
json=dict(new_user_ids=[user_A_id, user_B_id, user_C_id]),
)
assert res.status_code == 200

# Add user B to group 2
res = await registered_superuser_client.patch(
f"{PREFIX}/group/{group_2_id}/", json=dict(new_user_ids=[user_B_id])
f"{PREFIX}/group/{group_2_id}/",
json=dict(new_user_ids=[user_B_id]),
)
assert res.status_code == 200

Expand All @@ -133,11 +141,12 @@ async def test_user_group_crud(registered_superuser_client):
f"{PREFIX}/group/?user_ids=true"
)
assert res.status_code == 200

groups_data = res.json()
assert len(groups_data) == 2
for group in groups_data:
if group["name"] == "group 1":
assert set(group["user_ids"]) == {user_A_id, user_B_id}
assert set(group["user_ids"]) == {user_A_id, user_B_id, user_C_id}
elif group["name"] == "group 2":
assert group["user_ids"] == [user_B_id]
else:
Expand Down
Loading