Skip to content

Commit

Permalink
Take into consideration child organizations for dashboard admins
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Aug 6, 2024
1 parent bcdf643 commit c05a73c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
28 changes: 23 additions & 5 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

class DashboardService:
def __init__(
self, request, assignment_service, course_service, organization_service
self,
request,
assignment_service,
course_service,
organization_service: OrganizationService,
):
self._db = request.db

Expand Down Expand Up @@ -69,11 +73,19 @@ def get_request_course(self, request):

def get_organizations_by_admin_email(self, email: str) -> list[Organization]:
"""Get a list of organizations where the user with email `email` is an admin in."""
return self._db.scalars(
select(Organization)
.join(DashboardAdmin)
organization_ids = []

for org_id in self._db.scalars(
select(DashboardAdmin.organization_id)
.where(DashboardAdmin.email == email)
.distinct()
).all():
organization_ids.extend(
self._organization_service.get_hierarchy_ids(org_id)
)

return self._db.scalars(
select(Organization).where(Organization.id.in_(organization_ids))
).all()

def add_dashboard_admin(
Expand Down Expand Up @@ -104,7 +116,13 @@ def get_request_admin_organizations(self, request) -> list[Organization]:
if not organization:
raise HTTPNotFound()

return [organization]
return self._db.scalars(
select(Organization).where(
Organization.id.in_(
self._organization_service.get_hierarchy_ids(organization.id)
)
)
).all()

return self.get_organizations_by_admin_email(
request.lti_user.email if request.lti_user else request.identity.userid
Expand Down
38 changes: 23 additions & 15 deletions tests/unit/lms/services/dashboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,14 @@


class TestDashboardService:
def test_get_request_assignment_404(
self,
pyramid_request,
assignment_service,
svc,
):
def test_get_request_assignment_404(self, pyramid_request, assignment_service, svc):
pyramid_request.matchdict["assignment_id"] = sentinel.id
assignment_service.get_by_id.return_value = None

with pytest.raises(HTTPNotFound):
svc.get_request_assignment(pyramid_request)

def test_get_request_assignment_403(
self,
pyramid_request,
assignment_service,
svc,
):
def test_get_request_assignment_403(self, pyramid_request, assignment_service, svc):
pyramid_request.matchdict["assignment_id"] = sentinel.id
assignment_service.is_member.return_value = False

Expand Down Expand Up @@ -143,13 +133,23 @@ def test_delete_dashboard_admin(self, svc, db_session, organization):

assert not db_session.query(DashboardAdmin).filter_by(id=admin.id).first()

def test_get_organizations_by_admin_email(self, svc, db_session, organization):
def test_get_organizations_by_admin_email(
self, svc, db_session, organization, organization_service
):
child_organization = factories.Organization(parent=organization)
admin = factories.DashboardAdmin(
organization=organization, email="[email protected]", created_by="creator"
)
db_session.flush()
organization_service.get_hierarchy_ids.return_value = [
organization.id,
child_organization.id,
]

assert svc.get_organizations_by_admin_email(admin.email) == [organization]
assert set(svc.get_organizations_by_admin_email(admin.email)) == {
organization,
child_organization,
}

def test_get_request_admin_organizations_for_non_staff(self, pyramid_request, svc):
pyramid_request.params = {"public_id": sentinel.public_id}
Expand Down Expand Up @@ -178,6 +178,7 @@ def test_get_request_admin_organizations(
):
pyramid_config.testing_securitypolicy(permissive=True)
organization_service.get_by_public_id.return_value = organization
organization_service.get_hierarchy_ids.return_value = [organization.id]
pyramid_request.params = {"public_id": sentinel.public_id}

assert svc.get_request_admin_organizations(pyramid_request) == [organization]
Expand All @@ -191,6 +192,7 @@ def test_get_request_admin_organizations_for_staff(
pyramid_config.testing_securitypolicy(permissive=True)
pyramid_request.params = {"public_id": sentinel.id}
organization_service.get_by_public_id.return_value = organization
organization_service.get_hierarchy_ids.return_value = [organization.id]

assert svc.get_request_admin_organizations(pyramid_request) == [organization]

Expand All @@ -202,7 +204,13 @@ def svc(
pyramid_request, assignment_service, course_service, organization_service
)

@pytest.fixture()
@pytest.fixture(autouse=True)
def organization_service(self, organization_service, organization):
organization_service.get_by_public_id.return_value = organization
organization_service.get_hierarchy_ids.return_value = []
return organization_service

@pytest.fixture
def get_request_admin_organizations(self, svc):
with patch.object(
svc, "get_request_admin_organizations"
Expand Down

0 comments on commit c05a73c

Please sign in to comment.