From 5f7f6d9e0f19b14353ec8b3c2a6766fc782bcea1 Mon Sep 17 00:00:00 2001 From: acbaez9 <97056049+acbaez9@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:43:11 -0700 Subject: [PATCH] fixed bug preventing proper task monitoring --- distributaur/distributaur.py | 52 +++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/distributaur/distributaur.py b/distributaur/distributaur.py index a6bae62..8348a96 100644 --- a/distributaur/distributaur.py +++ b/distributaur/distributaur.py @@ -89,7 +89,23 @@ def __init__( self.app.conf.broker_pool_limit = self.settings["BROKER_POOL_LIMIT"] # At exit, close app + + def cleanup_redis(): + patterns = ["celery-task*", "task_status*"] + redis_connection = self.get_redis_connection() + for pattern in patterns: + for key in redis_connection.scan_iter(match=pattern): + redis_connection.delete(key) + redis_connection.close() + print("Redis cleared") + + def cleanup_celery(): + self.app.control.purge() + print("Celery queue cleared") + atexit.register(self.app.close) + atexit.register(cleanup_redis) + atexit.register(cleanup_celery) self.app.task_acks_late = True self.app.worker_prefetch_multiplier = 1 @@ -155,7 +171,6 @@ def get_redis_connection(self, force_new: bool = False) -> Redis: self.pool = ConnectionPool.from_url(redis_url) self.redis_client = Redis(connection_pool=self.pool) atexit.register(self.pool.disconnect) - atexit.register(self.redis_client.close) return self.redis_client @@ -194,7 +209,10 @@ def call_function_task(self, func_name: str, args_json: str) -> any: func = self.registered_functions[func_name] args = json.loads(args_json) result = func(**args) - # self.update_function_status(self.call_function_task.request.id, "success") + + the_id = self.call_function_task.request.id + self.log(f"the id {the_id}") + self.update_function_status(self.call_function_task.request.id, "success") return result except Exception as e: @@ -227,13 +245,9 @@ def execute_function(self, func_name: str, args: dict) -> Celery.AsyncResult: celery.result.AsyncResult: An object representing the asynchronous result of the task. """ args_json = json.dumps(args) - print("obj", self.call_function_task) - print("type", type(self.call_function_task)) - self.log(self.call_function_task) - self.log(self.call_function_task) + async_result = self.call_function_task.delay(func_name, args_json) + return async_result - return self.call_function_task.delay(func_name, args_json) - def update_function_status(self, task_id: str, status: str) -> None: """ Update the status of a function task in Redis. @@ -242,9 +256,8 @@ def update_function_status(self, task_id: str, status: str) -> None: task_id (str): The ID of the task. status (str): The new status to set. """ - # redis_client = self.get_redis_connection() - # redis_client.set(f"task_status:{task_id}", status) - + redis_client = self.get_redis_connection() + redis_client.set(f"task_status:{task_id}", status) def initialize_dataset(self, **kwargs) -> None: """Initialize a Hugging Face repository if it doesn't exist.""" @@ -487,7 +500,9 @@ def search_offers(self, max_price: float) -> List[Dict]: ) raise - def create_instance(self, offer_id: str, image: str, module_name: str, command: str = None) -> Dict: + def create_instance( + self, offer_id: str, image: str, module_name: str, command: str = None + ) -> Dict: """ Create an instance on the Vast.ai platform. @@ -548,7 +563,12 @@ def destroy_instance(self, instance_id: str) -> Dict: return response.json() def rent_nodes( - self, max_price: float, max_nodes: int, image: str, module_name: str, command: str = None + self, + max_price: float, + max_nodes: int, + image: str, + module_name: str, + command: str = None, ) -> List[Dict]: """ Rent nodes on the Vast.ai platform. @@ -586,7 +606,9 @@ def rent_nodes( if len(rented_nodes) >= max_nodes: break try: - instance = self.create_instance(offer["id"], image, module_name, command) + instance = self.create_instance( + offer["id"], image, module_name, command + ) atexit.register(self.destroy_instance, instance["new_contract"]) rented_nodes.append( { @@ -630,7 +652,7 @@ def create_from_config(config_path="config.json", env_path=".env") -> Distributa Create distributaur instance using settings using config that merges config.json and .env files present in distributaur directory. Args: - config_path (str): path to config.json file + config_path (str): path to config.json file env_path (str): path to .env file """ print("**** CREATE_FROM_CONFIG ****")