diff --git a/publishing/tests/test_models.py b/publishing/tests/test_models.py index dd7371d6f..d67445480 100644 --- a/publishing/tests/test_models.py +++ b/publishing/tests/test_models.py @@ -1,3 +1,5 @@ +import threading +from functools import wraps from unittest import mock from unittest.mock import MagicMock from unittest.mock import patch @@ -5,6 +7,7 @@ import factory import freezegun import pytest +from django.db import OperationalError from django_fsm import TransitionNotAllowed from common.tests import factories @@ -469,6 +472,204 @@ def test_next_envelope_id(envelope_storage): assert Envelope.next_envelope_id() == "230002" +@pytest.mark.django_db(transaction=True) +class TestPackagingQueueRaceConditions: + """Tests that concurrent requests to reorder packaged workbaskets don't + result in duplicate or non-consecutive positions.""" + + NUM_THREADS: int = 2 + """The number of threads each test uses.""" + + THREAD_TIMEOUT: int = 2 + """The duration in seconds to wait for a thread to complete before timing + out.""" + + NUM_PWBS: int = 5 + """The number of packaged workbaskets to create for each test.""" + + @pytest.fixture(autouse=True) + def setup(self, settings): + """Initialises a barrier to synchronise threads and creates packaged + workbaskets anew for each test.""" + settings.ENABLE_PACKAGING_NOTIFICATIONS = False + + self.unexpected_exception = None + self.barrier = threading.Barrier(self.NUM_THREADS) + + for _ in range(self.NUM_PWBS): + self._create_packaged_workbasket() + + def _create_packaged_workbasket(self): + """Creates a new packaged workbasket with a unique + create_envelope_task_id.""" + with patch( + "publishing.tasks.create_xml_envelope_file.apply_async", + return_value=MagicMock(id=factory.Faker("uuid4")), + ): + factories.QueuedPackagedWorkBasketFactory() + + def assert_no_unexpected_exception(self): + """Asserts that a thread didn't raise an unexpected exception.""" + assert ( + not self.unexpected_exception + ), f"Unexpected exception raised: {self.unexpected_exception}" + + def assert_expected_positions(self): + """Asserts that positions in the packaging queue are both unique and in + consecutive sequence.""" + positions = list( + PackagedWorkBasket.objects.filter( + processing_state__in=ProcessingState.queued_states(), + ) + .order_by("position") + .values_list("position", flat=True), + ) + + assert len(set(positions)) == len(positions), "Duplicate positions found!" + + assert positions == list( + range(min(positions), max(positions) + 1), + ), "Non-consecutive positions found!" + + def synchronised(func): + """ + Decorator that ensures all threads wait until they can call their target + function in a synchronised fashion. + + Any unexpected exceptions raised during the execution of the decorated + function are stored for the individual test to re-raise. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + self.barrier.wait() + func(self, *args, **kwargs) + except (TransitionNotAllowed, OperationalError): + pass + except Exception as error: + self.unexpected_exception = error + + return wrapper + + @synchronised + def _begin_processing(self, pwb: PackagedWorkBasket): + """Wrapper method to call `pwb.begin_processing()`.""" + pwb.begin_processing() + + @synchronised + def _abandon(self, pwb: PackagedWorkBasket): + """Wrapper method to call `pwb.abandon()`.""" + pwb.abandon() + + @synchronised + def _create(self): + """Wrapper method to create a new `PackagedWorkbasket` instance.""" + self._create_packaged_workbasket() + + @synchronised + def _promote_to_top_position(self, pwb: PackagedWorkBasket): + """Wrapper method to call `pwb.promote_to_top_position()`.""" + pwb.promote_to_top_position() + + @synchronised + def _promote_position(self, pwb: PackagedWorkBasket): + """Wrapper method to call `pwb.promote_position()`.""" + pwb.promote_position() + + @synchronised + def _demote_position(self, pwb: PackagedWorkBasket): + """Wrapper method to call `pwb.demote_position()`.""" + pwb.demote_position() + + def execute_threads(self, threads: list[threading.Thread]): + """Starts a list of threads and waits for them to complete or + timeout.""" + for thread in threads: + thread.start() + + for thread in threads: + thread.join(timeout=self.THREAD_TIMEOUT) + + def test_begin_processing_and_promote_to_top(self): + """Begins processing the top-most packaged workbasket while promoting to + the top the packaged workbasket in last place.""" + pwbs = PackagedWorkBasket.objects.filter( + processing_state__in=ProcessingState.queued_states(), + ) + + thread1 = threading.Thread( + target=self._begin_processing, + kwargs={"pwb": pwbs[0]}, + ) + thread2 = threading.Thread( + target=self._promote_to_top_position, + kwargs={"pwb": pwbs[4]}, + ) + + self.execute_threads(threads=[thread1, thread2]) + self.assert_no_unexpected_exception() + self.assert_expected_positions() + + def test_promote_and_promote_to_top(self): + """Promotes to the top the last-placed packaged workbasket while + promoting the one above it.""" + pwbs = PackagedWorkBasket.objects.filter( + processing_state__in=ProcessingState.queued_states(), + ) + + thread1 = threading.Thread( + target=self._promote_to_top_position, + kwargs={"pwb": pwbs[4]}, + ) + thread2 = threading.Thread( + target=self._promote_position, + kwargs={"pwb": pwbs[3]}, + ) + + self.execute_threads(threads=[thread1, thread2]) + self.assert_no_unexpected_exception() + self.assert_expected_positions() + + def test_demote_and_promote(self): + """Demotes and promotes the same packaged workbasket.""" + pwbs = PackagedWorkBasket.objects.filter( + processing_state__in=ProcessingState.queued_states(), + ) + + thread1 = threading.Thread( + target=self._demote_position, + kwargs={"pwb": pwbs[2]}, + ) + thread2 = threading.Thread( + target=self._promote_position, + kwargs={"pwb": pwbs[2]}, + ) + + self.execute_threads(threads=[thread1, thread2]) + self.assert_no_unexpected_exception() + self.assert_expected_positions() + + def test_abandon_and_create(self): + """Abandons the last-placed packaged workbasket while creating a new + one.""" + pwbs = PackagedWorkBasket.objects.filter( + processing_state__in=ProcessingState.queued_states(), + ) + + thread1 = threading.Thread( + target=self._abandon, + kwargs={"pwb": pwbs[4]}, + ) + thread2 = threading.Thread( + target=self._create, + ) + + self.execute_threads(threads=[thread1, thread2]) + self.assert_no_unexpected_exception() + self.assert_expected_positions() + + def test_crown_dependencies_publishing_pause_and_unpause(unpause_publishing): """Test that Crown Dependencies publishing operational status can be paused and unpaused."""