Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Periodic tasks on top of the per-second main loop runner #324

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from datetime import datetime, timedelta
from logging import basicConfig, getLevelName, getLogger
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import pytz
from pycron import is_now
Expand Down Expand Up @@ -100,9 +100,58 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:
one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0)
if task_time <= one_min_ahead:
return int((task_time - now).total_seconds())
if task.period is not None and int(now.timestamp()) % int(task.period) == 0:
return 0

return None


def is_task_executed_recently(
task: ScheduledTask,
recent_tasks: Dict[str, int],
) -> bool:
"""
Check if the task has been run recently to avoid duplicate executions.

:param task: task to check.
:param recent_tasks: tuple of recent tasks exec.
:return: True if task must be sent.
"""
task_identifier = get_cron_task_identifier(task)
if task_identifier is None:
return False
task_name, task_now_ts = task_identifier
if task_name not in recent_tasks:
return False
recent_task_ts = recent_tasks[task_name]
return recent_task_ts == task_now_ts


def get_cron_task_identifier(
task: ScheduledTask,
dt: Optional[datetime] = None,
) -> Optional[Tuple[str, int]]:
"""
Get the (task_id, timestamp) task identifier.

:param task: task to check.
:return Tuple[str, datetime] | None: (task name, datetime for the task type)
"""
if task.cron is None:
return None
dt = dt or datetime.now(tz=pytz.UTC)
# If user specified cron offset we apply it.
# If it's timedelta, we simply add the delta to current time.
if task.cron_offset and isinstance(task.cron_offset, timedelta):
dt += task.cron_offset
# If timezone was specified as string we convert it timzone
# offset and then apply.
elif task.cron_offset and isinstance(task.cron_offset, str):
dt = dt.astimezone(pytz.timezone(task.cron_offset))
secondless_dt = dt.replace(second=0, microsecond=0)
return (task.task_name, int(secondless_dt.timestamp()))


async def delayed_send(
scheduler: TaskiqScheduler,
source: ScheduleSource,
Expand All @@ -123,7 +172,7 @@ async def delayed_send(
:param scheduler: current scheduler.
:param source: source of the task.
:param task: task to send.
:param delay: how long to wait.
:param delay: task execution delay in seconds.
"""
if delay > 0:
await asyncio.sleep(delay)
Expand All @@ -136,19 +185,22 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
Runs scheduler loop.

This function imports taskiq scheduler
and runs tasks when needed.
and runs tasks to be executed.

:param scheduler: current scheduler.
"""
loop = asyncio.get_event_loop()
running_schedules = set()
recent_schedules: Dict[str, int] = {}
while True:
# We use this method to correctly sleep for one minute.
scheduled_tasks = await get_all_schedules(scheduler)
for source, task_list in scheduled_tasks.items():
for task in task_list:
if is_task_executed_recently(task, recent_schedules):
continue
try:
task_delay = get_task_delay(task)
task_delay_seconds = get_task_delay(task)
except ValueError:
logger.warning(
"Cannot parse cron: %s for task: %s, schedule_id: %s",
Expand All @@ -157,16 +209,20 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
task.schedule_id,
)
continue
if task_delay is not None:
if task_delay_seconds is not None:
send_task = loop.create_task(
delayed_send(scheduler, source, task, task_delay),
delayed_send(scheduler, source, task, task_delay_seconds),
)
task_identifier = get_cron_task_identifier(task)
if isinstance(task_identifier, tuple):
recent_schedules[task_identifier[0]] = task_identifier[1]

running_schedules.add(send_task)
send_task.add_done_callback(running_schedules.discard)
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
minutes=1,
next_second_datetime = datetime.now().replace(microsecond=0) + timedelta(
seconds=1,
)
delay = next_minute - datetime.now()
delay = next_second_datetime - datetime.now()
await asyncio.sleep(delay.total_seconds())


Expand Down
28 changes: 28 additions & 0 deletions taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,34 @@ async def schedule_by_time(
await source.add_schedule(scheduled)
return CreatedSchedule(self, source, scheduled)

async def schedule_by_period(
self,
source: "ScheduleSource",
period: Union[float],
*args: _FuncParams.args,
**kwargs: _FuncParams.kwargs,
) -> CreatedSchedule[_ReturnType]:
"""
Function to schedule task to run periodically.

:param source: schedule source.
:param period: period to run the tasks at.
:param args: function's args.
:param kwargs: function's kwargs.
"""
schedule_id = self.broker.id_generator()
message = self._prepare_message(*args, **kwargs)
scheduled = ScheduledTask(
schedule_id=schedule_id,
task_name=message.task_name,
labels=message.labels,
args=message.args,
kwargs=message.kwargs,
period=int(period),
)
await source.add_schedule(scheduled)
return CreatedSchedule(self, source, scheduled)

@classmethod
def _prepare_arg(cls, arg: Any) -> Any:
"""
Expand Down
3 changes: 2 additions & 1 deletion taskiq/schedule_sources/label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def get_schedules(self) -> List["ScheduledTask"]:
if task.broker != self.broker:
continue
for schedule in task.labels.get("schedule", []):
if "cron" not in schedule and "time" not in schedule:
if all(field not in schedule for field in ("cron", "time", "period")):
continue
labels = schedule.get("labels", {})
labels.update(task.labels)
Expand All @@ -43,6 +43,7 @@ async def get_schedules(self) -> List["ScheduledTask"]:
kwargs=schedule.get("kwargs", {}),
cron=schedule.get("cron"),
time=schedule.get("time"),
period=schedule.get("period"),
cron_offset=schedule.get("cron_offset"),
),
)
Expand Down
1 change: 1 addition & 0 deletions taskiq/scheduler/created_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __str__(self) -> str:
f"id={self.schedule_id}, "
f"time={self.task.time}, "
f"cron={self.task.cron}, "
f"period={self.task.period}, "
f"cron_offset={self.task.cron_offset or 'UTC'}, "
f"task_name={self.task.task_name}, "
f"args={self.task.args}, "
Expand Down
10 changes: 8 additions & 2 deletions taskiq/scheduler/scheduled_task/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from pydantic import BaseModel, Field, root_validator

from taskiq.utils import get_present_object_fields


class ScheduledTask(BaseModel):
"""Abstraction over task schedule."""
Expand All @@ -16,6 +18,7 @@ class ScheduledTask(BaseModel):
cron: Optional[str] = None
cron_offset: Optional[Union[str, timedelta]] = None
time: Optional[datetime] = None
period: Optional[float | int] = None

@root_validator(pre=False) # type: ignore
@classmethod
Expand All @@ -25,6 +28,9 @@ def __check(cls, values: Dict[str, Any]) -> Dict[str, Any]:

:raises ValueError: if cron and time are none.
"""
if values.get("cron") is None and values.get("time") is None:
raise ValueError("Either cron or datetime must be present.")
required_fields = ("cron", "time", "period")
present_fields = get_present_object_fields(values, required_fields)
if not present_fields:
message = f"At least one of {required_fields} must be set."
raise ValueError(message)
return values
10 changes: 8 additions & 2 deletions taskiq/scheduler/scheduled_task/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from taskiq.utils import get_present_object_fields


class ScheduledTask(BaseModel):
"""Abstraction over task schedule."""
Expand All @@ -17,6 +19,7 @@ class ScheduledTask(BaseModel):
cron: Optional[str] = None
cron_offset: Optional[Union[str, timedelta]] = None
time: Optional[datetime] = None
period: Optional[Union[float, int]] = None

@model_validator(mode="after")
def __check(self) -> Self:
Expand All @@ -25,6 +28,9 @@ def __check(self) -> Self:

:raises ValueError: if cron and time are none.
"""
if self.cron is None and self.time is None:
raise ValueError("Either cron or datetime must be present.")
required_fields = ("cron", "time", "period")
present_fields = get_present_object_fields(self, required_fields)
if not present_fields:
message = f"At least one of {required_fields} must be set."
raise ValueError(message)
return self
2 changes: 1 addition & 1 deletion taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None:
"""
This method is called when task is ready to be enqueued.

It's triggered on proper time depending on `task.cron` or `task.time` attribute.
It's triggered on proper time depending on a task.{cron,time,period} attribute.
:param source: source that triggered this event.
:param task: task to send
"""
Expand Down
47 changes: 46 additions & 1 deletion taskiq/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import inspect
from typing import Any, Awaitable, Coroutine, TypeVar, Union
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Optional,
Sequence,
TypeVar,
Union,
)

_T = TypeVar("_T")

Expand Down Expand Up @@ -35,3 +46,37 @@ def remove_suffix(text: str, suffix: str) -> str:
if text.endswith(suffix):
return text[: -len(suffix)]
return text


def get_present_object_fields(
obj: Any,
fields: Sequence[str],
check_condition: Optional[Callable[[Any, str], bool]] = None,
) -> List[str]:
"""
Check the presence of the fields in the object.

:param obj: Object to check fields in
:param fields: Sequence of fields
:param check_condition: A function to check the value is considered present
:return: present fields.
"""
if not check_condition:
if isinstance(obj, dict):

def check_condition(obj: Dict[str, Any], field: str) -> bool:
return field in obj and obj[field] is not None

else:

def check_condition(obj: Any, field: str) -> bool:
return getattr(obj, field, None) is not None

present_fields = []
for field in fields:
try:
if check_condition(obj, field):
present_fields.append(field)
except AttributeError:
pass
return present_fields
2 changes: 2 additions & 0 deletions tests/schedule_sources/test_label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
[
pytest.param([{"cron": "* * * * *"}], id="cron"),
pytest.param([{"time": datetime.utcnow()}], id="time"),
pytest.param([{"period": 1.0}], id="period"),
],
)
async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
Expand All @@ -33,6 +34,7 @@ def task() -> None:
schedule_id=schedules[0].schedule_id,
cron=schedule_label[0].get("cron"),
time=schedule_label[0].get("time"),
period=schedule_label[0].get("period"),
task_name="test_task",
labels={"schedule": schedule_label},
args=[],
Expand Down