diff --git a/controlpanel/api/rules.py b/controlpanel/api/rules.py index 9a62c1469..765cd23ec 100644 --- a/controlpanel/api/rules.py +++ b/controlpanel/api/rules.py @@ -133,7 +133,7 @@ def is_self(user, other): return user == other -add_perm('api.list_user', is_authenticated) +add_perm('api.list_user', is_authenticated & is_superuser) add_perm('api.create_user', is_authenticated & is_superuser) add_perm('api.retrieve_user', is_authenticated & is_self) add_perm('api.update_user', is_authenticated & is_self) diff --git a/tests/api/filters/test_user_filter.py b/tests/api/filters/test_user_filter.py index ce48a6226..bd834c067 100644 --- a/tests/api/filters/test_user_filter.py +++ b/tests/api/filters/test_user_filter.py @@ -1,29 +1,23 @@ -from model_mommy import mommy +from rest_framework import status from rest_framework.reverse import reverse -from rest_framework.status import HTTP_200_OK -from rest_framework.test import APITestCase -class UserFilterTest(APITestCase): +def user_list(client): + return client.get(reverse('user-list')) - def setUp(self): - self.superuser = mommy.make( - "api.User", is_superuser=True) - self.normal_user = mommy.make( - "api.User", is_superuser=False) - def test_superuser_sees_everything(self): - self.client.force_login(self.superuser) +def test_superuser_sees_everything(client, users): + client.force_login(users['superuser']) + response = user_list(client) + assert response.status_code == status.HTTP_200_OK - response = self.client.get(reverse("user-list")) - user_ids = [user["auth0_id"] for user in response.data["results"]] - self.assertEqual(len(user_ids), 2) - self.assertIn(self.superuser.auth0_id, user_ids) - self.assertIn(self.normal_user.auth0_id, user_ids) + all_user_ids = [user.auth0_id for key, user in users.items()] + returned_user_ids = [user["auth0_id"] for user in response.data["results"]] - def test_normal_user_sees_everything(self): - self.client.force_login(self.normal_user) + assert set(returned_user_ids) == set(all_user_ids) - response = self.client.get(reverse("user-list")) - self.assertEqual(HTTP_200_OK, response.status_code) - self.assertEqual(len(response.data["results"]), 2) + +def test_normal_user_sees_nothing(client, users): + client.force_login(users['normal_user']) + response = user_list(client) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/tests/api/permissions/test_user_permissions.py b/tests/api/permissions/test_user_permissions.py index d81c2bd39..fa04f1529 100644 --- a/tests/api/permissions/test_user_permissions.py +++ b/tests/api/permissions/test_user_permissions.py @@ -73,7 +73,7 @@ def user_update_self(client, users): (user_create, "superuser", status.HTTP_201_CREATED), (user_update, "superuser", status.HTTP_200_OK), - (user_list, "normal_user", status.HTTP_200_OK), + (user_list, "normal_user", status.HTTP_403_FORBIDDEN), (user_detail, "normal_user", status.HTTP_403_FORBIDDEN), (user_own_detail, "normal_user", status.HTTP_200_OK), (user_delete, "normal_user", status.HTTP_403_FORBIDDEN), diff --git a/tests/api/test_authentication.py b/tests/api/test_authentication.py index 37de5131c..1e946dcf7 100644 --- a/tests/api/test_authentication.py +++ b/tests/api/test_authentication.py @@ -66,7 +66,7 @@ def make_request(**headers): if value is not None: filtered[header] = value return client.get( - "/api/cpanel/v1/users", + "/api/cpanel/v1/apps", follow=True, **filtered, ) diff --git a/tests/frontend/views/test_user.py b/tests/frontend/views/test_user.py index 4805cf919..512d8f313 100644 --- a/tests/frontend/views/test_user.py +++ b/tests/frontend/views/test_user.py @@ -43,7 +43,7 @@ def reset_mfa(client, users, *args): 'view,user,expected_status', [ (list, 'superuser', status.HTTP_200_OK), - (list, 'normal_user', status.HTTP_200_OK), + (list, 'normal_user', status.HTTP_403_FORBIDDEN), (delete, 'superuser', status.HTTP_302_FOUND), (delete, 'normal_user', status.HTTP_403_FORBIDDEN), @@ -72,7 +72,6 @@ def test_permission(client, users, view, user, expected_status): 'view,user,expected_count', [ (list, 'superuser', 3), - (list, 'normal_user', 3), ], ) def test_list(client, users, view, user, expected_count):