Skip to content

Commit

Permalink
Merge branch 'migrate-to-ghcr' of github.com:ministryofjustice/analyt…
Browse files Browse the repository at this point in the history
…ics-platform-control-panel into migrate-to-ghcr
  • Loading branch information
Emterry committed Dec 3, 2024
2 parents 84bccaa + 79a0e1e commit 1493325
Show file tree
Hide file tree
Showing 32 changed files with 1,041 additions and 40 deletions.
41 changes: 40 additions & 1 deletion controlpanel/api/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ExtendedAuth0(Auth0):

DEFAULT_GRANT_TYPES = ["authorization_code", "client_credentials"]
DEFAULT_APP_TYPE = "regular_web"
M2M_APP_TYPE = "non_interactive"
M2M_GRANT_TYPES = ["client_credentials"]

DEFAULT_CONNECTION_OPTION = "email"

Expand Down Expand Up @@ -185,6 +187,41 @@ def setup_auth0_client(self, client_name, app_url_name=None, connections=None, a
self._enable_connections_for_new_client(client_id, chosen_connections=new_connections)
return client, group

def setup_m2m_client(self, client_name, scopes):
client, created = self.clients.get_or_create(
{
"name": client_name,
"app_type": "non_interactive",
"grant_types": ExtendedAuth0.M2M_GRANT_TYPES,
}
)
if not created:
return client

try:
body = {
"client_id": client["client_id"],
"scope": scopes,
"audience": settings.OIDC_CPANEL_API_AUDIENCE,
}
self.client_grants.create(body=body)
except exceptions.Auth0Error as error:
# if the client grant already exists, it will raise 409 error, so we can ignore it.
# otherwise, raise the error
if error.status_code != 409:
self.clients.delete(client["client_id"])
raise Auth0Error(error.__str__(), code=error.status_code)

return client

def rotate_m2m_client_secret(self, client_id):
try:
return self.clients.rotate_secret(client_id)
except exceptions.Auth0Error as error:
if error.status_code == 404:
return None
raise Auth0Error(error.__str__(), code=error.status_code)

def add_group_members_by_emails(self, emails, user_options={}, group_id=None, group_name=None):
user_ids = self.users.add_users_by_emails(emails, user_options=user_options)
self.groups.add_group_members(user_ids=user_ids, group_id=group_id, group_name=group_name)
Expand Down Expand Up @@ -417,9 +454,11 @@ def search_first_match(self, resource):

def get_or_create(self, resource):
result = self.search_first_match(resource)
created = False
if result is None:
result = self.create(resource)
return result
created = True
return result, created


class ExtendedClients(ExtendedAPIMethods, Clients):
Expand Down
51 changes: 50 additions & 1 deletion controlpanel/api/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ class App(EntityResource):
AUTHENTICATION_REQUIRED = "AUTHENTICATION_REQUIRED"
AUTH0_PASSWORDLESS = "AUTH0_PASSWORDLESS" # gitleaks:allow
APP_ROLE_ARN = "APP_ROLE_ARN"
API_SCOPES = ["retrieve:app", "customers:app", "add_customers:app"]

def __init__(self, app, github_api_token=None, auth0_instance=None):
super(App, self).__init__()
Expand Down Expand Up @@ -501,13 +502,31 @@ def oidc_provider_statement(self):
)
return json.loads(statement)

@property
def xacct_trust_statement(self):
"""
Builds an assume role statement for a Cloud Platform IAM role
"""
statement = render_to_string(
template_name="assume_roles/cloud_platform_xacct.json",
context={"app_role": self.app.cloud_platform_role_arn},
)
return json.loads(statement)

def create_iam_role(self):
statement = self._get_statement()
assume_role_policy = deepcopy(BASE_ASSUME_ROLE_POLICY)
assume_role_policy["Statement"].append(self.oidc_provider_statement)
assume_role_policy["Statement"].append(statement)
self.aws_role_service.create_role(self.iam_role_name, assume_role_policy)
for env in self.get_deployment_envs():
self._create_secrets(env_name=env)

def _get_statement(self):
if self.app.cloud_platform_role_arn:
return self.xacct_trust_statement

return self.oidc_provider_statement

def grant_bucket_access(self, bucket_arn, access_level, path_arns):
self.aws_role_service.grant_bucket_access(
self.iam_role_name, bucket_arn, access_level, path_arns
Expand Down Expand Up @@ -675,6 +694,36 @@ def create_auth_settings(
)
return client, group

def create_m2m_client(self):
m2m_client = self._get_auth0_instance().setup_m2m_client(
client_name=self.app.auth0_client_name("m2m"),
scopes=self.API_SCOPES,
)
if not self.app.app_conf:
self.app.app_conf = {}

# save the client ID, which we can use to retrieve the client secret
self.app.app_conf["m2m"] = {
"client_id": m2m_client["client_id"],
}
self.app.save()
return m2m_client

def rotate_m2m_client_secret(self):
m2m_client = self._get_auth0_instance().rotate_m2m_client_secret(
client_id=self.app.m2m_client_id
)
if not m2m_client:
self.app.app_conf.pop("m2m", None)
self.app.save()
return m2m_client

def delete_m2m_client(self):
response = self._get_auth0_instance().clients.delete(id=self.app.m2m_client_id)
self.app.app_conf.pop("m2m", None)
self.app.save()
return response

def remove_auth_settings(self, env_name):
try:
secrets_require_remove = [App.AUTH0_CLIENT_ID, App.AUTH0_CLIENT_SECRET]
Expand Down
24 changes: 24 additions & 0 deletions controlpanel/api/migrations/0047_app_cloud_platform_role_arn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 5.1.2 on 2024-11-27 15:46

# Third-party
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("api", "0046_alter_user_options"),
]

operations = [
migrations.AddField(
model_name="app",
name="cloud_platform_role_arn",
field=models.CharField(
default=None,
help_text="The cloud platform arn for the app",
max_length=130,
null=True,
),
),
]
19 changes: 17 additions & 2 deletions controlpanel/api/models/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


class App(TimeStampedModel):

name = models.CharField(max_length=100, blank=False)
description = models.TextField(blank=True)
slug = AutoSlugField(populate_from="_repo_name", slugify_function=s3_slugify)
Expand All @@ -32,6 +33,12 @@ class App(TimeStampedModel):
related_query_name="app",
blank=True,
)
cloud_platform_role_arn = models.CharField(
help_text="The cloud platform arn for the app",
max_length=130,
null=True,
default=None,
)
res_id = models.UUIDField(unique=True, default=uuid.uuid4, editable=False)
is_bedrock_enabled = models.BooleanField(default=False)
is_textract_enabled = models.BooleanField(default=False)
Expand Down Expand Up @@ -91,6 +98,12 @@ def release_name(self):
def iam_role_arn(self):
return cluster.iam_arn(f"role/{self.iam_role_name}")

@property
def m2m_client_id(self):
if self.app_conf is None:
return None
return self.app_conf.get("m2m", {}).get("client_id")

def get_group_id(self, env_name):
return self.get_auth_client(env_name).get("group_id")

Expand Down Expand Up @@ -211,12 +224,14 @@ def delete_customers(self, user_ids, env_name=None, group_id=None):
except auth0.Auth0Error as e:
raise DeleteCustomerError from e

def delete_customer_by_email(self, email, group_id):
def delete_customer_by_email(self, email, group_id=None, env_name=None):
"""
Attempt to find a customer by email and delete them from the group.
If the user is not found, or the user does not belong to the given group, raise
an error.
"""
if not group_id:
group_id = self.get_auth_client(env_name).get("group_id")
auth0_client = auth0.ExtendedAuth0()
try:
user = auth0_client.users.get_users_email_search(
Expand All @@ -232,7 +247,7 @@ def delete_customer_by_email(self, email, group_id):
if group_id == group["_id"]:
return self.delete_customers(user_ids=[user["user_id"]], group_id=group_id)

raise DeleteCustomerError(f"User {email} cannot be found in this application group")
raise DeleteCustomerError(f"User {email} not found for this application and environment")

@property
def status(self):
Expand Down
37 changes: 36 additions & 1 deletion controlpanel/api/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Third-party
from django.core.paginator import Paginator
from django.core.paginator import InvalidPage, Paginator
from rest_framework import serializers
from rest_framework.pagination import PageNumberPagination, _positive_int
from rest_framework.response import Response
from rest_framework.utils.urls import replace_query_param


class CustomPageNumberPagination(PageNumberPagination):
Expand Down Expand Up @@ -53,3 +56,35 @@ def __init__(self, object_list, per_page, total_count=25, **kwargs):
def count(self):
"""Return the total number of objects, across all pages."""
return self.total_count


class Auth0ApiPagination(Auth0Paginator):

def __init__(self, request, page_number, *args, **kwargs):
self.request = request
super().__init__(*args, **kwargs)
self._page = self.get_page(page_number)

def get_page_url(self, page_number):
url = self.request.build_absolute_uri()
return replace_query_param(url, "page", page_number)

def get_next_link(self):
if not self._page.has_next():
return None
return self.get_page_url(self._page.next_page_number())

def get_previous_link(self):
if not self._page.has_previous():
return None
return self.get_page_url(self._page.previous_page_number())

def get_paginated_response(self):
return Response(
{
"count": self.count,
"next": self.get_next_link(),
"previous": self.get_previous_link(),
"results": self.object_list,
}
)
9 changes: 9 additions & 0 deletions controlpanel/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def has_object_permission(self, request, view, obj):
return hasattr(request.user, "is_client") and request.user.is_client


class AppJwtPermissions(JWTTokenResourcePermissions):

def has_object_permission(self, request, view, obj):
if not super().has_object_permission(request, view, obj):
return False
client_id = request.user.pk.removesuffix("@clients")
return client_id == obj.m2m_client_id


class IsSuperuser(BasePermission):
"""
Only superusers are authorised
Expand Down
2 changes: 2 additions & 0 deletions controlpanel/api/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def is_app_admin(user, obj):
add_perm("api.manage_groups", is_authenticated & is_superuser)
add_perm("api.create_policys3bucket", is_authenticated & is_superuser)
add_perm("api.update_app_settings", is_authenticated & is_app_admin)
add_perm("api.customers_app", is_authenticated & is_app_admin)
add_perm("api.add_customers_app", is_authenticated & is_app_admin)
add_perm("api.update_app_ip_allowlists", is_authenticated & is_app_admin)


Expand Down
5 changes: 5 additions & 0 deletions controlpanel/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ class Meta:
)


class DeleteAppCustomerSerializer(serializers.Serializer):
email = serializers.EmailField(required=True)
env_name = serializers.CharField(max_length=64, required=True)


class ToolDeploymentSerializer(serializers.Serializer):
old_chart_name = serializers.CharField(max_length=64, required=False)
version = serializers.CharField(max_length=64, required=True)
Expand Down
19 changes: 15 additions & 4 deletions controlpanel/api/tasks/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Third-party
import structlog
from celery import Task as CeleryTask

# First-party/Local
from controlpanel.api.models import Task

log = structlog.getLogger(__name__)


class BaseTaskHandler(CeleryTask):
# can be applied to project settings also
Expand All @@ -16,12 +19,20 @@ class BaseTaskHandler(CeleryTask):
task_obj = None

def complete(self):
if self.task_obj:
self.task_obj.completed = True
self.task_obj.save()
if not self.task_obj:
return log.warn("Task completed, but no object to mark as completed.")

self.task_obj.completed = True
self.task_obj.save()
log.info(f"Task object completed: {self.task_obj.task_id}")

def get_task_obj(self):
return Task.objects.filter(task_id=self.request.id).first()
task_id = self.request.id
log.info(f"Getting task object with ID: {task_id}")
task = Task.objects.filter(task_id=task_id).first()
if not task:
log.warn(f"Task object not found with ID: {task_id}. Continuing...")
return task

def run(self, *args, **kwargs):
self.task_obj = self.get_task_obj()
Expand Down
9 changes: 8 additions & 1 deletion controlpanel/api/tasks/handlers/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ class S3BucketRevokeUserAccess(BaseTaskHandler):
name = "revoke_user_s3bucket_access"

def handle(self, bucket_identifier, bucket_user_pk, is_folder):
bucket_user = User.objects.get(pk=bucket_user_pk)
try:
bucket_user = User.objects.get(pk=bucket_user_pk)
except User.DoesNotExist:
# if the user doesnt exist, nothing to revoke, so mark completed
log.warn("User does not exist. Skipping...")
self.complete()
return

if is_folder:
cluster.User(bucket_user).revoke_folder_access(bucket_identifier)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"Sid": "AllowCloudPlatformCrossAccountIAM",
"Effect": "Allow",
"Action": "sts:AssumeRole",
"Principal": {
"AWS": "{{ app_role }}"
},
"Condition": {}
}
8 changes: 8 additions & 0 deletions controlpanel/api/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __call__(self, value):
"can only contain alphanumeric, underscores and hyphens)",
)

validate_aws_role_arn = RegexValidator(
regex=r"^arn:aws:iam::[0-9]{12}:role/[a-zA-Z0-9-_]+$",
message=(
"ARN is invalid. Check AWS ARN format "
"(for example, 'arn:aws:iam::123456789012:role/role_name')"
),
)


def validate_github_repository_url(value):
github_base_url = "https://github.com/"
Expand Down
Loading

0 comments on commit 1493325

Please sign in to comment.