FS-3535 adding sqs extended client message executor service into util (#144)

* FS-3535 adding sqs extended client message executor service into util

* Update version to 2.0.51

* FS-3535 adding sqs extended client message executor service into util

* Update context_aware_executor.py

* FS-3535 addressing review comments

---------

Co-authored-by: FSD Github Actions <[email protected]>
nuwan-samarasinghe and FSD Github Actions authored Jun 5, 2024
1 parent da23778 commit 8f770bc
Showing 7 changed files with 267 additions and 1 deletion.
Empty file.
36 changes: 36 additions & 0 deletions fsd_utils/sqs_scheduler/context_aware_executor.py
@@ -0,0 +1,36 @@
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context


class ContextAwareExecutor:
"""This Executor copy current flask application context and then inherit the
flask application context for each executor thread then those threads will
have the ability to use flask resources with its own flask context."""

def __init__(self, max_workers, thread_name_prefix, flask_app):
"""Initialize Threadpool executor and ContextAwareExecutor :max_workers
number of workers for the thread pool :thread_name_prefix prefix of the
thread pool name :flask_app original flask application context."""
self.executor = ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix=thread_name_prefix
)
self.flask_app = flask_app

def queue_size(self):
"""Get queue size of the Thread pool."""
return self.executor._work_queue.qsize()

def submit(self, fn, *args, **kwargs):
"""Submit executor to the thread pool."""
ctx = copy_context()
future = self.executor.submit(ctx.run, self.wrap_function(fn), *args, **kwargs)
return future

def wrap_function(self, fn):
"""Wrap the function with copied application context."""

def wrapped(*args, **kwargs):
with self.flask_app.app_context():
return fn(*args, **kwargs)

return wrapped
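
A minimal usage sketch for the executor above (the app setup and task are hypothetical; only ContextAwareExecutor comes from this commit). Each submitted callable runs in a copy of the caller's context and re-enters the Flask application context, so current_app resolves on the worker thread:

from flask import Flask, current_app

from fsd_utils.sqs_scheduler.context_aware_executor import ContextAwareExecutor

app = Flask(__name__)
app.config["SERVICE_NAME"] = "notification"  # hypothetical config value

executor = ContextAwareExecutor(
    max_workers=4, thread_name_prefix="NotifTask", flask_app=app
)


def task():
    # Runs on a worker thread, but current_app still resolves because
    # wrap_function() re-enters the Flask application context there.
    return current_app.config["SERVICE_NAME"]


future = executor.submit(task)
assert future.result() == "notification"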
5 changes: 5 additions & 0 deletions fsd_utils/sqs_scheduler/scheduler_service.py
@@ -0,0 +1,5 @@
from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService


def scheduler_executor(task_executor_service: TaskExecutorService):
    """Entry point for a scheduler job: process any pending SQS messages."""
    task_executor_service.process_messages()
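
One plausible way to drive scheduler_executor on an interval is APScheduler; this wiring is an assumption, since the commit does not pin any particular scheduler:

from apscheduler.schedulers.background import BackgroundScheduler

from fsd_utils.sqs_scheduler.scheduler_service import scheduler_executor


def start_scheduler(task_executor_service, poll_seconds=60):
    """Poll the SQS queue every poll_seconds using the given service."""
    scheduler = BackgroundScheduler()
    scheduler.add_job(
        scheduler_executor,
        trigger="interval",
        seconds=poll_seconds,
        args=[task_executor_service],  # a concrete TaskExecutorService subclass
    )
    scheduler.start()
    return scheduler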
135 changes: 135 additions & 0 deletions fsd_utils/sqs_scheduler/task_executer_service.py
@@ -0,0 +1,135 @@
import threading
from abc import abstractmethod
from concurrent.futures import as_completed

from fsd_utils.services.aws_extended_client import SQSExtendedClient


class TaskExecutorService:
def __init__(
self,
flask_app,
executor,
s3_bucket,
sqs_primary_url,
task_executor_max_thread,
sqs_batch_size,
visibility_time,
sqs_wait_time,
endpoint_url_override=None,
aws_access_key_id=None,
aws_secret_access_key=None,
region_name=None,
):
self.executor = executor
self.sqs_primary_url = sqs_primary_url
self.task_executor_max_thread = task_executor_max_thread
self.sqs_batch_size = sqs_batch_size
self.visibility_time = visibility_time
self.sqs_wait_time = sqs_wait_time
self.logger = flask_app.logger
self.sqs_extended_client = SQSExtendedClient(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
endpoint_url=endpoint_url_override,
large_payload_support=s3_bucket,
always_through_s3=True,
delete_payload_from_s3=True,
logger=self.logger,
)
self.logger.info(
"Created the thread pool executor to process messages in extended SQS queue"
)

def process_messages(self):
"""
        Called by the scheduler on a cron-style interval: messages are read
        from the SQS queue in AWS and, since large-payload support is enabled,
        their bodies are retrieved from the S3 bucket.
"""
current_thread = threading.current_thread()
thread_id = f"[{current_thread.name}:{current_thread.ident}]"
self.logger.debug(f"{thread_id} Triggered schedular to get messages")

running_threads, read_msg_ids = self._handle_message_receiving_and_processing()

self._handle_message_delete_processing(running_threads, read_msg_ids)

self.logger.debug(
f"{thread_id} Message Processing completed and will start again later"
)

@abstractmethod
def message_executor(self, message):
"""
        Process a single message on a worker thread; a concrete implementation
        might, for example, call the GOV.UK Notify service to send emails.
        Override this method in a subclass.
        :param message: JSON message read from the queue
"""
pass

def _handle_message_receiving_and_processing(self):
"""
        Receive messages from the SQS queue (retrieving large payloads from the
        S3 bucket) and submit each one to the executor for processing.
"""
current_thread = threading.current_thread()
thread_id = f"[{current_thread.name}:{current_thread.ident}]"
running_threads = []
read_msg_ids = []
if self.task_executor_max_thread >= self.executor.queue_size():
sqs_messages = self.sqs_extended_client.receive_messages(
self.sqs_primary_url,
self.sqs_batch_size,
self.visibility_time,
self.sqs_wait_time,
)
self.logger.debug(f"{thread_id} Message Count [{len(sqs_messages)}]")
if sqs_messages:
for message in sqs_messages:
                    message_id = message["sqs"]["MessageId"]
                    self.logger.info(f"{thread_id} Message id [{message_id}]")
                    read_msg_ids.append(message_id)
task = self.executor.submit(self.message_executor, message)
running_threads.append(task)
else:
self.logger.info(
f"{thread_id} Max thread limit reached hence stop reading messages from queue"
)

self.logger.debug(
f"{thread_id} Received Message count [{len(read_msg_ids)}] "
f"Created thread count [{len(running_threads)}]"
)
return running_threads, read_msg_ids

def _handle_message_delete_processing(self, running_threads, read_msg_ids):
"""
        Delete messages from SQS (and their payloads from S3) once their
        processing tasks have completed.
        :param running_threads: futures for the in-flight processing tasks
        :param read_msg_ids: ids of all messages read from SQS
"""
current_thread = threading.current_thread()
thread_id = f"[{current_thread.name}:{current_thread.ident}]"
receipt_handles_to_delete = []
completed_msg_ids = []
for future in as_completed(running_threads):
try:
msg = future.result()
msg_id = msg["sqs"]["MessageId"]
receipt_handles_to_delete.append(msg["sqs"])
completed_msg_ids.append(msg_id)
self.logger.debug(
f"{thread_id} Execution completed and deleted from queue: {msg_id}"
)
except Exception as e:
self.logger.error(
f"{thread_id} An error occurred while processing the message {e}"
)
dif_msg_ids = [i for i in read_msg_ids if i not in completed_msg_ids]
self.logger.debug(
f"No of messages not processed [{len(dif_msg_ids)}] and msg ids are {dif_msg_ids}"
)
if receipt_handles_to_delete:
self.sqs_extended_client.delete_messages(
self.sqs_primary_url, receipt_handles_to_delete
)
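
Since message_executor is abstract, consumers subclass TaskExecutorService; _handle_message_delete_processing expects the method to return the message it was given on success, so the receipt handle can be batched for deletion. A hypothetical subclass sketch (EmailTaskExecutorService, _send_email and the "body" key are assumptions, not part of this commit):

from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService


class EmailTaskExecutorService(TaskExecutorService):
    def message_executor(self, message):
        body = message["body"]  # assumed payload key; check SQSExtendedClient
        self._send_email(body)
        return message  # returning the message marks it for deletion

    def _send_email(self, body):
        # Hypothetical: call an email provider such as GOV.UK Notify here.
        self.logger.info("Sending notification for payload %s", body)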
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "funding-service-design-utils"

version = "2.0.50"
version = "2.0.51"

authors = [
{ name="DLUHC", email="[email protected]" },
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -11,3 +11,4 @@ Flask-Migrate
Flask-SQLAlchemy>=3.0.3
sqlalchemy-utils==0.38.3
beautifulsoup4==4.12.2
moto[s3,sqs]==5.0.7
89 changes: 89 additions & 0 deletions tests/test_task_executor_service.py
@@ -0,0 +1,89 @@
import unittest
from unittest.mock import MagicMock
from uuid import uuid4

import boto3
from fsd_utils.sqs_scheduler.context_aware_executor import ContextAwareExecutor
from fsd_utils.sqs_scheduler.task_executer_service import TaskExecutorService
from moto import mock_aws


class TestTaskExecutorService(unittest.TestCase):
@mock_aws
def test_message_in_mock_environment_processing(self):
"""
        Ensure that when a message is on the queue and no errors occur while
        processing it, the message is removed from the queue.
"""
self._mock_aws_client()
self._add_data_to_queue()

self.task_executor.process_messages()

self._check_is_data_available(0)

def _mock_aws_client(self):
"""
        Mock the AWS resources so the test behaves like a real AWS environment.
"""
bucket_name = "fsd_msg_s3_bucket"
self.flask_app = MagicMock()
self.executor = ContextAwareExecutor(
max_workers=10, thread_name_prefix="NotifTask", flask_app=self.flask_app
)
s3_connection = boto3.client(
"s3",
region_name="us-east-1",
aws_access_key_id="test_accesstoken", # pragma: allowlist secret
aws_secret_access_key="secret_key", # pragma: allowlist secret
)
sqs_connection = boto3.client(
"sqs",
region_name="us-east-1",
aws_access_key_id="test_accesstoken", # pragma: allowlist secret
aws_secret_access_key="secret_key", # pragma: allowlist secret
)
s3_connection.create_bucket(Bucket=bucket_name)
self.queue_response = sqs_connection.create_queue(
QueueName="notif-queue.fifo", Attributes={"FifoQueue": "true"}
)
self.task_executor = AnyTaskExecutorService(
flask_app=MagicMock(),
executor=self.executor,
s3_bucket=bucket_name,
sqs_primary_url=self.queue_response["QueueUrl"],
task_executor_max_thread=5,
sqs_batch_size=10,
visibility_time=1,
sqs_wait_time=2,
endpoint_url_override=None,
aws_access_key_id="test_accesstoken", # pragma: allowlist secret
aws_secret_access_key="secret_key", # pragma: allowlist secret
region_name="us-east-1",
)
self.task_executor.sqs_extended_client.sqs_client = sqs_connection
self.task_executor.sqs_extended_client.s3_client = s3_connection

def _add_data_to_queue(self):
"""
        Add a test message to the queue.
        """
        for _ in range(1):  # a single message suffices for this test
message_id = self.task_executor.sqs_extended_client.submit_single_message(
queue_url=self.queue_response["QueueUrl"],
message="message",
message_group_id="import_applications_group",
message_deduplication_id=str(uuid4()), # ensures message uniqueness
)
assert message_id is not None

def _check_is_data_available(self, count):
response = self.task_executor.sqs_extended_client.receive_messages(
queue_url=self.queue_response["QueueUrl"], max_number=1
)
assert len(response) == count


class AnyTaskExecutorService(TaskExecutorService):
def message_executor(self, message):
return message
