Skip to content

Commit

Permalink
Merge pull request #1208 from ministryofjustice/ANPL-1704-bugfixes
Browse files Browse the repository at this point in the history
Anpl 1704 bugfixes
  • Loading branch information
michaeljcollinsuk authored Sep 29, 2023
2 parents 7d2c8ec + ea69bf8 commit 4184f45
Show file tree
Hide file tree
Showing 17 changed files with 156 additions and 55 deletions.
13 changes: 11 additions & 2 deletions controlpanel/api/models/access_to_s3bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def resources(self):
# MUST use signals because cascade deletes do not call delete()
@receiver(models.signals.pre_delete)
def revoke_access(sender, **kwargs):
if issubclass(sender, AccessToS3Bucket):
obj = kwargs["instance"]
"""
Revokes access when the delete is via cascade delete, as these do not call the
instance level delete() method. Checks that the origin is different to the instance,
to ensure that revoke_bucket_access is not called twice when deleting an access
object directly.
"""
if not issubclass(sender, AccessToS3Bucket):
return

obj = kwargs["instance"]
if obj != kwargs["origin"]:
obj.revoke_bucket_access()
4 changes: 4 additions & 0 deletions controlpanel/api/models/policys3bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class Meta:
unique_together = ("policy", "s3bucket")
ordering = ("id",)

def __init__(self, *args, **kwargs):
self.current_user = kwargs.pop("current_user", None)
super().__init__(*args, **kwargs)

def grant_bucket_access(self):
if self.s3bucket.is_folder:
return cluster.RoleGroup(self.policy).grant_folder_access(
Expand Down
1 change: 1 addition & 0 deletions controlpanel/api/tasks/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class AppCreateRole(TaskBase):
ENTITY_CLASS = "App"
QUEUE_NAME = settings.IAM_QUEUE_NAME

@property
def task_name(self):
Expand Down
33 changes: 21 additions & 12 deletions controlpanel/api/tasks/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,31 @@
from celery import Task as CeleryTask

# First-party/Local
from controlpanel.api.models import Task, User
from controlpanel.api.models import Task


class BaseTaskHandler(CeleryTask):
# can be applied to project settings also
# these settings mean that messages are only removed from the queue (acknowledged)
# when returned. if an error occurs, they remain in the queue, and will be resent
# to the worker when the "visibility_timeout" has expired. "visibility_timeout" is
# a setting that is configured in SQS per queue. Currently set to 30secs
acks_late = True
acks_on_failure_or_timeout = False
task_obj = None

def complete(self):
task = Task.objects.filter(task_id=self.request.id).first()
if task:
task.completed = True
task.save()
if self.task_obj:
self.task_obj.completed = True
self.task_obj.save()

def get_task_obj(self):
return Task.objects.filter(task_id=self.request.id).first()

def run(self, *args, **kwargs):
self.task_obj = self.get_task_obj()
if self.task_obj and self.task_obj.completed:
return
self.handle(*args, **kwargs)

def handle(self, *args, **kwargs):
Expand All @@ -29,13 +42,6 @@ class BaseModelTaskHandler(BaseTaskHandler):
model = None
object = None
task_user_pk = None
# can be applied to project settings also
# these settings mean that messages are only removed from the queue (acknowledged)
# when returned. if an error occurs, they remain in the queue, and will be resent
# to the worker when the "visibility_timeout" has expired. "visibility_timeout" is
# a setting that is configured in SQS per queue. Currently set to 30secs
acks_late = True
acks_on_failure_or_timeout = False

def get_object(self, pk):
try:
Expand All @@ -53,6 +59,9 @@ def run(self, obj_pk, task_user_pk, *args, **kwargs):
to look up the user later if required. The `handle` method is then called
with any other args and kwargs sent.
"""
self.task_obj = self.get_task_obj()
if self.task_obj and self.task_obj.completed:
return
self.object = self.get_object(obj_pk)
self.task_user_pk = task_user_pk
self.handle(*args, **kwargs)
1 change: 1 addition & 0 deletions controlpanel/api/tasks/s3bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def _get_args_list(self):
class S3AccessMixin:
ACTION = None
ROLE = None
QUEUE_NAME = settings.IAM_QUEUE_NAME

@property
def task_name(self):
Expand Down
18 changes: 11 additions & 7 deletions controlpanel/celery.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# Standard library
import os
from celery import Celery
import dotenv
from kombu import Queue
from pathlib import Path

# Third-party
import dotenv
import structlog
from celery import Celery
from django.conf import settings
from kombu import Queue

# First-party/Local
from controlpanel.utils import load_app_conf_from_file
from django.conf import settings

dotenv.load_dotenv()

Expand Down Expand Up @@ -43,7 +46,8 @@ def worker_health_check(self):
# ensures worker picks and runs tasks from all queues rather than just default queue
# alternative is to run the worker and pass queue name to -Q flag
app.conf.task_queues = [
Queue(settings.IAM_QUEUE_NAME),
Queue(settings.AUTH_QUEUE_NAME),
Queue(settings.S3_QUEUE_NAME),
Queue(settings.IAM_QUEUE_NAME, routing_key=settings.IAM_QUEUE_NAME),
Queue(settings.AUTH_QUEUE_NAME, routing_key=settings.AUTH_QUEUE_NAME),
Queue(settings.S3_QUEUE_NAME, routing_key=settings.S3_QUEUE_NAME),
]
app.conf.task_default_exchange = "tasks"
2 changes: 1 addition & 1 deletion controlpanel/frontend/jinja2/datasource-detail.html
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ <h1 class="govuk-heading-xl">{{ page_title }}</h1>
</a>
</p>

<section class="cpanel-section">
<section class="cpanel-section track_task">
<h2 class="govuk-heading-m">Users and groups with access</h2>
<table class="govuk-table">
<thead class="govuk-table__head">
Expand Down
2 changes: 2 additions & 0 deletions controlpanel/frontend/jinja2/includes/task-list.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<th class="govuk-table__header">Entity class</th>
<th class="govuk-table__header">Entity ID</th>
<th class="govuk-table__header">Entity description</th>
<th class="govuk-table__header">Task ID</th>
<th class="govuk-table__header">Task description</th>
<th class="govuk-table__header">Create time</th>
<th class="govuk-table__header">
Expand All @@ -19,6 +20,7 @@
<td class="govuk-table__cell">{{ task.entity_class }}</td>
<td class="govuk-table__cell">{{ task.entity_id }}</td>
<td class="govuk-table__cell">{{ task.entity_description }}</td>
<td class="govuk-table__cell">{{ task.task_id }}</td>
<td class="govuk-table__cell">{{ task.task_description }}</td>
<td class="govuk-table__cell">{{ task.created }}</td>
<td class="govuk-table__cell">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Standard library
import random
from datetime import datetime, timedelta
from pathlib import Path
from sys import exit

# Third-party
from django.core.management.base import BaseCommand
from django.conf import settings
from django.core.management.base import BaseCommand

# First-party/Local
from controlpanel.celery import worker_health_check
Expand All @@ -24,8 +25,10 @@ def add_arguments(self, parser):

def handle(self, *args, **options):
stale_after_secs = options["stale_after_secs"]

worker_health_check.delay().get()
# send task to randomly chosen queue
worker_health_check.apply_async(
queue=random.choice(settings.PRE_DEFINED_QUEUES)
)
# Attempt to read worker health ping file
# NOTE: This may initially fail depending on timing of health task
# execution but that's fine as Kubernetes' `failureThreashold`
Expand Down
8 changes: 3 additions & 5 deletions controlpanel/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,19 +545,17 @@
AUTH_QUEUE_NAME = os.environ.get("AUTH_QUEUE_NAME", "control-panel-auth")

BROKER_URL = os.environ.get("BROKER_URL", "sqs://")
DEFAULT_QUEUE = IAM_QUEUE_NAME
DEFAULT_QUEUE = AUTH_QUEUE_NAME
DEFAULT_BACKOFF_POLICY = {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}
PRE_DEFINED_QUEUES = [IAM_QUEUE_NAME, S3_QUEUE_NAME, AUTH_QUEUE_NAME]
CELERY_DEFAULT_QUEUE = DEFAULT_QUEUE
SQS_REGION = os.environ.get("SQS_REGION", "eu-west-2")

BROKER_TRANSPORT_OPTIONS = {
"polling_interval": 10,
"polling_interval": 1,
"region": SQS_REGION,
"wait_time_seconds": 20,
"wait_time_seconds": 0,
"predefined_queues": {}
}

for queue in PRE_DEFINED_QUEUES:
BROKER_TRANSPORT_OPTIONS['predefined_queues'][queue] = {
'url': f'https://sqs.{SQS_REGION}.amazonaws.com/{AWS_DATA_ACCOUNT_ID}/{queue}',
Expand Down
2 changes: 1 addition & 1 deletion tests/api/fixtures/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def sqs(aws_creds):
with moto.mock_sqs():
sqs = boto3.resource("sqs")
sqs.create_queue(QueueName=settings.DEFAULT_QUEUE)
sqs.create_queue(QueueName=settings.AUTH_QUEUE_NAME)
sqs.create_queue(QueueName=settings.IAM_QUEUE_NAME)
sqs.create_queue(QueueName=settings.S3_QUEUE_NAME)
yield sqs

Expand Down
14 changes: 8 additions & 6 deletions tests/api/models/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def update_aws_secrets_manager():
@pytest.fixture
def app():
app = mommy.make("api.App")
app.repo_url="https://github.com/example.com/repo_name"
app.repo_url = "https://github.com/example.com/repo_name"
auth_settings = dict(
client_id="testing_client_id",
group_id="testing_group_id"
Expand All @@ -46,9 +46,9 @@ def app():
def test_create(sqs, helpers):
repo_url = "https://example.com/foo__bar-baz!bat-1337"
app = App.objects.create(repo_url=repo_url)
iam_messages = helpers.retrieve_messages(sqs, queue_name=settings.DEFAULT_QUEUE)
iam_messages = helpers.retrieve_messages(sqs, queue_name=settings.IAM_QUEUE_NAME)
helpers.validate_task_with_sqs_messages(
iam_messages, App.__name__, app.id, queue_name=settings.DEFAULT_QUEUE
iam_messages, App.__name__, app.id, queue_name=settings.IAM_QUEUE_NAME
)
auth_messages = helpers.retrieve_messages(sqs, queue_name=settings.AUTH_QUEUE_NAME)
helpers.validate_task_with_sqs_messages(
Expand Down Expand Up @@ -195,8 +195,11 @@ def test_app_allowed_ip_ranges():
]
app = mommy.make("api.App") # noqa:F841
for item in ip_allow_lists:
mommy.make("api.AppIPAllowList",
app_id=app.id, ip_allowlist_id=item.id, deployment_env="test"
mommy.make(
"api.AppIPAllowList",
app_id=app.id,
ip_allowlist_id=item.id,
deployment_env="test",
)
app_ip_ranges = app.env_allowed_ip_ranges("test")
assert " " not in app_ip_ranges
Expand All @@ -205,4 +208,3 @@ def test_app_allowed_ip_ranges():
full_app_ip_ranges = app.app_allowed_ip_ranges
assert " " not in full_app_ip_ranges
assert len(full_app_ip_ranges.split(",")) == 4

9 changes: 5 additions & 4 deletions tests/api/models/test_apps3bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Third-party
import pytest
from django.conf import settings
from django.db.utils import IntegrityError
from model_mommy import mommy

Expand Down Expand Up @@ -45,9 +46,9 @@ def test_aws_permissions(app, bucket, sqs, helpers):
)

apps3bucket.save()
messages = helpers.retrieve_messages(sqs)
messages = helpers.retrieve_messages(sqs, queue_name=settings.IAM_QUEUE_NAME)
helpers.validate_task_with_sqs_messages(
messages, AppS3Bucket.__name__, apps3bucket.id
messages, AppS3Bucket.__name__, apps3bucket.id, settings.IAM_QUEUE_NAME
)


Expand All @@ -65,8 +66,8 @@ def test_delete_revoke_permissions(app, bucket):

apps3bucket.delete()

revoke_bucket_access_task.assert_called_with(
revoke_bucket_access_task.assert_called_once_with(
apps3bucket,
None
)
revoke_bucket_access_task.return_value.create_task.assert_called()
revoke_bucket_access_task.return_value.create_task.assert_called_once()
9 changes: 5 additions & 4 deletions tests/api/models/test_users3bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Third-party
import pytest
from django.conf import settings
from django.db.utils import IntegrityError
from model_mommy import mommy

Expand Down Expand Up @@ -45,9 +46,9 @@ def test_aws_create_bucket(user, bucket, sqs, helpers):
s3bucket=bucket,
access_level=AccessToS3Bucket.READONLY
)
messages = helpers.retrieve_messages(sqs)
messages = helpers.retrieve_messages(sqs, settings.IAM_QUEUE_NAME)
helpers.validate_task_with_sqs_messages(
messages, UserS3Bucket.__name__, users3bucket.id
messages, UserS3Bucket.__name__, users3bucket.id, settings.IAM_QUEUE_NAME
)


Expand All @@ -74,5 +75,5 @@ def test_delete_revoke_permissions(bucket, users3bucket):
"controlpanel.api.tasks.S3BucketRevokeUserAccess"
) as revoke_user_access_task:
users3bucket.delete()
revoke_user_access_task.assert_called_with(users3bucket, None)
revoke_user_access_task.return_value.create_task.assert_called()
revoke_user_access_task.assert_called_once_with(users3bucket, None)
revoke_user_access_task.return_value.create_task.assert_called_once()
60 changes: 60 additions & 0 deletions tests/api/tasks/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Standard library
from unittest.mock import MagicMock, patch

# Third-party
import pytest

# First-party/Local
from controlpanel.api.models import Task
from controlpanel.api.tasks.handlers.base import BaseModelTaskHandler, BaseTaskHandler


@pytest.mark.parametrize("handler_cls, args", [
(BaseTaskHandler, (None,)),
(BaseModelTaskHandler, (1, 1)),
])
@patch("controlpanel.api.tasks.handlers.base.BaseTaskHandler.handle")
def test_completed_task_handle_not_run(handle, handler_cls, args):
completed_task = MagicMock(spec=Task, completed=True)
base_task_handler = handler_cls()

with patch.object(base_task_handler, "get_task_obj", return_value=completed_task):
base_task_handler.run(*args)

handle.assert_not_called()


@pytest.mark.parametrize("handler_cls, args", [
(BaseTaskHandler, (None,)),
(BaseModelTaskHandler, (1, 1)),
])
@patch("controlpanel.api.tasks.handlers.base.BaseTaskHandler.handle")
@patch(
"controlpanel.api.tasks.handlers.base.BaseModelTaskHandler.get_object",
new=MagicMock,
)
def test_uncompleted_task_handle_is_run(handle, handler_cls, args):
completed_task = MagicMock(spec=Task, completed=False)
base_task_handler = handler_cls()

with patch.object(base_task_handler, "get_task_obj", return_value=completed_task):
base_task_handler.run(*args)

handle.assert_called_once()


@pytest.mark.parametrize("handler_cls, args", [
(BaseTaskHandler, (None,)),
(BaseModelTaskHandler, (1, 1)),
])
@patch("controlpanel.api.tasks.handlers.base.BaseTaskHandler.handle")
@patch(
"controlpanel.api.tasks.handlers.base.BaseModelTaskHandler.get_object",
new=MagicMock,
)
def test_no_task_obj_handle_is_run(handle, handler_cls, args):
base_task_handler = handler_cls()
with patch.object(base_task_handler, "get_task_obj", return_value=None):
base_task_handler.run(*args)

handle.assert_called_once()
Loading

0 comments on commit 4184f45

Please sign in to comment.