Skip to content

Commit

Permalink
Bugfix/customers api (#1403)
Browse files Browse the repository at this point in the history
* Update add_customers endpoint

* Refactor validation when getting app customers

* Raise validation error if invalid env is provided

* Catch invalid env name when adding customers
  • Loading branch information
michaeljcollinsuk authored Dec 6, 2024
1 parent c010466 commit bb318f1
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 48 deletions.
46 changes: 46 additions & 0 deletions controlpanel/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Third-party
from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.validators import EmailValidator
from rest_framework import serializers

# First-party/Local
Expand Down Expand Up @@ -286,6 +287,51 @@ class Meta:
)


class AppCustomersQueryParamsSerializer(serializers.Serializer):
env_name = serializers.CharField(max_length=64, required=True)
page = serializers.IntegerField(min_value=1, required=False, default=1)
per_page = serializers.IntegerField(min_value=1, required=False, default=25)

def __init__(self, *args, **kwargs):
self.app = kwargs.pop("app")
super().__init__(*args, **kwargs)

def validate_env_name(self, env_name):
if not self.app.get_group_id(env_name):
raise serializers.ValidationError(f"{env_name} is invalid for this app.")
return env_name


class AddAppCustomersSerializer(serializers.Serializer):
emails = serializers.CharField(max_length=None, required=True)
env_name = serializers.CharField(max_length=64, required=True)

def __init__(self, *args, **kwargs):
self.app = kwargs.pop("app")
super().__init__(*args, **kwargs)

def validate_emails(self, emails):
delimiters = re.compile(r"[,; ]+") # split by comma, semicolon, and space
emails = delimiters.split(emails)
errors = []
validator = EmailValidator()
for email in emails:
try:
validator(email)
except ValidationError:
errors.append(email)
if errors:
raise serializers.ValidationError(
f"Request contains invalid emails: {', '.join(errors)}"
)
return emails

def validate_env_name(self, env_name):
if not self.app.get_group_id(env_name):
raise serializers.ValidationError(f"{env_name} is invalid for this app.")
return env_name


class DeleteAppCustomerSerializer(serializers.Serializer):
email = serializers.EmailField(required=True)
env_name = serializers.CharField(max_length=64, required=True)
Expand Down
73 changes: 32 additions & 41 deletions controlpanel/api/views/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
import re

# Third-party
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import EmailValidator
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import get_error_detail
from rest_framework.response import Response

# First-party/Local
Expand All @@ -33,7 +30,7 @@ def dispatch(self, request, *args, **kwargs):
def get_serializer_class(self, *args, **kwargs):
mapping = {
"customers": serializers.AppCustomerSerializer,
"add_customers": serializers.AppCustomerSerializer,
"add_customers": serializers.AddAppCustomersSerializer,
"delete_customers": serializers.DeleteAppCustomerSerializer,
}
serializer = mapping.get(self.action)
Expand All @@ -43,54 +40,48 @@ def get_serializer_class(self, *args, **kwargs):

@action(detail=True, methods=["get"])
def customers(self, request, *args, **kwargs):
if "env_name" not in request.query_params:
raise ValidationError({"env_name": "This field is required."})

app = self.get_object()
group_id = app.get_group_id(request.query_params.get("env_name", ""))
page_number = request.query_params.get("page", 1)
per_page = request.query_params.get("per_page", 25)
serializer = serializers.AppCustomersQueryParamsSerializer(
data=request.query_params, app=app
)
serializer.is_valid(raise_exception=True)
validated_params = serializer.validated_data

group_id = app.get_group_id(validated_params["env_name"])
customers = app.customer_paginated(
page=page_number,
page=validated_params["page"],
group_id=group_id,
per_page=per_page,
per_page=validated_params["per_page"],
)
serializer = self.get_serializer(data=customers["users"], many=True)
serializer.is_valid()
pagination = Auth0ApiPagination(
request,
page_number,
object_list=serializer.validated_data,
customers_serializer = self.get_serializer(data=customers["users"], many=True)
customers_serializer.is_valid(raise_exception=True)

return Auth0ApiPagination(
request=request,
page_number=validated_params["page"],
object_list=customers_serializer.validated_data,
total_count=customers["total"],
per_page=per_page,
)
return pagination.get_paginated_response()
per_page=validated_params["per_page"],
).get_paginated_response()

@customers.mapping.post
def add_customers(self, request, *args, **kwargs):
if "env_name" not in request.query_params:
raise ValidationError({"env_name": "This field is required."})

serializer = self.get_serializer(data=request.data)
app = self.get_object()
serializer = self.get_serializer(data=request.data, app=app)
serializer.is_valid(raise_exception=True)

app = self.get_object()
try:
app.add_customers(
serializer.validated_data["emails"], env_name=serializer.validated_data["env_name"]
)
except app.AddCustomerError:
raise ValidationError(
"An error occurred trying to add customers, check that the environment exists."
)

delimiters = re.compile(r"[,; ]+")
emails = delimiters.split(serializer.validated_data["email"])
errors = []
for email in emails:
validator = EmailValidator(message=f"{email} is not a valid email address")
try:
validator(email)
except DjangoValidationError as error:
errors.extend(get_error_detail(error))
if errors:
raise ValidationError(errors)

app.add_customers(emails, env_name=request.query_params.get("env_name", ""))

return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(
{"message": "Successfully added customers."}, status=status.HTTP_201_CREATED
)

@customers.mapping.delete
def delete_customers(self, request, *args, **kwargs):
Expand Down
15 changes: 12 additions & 3 deletions tests/api/permissions/test_app_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# First-party/Local
from controlpanel.api.jwt_auth import AuthenticatedServiceClient
from controlpanel.api.models import App
from controlpanel.api.permissions import AppJwtPermissions


Expand Down Expand Up @@ -42,7 +43,16 @@ def users(users, authenticated_client, invalid_client_sub, invalid_client_scope)

@pytest.fixture(autouse=True)
def app(users):
app = baker.make("api.App", name="Test App 1", app_conf={"m2m": {"client_id": "abc123"}})
app = baker.make(
"api.App",
name="Test App 1",
app_conf={
"m2m": {"client_id": "abc123"},
App.KEY_WORD_FOR_AUTH_SETTINGS: {
"test": {"group_id": "test_group_id"},
},
},
)
user = users["app_admin"]
baker.make("api.UserApp", user=user, app=app, is_admin=True)

Expand Down Expand Up @@ -125,12 +135,11 @@ def app_by_name_customers(client, app, *args):


def app_by_name_add_customers(client, app, *args):
data = {"email": "[email protected]"}
data = {"emails": "[email protected]", "env_name": "test"}
with patch("controlpanel.api.models.App.add_customers"):
return client.post(
reverse("apps-by-name-customers", kwargs={"name": app.name}),
data,
query_params={"env_name": "test"},
)


Expand Down
40 changes: 36 additions & 4 deletions tests/api/views/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def app():
return baker.make(
"api.App",
repo_url="https://github.com/ministryofjustice/example.git",
app_conf={
App.KEY_WORD_FOR_AUTH_SETTINGS: {
"dev": {"group_id": "dev_group_id"},
"prod": {"group_id": "prod_group_id"},
}
},
)


Expand Down Expand Up @@ -114,7 +120,7 @@ def customer():
}


@pytest.mark.parametrize("env_name", ["dev", "prod"])
@pytest.mark.parametrize("env_name", ["dev", "prod"], ids=["dev", "prod"])
def test_app_by_name_get_customers(client, app, customer, env_name):
with patch("controlpanel.api.models.App.customer_paginated") as customer_paginated:
customer_paginated.return_value = {"total": 1, "users": [customer]}
Expand All @@ -135,17 +141,43 @@ def test_app_by_name_get_customers(client, app, customer, env_name):
assert response.data["results"] == [{field: customer[field] for field in expected_fields}]


@pytest.mark.parametrize("env_name", ["", "foo"])
def test_app_by_name_get_customers_invalid(client, app, env_name):
with patch("controlpanel.api.models.App.customer_paginated") as customer_paginated:

response = client.get(
reverse("apps-by-name-customers", kwargs={"name": app.name}),
query_params={"env_name": env_name},
)
customer_paginated.assert_not_called()
assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.parametrize("env_name", ["dev", "prod"])
def test_app_by_name_add_customers(client, app, env_name):
emails = ["[email protected]", "[email protected]"]
data = {"email": ", ".join(emails)}
data = {"emails": ", ".join(emails), "env_name": env_name}

with patch("controlpanel.api.models.App.add_customers") as add_customers:
url = reverse("apps-by-name-customers", kwargs={"name": app.name})
response = client.post(f"{url}?env_name={env_name}", data=data)
response = client.post(url, data=data)

add_customers.assert_called_once_with(emails, env_name=env_name)
assert response.status_code == status.HTTP_201_CREATED
add_customers.assert_called_once_with(emails, env_name=env_name)


@pytest.mark.parametrize("env_name", ["", "foo"])
def test_app_by_name_add_customers_invalid(client, app, env_name):
emails = ["[email protected]", "[email protected]"]
data = {"emails": ", ".join(emails)}

with patch("controlpanel.api.models.App.add_customers") as add_customers:
add_customers.side_effect = app.AddCustomerError
url = reverse("apps-by-name-customers", kwargs={"name": app.name})
response = client.post(url, data=data)

assert response.status_code == status.HTTP_400_BAD_REQUEST
add_customers.assert_not_called()


def test_app_detail_by_name(client, app):
Expand Down

0 comments on commit bb318f1

Please sign in to comment.