diff --git a/requirements.txt b/requirements.txt index 74d47ca..667ebff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ pydantic-settings>=2.0.0 +python-dateutil==2.8.2 diff --git a/sharded_queue/__init__.py b/sharded_queue/__init__.py index 059c072..074405c 100644 --- a/sharded_queue/__init__.py +++ b/sharded_queue/__init__.py @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Generic, NamedTuple, Optional, Self, TypeVar, get_type_hints) +from dateutil.rrule import rrule, rrulestr from sharded_queue.drivers import JsonTupleSerializer from sharded_queue.protocols import Lock, Serializer, Storage from sharded_queue.settings import WorkerSettings @@ -292,7 +293,13 @@ class DeferredRequest(NamedTuple): msg: list @classmethod - def calculate_timestamp(cls, delta: float | int | timedelta) -> float: + def calculate_timestamp( + cls, + delta: float | int | str | timedelta, + ) -> float: + if isinstance(delta, str): + return rrulestr(delta).after(datetime.now()).timestamp() + now: datetime = datetime.now() if isinstance(delta, timedelta): now = now + delta @@ -364,12 +371,18 @@ def transform( class RecurrentRequest(NamedTuple): - interval: float + interval: float | str pipe: str msg: list @classmethod - def get_interval(cls, interval: int | float | timedelta) -> float: + def get_interval( + cls, + interval: int | float | timedelta | rrule + ) -> float | str: + if isinstance(interval, rrule): + return str(interval) + if isinstance(interval, timedelta): return float(int(interval.total_seconds())) @@ -410,10 +423,10 @@ async def handle(self, *requests: RecurrentRequest) -> None: def transform( cls, pipe_messages: list[tuple[str, T]], - recurrent: float | int | timedelta, + recurrent: float | int | timedelta | rrule, serializer: Serializer, ) -> list[tuple[str, RecurrentRequest]]: - interval: float = RecurrentRequest.get_interval(recurrent) + interval: float | str = RecurrentRequest.get_interval(recurrent) return [ ( Tube(RecurrentHandler, Route()).pipe,