diff --git a/controlpanel/api/serializers.py b/controlpanel/api/serializers.py index fecdd182..4f101302 100644 --- a/controlpanel/api/serializers.py +++ b/controlpanel/api/serializers.py @@ -6,6 +6,7 @@ # Third-party from django.conf import settings from django.core.exceptions import ValidationError +from django.core.validators import EmailValidator from rest_framework import serializers # First-party/Local @@ -286,6 +287,51 @@ class Meta: ) +class AppCustomersQueryParamsSerializer(serializers.Serializer): + env_name = serializers.CharField(max_length=64, required=True) + page = serializers.IntegerField(min_value=1, required=False, default=1) + per_page = serializers.IntegerField(min_value=1, required=False, default=25) + + def __init__(self, *args, **kwargs): + self.app = kwargs.pop("app") + super().__init__(*args, **kwargs) + + def validate_env_name(self, env_name): + if not self.app.get_group_id(env_name): + raise serializers.ValidationError(f"{env_name} is invalid for this app.") + return env_name + + +class AddAppCustomersSerializer(serializers.Serializer): + emails = serializers.CharField(max_length=None, required=True) + env_name = serializers.CharField(max_length=64, required=True) + + def __init__(self, *args, **kwargs): + self.app = kwargs.pop("app") + super().__init__(*args, **kwargs) + + def validate_emails(self, emails): + delimiters = re.compile(r"[,; ]+") # split by comma, semicolon, and space + emails = delimiters.split(emails) + errors = [] + validator = EmailValidator() + for email in emails: + try: + validator(email) + except ValidationError: + errors.append(email) + if errors: + raise serializers.ValidationError( + f"Request contains invalid emails: {', '.join(errors)}" + ) + return emails + + def validate_env_name(self, env_name): + if not self.app.get_group_id(env_name): + raise serializers.ValidationError(f"{env_name} is invalid for this app.") + return env_name + + class DeleteAppCustomerSerializer(serializers.Serializer): email = serializers.EmailField(required=True) env_name = serializers.CharField(max_length=64, required=True) diff --git a/controlpanel/api/views/apps.py b/controlpanel/api/views/apps.py index 683cde0e..4fc10085 100644 --- a/controlpanel/api/views/apps.py +++ b/controlpanel/api/views/apps.py @@ -2,13 +2,10 @@ import re # Third-party -from django.core.exceptions import ValidationError as DjangoValidationError -from django.core.validators import EmailValidator from django_filters.rest_framework import DjangoFilterBackend from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.exceptions import ValidationError -from rest_framework.fields import get_error_detail from rest_framework.response import Response # First-party/Local @@ -33,7 +30,7 @@ def dispatch(self, request, *args, **kwargs): def get_serializer_class(self, *args, **kwargs): mapping = { "customers": serializers.AppCustomerSerializer, - "add_customers": serializers.AppCustomerSerializer, + "add_customers": serializers.AddAppCustomersSerializer, "delete_customers": serializers.DeleteAppCustomerSerializer, } serializer = mapping.get(self.action) @@ -43,54 +40,48 @@ def get_serializer_class(self, *args, **kwargs): @action(detail=True, methods=["get"]) def customers(self, request, *args, **kwargs): - if "env_name" not in request.query_params: - raise ValidationError({"env_name": "This field is required."}) - app = self.get_object() - group_id = app.get_group_id(request.query_params.get("env_name", "")) - page_number = request.query_params.get("page", 1) - per_page = request.query_params.get("per_page", 25) + serializer = serializers.AppCustomersQueryParamsSerializer( + data=request.query_params, app=app + ) + serializer.is_valid(raise_exception=True) + validated_params = serializer.validated_data + + group_id = app.get_group_id(validated_params["env_name"]) customers = app.customer_paginated( - page=page_number, + page=validated_params["page"], group_id=group_id, - per_page=per_page, + per_page=validated_params["per_page"], ) - serializer = self.get_serializer(data=customers["users"], many=True) - serializer.is_valid() - pagination = Auth0ApiPagination( - request, - page_number, - object_list=serializer.validated_data, + customers_serializer = self.get_serializer(data=customers["users"], many=True) + customers_serializer.is_valid(raise_exception=True) + + return Auth0ApiPagination( + request=request, + page_number=validated_params["page"], + object_list=customers_serializer.validated_data, total_count=customers["total"], - per_page=per_page, - ) - return pagination.get_paginated_response() + per_page=validated_params["per_page"], + ).get_paginated_response() @customers.mapping.post def add_customers(self, request, *args, **kwargs): - if "env_name" not in request.query_params: - raise ValidationError({"env_name": "This field is required."}) - - serializer = self.get_serializer(data=request.data) + app = self.get_object() + serializer = self.get_serializer(data=request.data, app=app) serializer.is_valid(raise_exception=True) - app = self.get_object() + try: + app.add_customers( + serializer.validated_data["emails"], env_name=serializer.validated_data["env_name"] + ) + except app.AddCustomerError: + raise ValidationError( + "An error occurred trying to add customers, check that the environment exists." + ) - delimiters = re.compile(r"[,; ]+") - emails = delimiters.split(serializer.validated_data["email"]) - errors = [] - for email in emails: - validator = EmailValidator(message=f"{email} is not a valid email address") - try: - validator(email) - except DjangoValidationError as error: - errors.extend(get_error_detail(error)) - if errors: - raise ValidationError(errors) - - app.add_customers(emails, env_name=request.query_params.get("env_name", "")) - - return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response( + {"message": "Successfully added customers."}, status=status.HTTP_201_CREATED + ) @customers.mapping.delete def delete_customers(self, request, *args, **kwargs): diff --git a/tests/api/permissions/test_app_permissions.py b/tests/api/permissions/test_app_permissions.py index e5d2e08f..d4938a38 100644 --- a/tests/api/permissions/test_app_permissions.py +++ b/tests/api/permissions/test_app_permissions.py @@ -15,6 +15,7 @@ # First-party/Local from controlpanel.api.jwt_auth import AuthenticatedServiceClient +from controlpanel.api.models import App from controlpanel.api.permissions import AppJwtPermissions @@ -42,7 +43,16 @@ def users(users, authenticated_client, invalid_client_sub, invalid_client_scope) @pytest.fixture(autouse=True) def app(users): - app = baker.make("api.App", name="Test App 1", app_conf={"m2m": {"client_id": "abc123"}}) + app = baker.make( + "api.App", + name="Test App 1", + app_conf={ + "m2m": {"client_id": "abc123"}, + App.KEY_WORD_FOR_AUTH_SETTINGS: { + "test": {"group_id": "test_group_id"}, + }, + }, + ) user = users["app_admin"] baker.make("api.UserApp", user=user, app=app, is_admin=True) @@ -125,12 +135,11 @@ def app_by_name_customers(client, app, *args): def app_by_name_add_customers(client, app, *args): - data = {"email": "example@email.com"} + data = {"emails": "example@email.com", "env_name": "test"} with patch("controlpanel.api.models.App.add_customers"): return client.post( reverse("apps-by-name-customers", kwargs={"name": app.name}), data, - query_params={"env_name": "test"}, ) diff --git a/tests/api/views/test_app.py b/tests/api/views/test_app.py index 93d1132e..928d7405 100644 --- a/tests/api/views/test_app.py +++ b/tests/api/views/test_app.py @@ -21,6 +21,12 @@ def app(): return baker.make( "api.App", repo_url="https://github.com/ministryofjustice/example.git", + app_conf={ + App.KEY_WORD_FOR_AUTH_SETTINGS: { + "dev": {"group_id": "dev_group_id"}, + "prod": {"group_id": "prod_group_id"}, + } + }, ) @@ -114,7 +120,7 @@ def customer(): } -@pytest.mark.parametrize("env_name", ["dev", "prod"]) +@pytest.mark.parametrize("env_name", ["dev", "prod"], ids=["dev", "prod"]) def test_app_by_name_get_customers(client, app, customer, env_name): with patch("controlpanel.api.models.App.customer_paginated") as customer_paginated: customer_paginated.return_value = {"total": 1, "users": [customer]} @@ -135,17 +141,43 @@ def test_app_by_name_get_customers(client, app, customer, env_name): assert response.data["results"] == [{field: customer[field] for field in expected_fields}] +@pytest.mark.parametrize("env_name", ["", "foo"]) +def test_app_by_name_get_customers_invalid(client, app, env_name): + with patch("controlpanel.api.models.App.customer_paginated") as customer_paginated: + + response = client.get( + reverse("apps-by-name-customers", kwargs={"name": app.name}), + query_params={"env_name": env_name}, + ) + customer_paginated.assert_not_called() + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.parametrize("env_name", ["dev", "prod"]) def test_app_by_name_add_customers(client, app, env_name): emails = ["test1@example.com", "test2@example.com"] - data = {"email": ", ".join(emails)} + data = {"emails": ", ".join(emails), "env_name": env_name} with patch("controlpanel.api.models.App.add_customers") as add_customers: url = reverse("apps-by-name-customers", kwargs={"name": app.name}) - response = client.post(f"{url}?env_name={env_name}", data=data) + response = client.post(url, data=data) - add_customers.assert_called_once_with(emails, env_name=env_name) assert response.status_code == status.HTTP_201_CREATED + add_customers.assert_called_once_with(emails, env_name=env_name) + + +@pytest.mark.parametrize("env_name", ["", "foo"]) +def test_app_by_name_add_customers_invalid(client, app, env_name): + emails = ["test1@example.com", "test2@example.com"] + data = {"emails": ", ".join(emails)} + + with patch("controlpanel.api.models.App.add_customers") as add_customers: + add_customers.side_effect = app.AddCustomerError + url = reverse("apps-by-name-customers", kwargs={"name": app.name}) + response = client.post(url, data=data) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + add_customers.assert_not_called() def test_app_detail_by_name(client, app):