diff --git a/api/tacticalrmm/core/tasks.py b/api/tacticalrmm/core/tasks.py
index b4530e1b30..ad467f9de2 100644
--- a/api/tacticalrmm/core/tasks.py
+++ b/api/tacticalrmm/core/tasks.py
@@ -1,8 +1,9 @@
 import asyncio
 import logging
 from contextlib import suppress
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+import nats
 from django.conf import settings
 from django.db.models import Prefetch
 from django.db.utils import DatabaseError
@@ -34,10 +35,13 @@
     TaskStatus,
     TaskSyncStatus,
 )
+from tacticalrmm.helpers import setup_nats_options
+from tacticalrmm.nats_utils import a_nats_cmd
 from tacticalrmm.utils import redis_lock
 
 if TYPE_CHECKING:
     from django.db.models import QuerySet
+    from nats.aio.client import Client as NATSClient
 
 logger = logging.getLogger("trmm")
 
@@ -147,12 +151,12 @@ def resolve_alerts_task(self) -> str:
 
 
 @app.task(bind=True)
-def sync_scheduled_tasks(self) -> None:
+def sync_scheduled_tasks(self) -> str:
     with redis_lock(SYNC_SCHED_TASK_LOCK, self.app.oid) as acquired:
         if not acquired:
             return f"{self.app.oid} still running"
 
-        actions: list[tuple[str, int, Agent]] = []  # list of tuples
+        actions: list[tuple[str, int, Agent, Any, str, str]] = []  # list of tuples
 
         for agent in _get_agent_qs():
             if (
@@ -168,24 +172,60 @@ def sync_scheduled_tasks(self) -> None:
                         isinstance(task.task_result, TaskResult)
                         and task.task_result.sync_status == TaskSyncStatus.INITIAL
                     ):
-                        actions.append(("create", task.id, agent_obj))
+                        actions.append(
+                            (
+                                "create",
+                                task.id,
+                                agent_obj,
+                                task.generate_nats_task_payload(
+                                    agent=agent_obj, editing=False
+                                ),
+                                agent.agent_id,
+                                agent.hostname,
+                            )
+                        )
                     elif (
                         isinstance(task.task_result, TaskResult)
                         and task.task_result.sync_status
                         == TaskSyncStatus.PENDING_DELETION
                     ):
-                        actions.append(("delete", task.id, agent_obj))
+                        actions.append(
+                            (
+                                "delete",
+                                task.id,
+                                agent_obj,
+                                {},
+                                agent.agent_id,
+                                agent.hostname,
+                            )
+                        )
                     elif (
                         isinstance(task.task_result, TaskResult)
                         and task.task_result.sync_status == TaskSyncStatus.NOT_SYNCED
                     ):
-                        actions.append(("modify", task.id, agent_obj))
-
-        async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
-            # tuple: (0: action, 1: task.id, 2: agent object)
+                        actions.append(
+                            (
+                                "modify",
+                                task.id,
+                                agent_obj,
+                                task.generate_nats_task_payload(
+                                    agent=None, editing=True
+                                ),
+                                agent.agent_id,
+                                agent.hostname,
+                            )
+                        )
+
+        async def _handle_task_on_agent(
+            nc: "NATSClient", actions: tuple[str, int, Agent, Any, str, str]
+        ) -> None:
+            # tuple: (0: action, 1: task.id, 2: agent object, 3: nats task payload, 4: agent_id, 5: agent hostname)
             action = actions[0]
             task_id = actions[1]
             agent = actions[2]
+            payload = actions[3]
+            agent_id = actions[4]
+            hostname = actions[5]
 
             task: "AutomatedTask" = await AutomatedTask.objects.aget(id=task_id)
             try:
@@ -195,19 +235,13 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
                 await task_result.asave()
 
             if action in ("create", "modify"):
-                task_args = {
-                    "agent": agent if action == "create" else None,
-                    "editing": action != "create",
-                }
-
-                payload = task.generate_nats_task_payload(**task_args)
                 logger.debug(payload)
                 nats_data = {
                     "func": "schedtask",
                     "schedtaskpayload": payload,
                 }
 
-                r = await agent.nats_cmd(nats_data, timeout=5)
+                r = await a_nats_cmd(nc=nc, sub=agent_id, data=nats_data, timeout=5)
                 if r != "ok":
                     if action == "create":
                         task_result.sync_status = TaskSyncStatus.INITIAL
@@ -215,12 +249,12 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
                         task_result.sync_status = TaskSyncStatus.NOT_SYNCED
 
                     logger.error(
-                        f"Unable to {action} scheduled task {task.name} on {agent.hostname}: {r}"
+                        f"Unable to {action} scheduled task {task.name} on {hostname}: {r}"
                     )
                 else:
                     task_result.sync_status = TaskSyncStatus.SYNCED
                     logger.info(
-                        f"{agent.hostname} task {task.name} was {'created' if action == 'create' else 'modified'}"
+                        f"{hostname} task {task.name} was {'created' if action == 'create' else 'modified'}"
                     )
 
                 await task_result.asave(update_fields=["sync_status"])
@@ -230,7 +264,7 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
                     "func": "delschedtask",
                     "schedtaskpayload": {"name": task.win_task_name},
                 }
-                r = await agent.nats_cmd(nats_data, timeout=5)
+                r = await a_nats_cmd(nc=nc, sub=agent_id, data=nats_data, timeout=5)
 
                 if r != "ok" and "The system cannot find the file specified" not in r:
                     task_result.sync_status = TaskSyncStatus.PENDING_DELETION
@@ -239,17 +273,33 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
                         await task_result.asave(update_fields=["sync_status"])
 
                     logger.error(
-                        f"Unable to {action} scheduled task {task.name} on {agent.hostname}: {r}"
+                        f"Unable to {action} scheduled task {task.name} on {hostname}: {r}"
                     )
                 else:
                     await task.adelete()
-                    logger.info(f"{agent.hostname} task {task.name} was deleted")
+                    logger.info(f"{hostname} task {task.name} was deleted")
 
-        async def _run() -> None:
-            tasks = [_handle_task_on_agent(task) for task in actions]
-            await asyncio.gather(*tasks)
+        async def _run() -> str | None:
+            """Run all queued actions over one shared NATS connection.
+
+            Returns the connection error text if connecting fails, else None.
+            """
+            opts = setup_nats_options()
+            try:
+                nc = await nats.connect(**opts)
+            except Exception as e:
+                return str(e)
+
+            try:
+                if tasks := [_handle_task_on_agent(nc, task) for task in actions]:
+                    await asyncio.gather(*tasks)
+                await nc.flush()
+            finally:
+                # always release the connection, even if an action raised
+                await nc.close()
 
-        asyncio.run(_run())
+        ret = asyncio.run(_run())
+        return ret or "ok"
 
 
 def _get_failing_data(agents: "QuerySet[Agent]") -> dict[str, bool]:
diff --git a/api/tacticalrmm/tacticalrmm/nats_utils.py b/api/tacticalrmm/tacticalrmm/nats_utils.py
index 0ccb41e463..d145857cda 100644
--- a/api/tacticalrmm/tacticalrmm/nats_utils.py
+++ b/api/tacticalrmm/tacticalrmm/nats_utils.py
@@ -3,6 +3,7 @@
 
 import msgpack
 import nats
+from nats.errors import TimeoutError as NatsTimeout
 
 from tacticalrmm.exceptions import NatsDown
 from tacticalrmm.helpers import setup_nats_options
@@ -36,3 +37,24 @@ async def abulk_nats_command(*, items: "BULK_NATS_TASKS") -> None:
     await asyncio.gather(*tasks)
     await nc.flush()
     await nc.close()
+
+
+async def a_nats_cmd(
+    *, nc: "NClient", sub: str, data: NATS_DATA, timeout: int = 10
+) -> str | Any:
+    """Send `data` as a request on subject `sub` over the supplied connection `nc`.
+
+    Returns "timeout" when the request times out, otherwise the
+    msgpack-decoded reply, or the decode error text if decoding fails.
+    """
+    try:
+        msg = await nc.request(
+            subject=sub, payload=msgpack.dumps(data), timeout=timeout
+        )
+    except NatsTimeout:
+        return "timeout"
+
+    try:
+        return msgpack.loads(msg.data)
+    except Exception as e:
+        return str(e)