diff --git a/sharded_queue/__init__.py b/sharded_queue/__init__.py index 1dbe040..6ad1393 100644 --- a/sharded_queue/__init__.py +++ b/sharded_queue/__init__.py @@ -102,10 +102,23 @@ async def register( if recurrent: if_not_exists = True + tube = Tube(RecurrentHandler, Route()) + length = await self.storage.length(tube.pipe) + messages = await self.storage.range(tube.pipe, length) + pipe_messages = RecurrentHandler.transform( pipe_messages, recurrent, self.serializer ) + recurrent_tuples = [ + (request.pipe, request.msg) for (_, request) in pipe_messages + ] + + for msg in reversed(messages): + request = self.serializer.deserialize(RecurrentRequest, msg) + if (request.pipe, request.msg) in recurrent_tuples: + await self.storage.remove(tube.pipe, msg) + if defer: if_not_exists = True pipe_messages = DeferredHandler.transform( diff --git a/tests/test_recurrent.py b/tests/test_recurrent.py index b9832c6..734aac7 100644 --- a/tests/test_recurrent.py +++ b/tests/test_recurrent.py @@ -1,11 +1,12 @@ from asyncio import sleep -from datetime import timedelta +from datetime import datetime, timedelta from typing import NamedTuple from pytest import mark -from sharded_queue import (DeferredHandler, Handler, Queue, RecurrentHandler, - Route, Tube, Worker) +from sharded_queue import (DeferredHandler, DeferredRequest, Handler, Queue, + RecurrentHandler, RecurrentRequest, Route, Tube, + Worker) from sharded_queue.drivers import RuntimeLock, RuntimeStorage from sharded_queue.protocols import Lock, Storage @@ -41,7 +42,7 @@ async def stats() -> tuple[int, int, int]: ) await queue.register( - ValidateAccess, CompanyRequest(1), recurrent=timedelta(milliseconds=10) + ValidateAccess, CompanyRequest(1), recurrent=timedelta(seconds=1) ) assert await stats() == (0, 1, 0), 'recurrent pipe contains request' @@ -58,7 +59,7 @@ async def stats() -> tuple[int, int, int]: await Worker(lock, queue).loop(1, handler=RecurrentHandler) assert await stats() == (1, 1, 0), 'no deffered duplicates' - await sleep(0.01) + await sleep(1) await lock.release(recurrent_pipe) await Worker(lock, queue).loop(1, handler=RecurrentHandler) @@ -71,7 +72,20 @@ async def stats() -> tuple[int, int, int]: await worker.loop(1, handler=RecurrentHandler) assert await stats() == (1, 1, 1), 'deferred added' - await sleep(0.01) + await sleep(1) await worker.loop(1, handler=DeferredHandler) assert await stats() == (0, 1, 1), 'no validation duplicate' + + [recurrent] = await queue.storage.range(recurrent_pipe, 1) + request = queue.serializer.deserialize(RecurrentRequest, recurrent) + assert request.interval == 1 + await queue.register( + ValidateAccess, CompanyRequest(1), recurrent=timedelta(seconds=2) + ) + + assert await queue.storage.length(recurrent_pipe) == 1 + + [recurrent] = await queue.storage.range(recurrent_pipe, 1) + request = queue.serializer.deserialize(RecurrentRequest, recurrent) + assert request.interval == 2