Skip to content

Commit

Permalink
Merge branch 'main' into feature/feedback-form
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesstottmoj committed Dec 2, 2024
2 parents 144e611 + 2845adb commit f951bb1
Show file tree
Hide file tree
Showing 24 changed files with 854 additions and 36 deletions.
41 changes: 40 additions & 1 deletion controlpanel/api/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ExtendedAuth0(Auth0):

DEFAULT_GRANT_TYPES = ["authorization_code", "client_credentials"]
DEFAULT_APP_TYPE = "regular_web"
M2M_APP_TYPE = "non_interactive"
M2M_GRANT_TYPES = ["client_credentials"]

DEFAULT_CONNECTION_OPTION = "email"

Expand Down Expand Up @@ -185,6 +187,41 @@ def setup_auth0_client(self, client_name, app_url_name=None, connections=None, a
self._enable_connections_for_new_client(client_id, chosen_connections=new_connections)
return client, group

def setup_m2m_client(self, client_name, scopes):
client, created = self.clients.get_or_create(
{
"name": client_name,
"app_type": "non_interactive",
"grant_types": ExtendedAuth0.M2M_GRANT_TYPES,
}
)
if not created:
return client

try:
body = {
"client_id": client["client_id"],
"scope": scopes,
"audience": settings.OIDC_CPANEL_API_AUDIENCE,
}
self.client_grants.create(body=body)
except exceptions.Auth0Error as error:
# if the client grant already exists, it will raise 409 error, so we can ignore it.
# otherwise, raise the error
if error.status_code != 409:
self.clients.delete(client["client_id"])
raise Auth0Error(error.__str__(), code=error.status_code)

return client

def rotate_m2m_client_secret(self, client_id):
try:
return self.clients.rotate_secret(client_id)
except exceptions.Auth0Error as error:
if error.status_code == 404:
return None
raise Auth0Error(error.__str__(), code=error.status_code)

def add_group_members_by_emails(self, emails, user_options={}, group_id=None, group_name=None):
user_ids = self.users.add_users_by_emails(emails, user_options=user_options)
self.groups.add_group_members(user_ids=user_ids, group_id=group_id, group_name=group_name)
Expand Down Expand Up @@ -417,9 +454,11 @@ def search_first_match(self, resource):

def get_or_create(self, resource):
result = self.search_first_match(resource)
created = False
if result is None:
result = self.create(resource)
return result
created = True
return result, created


class ExtendedClients(ExtendedAPIMethods, Clients):
Expand Down
31 changes: 31 additions & 0 deletions controlpanel/api/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ class App(EntityResource):
AUTHENTICATION_REQUIRED = "AUTHENTICATION_REQUIRED"
AUTH0_PASSWORDLESS = "AUTH0_PASSWORDLESS" # gitleaks:allow
APP_ROLE_ARN = "APP_ROLE_ARN"
API_SCOPES = ["retrieve:app", "customers:app", "add_customers:app"]

def __init__(self, app, github_api_token=None, auth0_instance=None):
super(App, self).__init__()
Expand Down Expand Up @@ -693,6 +694,36 @@ def create_auth_settings(
)
return client, group

def create_m2m_client(self):
m2m_client = self._get_auth0_instance().setup_m2m_client(
client_name=self.app.auth0_client_name("m2m"),
scopes=self.API_SCOPES,
)
if not self.app.app_conf:
self.app.app_conf = {}

# save the client ID, which we can use to retrieve the client secret
self.app.app_conf["m2m"] = {
"client_id": m2m_client["client_id"],
}
self.app.save()
return m2m_client

def rotate_m2m_client_secret(self):
m2m_client = self._get_auth0_instance().rotate_m2m_client_secret(
client_id=self.app.m2m_client_id
)
if not m2m_client:
self.app.app_conf.pop("m2m", None)
self.app.save()
return m2m_client

def delete_m2m_client(self):
response = self._get_auth0_instance().clients.delete(id=self.app.m2m_client_id)
self.app.app_conf.pop("m2m", None)
self.app.save()
return response

def remove_auth_settings(self, env_name):
try:
secrets_require_remove = [App.AUTH0_CLIENT_ID, App.AUTH0_CLIENT_SECRET]
Expand Down
12 changes: 10 additions & 2 deletions controlpanel/api/models/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def release_name(self):
def iam_role_arn(self):
return cluster.iam_arn(f"role/{self.iam_role_name}")

@property
def m2m_client_id(self):
if self.app_conf is None:
return None
return self.app_conf.get("m2m", {}).get("client_id")

def get_group_id(self, env_name):
return self.get_auth_client(env_name).get("group_id")

Expand Down Expand Up @@ -218,12 +224,14 @@ def delete_customers(self, user_ids, env_name=None, group_id=None):
except auth0.Auth0Error as e:
raise DeleteCustomerError from e

def delete_customer_by_email(self, email, group_id):
def delete_customer_by_email(self, email, group_id=None, env_name=None):
"""
Attempt to find a customer by email and delete them from the group.
If the user is not found, or the user does not belong to the given group, raise
an error.
"""
if not group_id:
group_id = self.get_auth_client(env_name).get("group_id")
auth0_client = auth0.ExtendedAuth0()
try:
user = auth0_client.users.get_users_email_search(
Expand All @@ -239,7 +247,7 @@ def delete_customer_by_email(self, email, group_id):
if group_id == group["_id"]:
return self.delete_customers(user_ids=[user["user_id"]], group_id=group_id)

raise DeleteCustomerError(f"User {email} cannot be found in this application group")
raise DeleteCustomerError(f"User {email} not found for this application and environment")

@property
def status(self):
Expand Down
37 changes: 36 additions & 1 deletion controlpanel/api/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Third-party
from django.core.paginator import Paginator
from django.core.paginator import InvalidPage, Paginator
from rest_framework import serializers
from rest_framework.pagination import PageNumberPagination, _positive_int
from rest_framework.response import Response
from rest_framework.utils.urls import replace_query_param


class CustomPageNumberPagination(PageNumberPagination):
Expand Down Expand Up @@ -53,3 +56,35 @@ def __init__(self, object_list, per_page, total_count=25, **kwargs):
def count(self):
"""Return the total number of objects, across all pages."""
return self.total_count


class Auth0ApiPagination(Auth0Paginator):

def __init__(self, request, page_number, *args, **kwargs):
self.request = request
super().__init__(*args, **kwargs)
self._page = self.get_page(page_number)

def get_page_url(self, page_number):
url = self.request.build_absolute_uri()
return replace_query_param(url, "page", page_number)

def get_next_link(self):
if not self._page.has_next():
return None
return self.get_page_url(self._page.next_page_number())

def get_previous_link(self):
if not self._page.has_previous():
return None
return self.get_page_url(self._page.previous_page_number())

def get_paginated_response(self):
return Response(
{
"count": self.count,
"next": self.get_next_link(),
"previous": self.get_previous_link(),
"results": self.object_list,
}
)
9 changes: 9 additions & 0 deletions controlpanel/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def has_object_permission(self, request, view, obj):
return hasattr(request.user, "is_client") and request.user.is_client


class AppJwtPermissions(JWTTokenResourcePermissions):

def has_object_permission(self, request, view, obj):
if not super().has_object_permission(request, view, obj):
return False
client_id = request.user.pk.removesuffix("@clients")
return client_id == obj.m2m_client_id


class IsSuperuser(BasePermission):
"""
Only superusers are authorised
Expand Down
2 changes: 2 additions & 0 deletions controlpanel/api/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def is_app_admin(user, obj):
add_perm("api.manage_groups", is_authenticated & is_superuser)
add_perm("api.create_policys3bucket", is_authenticated & is_superuser)
add_perm("api.update_app_settings", is_authenticated & is_app_admin)
add_perm("api.customers_app", is_authenticated & is_app_admin)
add_perm("api.add_customers_app", is_authenticated & is_app_admin)
add_perm("api.update_app_ip_allowlists", is_authenticated & is_app_admin)


Expand Down
5 changes: 5 additions & 0 deletions controlpanel/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ class Meta:
)


class DeleteAppCustomerSerializer(serializers.Serializer):
email = serializers.EmailField(required=True)
env_name = serializers.CharField(max_length=64, required=True)


class ToolDeploymentSerializer(serializers.Serializer):
old_chart_name = serializers.CharField(max_length=64, required=False)
version = serializers.CharField(max_length=64, required=True)
Expand Down
19 changes: 15 additions & 4 deletions controlpanel/api/tasks/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Third-party
import structlog
from celery import Task as CeleryTask

# First-party/Local
from controlpanel.api.models import Task

log = structlog.getLogger(__name__)


class BaseTaskHandler(CeleryTask):
# can be applied to project settings also
Expand All @@ -16,12 +19,20 @@ class BaseTaskHandler(CeleryTask):
task_obj = None

def complete(self):
if self.task_obj:
self.task_obj.completed = True
self.task_obj.save()
if not self.task_obj:
return log.warn("Task completed, but no object to mark as completed.")

self.task_obj.completed = True
self.task_obj.save()
log.info(f"Task object completed: {self.task_obj.task_id}")

def get_task_obj(self):
return Task.objects.filter(task_id=self.request.id).first()
task_id = self.request.id
log.info(f"Getting task object with ID: {task_id}")
task = Task.objects.filter(task_id=task_id).first()
if not task:
log.warn(f"Task object not found with ID: {task_id}. Continuing...")
return task

def run(self, *args, **kwargs):
self.task_obj = self.get_task_obj()
Expand Down
98 changes: 95 additions & 3 deletions controlpanel/api/views/apps.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
# Standard library
import re

# Third-party
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import EmailValidator
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins, viewsets
from rest_framework import mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.fields import get_error_detail
from rest_framework.response import Response

# First-party/Local
from controlpanel.api import permissions, serializers
from controlpanel.api.models import App
from controlpanel.api.pagination import Auth0ApiPagination


class AppByNameViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
Expand All @@ -13,7 +23,89 @@ class AppByNameViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
queryset = App.objects.all()

serializer_class = serializers.AppSerializer
permission_classes = (permissions.AppPermissions | permissions.JWTTokenResourcePermissions,)
permission_classes = (permissions.AppPermissions | permissions.AppJwtPermissions,)
filter_backends = (DjangoFilterBackend,)
http_method_names = ["get"]
lookup_field = "name"

def dispatch(self, request, *args, **kwargs):
return super().dispatch(request, *args, **kwargs)

def get_serializer_class(self, *args, **kwargs):
mapping = {
"customers": serializers.AppCustomerSerializer,
"add_customers": serializers.AppCustomerSerializer,
"delete_customers": serializers.DeleteAppCustomerSerializer,
}
serializer = mapping.get(self.action)
if serializer:
return serializer
return super().get_serializer_class(*args, **kwargs)

@action(detail=True, methods=["get"])
def customers(self, request, *args, **kwargs):
if "env_name" not in request.query_params:
raise ValidationError({"env_name": "This field is required."})

app = self.get_object()
group_id = app.get_group_id(request.query_params.get("env_name", ""))
page_number = request.query_params.get("page", 1)
per_page = request.query_params.get("per_page", 25)
customers = app.customer_paginated(
page=page_number,
group_id=group_id,
per_page=per_page,
)
serializer = self.get_serializer(data=customers["users"], many=True)
serializer.is_valid()
pagination = Auth0ApiPagination(
request,
page_number,
object_list=serializer.validated_data,
total_count=customers["total"],
per_page=per_page,
)
return pagination.get_paginated_response()

@customers.mapping.post
def add_customers(self, request, *args, **kwargs):
if "env_name" not in request.query_params:
raise ValidationError({"env_name": "This field is required."})

serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

app = self.get_object()

delimiters = re.compile(r"[,; ]+")
emails = delimiters.split(serializer.validated_data["email"])
errors = []
for email in emails:
validator = EmailValidator(message=f"{email} is not a valid email address")
try:
validator(email)
except DjangoValidationError as error:
errors.extend(get_error_detail(error))
if errors:
raise ValidationError(errors)

app.add_customers(emails, env_name=request.query_params.get("env_name", ""))

return Response(serializer.data, status=status.HTTP_201_CREATED)

@customers.mapping.delete
def delete_customers(self, request, *args, **kwargs):
"""
Delete a customer from an environment. Requires the customers email and the env name.
"""
app = self.get_object()
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

try:
app.delete_customer_by_email(
serializer.validated_data["email"], env_name=serializer.validated_data["env_name"]
)
except app.DeleteCustomerError as error:
raise ValidationError({"email": error.args[0]})

return Response(status=status.HTTP_204_NO_CONTENT)
2 changes: 1 addition & 1 deletion controlpanel/api/views/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class AppViewSet(viewsets.ModelViewSet):

serializer_class = serializers.AppSerializer
filter_backends = (DjangoFilterBackend,)
permission_classes = (permissions.AppPermissions | permissions.JWTTokenResourcePermissions,)
permission_classes = (permissions.AppPermissions | permissions.AppJwtPermissions,)
filterset_fields = ("name", "repo_url", "slug")
lookup_field = "res_id"

Expand Down
Loading

0 comments on commit f951bb1

Please sign in to comment.