diff --git a/agixt/TaskMonitor.py b/agixt/TaskMonitor.py index 50bc71d4af75..5b0514a420ea 100644 --- a/agixt/TaskMonitor.py +++ b/agixt/TaskMonitor.py @@ -1,4 +1,4 @@ -import time +import asyncio import logging from DB import get_session, TaskItem, User from Globals import getenv @@ -37,8 +37,13 @@ def impersonate_user(user_id: str): class TaskMonitor: def __init__(self): self.running = False + self.tasks = [] - def get_all_pending_tasks(self) -> list: + def is_running(self): + """Check if the monitor is running and tasks are healthy""" + return self.running and any(not task.done() for task in self.tasks) + + async def get_all_pending_tasks(self) -> list: """Get all pending tasks for all users""" session = get_session() now = datetime.now() @@ -56,13 +61,13 @@ def get_all_pending_tasks(self) -> list: finally: session.close() - def process_tasks(self): + async def process_tasks(self): """Process all pending tasks across users""" while self.running: try: session = get_session() try: - pending_tasks = self.get_all_pending_tasks() + pending_tasks = await self.get_all_pending_tasks() for pending_task in pending_tasks: # Create task manager with impersonated user context logging.info( @@ -89,7 +94,7 @@ def process_tasks(self): ) try: # Execute single task - task_manager.execute_pending_tasks() + await task_manager.execute_pending_tasks() except Exception as e: logger.error( f"Error processing task {pending_task.id}: {str(e)}" @@ -101,18 +106,27 @@ def process_tasks(self): session.close() # Wait before next check - time.sleep(60) + await asyncio.sleep(60) except Exception as e: logger.error(f"Error in task processing loop: {str(e)}") - time.sleep(60) + await asyncio.sleep(60) - def start(self): + async def start(self): """Start the task monitoring service""" self.running = True logger.info("Starting task monitor service...") - self.process_tasks() + task = asyncio.create_task(self.process_tasks()) + self.tasks.append(task) - def stop(self): + async def stop(self): """Stop the task monitoring service""" self.running = False + for task in self.tasks: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self.tasks.clear() logger.info("Task monitor service stopped.") diff --git a/agixt/app.py b/agixt/app.py index 54c3afcc6942..1cbac92d1704 100644 --- a/agixt/app.py +++ b/agixt/app.py @@ -3,6 +3,7 @@ import sys import logging import signal +import asyncio import mimetypes from pathlib import Path from fastapi import FastAPI, HTTPException @@ -22,7 +23,6 @@ from Workspaces import WorkspaceManager from typing import Optional from TaskMonitor import TaskMonitor -import threading os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -41,7 +41,9 @@ @asynccontextmanager async def lifespan(app: FastAPI): workspace_manager.start_file_watcher() - tasks = threading.Thread(target=task_monitor.start).start() + # Start the task monitor asynchronously + await task_monitor.start() + NGROK_TOKEN = getenv("NGROK_TOKEN") if NGROK_TOKEN: from pyngrok import ngrok @@ -62,9 +64,7 @@ async def lifespan(app: FastAPI): finally: # Shutdown workspace_manager.stop_file_watcher() - task_monitor.stop() - if tasks: - tasks.kill() + await task_monitor.stop() # Make sure to await the stop if NGROK_TOKEN: try: ngrok.kill() @@ -73,8 +73,13 @@ async def lifespan(app: FastAPI): # Register signal handlers for unexpected shutdowns -def signal_handler(signum, frame): +async def cleanup(): workspace_manager.stop_file_watcher() + await task_monitor.stop() + + +def signal_handler(signum, frame): + asyncio.run(cleanup()) sys.exit(0)