Skip to content

Commit

Permalink
fixes to async rework
Browse files Browse the repository at this point in the history
  • Loading branch information
wh1te909 committed Nov 15, 2023
1 parent 7377906 commit 597240d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 24 deletions.
92 changes: 68 additions & 24 deletions api/tacticalrmm/core/tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import logging
from contextlib import suppress
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import nats
from django.conf import settings
from django.db.models import Prefetch
from django.db.utils import DatabaseError
Expand Down Expand Up @@ -34,10 +35,13 @@
TaskStatus,
TaskSyncStatus,
)
from tacticalrmm.helpers import setup_nats_options
from tacticalrmm.nats_utils import a_nats_cmd
from tacticalrmm.utils import redis_lock

if TYPE_CHECKING:
from django.db.models import QuerySet
from nats.aio.client import Client as NATSClient

logger = logging.getLogger("trmm")

Expand Down Expand Up @@ -147,12 +151,12 @@ def resolve_alerts_task(self) -> str:


@app.task(bind=True)
def sync_scheduled_tasks(self) -> None:
def sync_scheduled_tasks(self) -> str:
with redis_lock(SYNC_SCHED_TASK_LOCK, self.app.oid) as acquired:
if not acquired:
return f"{self.app.oid} still running"

actions: list[tuple[str, int, Agent]] = [] # list of tuples
actions: list[tuple[str, int, Agent, Any, str, str]] = [] # list of tuples

for agent in _get_agent_qs():
if (
Expand All @@ -168,24 +172,60 @@ def sync_scheduled_tasks(self) -> None:
isinstance(task.task_result, TaskResult)
and task.task_result.sync_status == TaskSyncStatus.INITIAL
):
actions.append(("create", task.id, agent_obj))
actions.append(
(
"create",
task.id,
agent_obj,
task.generate_nats_task_payload(
agent=agent_obj, editing=False
),
agent.agent_id,
agent.hostname,
)
)
elif (
isinstance(task.task_result, TaskResult)
and task.task_result.sync_status
== TaskSyncStatus.PENDING_DELETION
):
actions.append(("delete", task.id, agent_obj))
actions.append(
(
"delete",
task.id,
agent_obj,
{},
agent.agent_id,
agent.hostname,
)
)
elif (
isinstance(task.task_result, TaskResult)
and task.task_result.sync_status == TaskSyncStatus.NOT_SYNCED
):
actions.append(("modify", task.id, agent_obj))

async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
# tuple: (0: action, 1: task.id, 2: agent object)
actions.append(
(
"modify",
task.id,
agent_obj,
task.generate_nats_task_payload(
agent=None, editing=True
),
agent.agent_id,
agent.hostname,
)
)

async def _handle_task_on_agent(
nc: "NATSClient", actions: tuple[str, int, Agent, Any, str, str]
) -> None:
# tuple: (0: action, 1: task.id, 2: agent object, 3: nats task payload, 4: agent_id, 5: agent hostname)
action = actions[0]
task_id = actions[1]
agent = actions[2]
payload = action[3]
agent_id = actions[4]
hostname = actions[5]

task: "AutomatedTask" = await AutomatedTask.objects.aget(id=task_id)
try:
Expand All @@ -195,32 +235,26 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
await task_result.asave()

if action in ("create", "modify"):
task_args = {
"agent": agent if action == "create" else None,
"editing": action != "create",
}

payload = task.generate_nats_task_payload(**task_args)
logger.debug(payload)
nats_data = {
"func": "schedtask",
"schedtaskpayload": payload,
}

r = await agent.nats_cmd(nats_data, timeout=5)
r = await a_nats_cmd(nc=nc, sub=agent_id, data=nats_data, timeout=5)
if r != "ok":
if action == "create":
task_result.sync_status = TaskSyncStatus.INITIAL
else:
task_result.sync_status = TaskSyncStatus.NOT_SYNCED

logger.error(
f"Unable to {action} scheduled task {task.name} on {agent.hostname}: {r}"
f"Unable to {action} scheduled task {task.name} on {hostname}: {r}"
)
else:
task_result.sync_status = TaskSyncStatus.SYNCED
logger.info(
f"{agent.hostname} task {task.name} was {'created' if action == 'create' else 'modified'}"
f"{hostname} task {task.name} was {'created' if action == 'create' else 'modified'}"
)

await task_result.asave(update_fields=["sync_status"])
Expand All @@ -230,7 +264,7 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
"func": "delschedtask",
"schedtaskpayload": {"name": task.win_task_name},
}
r = await agent.nats_cmd(nats_data, timeout=5)
r = await a_nats_cmd(nc=nc, sub=agent_id, data=nats_data, timeout=5)

if r != "ok" and "The system cannot find the file specified" not in r:
task_result.sync_status = TaskSyncStatus.PENDING_DELETION
Expand All @@ -239,17 +273,27 @@ async def _handle_task_on_agent(actions: tuple[str, int, Agent]) -> None:
await task_result.asave(update_fields=["sync_status"])

logger.error(
f"Unable to {action} scheduled task {task.name} on {agent.hostname}: {r}"
f"Unable to {action} scheduled task {task.name} on {hostname}: {r}"
)
else:
await task.adelete()
logger.info(f"{agent.hostname} task {task.name} was deleted")
logger.info(f"{hostname} task {task.name} was deleted")

async def _run() -> str | None:
opts = setup_nats_options()
try:
nc = await nats.connect(**opts)
except Exception as e:
return str(e)

if tasks := [_handle_task_on_agent(nc, task) for task in actions]:
await asyncio.gather(*tasks)

async def _run() -> None:
tasks = [_handle_task_on_agent(task) for task in actions]
await asyncio.gather(*tasks)
await nc.flush()
await nc.close()

asyncio.run(_run())
return "ok"


def _get_failing_data(agents: "QuerySet[Agent]") -> dict[str, bool]:
Expand Down
17 changes: 17 additions & 0 deletions api/tacticalrmm/tacticalrmm/nats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import msgpack
import nats
from nats.errors import TimeoutError as NatsTimeout

from tacticalrmm.exceptions import NatsDown
from tacticalrmm.helpers import setup_nats_options
Expand Down Expand Up @@ -36,3 +37,19 @@ async def abulk_nats_command(*, items: "BULK_NATS_TASKS") -> None:
await asyncio.gather(*tasks)
await nc.flush()
await nc.close()


async def a_nats_cmd(
*, nc: "NClient", sub: str, data: NATS_DATA, timeout: int = 10
) -> str | Any:
try:
msg = await nc.request(
subject=sub, payload=msgpack.dumps(data), timeout=timeout
)
except NatsTimeout:
return "timeout"

try:
return msgpack.loads(msg.data)
except Exception as e:
return str(e)

0 comments on commit 597240d

Please sign in to comment.