diff --git a/fsd_utils/sqs_scheduler/__init__.py b/fsd_utils/sqs_scheduler/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fsd_utils/sqs_scheduler/context_aware_executor.py b/fsd_utils/sqs_scheduler/context_aware_executor.py
new file mode 100644
index 0000000..0cd360f
--- /dev/null
+++ b/fsd_utils/sqs_scheduler/context_aware_executor.py
@@ -0,0 +1,37 @@
+from concurrent.futures import ThreadPoolExecutor
+from contextvars import copy_context
+
+
+class ContextAwareExecutor:
+    """Copies the current Flask application context and makes it available to
+    each executor thread, so that those threads can use Flask resources within
+    their own application context."""
+
+    def __init__(self, max_workers, thread_name_prefix, flask_app):
+        """Initialise the ContextAwareExecutor and its thread pool.
+        :param max_workers: number of workers for the thread pool
+        :param thread_name_prefix: prefix used for the thread pool's thread names
+        :param flask_app: the Flask application whose context is inherited."""
+        self.executor = ThreadPoolExecutor(
+            max_workers=max_workers, thread_name_prefix=thread_name_prefix
+        )
+        self.flask_app = flask_app
+
+    def queue_size(self):
+        """Return the number of tasks waiting in the thread pool's work queue."""
+        return self.executor._work_queue.qsize()
+
+    def submit(self, fn, *args, **kwargs):
+        """Submit a callable to the thread pool, run within a copied context."""
+        ctx = copy_context()
+        future = self.executor.submit(ctx.run, self.wrap_function(fn), *args, **kwargs)
+        return future
+
+    def wrap_function(self, fn):
+        """Wrap the function so it executes inside the Flask application context."""
+
+        def wrapped(*args, **kwargs):
+            with self.flask_app.app_context():
+                return fn(*args, **kwargs)
+
+        return wrapped
diff --git a/fsd_utils/sqs_scheduler/scheduler_service.py b/fsd_utils/sqs_scheduler/scheduler_service.py
new file mode 100644
index 0000000..e25e94f
--- /dev/null
+++ b/fsd_utils/sqs_scheduler/scheduler_service.py
@@ -0,0 +1,5 @@
+from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService
+
+
+def scheduler_executor(task_executor_service: TaskExecutorService):
+    task_executor_service.process_messages()
diff --git a/fsd_utils/sqs_scheduler/task_executer_service.py b/fsd_utils/sqs_scheduler/task_executer_service.py
new file mode 100644
index 0000000..77116a0
--- /dev/null
+++ b/fsd_utils/sqs_scheduler/task_executer_service.py
@@ -0,0 +1,135 @@
+import threading
+from abc import abstractmethod
+from concurrent.futures import as_completed
+
+from fsd_utils.services.aws_extended_client import SQSExtendedClient
+
+
+class TaskExecutorService:
+    def __init__(
+        self,
+        flask_app,
+        executor,
+        s3_bucket,
+        sqs_primary_url,
+        task_executor_max_thread,
+        sqs_batch_size,
+        visibility_time,
+        sqs_wait_time,
+        endpoint_url_override=None,
+        aws_access_key_id=None,
+        aws_secret_access_key=None,
+        region_name=None,
+    ):
+        self.executor = executor
+        self.sqs_primary_url = sqs_primary_url
+        self.task_executor_max_thread = task_executor_max_thread
+        self.sqs_batch_size = sqs_batch_size
+        self.visibility_time = visibility_time
+        self.sqs_wait_time = sqs_wait_time
+        self.logger = flask_app.logger
+        self.sqs_extended_client = SQSExtendedClient(
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=region_name,
+            endpoint_url=endpoint_url_override,
+            large_payload_support=s3_bucket,
+            always_through_s3=True,
+            delete_payload_from_s3=True,
+            logger=self.logger,
+        )
+        self.logger.info(
+            "Created the task executor service to process messages from the extended SQS queue"
+        )
+
+    def process_messages(self):
+        """
+        Called by the scheduler on a cron interval: messages are read from the SQS queue
+        in AWS and, where S3 usage is enabled, the message bodies are retrieved from S3.
+        """
+        current_thread = threading.current_thread()
+        thread_id = f"[{current_thread.name}:{current_thread.ident}]"
+        self.logger.debug(f"{thread_id} Triggered scheduler to get messages")
+
+        running_threads, read_msg_ids = self._handle_message_receiving_and_processing()
+
+        self._handle_message_delete_processing(running_threads, read_msg_ids)
+
+        self.logger.debug(
+            f"{thread_id} Message processing completed and will start again on the next run"
+        )
+
+    @abstractmethod
+    def message_executor(self, message):
+        """
+        Process a single message in a separate worker thread, e.g. by calling the GOV.UK Notify service to send emails.
+        :param message: JSON message
+        Override this method to provide the implementation.
+        """
+        pass
+
+    def _handle_message_receiving_and_processing(self):
+        """
+        Retrieve messages from SQS and fetch the corresponding JSON payloads from the S3 bucket.
+        """
+        current_thread = threading.current_thread()
+        thread_id = f"[{current_thread.name}:{current_thread.ident}]"
+        running_threads = []
+        read_msg_ids = []
+        if self.task_executor_max_thread >= self.executor.queue_size():
+            sqs_messages = self.sqs_extended_client.receive_messages(
+                self.sqs_primary_url,
+                self.sqs_batch_size,
+                self.visibility_time,
+                self.sqs_wait_time,
+            )
+            self.logger.debug(f"{thread_id} Message Count [{len(sqs_messages)}]")
+            if sqs_messages:
+                for message in sqs_messages:
+                    message_id = message["sqs"]["MessageId"]
+                    self.logger.info(f"{thread_id} Message id [{message_id}]")
+                    read_msg_ids.append(message_id)
+                    task = self.executor.submit(self.message_executor, message)
+                    running_threads.append(task)
+        else:
+            self.logger.info(
+                f"{thread_id} Max thread limit reached, not reading messages from the queue"
+            )
+
+        self.logger.debug(
+            f"{thread_id} Received Message count [{len(read_msg_ids)}] "
+            f"Created thread count [{len(running_threads)}]"
+        )
+        return running_threads, read_msg_ids
+
+    def _handle_message_delete_processing(self, running_threads, read_msg_ids):
+        """
+        Delete messages from SQS (and their payloads from S3) once processing has completed.
+        :param running_threads: futures for the tasks currently executing
+        :param read_msg_ids: all message ids that were read from SQS
+        """
+        current_thread = threading.current_thread()
+        thread_id = f"[{current_thread.name}:{current_thread.ident}]"
+        receipt_handles_to_delete = []
+        completed_msg_ids = []
+        for future in as_completed(running_threads):
+            try:
+                msg = future.result()
+                msg_id = msg["sqs"]["MessageId"]
+                receipt_handles_to_delete.append(msg["sqs"])
+                completed_msg_ids.append(msg_id)
+                self.logger.debug(
+                    f"{thread_id} Execution completed and message queued for deletion: {msg_id}"
+                )
+            except Exception as e:
+                self.logger.error(
+                    f"{thread_id} An error occurred while processing the message: {e}"
+                )
+        dif_msg_ids = [i for i in read_msg_ids if i not in completed_msg_ids]
+        self.logger.debug(
+            f"{thread_id} Number of messages not processed [{len(dif_msg_ids)}], msg ids: {dif_msg_ids}"
+        )
+        if receipt_handles_to_delete:
+            self.sqs_extended_client.delete_messages(
+                self.sqs_primary_url, receipt_handles_to_delete
+            )
diff --git a/pyproject.toml b/pyproject.toml
index e1388f7..02f5d72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "funding-service-design-utils"
-version = "2.0.50"
+version = "2.0.51"
 authors = [
   { name="DLUHC", email="FundingServiceDesignTeam@levellingup.gov.uk" },
 ]
 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 021b5d3..114c267 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -11,3 +11,4 @@ Flask-Migrate
 Flask-SQLAlchemy>=3.0.3
 sqlalchemy-utils==0.38.3
 beautifulsoup4==4.12.2
+moto[s3,sqs]==5.0.7
diff --git a/tests/test_task_executor_service.py b/tests/test_task_executor_service.py
new file mode 100644
index 0000000..c35b916
--- /dev/null
+++ b/tests/test_task_executor_service.py
@@ -0,0 +1,89 @@
+import unittest
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import boto3
+from fsd_utils.sqs_scheduler.context_aware_executor import ContextAwareExecutor
+from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService
+from moto import mock_aws
+
+
+class TestTaskExecutorService(unittest.TestCase):
+    @mock_aws
+    def test_message_in_mock_environment_processing(self):
+        """
+        Ensure that when a message is on the queue and it is processed without
+        errors, it is removed from the queue.
+        """
+        self._mock_aws_client()
+        self._add_data_to_queue()
+
+        self.task_executor.process_messages()
+
+        self._check_is_data_available(0)
+
+    def _mock_aws_client(self):
+        """
+        Mock the AWS resources so the test behaves like a real AWS environment.
+        """
+        bucket_name = "fsd_msg_s3_bucket"
+        self.flask_app = MagicMock()
+        self.executor = ContextAwareExecutor(
+            max_workers=10, thread_name_prefix="NotifTask", flask_app=self.flask_app
+        )
+        s3_connection = boto3.client(
+            "s3",
+            region_name="us-east-1",
+            aws_access_key_id="test_accesstoken",  # pragma: allowlist secret
+            aws_secret_access_key="secret_key",  # pragma: allowlist secret
+        )
+        sqs_connection = boto3.client(
+            "sqs",
+            region_name="us-east-1",
+            aws_access_key_id="test_accesstoken",  # pragma: allowlist secret
+            aws_secret_access_key="secret_key",  # pragma: allowlist secret
+        )
+        s3_connection.create_bucket(Bucket=bucket_name)
+        self.queue_response = sqs_connection.create_queue(
+            QueueName="notif-queue.fifo", Attributes={"FifoQueue": "true"}
+        )
+        self.task_executor = AnyTaskExecutorService(
+            flask_app=MagicMock(),
+            executor=self.executor,
+            s3_bucket=bucket_name,
+            sqs_primary_url=self.queue_response["QueueUrl"],
+            task_executor_max_thread=5,
+            sqs_batch_size=10,
+            visibility_time=1,
+            sqs_wait_time=2,
+            endpoint_url_override=None,
+            aws_access_key_id="test_accesstoken",  # pragma: allowlist secret
+            aws_secret_access_key="secret_key",  # pragma: allowlist secret
+            region_name="us-east-1",
+        )
+        self.task_executor.sqs_extended_client.sqs_client = sqs_connection
+        self.task_executor.sqs_extended_client.s3_client = s3_connection
+
+    def _add_data_to_queue(self):
+        """
+        Add test data to the queue.
+        """
+        for _ in range(1):
+            message_id = self.task_executor.sqs_extended_client.submit_single_message(
+                queue_url=self.queue_response["QueueUrl"],
+                message="message",
+                message_group_id="import_applications_group",
+                message_deduplication_id=str(uuid4()),  # ensures message uniqueness
+            )
+            assert message_id is not None
+
+    def _check_is_data_available(self, count):
+        response = self.task_executor.sqs_extended_client.receive_messages(
+            queue_url=self.queue_response["QueueUrl"], max_number=1
+        )
+        assert len(response) == count
+
+
+class AnyTaskExecutorService(TaskExecutorService):
+    def message_executor(self, message):
+        return message
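
Usage note: the sketch below shows how a consuming Flask service might wire these pieces together. It is a minimal, illustrative example only: it assumes APScheduler drives the periodic call to `scheduler_executor`, and the queue URL, bucket name, schedule interval and the `ImportTaskExecutorService` subclass are assumptions for illustration rather than part of this change.

```python
from apscheduler.schedulers.background import BackgroundScheduler
from flask import Flask

from fsd_utils.sqs_scheduler.context_aware_executor import ContextAwareExecutor
from fsd_utils.sqs_scheduler.scheduler_service import scheduler_executor
from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService


class ImportTaskExecutorService(TaskExecutorService):  # hypothetical subclass
    def message_executor(self, message):
        # Application-specific work goes here; returning the message lets the
        # base class delete it from SQS once the future completes.
        self.logger.info("Processing message %s", message["sqs"]["MessageId"])
        return message


def create_app():
    app = Flask(__name__)
    executor = ContextAwareExecutor(
        max_workers=10, thread_name_prefix="ImportTask", flask_app=app
    )
    task_executor_service = ImportTaskExecutorService(
        flask_app=app,
        executor=executor,
        s3_bucket="fsd-large-message-bucket",  # assumed bucket name
        sqs_primary_url="https://sqs.eu-west-2.amazonaws.com/123456789012/import-queue",  # assumed
        task_executor_max_thread=5,
        sqs_batch_size=10,
        visibility_time=60,
        sqs_wait_time=2,
        region_name="eu-west-2",
    )
    # Poll the queue every 60 seconds; APScheduler stands in here for whatever
    # cron mechanism the consuming service already uses.
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        scheduler_executor,
        "interval",
        seconds=60,
        kwargs={"task_executor_service": task_executor_service},
    )
    scheduler.start()
    return app
```

The same wiring can be exercised against moto's `mock_aws`, as the new test module does, by pointing the service at a mocked queue and bucket.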