Skip to content

Commit

Permalink
Merge pull request #43 from RaccoonResearch/antbaez/fixing_issues
Browse files Browse the repository at this point in the history
fixed bug preventing proper task monitoring
  • Loading branch information
antbaez9 authored Jun 24, 2024
2 parents 16b9d75 + 5f7f6d9 commit bfae5ea
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions distributaur/distributaur.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,23 @@ def __init__(
self.app.conf.broker_pool_limit = self.settings["BROKER_POOL_LIMIT"]

# At exit, close app

def cleanup_redis():
patterns = ["celery-task*", "task_status*"]
redis_connection = self.get_redis_connection()
for pattern in patterns:
for key in redis_connection.scan_iter(match=pattern):
redis_connection.delete(key)
redis_connection.close()
print("Redis cleared")

def cleanup_celery():
self.app.control.purge()
print("Celery queue cleared")

atexit.register(self.app.close)
atexit.register(cleanup_redis)
atexit.register(cleanup_celery)

self.app.task_acks_late = True
self.app.worker_prefetch_multiplier = 1
Expand Down Expand Up @@ -155,7 +171,6 @@ def get_redis_connection(self, force_new: bool = False) -> Redis:
self.pool = ConnectionPool.from_url(redis_url)
self.redis_client = Redis(connection_pool=self.pool)
atexit.register(self.pool.disconnect)
atexit.register(self.redis_client.close)

return self.redis_client

Expand Down Expand Up @@ -194,7 +209,10 @@ def call_function_task(self, func_name: str, args_json: str) -> any:
func = self.registered_functions[func_name]
args = json.loads(args_json)
result = func(**args)
# self.update_function_status(self.call_function_task.request.id, "success")

the_id = self.call_function_task.request.id
self.log(f"the id {the_id}")
self.update_function_status(self.call_function_task.request.id, "success")

return result
except Exception as e:
Expand Down Expand Up @@ -227,13 +245,9 @@ def execute_function(self, func_name: str, args: dict) -> Celery.AsyncResult:
celery.result.AsyncResult: An object representing the asynchronous result of the task.
"""
args_json = json.dumps(args)
print("obj", self.call_function_task)
print("type", type(self.call_function_task))
self.log(self.call_function_task)
self.log(self.call_function_task)
async_result = self.call_function_task.delay(func_name, args_json)
return async_result

return self.call_function_task.delay(func_name, args_json)

def update_function_status(self, task_id: str, status: str) -> None:
"""
Update the status of a function task in Redis.
Expand All @@ -242,9 +256,8 @@ def update_function_status(self, task_id: str, status: str) -> None:
task_id (str): The ID of the task.
status (str): The new status to set.
"""
# redis_client = self.get_redis_connection()
# redis_client.set(f"task_status:{task_id}", status)

redis_client = self.get_redis_connection()
redis_client.set(f"task_status:{task_id}", status)

def initialize_dataset(self, **kwargs) -> None:
"""Initialize a Hugging Face repository if it doesn't exist."""
Expand Down Expand Up @@ -487,7 +500,9 @@ def search_offers(self, max_price: float) -> List[Dict]:
)
raise

def create_instance(self, offer_id: str, image: str, module_name: str, command: str = None) -> Dict:
def create_instance(
self, offer_id: str, image: str, module_name: str, command: str = None
) -> Dict:
"""
Create an instance on the Vast.ai platform.
Expand Down Expand Up @@ -548,7 +563,12 @@ def destroy_instance(self, instance_id: str) -> Dict:
return response.json()

def rent_nodes(
self, max_price: float, max_nodes: int, image: str, module_name: str, command: str = None
self,
max_price: float,
max_nodes: int,
image: str,
module_name: str,
command: str = None,
) -> List[Dict]:
"""
Rent nodes on the Vast.ai platform.
Expand Down Expand Up @@ -586,7 +606,9 @@ def rent_nodes(
if len(rented_nodes) >= max_nodes:
break
try:
instance = self.create_instance(offer["id"], image, module_name, command)
instance = self.create_instance(
offer["id"], image, module_name, command
)
atexit.register(self.destroy_instance, instance["new_contract"])
rented_nodes.append(
{
Expand Down Expand Up @@ -630,7 +652,7 @@ def create_from_config(config_path="config.json", env_path=".env") -> Distributa
Create distributaur instance using settings using config that merges config.json and .env files present in distributaur directory.
Args:
config_path (str): path to config.json file
config_path (str): path to config.json file
env_path (str): path to .env file
"""
print("**** CREATE_FROM_CONFIG ****")
Expand Down

0 comments on commit bfae5ea

Please sign in to comment.