fix: improved handling of api call failure
brandon-doist committed Jul 30, 2024
1 parent f1b78a1 commit 0f56ca0
Showing 3 changed files with 96 additions and 41 deletions.
120 changes: 82 additions & 38 deletions sqs_workers/queue.py
@@ -10,9 +10,11 @@
     Callable,
     Dict,
     Generator,
+    Iterable,
     List,
     Literal,
     Optional,
+    Tuple,
     TypeVar,
 )
@@ -27,6 +29,7 @@
 from sqs_workers.exceptions import SQSError
 from sqs_workers.processors import DEFAULT_CONTEXT_VAR, Processor
 from sqs_workers.shutdown_policies import NEVER_SHUTDOWN
+from sqs_workers.utils import batcher
 
 DEFAULT_MESSAGE_GROUP_ID = "default"
 SEND_BATCH_SIZE = 10
@@ -85,53 +88,80 @@ def process_queue(self, shutdown_policy=NEVER_SHUTDOWN, wait_second=10):
                 )
                 break
 
-    def process_batch(self, wait_seconds=0) -> BatchProcessingResult:
+    def process_batch(self, wait_seconds: int = 0) -> BatchProcessingResult:
         """
         Process a batch of messages from the queue (10 messages at most), return
         the number of successfully processed messages, and exit
         """
-        queue = self.get_queue()
-
         if self.batching_policy.batching_enabled:
-            return self._process_messages_in_batch(queue, wait_seconds)
+            messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size)
+            success = self.process_messages(messages)
+            messages_with_success = ((m, success) for m in messages)
+        else:
+            messages = self.get_raw_messages(wait_seconds)
+            success = [self.process_message(message) for message in messages]
+            messages_with_success = zip(messages, success)
 
-        return self._process_messages_individually(queue, wait_seconds)
+        return self._handle_processed(messages_with_success)
 
-    def _process_messages_in_batch(self, queue, wait_seconds):
-        messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size)
-        result = BatchProcessingResult(self.name)
-
-        success = self.process_messages(messages)
-
-        for message in messages:
-            result.update_with_message(message, success)
-            if success:
-                entry = {
-                    "Id": message.message_id,
-                    "ReceiptHandle": message.receipt_handle,
-                }
-                queue.delete_messages(Entries=[entry])
-            else:
-                timeout = self.backoff_policy.get_visibility_timeout(message)
-                message.change_visibility(VisibilityTimeout=timeout)
-        return result
-
-    def _process_messages_individually(self, queue, wait_seconds):
-        messages = self.get_raw_messages(wait_seconds)
+    def _handle_processed(self, messages_with_success: Iterable[Tuple[Any, bool]]):
+        """
+        Handles the results of processing messages.
+
+        For successful messages, we delete the message ID from the queue, which is
+        equivalent to acknowledging it.
+
+        For failed messages, we change the visibility of the message, in order to
+        keep it un-consumeable for a little while (a form of backoff).
+
+        In each case (delete or change-viz), we batch the API calls to AWS in order
+        to try to avoid getting throttled, with batches of size 10 (the limit). The
+        config (see sqs_env.py) should also retry in the event of exceptions.
+        """
+        queue = self.get_queue()
+
         result = BatchProcessingResult(self.name)
 
-        for message in messages:
-            success = self.process_message(message)
-            result.update_with_message(message, success)
-            if success:
-                entry = {
-                    "Id": message.message_id,
-                    "ReceiptHandle": message.receipt_handle,
-                }
-                queue.delete_messages(Entries=[entry])
-            else:
-                timeout = self.backoff_policy.get_visibility_timeout(message)
-                message.change_visibility(VisibilityTimeout=timeout)
+        for subgroup in batcher(messages_with_success, batch_size=10):
+            entries_to_ack = []
+            entries_to_change_viz = []
+
+            for m, success in subgroup:
+                result.update_with_message(m, success)
+                if success:
+                    entries_to_ack.append(
+                        {
+                            "Id": m.message_id,
+                            "ReceiptHandle": m.receipt_handle,
+                        }
+                    )
+                else:
+                    entries_to_change_viz.append(
+                        {
+                            "Id": m.message_id,
+                            "ReceiptHandle": m.receipt_handle,
+                            "VisibilityTimeout": self.backoff_policy.get_visibility_timeout(m),
+                        }
+                    )
+
+            ack_response = queue.delete_messages(Entries=entries_to_ack)
+
+            if ack_response.get("Failed"):
+                logger.warning(
+                    "Failed to delete processed messages from queue",
+                    extra={"queue": self.name, "failures": ack_response["Failed"]},
+                )
+
+            viz_response = queue.change_message_visibility_batch(
+                Entries=entries_to_change_viz,
+            )
+
+            if viz_response.get("Failed"):
+                logger.warning(
+                    "Failed to change visibility of messages which failed to process",
+                    extra={"queue": self.name, "failures": viz_response["Failed"]},
+                )
+
         return result
 
     def process_message(self, message: Any) -> bool:
@@ -151,15 +181,16 @@ def process_messages(self, messages: List[Any]) -> bool:
         """
         raise NotImplementedError()
 
-    def get_raw_messages(self, wait_seconds, max_messages=10):
+    def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> list[Any]:
         """Return raw messages from the queue, addressed by its name"""
-        queue = self.get_queue()
-
         kwargs = {
             "WaitTimeSeconds": wait_seconds,
             "MaxNumberOfMessages": max_messages if max_messages <= 10 else 10,
             "MessageAttributeNames": ["All"],
             "AttributeNames": ["All"],
         }
+        queue = self.get_queue()
 
         if max_messages <= 10:
             return queue.receive_messages(**kwargs)
@@ -180,16 +211,29 @@ def get_raw_messages(self, wait_seconds, max_messages=10):
     def drain_queue(self, wait_seconds=0):
         """Delete all messages from the queue without calling purge()."""
         queue = self.get_queue()
+
         deleted_count = 0
         while True:
             messages = self.get_raw_messages(wait_seconds)
             if not messages:
                 break
+
             entries = [
                 {"Id": msg.message_id, "ReceiptHandle": msg.receipt_handle}
                 for msg in messages
             ]
-            queue.delete_messages(Entries=entries)
+
+            ack_response = queue.delete_messages(Entries=entries)
+
+            if ack_response.get("Failed"):
+                logger.warning(
+                    "Failed to delete processed messages from queue",
+                    extra={
+                        "queue": self.name,
+                        "failures": ack_response["Failed"],
+                    },
+                )
+
             deleted_count += len(messages)
         return deleted_count

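Aside on the change above: `_handle_processed` replaces one `DeleteMessage` or `ChangeMessageVisibility` call per message with the batch variants, which take between 1 and 10 entries per request. Below is a minimal standalone sketch of the same pattern in plain boto3; the queue name and the `handle` callback are invented placeholders, not part of this commit. Note that the batch APIs reject an empty `Entries` list with an EmptyBatchRequest error, so the sketch guards both calls.

import boto3

def handle(message) -> bool:
    # Placeholder processing callback: return True when the message was
    # handled successfully, False to trigger the visibility backoff.
    return bool(message.body)

sqs = boto3.resource("sqs")
queue = sqs.get_queue_by_name(QueueName="example-queue")  # placeholder name

messages = queue.receive_messages(MaxNumberOfMessages=10, WaitTimeSeconds=5)

entries_to_ack = []
entries_to_change_viz = []
for m in messages:
    if handle(m):
        # Deleting a received message is SQS's form of acknowledgement.
        entries_to_ack.append({"Id": m.message_id, "ReceiptHandle": m.receipt_handle})
    else:
        # Keep failed messages invisible for 60 seconds as a simple backoff.
        entries_to_change_viz.append(
            {
                "Id": m.message_id,
                "ReceiptHandle": m.receipt_handle,
                "VisibilityTimeout": 60,
            }
        )

if entries_to_ack:
    response = queue.delete_messages(Entries=entries_to_ack)
    if response.get("Failed"):
        print("failed to delete:", response["Failed"])

if entries_to_change_viz:
    response = queue.change_message_visibility_batch(Entries=entries_to_change_viz)
    if response.get("Failed"):
        print("failed to change visibility:", response["Failed"])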
7 changes: 5 additions & 2 deletions sqs_workers/sqs_env.py
@@ -15,6 +15,7 @@
 
 import attr
 import boto3
+from botocore.config import Config
 from typing_extensions import ParamSpec
 
 from sqs_workers import DEFAULT_BACKOFF, RawQueue, codecs, context, processors
@@ -60,10 +61,12 @@ class SQSEnv:
 
     def __attrs_post_init__(self):
         self.context = self.context_maker()
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
+        retry_config = Config(retries={"max_attempts": 3, "mode": "standard"})
         if not self.sqs_client:
-            self.sqs_client = self.session.client("sqs")
+            self.sqs_client = self.session.client("sqs", config=retry_config)
         if not self.sqs_resource:
-            self.sqs_resource = self.session.resource("sqs")
+            self.sqs_resource = self.session.resource("sqs", config=retry_config)
 
     @overload
     def queue(
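For reference, a minimal sketch of the retry setup in isolation, per the boto3 retries guide linked in the new comment; the region name here is an invented placeholder. In "standard" mode, botocore retries throttled and transient failures with exponential backoff plus jitter, and `max_attempts` caps the total number of attempts, including the initial call.

import boto3
from botocore.config import Config

# "standard" retry mode: backoff with jitter on throttling and
# transient errors; max_attempts includes the initial request.
retry_config = Config(retries={"max_attempts": 3, "mode": "standard"})

session = boto3.session.Session(region_name="us-east-1")  # placeholder region
sqs_client = session.client("sqs", config=retry_config)
sqs_resource = session.resource("sqs", config=retry_config)

As the `_handle_processed` docstring notes, the batched delete and change-visibility calls should then inherit this retry behaviour from the shared client and resource.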
10 changes: 9 additions & 1 deletion sqs_workers/utils.py
@@ -1,7 +1,8 @@
 import importlib
 import logging
 from inspect import Signature
-from typing import Any
+from itertools import islice
+from typing import Any, Iterable
 
 logger = logging.getLogger(__name__)
@@ -121,3 +122,10 @@ def ensure_string(obj: Any, encoding="utf-8", errors="strict") -> str:
         return obj.decode(encoding, errors)
     else:
         return str(obj)
+
+
+def batcher(iterable, batch_size) -> Iterable[Iterable[Any]]:
+    """Cuts an iterable up into sub-iterables of size batch_size."""
+    iterator = iter(iterable)
+    while batch := list(islice(iterator, batch_size)):
+        yield batch
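A quick usage sketch for the new `batcher` helper, with invented example values. It slices its input lazily via `islice`, so any iterable works, including the generator that `process_batch` builds; the `:=` assignment makes it Python 3.8+ only.

from sqs_workers.utils import batcher

# Each yielded batch is a list of at most batch_size items; the final
# batch holds whatever remains.
for batch in batcher(range(7), batch_size=3):
    print(batch)
# prints [0, 1, 2], then [3, 4, 5], then [6]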
