diff --git a/sqs_workers/memory_sqs.py b/sqs_workers/memory_sqs.py index e1a54ae..bb8bdb0 100644 --- a/sqs_workers/memory_sqs.py +++ b/sqs_workers/memory_sqs.py @@ -107,7 +107,7 @@ class MemoryQueue: name: str = attr.ib() attributes: Dict[str, Dict[str, str]] = attr.ib() messages: List["MemoryMessage"] = attr.ib(factory=list) - in_flight: List["MemoryMessage"] = attr.ib(factory=list) + in_flight: Dict[str, "MemoryMessage"] = attr.ib(factory=dict) def __attrs_post_init__(self): self.attributes["QueueArn"] = self.name @@ -147,7 +147,8 @@ def receive_messages(self, WaitTimeSeconds="0", MaxNumberOfMessages="10", **kwar else: ready_messages.append(message) self.messages[:] = push_back_messages - self.in_flight.extend(ready_messages) + for m in ready_messages: + self.in_flight[m.message_id] = m return ready_messages def delete_messages(self, Entries): @@ -157,17 +158,19 @@ def delete_messages(self, Entries): See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/ services/sqs.html#SQS.Queue.delete_messages """ - message_ids = {entry["Id"] for entry in Entries} + found_entries = [] + not_found_entries = [] - successfully_deleted = set() - - for i, message in enumerate(self.in_flight): - if message.message_id in message_ids: - successfully_deleted.add(message.message_id) - del self.in_flight[i] + for e in Entries: + if e["Id"] in self.in_flight: + found_entries.append(e) + self.in_flight.pop(e["Id"]) + else: + not_found_entries.append(e) return { - "Successful": [{"Id": _id} for _id in successfully_deleted], + "Successful": [{"Id": e["Id"]} for e in found_entries], + "Failed": [{"Id": e["Id"]} for e in not_found_entries], } def change_message_visibility_batch(self, Entries): @@ -175,25 +178,26 @@ def change_message_visibility_batch(self, Entries): Changes message visibility by looking at in-flight messages, setting a new execute_at, and returning it to the pool of processable messages """ - edited = [] - return_to_pool = [] - entries_by_id = {e["Id"]: e for e in Entries} - - for i, m in enumerate(self.in_flight): - if m.message_id in entries_by_id.keys(): - sec = int(entries_by_id[m.message_id]["VisibilityTimeout"]) + found_entries = [] + not_found_entries = [] + + for e in Entries: + if e["Id"] in self.in_flight: + found_entries.append(e) + in_flight_message = self.in_flight[e["Id"]] + sec = int(e["VisibilityTimeout"]) now = datetime.datetime.utcnow() execute_at = now + datetime.timedelta(seconds=sec) - changed = attr.evolve(m, execute_at=execute_at) - changed.attributes["ApproximateReceiveCount"] += 1 - edited.append(changed) - return_to_pool.append(changed) - del self.in_flight[i] - - self.messages.extend(return_to_pool) + updated_message = attr.evolve(in_flight_message, execute_at=execute_at) + updated_message.attributes["ApproximateReceiveCount"] += 1 + self.messages.append(updated_message) + self.in_flight.pop(e["Id"]) + else: + not_found_entries.append(e) return { - "Successful": [{"Id": _id} for _id in edited], + "Successful": [{"Id": e["Id"]} for e in found_entries], + "Failed": [{"Id": e["Id"]} for e in not_found_entries], } def delete(self):