diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index 9155324..c8842e1 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -2,7 +2,7 @@ import sys from datetime import datetime, timedelta from logging import basicConfig, getLevelName, getLogger -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import pytz from pycron import is_now @@ -100,9 +100,58 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]: one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0) if task_time <= one_min_ahead: return int((task_time - now).total_seconds()) + if task.period is not None and int(now.timestamp()) % int(task.period) == 0: + return 0 + return None +def is_task_executed_recently( + task: ScheduledTask, + recent_tasks: Dict[str, int], +) -> bool: + """ + Check if the task has been run recently to avoid duplicate executions. + + :param task: task to check. + :param recent_tasks: tuple of recent tasks exec. + :return: True if task must be sent. + """ + task_identifier = get_cron_task_identifier(task) + if task_identifier is None: + return False + task_name, task_now_ts = task_identifier + if task_name not in recent_tasks: + return False + recent_task_ts = recent_tasks[task_name] + return recent_task_ts == task_now_ts + + +def get_cron_task_identifier( + task: ScheduledTask, + dt: Optional[datetime] = None, +) -> Optional[Tuple[str, int]]: + """ + Get the (task_id, timestamp) task identifier. + + :param task: task to check. + :return Tuple[str, datetime] | None: (task name, datetime for the task type) + """ + if task.cron is None: + return None + dt = dt or datetime.now(tz=pytz.UTC) + # If user specified cron offset we apply it. + # If it's timedelta, we simply add the delta to current time. + if task.cron_offset and isinstance(task.cron_offset, timedelta): + dt += task.cron_offset + # If timezone was specified as string we convert it timzone + # offset and then apply. + elif task.cron_offset and isinstance(task.cron_offset, str): + dt = dt.astimezone(pytz.timezone(task.cron_offset)) + secondless_dt = dt.replace(second=0, microsecond=0) + return (task.task_name, int(secondless_dt.timestamp())) + + async def delayed_send( scheduler: TaskiqScheduler, source: ScheduleSource, @@ -123,7 +172,7 @@ async def delayed_send( :param scheduler: current scheduler. :param source: source of the task. :param task: task to send. - :param delay: how long to wait. + :param delay: task execution delay in seconds. """ if delay > 0: await asyncio.sleep(delay) @@ -136,19 +185,22 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: Runs scheduler loop. This function imports taskiq scheduler - and runs tasks when needed. + and runs tasks to be executed. :param scheduler: current scheduler. """ loop = asyncio.get_event_loop() running_schedules = set() + recent_schedules: Dict[str, int] = {} while True: # We use this method to correctly sleep for one minute. scheduled_tasks = await get_all_schedules(scheduler) for source, task_list in scheduled_tasks.items(): for task in task_list: + if is_task_executed_recently(task, recent_schedules): + continue try: - task_delay = get_task_delay(task) + task_delay_seconds = get_task_delay(task) except ValueError: logger.warning( "Cannot parse cron: %s for task: %s, schedule_id: %s", @@ -157,16 +209,20 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: task.schedule_id, ) continue - if task_delay is not None: + if task_delay_seconds is not None: send_task = loop.create_task( - delayed_send(scheduler, source, task, task_delay), + delayed_send(scheduler, source, task, task_delay_seconds), ) + task_identifier = get_cron_task_identifier(task) + if isinstance(task_identifier, tuple): + recent_schedules[task_identifier[0]] = task_identifier[1] + running_schedules.add(send_task) send_task.add_done_callback(running_schedules.discard) - next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta( - minutes=1, + next_second_datetime = datetime.now().replace(microsecond=0) + timedelta( + seconds=1, ) - delay = next_minute - datetime.now() + delay = next_second_datetime - datetime.now() await asyncio.sleep(delay.total_seconds()) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 5a3b975..5bc0de6 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -214,6 +214,34 @@ async def schedule_by_time( await source.add_schedule(scheduled) return CreatedSchedule(self, source, scheduled) + async def schedule_by_period( + self, + source: "ScheduleSource", + period: Union[float], + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> CreatedSchedule[_ReturnType]: + """ + Function to schedule task to run periodically. + + :param source: schedule source. + :param period: period to run the tasks at. + :param args: function's args. + :param kwargs: function's kwargs. + """ + schedule_id = self.broker.id_generator() + message = self._prepare_message(*args, **kwargs) + scheduled = ScheduledTask( + schedule_id=schedule_id, + task_name=message.task_name, + labels=message.labels, + args=message.args, + kwargs=message.kwargs, + period=int(period), + ) + await source.add_schedule(scheduled) + return CreatedSchedule(self, source, scheduled) + @classmethod def _prepare_arg(cls, arg: Any) -> Any: """ diff --git a/taskiq/schedule_sources/label_based.py b/taskiq/schedule_sources/label_based.py index e9116fd..5f27610 100644 --- a/taskiq/schedule_sources/label_based.py +++ b/taskiq/schedule_sources/label_based.py @@ -31,7 +31,7 @@ async def get_schedules(self) -> List["ScheduledTask"]: if task.broker != self.broker: continue for schedule in task.labels.get("schedule", []): - if "cron" not in schedule and "time" not in schedule: + if all(field not in schedule for field in ("cron", "time", "period")): continue labels = schedule.get("labels", {}) labels.update(task.labels) @@ -43,6 +43,7 @@ async def get_schedules(self) -> List["ScheduledTask"]: kwargs=schedule.get("kwargs", {}), cron=schedule.get("cron"), time=schedule.get("time"), + period=schedule.get("period"), cron_offset=schedule.get("cron_offset"), ), ) diff --git a/taskiq/scheduler/created_schedule.py b/taskiq/scheduler/created_schedule.py index 8e87083..e3dadc0 100644 --- a/taskiq/scheduler/created_schedule.py +++ b/taskiq/scheduler/created_schedule.py @@ -52,6 +52,7 @@ def __str__(self) -> str: f"id={self.schedule_id}, " f"time={self.task.time}, " f"cron={self.task.cron}, " + f"period={self.task.period}, " f"cron_offset={self.task.cron_offset or 'UTC'}, " f"task_name={self.task.task_name}, " f"args={self.task.args}, " diff --git a/taskiq/scheduler/scheduled_task/v1.py b/taskiq/scheduler/scheduled_task/v1.py index 5209f61..fccb76f 100644 --- a/taskiq/scheduler/scheduled_task/v1.py +++ b/taskiq/scheduler/scheduled_task/v1.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, Field, root_validator +from taskiq.utils import get_present_object_fields + class ScheduledTask(BaseModel): """Abstraction over task schedule.""" @@ -16,6 +18,7 @@ class ScheduledTask(BaseModel): cron: Optional[str] = None cron_offset: Optional[Union[str, timedelta]] = None time: Optional[datetime] = None + period: Optional[float | int] = None @root_validator(pre=False) # type: ignore @classmethod @@ -25,6 +28,9 @@ def __check(cls, values: Dict[str, Any]) -> Dict[str, Any]: :raises ValueError: if cron and time are none. """ - if values.get("cron") is None and values.get("time") is None: - raise ValueError("Either cron or datetime must be present.") + required_fields = ("cron", "time", "period") + present_fields = get_present_object_fields(values, required_fields) + if not present_fields: + message = f"At least one of {required_fields} must be set." + raise ValueError(message) return values diff --git a/taskiq/scheduler/scheduled_task/v2.py b/taskiq/scheduler/scheduled_task/v2.py index 332dce5..def3cb0 100644 --- a/taskiq/scheduler/scheduled_task/v2.py +++ b/taskiq/scheduler/scheduled_task/v2.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, Field, model_validator from typing_extensions import Self +from taskiq.utils import get_present_object_fields + class ScheduledTask(BaseModel): """Abstraction over task schedule.""" @@ -17,6 +19,7 @@ class ScheduledTask(BaseModel): cron: Optional[str] = None cron_offset: Optional[Union[str, timedelta]] = None time: Optional[datetime] = None + period: Optional[Union[float, int]] = None @model_validator(mode="after") def __check(self) -> Self: @@ -25,6 +28,9 @@ def __check(self) -> Self: :raises ValueError: if cron and time are none. """ - if self.cron is None and self.time is None: - raise ValueError("Either cron or datetime must be present.") + required_fields = ("cron", "time", "period") + present_fields = get_present_object_fields(self, required_fields) + if not present_fields: + message = f"At least one of {required_fields} must be set." + raise ValueError(message) return self diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index 04f0887..e94a388 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -37,7 +37,7 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None: """ This method is called when task is ready to be enqueued. - It's triggered on proper time depending on `task.cron` or `task.time` attribute. + It's triggered on proper time depending on a task.{cron,time,period} attribute. :param source: source that triggered this event. :param task: task to send """ diff --git a/taskiq/utils.py b/taskiq/utils.py index 9600cdd..aad04e6 100644 --- a/taskiq/utils.py +++ b/taskiq/utils.py @@ -1,5 +1,16 @@ import inspect -from typing import Any, Awaitable, Coroutine, TypeVar, Union +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, +) _T = TypeVar("_T") @@ -35,3 +46,37 @@ def remove_suffix(text: str, suffix: str) -> str: if text.endswith(suffix): return text[: -len(suffix)] return text + + +def get_present_object_fields( + obj: Any, + fields: Sequence[str], + check_condition: Optional[Callable[[Any, str], bool]] = None, +) -> List[str]: + """ + Check the presence of the fields in the object. + + :param obj: Object to check fields in + :param fields: Sequence of fields + :param check_condition: A function to check the value is considered present + :return: present fields. + """ + if not check_condition: + if isinstance(obj, dict): + + def check_condition(obj: Dict[str, Any], field: str) -> bool: + return field in obj and obj[field] is not None + + else: + + def check_condition(obj: Any, field: str) -> bool: + return getattr(obj, field, None) is not None + + present_fields = [] + for field in fields: + try: + if check_condition(obj, field): + present_fields.append(field) + except AttributeError: + pass + return present_fields diff --git a/tests/schedule_sources/test_label_based.py b/tests/schedule_sources/test_label_based.py index 9e68391..2454730 100644 --- a/tests/schedule_sources/test_label_based.py +++ b/tests/schedule_sources/test_label_based.py @@ -14,6 +14,7 @@ [ pytest.param([{"cron": "* * * * *"}], id="cron"), pytest.param([{"time": datetime.utcnow()}], id="time"), + pytest.param([{"period": 1.0}], id="period"), ], ) async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None: @@ -33,6 +34,7 @@ def task() -> None: schedule_id=schedules[0].schedule_id, cron=schedule_label[0].get("cron"), time=schedule_label[0].get("time"), + period=schedule_label[0].get("period"), task_name="test_task", labels={"schedule": schedule_label}, args=[],