Skip to content

Commit

Permalink
jobs, example, tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed May 19, 2024
1 parent c06ae0f commit 2ceaee1
Show file tree
Hide file tree
Showing 15 changed files with 486 additions and 456 deletions.
3 changes: 3 additions & 0 deletions distributaur/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .task_runner import *
from .vast import *
from .utils import *
197 changes: 0 additions & 197 deletions distributaur/batch.py

This file was deleted.

5 changes: 0 additions & 5 deletions distributaur/example.py

This file was deleted.

138 changes: 60 additions & 78 deletions distributaur/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,64 @@
import json
import subprocess
import sys
import os
import ssl
import time
from celery import Celery
from redis import ConnectionPool, Redis

ssl._create_default_https_context = ssl._create_unverified_context
import os
import sys
import json

from distributaur.utils import get_redis_values
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))
from distributaur.utils import get_redis_connection, get_redis_values

redis_url = get_redis_values()
pool = ConnectionPool.from_url(redis_url)
redis_client = Redis(connection_pool=pool)

app = Celery("tasks", broker=redis_url, backend=redis_url)


def run_task(task_func):
@app.task(name=task_func.__name__, acks_late=True, reject_on_worker_lost=True)
def wrapper(*args, **kwargs):
job_id = kwargs.get("job_id")
task_id = wrapper.request.id
print(f"Starting task {task_id} in job {job_id}")
update_task_status(job_id, task_id, "IN_PROGRESS")

timeout = 600 # 10 minutes in seconds
task_timeout = 2700 # 45 minutes in seconds

start_time = time.time()
print(f"Task {task_id} starting.")

while True:
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
update_task_status(task_id, "TIMEOUT")
print(f"Task {task_id} timed out before starting task")
return

try:
task_start_time = time.time()
print(f"Task {task_id} executing task function.")
result = task_func(*args, **kwargs)
print(f"Task {task_id} completed task function.")

elapsed_task_time = time.time() - task_start_time
if elapsed_task_time > task_timeout:
update_task_status(task_id, "TIMEOUT")
print(
f"Task {task_id} timed out after {elapsed_task_time} seconds of execution"
)
return

update_task_status(task_id, "COMPLETE")
print(f"Task {task_id} completed successfully")
return result

except subprocess.TimeoutExpired:
update_task_status(task_id, "TIMEOUT")
print(f"Task {task_id} timed out after {timeout} seconds")
return

except Exception as e:
update_task_status(job_id, task_id, "FAILED")
print(f"Task {task_id} failed with error: {str(e)}")
return

return wrapper


def update_task_status(job_id, task_id, status):
key = f"celery-task-meta-{task_id}"
value = json.dumps({"status": status})
redis_client.set(key, value)
print(f"Updated status for task {task_id} in job {job_id} to {status}")


if __name__ == "__main__":
print("Starting Celery worker...")
app.start(argv=["celery", "worker", "--loglevel=info"])
app = Celery(
"distributaur", broker=redis_url, backend=redis_url
)

registered_functions = {}

def register_function(func):
"""Decorator to register a function in the dictionary."""
registered_functions[func.__name__] = func
return func

@app.task
def call_function(func_name, args_json):
"""
Handle a task by executing the registered function with the provided arguments.
Args:
func_name (str): The name of the registered function to execute.
args_json (str): The JSON string representation of the arguments for the function.
"""
print(f"Received task with function: {func_name}, and args: {args_json}")
if func_name not in registered_functions:
print("registered_functions are", registered_functions)
raise ValueError(f"Function '{func_name}' is not registered.")

func = registered_functions[func_name]
args = json.loads(args_json)

print(f"Executing task with function: {func_name}, and args: {args}")
result = func(**args)
update_function_status(call_function.request.id, "completed")
return result

def execute_function(func_name, args):
"""
Execute a task by passing the function name and arguments.
Args:
func_name (str): The name of the registered function to execute.
args (dict): The dictionary of arguments for the function.
"""
args_json = json.dumps(args)
print(f"Dispatching task with function: {func_name}, and args: {args_json}")
return call_function.delay(func_name, args_json)

def update_function_status(task_id, status):
"""
Update the status of a task in Redis.
Args:
task_id (str): The ID of the task.
status (str): The new status of the task.
"""
redis_client = get_redis_connection()
redis_client.set(f"task_status:{task_id}", status)
4 changes: 3 additions & 1 deletion distributaur/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .vast import *
from .utils_test import *
from .task_runner_test import *
from .vast_test import *
Loading

0 comments on commit 2ceaee1

Please sign in to comment.