Skip to content

Commit

Permalink
make task monitor async
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 5, 2025
1 parent 39f0a0d commit 704f942
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
34 changes: 24 additions & 10 deletions agixt/TaskMonitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import time
import asyncio
import logging
from DB import get_session, TaskItem, User
from Globals import getenv
Expand Down Expand Up @@ -37,8 +37,13 @@ def impersonate_user(user_id: str):
class TaskMonitor:
def __init__(self):
self.running = False
self.tasks = []

def get_all_pending_tasks(self) -> list:
def is_running(self):
"""Check if the monitor is running and tasks are healthy"""
return self.running and any(not task.done() for task in self.tasks)

async def get_all_pending_tasks(self) -> list:
"""Get all pending tasks for all users"""
session = get_session()
now = datetime.now()
Expand All @@ -56,13 +61,13 @@ def get_all_pending_tasks(self) -> list:
finally:
session.close()

def process_tasks(self):
async def process_tasks(self):
"""Process all pending tasks across users"""
while self.running:
try:
session = get_session()
try:
pending_tasks = self.get_all_pending_tasks()
pending_tasks = await self.get_all_pending_tasks()
for pending_task in pending_tasks:
# Create task manager with impersonated user context
logging.info(
Expand All @@ -89,7 +94,7 @@ def process_tasks(self):
)
try:
# Execute single task
task_manager.execute_pending_tasks()
await task_manager.execute_pending_tasks()
except Exception as e:
logger.error(
f"Error processing task {pending_task.id}: {str(e)}"
Expand All @@ -101,18 +106,27 @@ def process_tasks(self):
session.close()

# Wait before next check
time.sleep(60)
await asyncio.sleep(60)
except Exception as e:
logger.error(f"Error in task processing loop: {str(e)}")
time.sleep(60)
await asyncio.sleep(60)

def start(self):
async def start(self):
"""Start the task monitoring service"""
self.running = True
logger.info("Starting task monitor service...")
self.process_tasks()
task = asyncio.create_task(self.process_tasks())
self.tasks.append(task)

def stop(self):
async def stop(self):
"""Stop the task monitoring service"""
self.running = False
for task in self.tasks:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.tasks.clear()
logger.info("Task monitor service stopped.")
17 changes: 11 additions & 6 deletions agixt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import logging
import signal
import asyncio
import mimetypes
from pathlib import Path
from fastapi import FastAPI, HTTPException
Expand All @@ -22,7 +23,6 @@
from Workspaces import WorkspaceManager
from typing import Optional
from TaskMonitor import TaskMonitor
import threading

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand All @@ -41,7 +41,9 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
workspace_manager.start_file_watcher()
tasks = threading.Thread(target=task_monitor.start).start()
# Start the task monitor asynchronously
await task_monitor.start()

NGROK_TOKEN = getenv("NGROK_TOKEN")
if NGROK_TOKEN:
from pyngrok import ngrok
Expand All @@ -62,9 +64,7 @@ async def lifespan(app: FastAPI):
finally:
# Shutdown
workspace_manager.stop_file_watcher()
task_monitor.stop()
if tasks:
tasks.kill()
await task_monitor.stop() # Make sure to await the stop
if NGROK_TOKEN:
try:
ngrok.kill()
Expand All @@ -73,8 +73,13 @@ async def lifespan(app: FastAPI):


# Register signal handlers for unexpected shutdowns
def signal_handler(signum, frame):
async def cleanup():
workspace_manager.stop_file_watcher()
await task_monitor.stop()


def signal_handler(signum, frame):
asyncio.run(cleanup())
sys.exit(0)


Expand Down

0 comments on commit 704f942

Please sign in to comment.