diff --git a/staff/models.py b/staff/models.py index d8e3b60e6..2e8a5cb57 100644 --- a/staff/models.py +++ b/staff/models.py @@ -11,7 +11,7 @@ def __str__(self) -> str: return f"{self.employee_no} - {self.first_name} {self.last_name}" -class StaffForecast(models.QuerySet): +class StaffForecastQuerySet(models.QuerySet): pass @@ -24,7 +24,7 @@ class Meta: ) ] - objects = StaffForecast.as_manager() + objects = StaffForecastQuerySet.as_manager() staff = models.ForeignKey(Staff, models.PROTECT, related_name="forecast") year = models.ForeignKey("core.FinancialYear", models.PROTECT) diff --git a/staff/tests.py b/staff/tests.py deleted file mode 100644 index 7ce503c2d..000000000 --- a/staff/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/staff/tests/test_views.py b/staff/tests/test_views.py new file mode 100644 index 000000000..6577f36c3 --- /dev/null +++ b/staff/tests/test_views.py @@ -0,0 +1,30 @@ +import pytest + +from django.contrib.auth import get_user_model + + +User = get_user_model() + + +@pytest.fixture +def user(db, client): + user = User.objects.create_user( + username="staff.test", + email="staff.test@example.com", + password="password", + ) + user.save() + client.force_login(user) + return user + + +@pytest.mark.parametrize( + "url", + [ + "/staff/edit-payroll/", + "/staff/debug/", + ], +) +def test_only_superuser_can_access(client, user, url): + r = client.get(url) + assert r.status_code == 403 diff --git a/staff/views.py b/staff/views.py index 09e3d03c4..c9abb4189 100644 --- a/staff/views.py +++ b/staff/views.py @@ -1,6 +1,8 @@ +from functools import wraps from django.http import HttpResponse, HttpRequest from django.template.response import TemplateResponse from django.contrib.auth.decorators import user_passes_test +from django.core.exceptions import PermissionDenied from core.models import FinancialYear from costcentre.models import CostCentre @@ -10,17 +12,23 @@ # TODO: Remove once no longer needed. -def _user_is_superuser(user): - return user.is_superuser +def superuser_view(view_func): + @wraps(view_func) + def wrapper(request, *args, **kwargs): + if not request.user.is_superuser: + raise PermissionDenied + return view_func(*args, **kwargs) + return wrapper -@user_passes_test(_user_is_superuser) + +@superuser_view def edit_payroll_page(request: HttpRequest) -> HttpResponse: context = {} return TemplateResponse(request, "staff/page/edit_payroll.html", context) -@user_passes_test(_user_is_superuser) +@superuser_view def staff_debug_page(request: HttpRequest) -> HttpResponse: if request.GET.get("cost_centre"): cost_centre = CostCentre.objects.get(pk=request.GET.get("cost_centre"))