diff --git a/distributaur/__init__.py b/distributaur/__init__.py new file mode 100644 index 0000000..56c4349 --- /dev/null +++ b/distributaur/__init__.py @@ -0,0 +1,3 @@ +from .task_runner import * +from .vast import * +from .utils import * \ No newline at end of file diff --git a/distributaur/batch.py b/distributaur/batch.py deleted file mode 100644 index 1ca8ab6..0000000 --- a/distributaur/batch.py +++ /dev/null @@ -1,197 +0,0 @@ -import json -import os -import signal -import sys -import argparse -import time -from celery import chord, uuid - -sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) - -from distributaur.vast import ( - rent_nodes, - terminate_nodes, - monitor_job_status, - handle_sigint, - attach_to_existing_job, - dump_redis_values, -) -from distributaur.worker import render_object, notify_completion - - -def render_objects( - job_id, - start_index, - end_index, - start_frame=0, - end_frame=65, - width=1920, - height=1080, - output_dir="./renders", - hdri_path="./backgrounds", - max_price=0.1, - max_nodes=1, - image="arfx/simian-worker:latest", - api_key=None, -): - combinations = [] - # read combinations.json - with open("combinations.json", "r") as file: - combinations = json.load(file) - combinations = combinations["combinations"] - - # make sure end_index is less than the number of combinations - end_index = min(end_index, len(combinations)) - - print(f"Rendering objects from {start_index} to {end_index}") - - tasks = [ - render_object.s( - job_id, - i, - combination, - width, - height, - output_dir, - hdri_path, - start_frame, - end_frame, - ) - for i, combination in enumerate(combinations[start_index:end_index]) - ] - callback = notify_completion.s(job_id) # Pass job_id to completion callback - job = chord(tasks)(callback) - - # Rent nodes using distributed_vast - nodes = rent_nodes(max_price, max_nodes, image, api_key) - - # Set up signal handler for SIGINT - signal.signal(signal.SIGINT, lambda sig, frame: handle_sigint(nodes)) - - # Add delay to wait for workers to start - time.sleep(30) # Adjust this time as needed - - # Monitor the job status - monitor_job_status(job) # Directly pass the job - - # Dump Redis values for debugging - # dump_redis_values() - - # Terminate nodes once the job is complete - terminate_nodes(nodes) - - print("All tasks have been completed!") - return job - - -def main(): - parser = argparse.ArgumentParser( - description="Automate the rendering of objects using Celery." - ) - parser.add_argument( - "--start_index", - type=int, - default=0, - help="Starting index for rendering from the combinations list.", - ) - parser.add_argument( - "--end_index", - type=int, - default=100, - help="Ending index for rendering from the combinations list.", - ) - parser.add_argument( - "--start_frame", - type=int, - default=0, - help="Starting frame number for the animation. Defaults to 0.", - ) - parser.add_argument( - "--end_frame", - type=int, - default=65, - help="Ending frame number for the animation. Defaults to 65.", - ) - parser.add_argument( - "--width", - type=int, - default=1920, - help="Width of the rendering in pixels. Defaults to 1920.", - ) - parser.add_argument( - "--height", - type=int, - default=1080, - help="Height of the rendering in pixels. Defaults to 1080.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="./renders", - help="Directory to save rendered outputs. 
Defaults to './renders'.", - ) - parser.add_argument( - "--hdri_path", - type=str, - default="./backgrounds", - help="Directory containing HDRI files for rendering. Defaults to './backgrounds'.", - ) - parser.add_argument( - "--max_price", - type=float, - default=0.1, - help="Maximum price per hour for renting nodes. Defaults to 0.1.", - ) - parser.add_argument( - "--max_nodes", - type=int, - default=1, - help="Maximum number of nodes to rent. Defaults to 1.", - ) - parser.add_argument( - "--image", - type=str, - default="arfx/simian-worker:latest", - help="Docker image to use for rendering. Defaults to 'arfx/simian-worker:latest'.", - ) - parser.add_argument( - "--api_key", - type=str, - default=None, - help="API key for renting nodes. Defaults to None.", - ) - # add job_id - parser.add_argument( - "--job_id", - type=str, - default=str(uuid()), - help="Unique job ID for the batch.", - ) - - args = parser.parse_args() - - job_id = args.job_id - # Check if attaching to an existing job - if attach_to_existing_job(job_id): - # Monitor the job status - monitor_job_status() - else: - render_objects( - job_id=job_id, - start_index=args.start_index, - end_index=args.end_index, - start_frame=args.start_frame, - end_frame=args.end_frame, - width=args.width, - height=args.height, - output_dir=args.output_dir, - hdri_path=args.hdri_path, - max_price=args.max_price, - max_nodes=args.max_nodes, - image=args.image, - api_key=args.api_key, - ) - - -if __name__ == "__main__": - main() diff --git a/distributaur/example.py b/distributaur/example.py deleted file mode 100644 index 2bfadfa..0000000 --- a/distributaur/example.py +++ /dev/null @@ -1,5 +0,0 @@ -from distributaur.task_runner import run_task - -@run_task -def run_example_job() -> None: - print("Running example job") \ No newline at end of file diff --git a/distributaur/task_runner.py b/distributaur/task_runner.py index eedb053..0243b3b 100644 --- a/distributaur/task_runner.py +++ b/distributaur/task_runner.py @@ -1,82 +1,64 @@ -import json -import subprocess -import sys -import os -import ssl -import time from celery import Celery -from redis import ConnectionPool, Redis - -ssl._create_default_https_context = ssl._create_unverified_context +import os +import sys +import json -from distributaur.utils import get_redis_values +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) +from distributaur.utils import get_redis_connection, get_redis_values redis_url = get_redis_values() -pool = ConnectionPool.from_url(redis_url) -redis_client = Redis(connection_pool=pool) - -app = Celery("tasks", broker=redis_url, backend=redis_url) - - -def run_task(task_func): - @app.task(name=task_func.__name__, acks_late=True, reject_on_worker_lost=True) - def wrapper(*args, **kwargs): - job_id = kwargs.get("job_id") - task_id = wrapper.request.id - print(f"Starting task {task_id} in job {job_id}") - update_task_status(job_id, task_id, "IN_PROGRESS") - - timeout = 600 # 10 minutes in seconds - task_timeout = 2700 # 45 minutes in seconds - - start_time = time.time() - print(f"Task {task_id} starting.") - - while True: - elapsed_time = time.time() - start_time - if elapsed_time > timeout: - update_task_status(task_id, "TIMEOUT") - print(f"Task {task_id} timed out before starting task") - return - - try: - task_start_time = time.time() - print(f"Task {task_id} executing task function.") - result = task_func(*args, **kwargs) - print(f"Task {task_id} completed task function.") - - elapsed_task_time = time.time() - task_start_time - if 
elapsed_task_time > task_timeout: - update_task_status(task_id, "TIMEOUT") - print( - f"Task {task_id} timed out after {elapsed_task_time} seconds of execution" - ) - return - - update_task_status(task_id, "COMPLETE") - print(f"Task {task_id} completed successfully") - return result - - except subprocess.TimeoutExpired: - update_task_status(task_id, "TIMEOUT") - print(f"Task {task_id} timed out after {timeout} seconds") - return - - except Exception as e: - update_task_status(job_id, task_id, "FAILED") - print(f"Task {task_id} failed with error: {str(e)}") - return - - return wrapper - - -def update_task_status(job_id, task_id, status): - key = f"celery-task-meta-{task_id}" - value = json.dumps({"status": status}) - redis_client.set(key, value) - print(f"Updated status for task {task_id} in job {job_id} to {status}") - - -if __name__ == "__main__": - print("Starting Celery worker...") - app.start(argv=["celery", "worker", "--loglevel=info"]) +app = Celery( + "distributaur", broker=redis_url, backend=redis_url +) + +registered_functions = {} + +def register_function(func): + """Decorator to register a function in the dictionary.""" + registered_functions[func.__name__] = func + return func + +@app.task +def call_function(func_name, args_json): + """ + Handle a task by executing the registered function with the provided arguments. + + Args: + func_name (str): The name of the registered function to execute. + args_json (str): The JSON string representation of the arguments for the function. + """ + print(f"Received task with function: {func_name}, and args: {args_json}") + if func_name not in registered_functions: + print("registered_functions are", registered_functions) + raise ValueError(f"Function '{func_name}' is not registered.") + + func = registered_functions[func_name] + args = json.loads(args_json) + + print(f"Executing task with function: {func_name}, and args: {args}") + result = func(**args) + update_function_status(call_function.request.id, "completed") + return result + +def execute_function(func_name, args): + """ + Execute a task by passing the function name and arguments. + + Args: + func_name (str): The name of the registered function to execute. + args (dict): The dictionary of arguments for the function. + """ + args_json = json.dumps(args) + print(f"Dispatching task with function: {func_name}, and args: {args_json}") + return call_function.delay(func_name, args_json) + +def update_function_status(task_id, status): + """ + Update the status of a task in Redis. + + Args: + task_id (str): The ID of the task. + status (str): The new status of the task. 
+ """ + redis_client = get_redis_connection() + redis_client.set(f"task_status:{task_id}", status) diff --git a/distributaur/tests/__init__.py b/distributaur/tests/__init__.py index 272f39e..4f9e6eb 100644 --- a/distributaur/tests/__init__.py +++ b/distributaur/tests/__init__.py @@ -1 +1,3 @@ -from .vast import * \ No newline at end of file +from .utils_test import * +from .task_runner_test import * +from .vast_test import * diff --git a/distributaur/tests/task_runner_test.py b/distributaur/tests/task_runner_test.py new file mode 100644 index 0000000..69ae30d --- /dev/null +++ b/distributaur/tests/task_runner_test.py @@ -0,0 +1,65 @@ +import json +import os +import sys +import pytest +from unittest.mock import MagicMock, patch + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) + +from distributaur.task_runner import execute_function, register_function, update_function_status, registered_functions +from distributaur.utils import close_redis_connection, get_redis_connection + +@pytest.fixture +def mock_task_function(): + """ + Fixture that returns a mock task function. + """ + return MagicMock() + + +def test_register_function(mock_task_function): + """ + Test the register_function function. + """ + mock_task_function.__name__ = "mock_task" # Set the __name__ attribute + decorated_task = register_function(mock_task_function) + + assert callable(decorated_task) + assert mock_task_function.__name__ in registered_functions + assert registered_functions[mock_task_function.__name__] == mock_task_function + print("Test passed") + +@patch("distributaur.task_runner.call_function.delay") +def test_execute_function(mock_delay, mock_task_function): + """ + Test the execute_function function. + """ + mock_task_function.__name__ = "mock_task" # Set the __name__ attribute + register_function(mock_task_function) + + params = {'arg1': 1, 'arg2': 2} + execute_function(mock_task_function.__name__, params) + + mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params)) + print("Test passed") + +@patch("distributaur.task_runner.get_redis_connection") +def test_update_function_status(mock_get_redis_connection): + """ + Test the update_function_status function. 
+ """ + mock_redis_client = MagicMock() + mock_get_redis_connection.return_value = mock_redis_client + + task_id = "task_123" + status = "SUCCESS" + + update_function_status(task_id, status) + + mock_redis_client.set.assert_called_once_with(f"task_status:{task_id}", status) + print("Test passed") + +# Add teardown to close Redis connections +def teardown_module(module): + client = get_redis_connection() + close_redis_connection(client) diff --git a/distributaur/tests/test_worker.py b/distributaur/tests/test_worker.py new file mode 100644 index 0000000..e4114d7 --- /dev/null +++ b/distributaur/tests/test_worker.py @@ -0,0 +1,12 @@ +# /Users/shawwalters/distributoor/example_worker.py + +from distributaur.task_runner import register_function, app + +# Ensure the Celery app is available as `celery` +celery = app + +# Define and register the example_function +def example_function(arg1, arg2): + return f"Result: arg1={arg1}, arg2={arg2}" + +register_function(example_function) diff --git a/distributaur/tests/utils_test.py b/distributaur/tests/utils_test.py new file mode 100644 index 0000000..95b2bde --- /dev/null +++ b/distributaur/tests/utils_test.py @@ -0,0 +1,118 @@ +# /Users/shawwalters/distributoor/distributaur/tests/utils_test.py + +import subprocess +import time +import pytest +from distributaur.task_runner import execute_function, register_function, registered_functions, update_function_status +from distributaur.utils import get_env_vars, get_redis_connection, get_redis_values, close_redis_connection + +@pytest.fixture +def env_file(tmpdir): + env_content = """\ +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_USER=user +REDIS_PASSWORD=password\ +""" + env_file = tmpdir.join(".env") + env_file.write(env_content) + return env_file + +@pytest.fixture +def redis_client(): + client = get_redis_connection() + yield client + close_redis_connection(client) + +def test_redis_connection(redis_client): + assert redis_client.ping() + print("Redis connection test passed") + +def test_get_redis_values(redis_client, env_file): + redis_url = get_redis_values(env_file) + assert redis_url == "redis://user:password@localhost:6379" + +def test_get_env_vars(env_file): + env_vars = get_env_vars(env_file) + assert env_vars == { + "REDIS_HOST": "localhost", + "REDIS_PORT": "6379", + "REDIS_USER": "user", + "REDIS_PASSWORD": "password", + } + +def test_get_redis_connection(redis_client): + assert redis_client.ping() + print("Redis connection test passed") + +def test_register_function(): + def example_function(arg1, arg2): + return f"Result: arg1={arg1}, arg2={arg2}" + + register_function(example_function) + assert "example_function" in registered_functions + assert registered_functions["example_function"] == example_function + print("Task registration test passed") + +def test_execute_function(): + def example_function(arg1, arg2): + return f"Result: arg1={arg1}, arg2={arg2}" + + register_function(example_function) + task_params = {"arg1": 10, "arg2": 20} + task = execute_function("example_function", task_params) + assert task.id is not None + print("Task execution test passed") + +def test_worker_task_execution(): + def example_function(arg1, arg2): + return f"Result: arg1={arg1}, arg2={arg2}" + + register_function(example_function) + + worker_cmd = [ + "celery", + "-A", + "distributaur.tests.test_worker", + "worker", + "--loglevel=info", + "--concurrency=1", + "--heartbeat-interval=1", + ] + worker_process = subprocess.Popen(worker_cmd) + + time.sleep(5) + + task_params = {"arg1": 10, "arg2": 20} + task = 
execute_function("example_function", task_params) + result = task.get(timeout=10) + + assert result == "Result: arg1=10, arg2=20" + + worker_process.terminate() + worker_process.wait() + + print("Worker task execution test passed") + +def test_task_status_update(): + redis_client = get_redis_connection() + + try: + task_status_keys = redis_client.keys("task_status:*") + if task_status_keys: + redis_client.delete(*task_status_keys) + + task_id = "test_task_123" + status = "COMPLETED" + + update_function_status(task_id, status) + + status_from_redis = redis_client.get(f"task_status:{task_id}").decode() + assert status_from_redis == status + + redis_client.delete(f"task_status:{task_id}") + + print("Task status update test passed") + finally: + close_redis_connection(redis_client) + diff --git a/distributaur/tests/vast.py b/distributaur/tests/vast.py deleted file mode 100644 index cb6b02e..0000000 --- a/distributaur/tests/vast.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import sys -import time -import pytest - -current_dir = os.path.dirname(os.path.abspath(__file__)) -simian_path = os.path.join(current_dir, "../") -sys.path.append(simian_path) - -from distributaur.vast import rent_nodes, terminate_nodes, headers -from distributaur.utils import get_env_vars - - -@pytest.fixture(scope="module") -def vast_api_key(): - env_vars = get_env_vars() - key = os.getenv("VAST_API_KEY") or env_vars.get("VAST_API_KEY") - if not key: - pytest.fail("Vast API key not found.") - return key - - -@pytest.fixture(scope="module") -def rented_nodes(vast_api_key): - headers["Authorization"] = "Bearer " + vast_api_key - - max_price = 0.5 - max_nodes = 1 - image = "arfx/simian-worker:latest" - - nodes = rent_nodes(max_price, max_nodes, image, vast_api_key) - yield nodes - terminate_nodes(nodes) - - -def test_rent_run_terminate(rented_nodes): - assert len(rented_nodes) == 1 - time.sleep(3) # sleep for 3 seconds to simulate runtime diff --git a/distributaur/tests/vast_test.py b/distributaur/tests/vast_test.py new file mode 100644 index 0000000..6b31488 --- /dev/null +++ b/distributaur/tests/vast_test.py @@ -0,0 +1,52 @@ +import os +import sys +import time +import pytest +from unittest.mock import patch +from distributaur.utils import get_env_vars, close_redis_connection, get_redis_connection + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) + +from distributaur.vast import rent_nodes, terminate_nodes, handle_signal, headers + + +@pytest.fixture(scope="module") +def vast_api_key(): + env_vars = get_env_vars() + key = os.getenv("VAST_API_KEY") or env_vars.get("VAST_API_KEY") + if not key: + pytest.fail("Vast API key not found.") + return key + + +@pytest.fixture(scope="module") +def rented_nodes(vast_api_key): + headers["Authorization"] = "Bearer " + vast_api_key + + max_price = 0.5 + max_nodes = 1 + image = "arfx/simian-worker:latest" + + nodes = rent_nodes(max_price, max_nodes, image, vast_api_key) + yield nodes + terminate_nodes(nodes) + + +def test_rent_run_terminate(rented_nodes): + assert len(rented_nodes) == 1 + time.sleep(3) # sleep for 3 seconds to simulate runtime + + +def test_handle_signal(): + nodes = [{"instance_id": "instance1"}, {"instance_id": "instance2"}] + with patch("distributaur.vast.terminate_nodes") as mock_terminate_nodes: + signal_handler = handle_signal(nodes) + with pytest.raises(SystemExit) as exc_info: + signal_handler(None, None) + assert exc_info.value.code == 0 + mock_terminate_nodes.assert_called_once_with(nodes) + +# Add teardown to close Redis 
connections +def teardown_module(module): + client = get_redis_connection() + close_redis_connection(client) diff --git a/distributaur/utils.py b/distributaur/utils.py index d44124b..03c8c9f 100644 --- a/distributaur/utils.py +++ b/distributaur/utils.py @@ -1,21 +1,14 @@ -import os -from sys import platform - +# /Users/shawwalters/distributoor/distributaur/utils.py -def get_env_vars(path=".env"): - env_vars = {} - if not os.path.exists(path): - return env_vars - with open(path, "r") as f: - for line in f: - key, value = line.strip().split("=") - env_vars[key] = value - return env_vars +import os +import redis +from redis import ConnectionPool +pool = None def get_redis_values(path=".env"): env_vars = get_env_vars(path) - + host = env_vars.get("REDIS_HOST", os.getenv("REDIS_HOST", "localhost")) password = env_vars.get("REDIS_PASSWORD", os.getenv("REDIS_PASSWORD", None)) port = env_vars.get("REDIS_PORT", os.getenv("REDIS_PORT", 6379)) @@ -26,116 +19,25 @@ def get_redis_values(path=".env"): redis_url = f"redis://{username}:{password}@{host}:{port}" return redis_url +def get_redis_connection(): + """Retrieve Redis connection from the connection pool.""" + global pool + if pool is None: + redis_url = get_redis_values() + pool = ConnectionPool.from_url(redis_url) + return redis.Redis(connection_pool=pool) -def get_blender_path(): - # if we are on macOS, then application_path is /Applications/Blender.app/Contents/MacOS/Blender - if platform.system() == "Darwin": - application_path = "/Applications/Blender.app/Contents/MacOS/Blender" - else: - application_path = "./blender/blender" - if not os.path.exists(application_path): - raise FileNotFoundError(f"Blender not found at {application_path}.") - return application_path - - -def upload_outputs(output_dir): - # determine if s3 or huggingface environment variables are set up - # env_vars = get_env_vars() - # aws_access_key_id = env_vars.get("AWS_ACCESS_KEY_ID") or os.getenv("AWS_ACCESS_KEY_ID") - # huggingface_token = env_vars.get("HF_TOKEN") or os.getenv("HF_TOKEN") - # if aws_access_key_id and huggingface_token: - # print("Warning: Both AWS and Hugging Face credentials are set. Defaulting to Huggingface. Remove credentials to default to AWS.") - # upload_to_huggingface(output_dir, combination) - # elif aws_access_key_id is None and huggingface_token is None: - # raise ValueError("No AWS or Hugging Face credentials found. Please set one.") - # elif aws_access_key_id: - # upload_to_s3(output_dir, combination) - # elif huggingface_token: - upload_to_huggingface(output_dir) - - -# def upload_to_s3(output_dir, combination): -# """ -# Uploads the rendered outputs to an S3 bucket. - -# Args: -# - output_dir (str): The directory where the rendered outputs are saved. -# - bucket_name (str): The name of the S3 bucket. -# - s3_path (str): The path in the S3 bucket where files should be uploaded. 
- -# Returns: -# - None -# """ -# import boto3 -# from botocore.exceptions import NoCredentialsError, PartialCredentialsError +def close_redis_connection(client): + """Close the Redis connection.""" + client.close() -# env_vars = get_env_vars() -# aws_access_key_id = env_vars.get("AWS_ACCESS_KEY_ID") or os.getenv("AWS_ACCESS_KEY_ID") -# aws_secret_access_key = env_vars.get("AWS_SECRET_ACCESS_KEY") or os.getenv("AWS_SECRET_ACCESS_KEY") - -# s3_client = boto3.client( -# "s3", -# aws_access_key_id, -# aws_secret_access_key -# ) - -# bucket_name = combination.get("bucket_name", os.getenv("AWS_BUCKET_NAME")) or env_vars.get("AWS_BUCKET_NAME") -# s3_path = combination.get("upload_path", os.getenv("AWS_UPLOAD_PATH")) or env_vars.get("AWS_UPLOAD_PATH") - -# for root, dirs, files in os.walk(output_dir): -# for file in files: -# local_path = os.path.join(root, file) -# s3_file_path = os.path.join(s3_path, file) if s3_path else file - -# try: -# s3_client.upload_file(local_path, bucket_name, s3_file_path) -# print(f"Uploaded {local_path} to s3://{bucket_name}/{s3_file_path}") -# except FileNotFoundError: -# print(f"File not found: {local_path}") -# except NoCredentialsError: -# print("AWS credentials not found.") -# except PartialCredentialsError: -# print("Incomplete AWS credentials.") -# except Exception as e: -# print(f"Failed to upload {local_path} to s3://{bucket_name}/{s3_file_path}: {e}") - - -def upload_to_huggingface(output_dir): - """ - Uploads the rendered outputs to a Hugging Face repository. - - Args: - - output_dir (str): The directory where the rendered outputs are saved. - - repo_id (str): The repository ID on Hugging Face. - - Returns: - - None - """ - env_vars = get_env_vars() - hf_token = os.getenv("HF_TOKEN") or env_vars.get("HF_TOKEN") - repo_id = os.getenv("HF_REPO_ID") or env_vars.get("HF_REPO_ID") - repo_path = os.getenv("HF_PATH") or env_vars.get("HF_PATH", "") - from huggingface_hub import HfApi - - api = HfApi(token=hf_token) - - for root, dirs, files in os.walk(output_dir): - for file in files: - local_path = os.path.join(root, file) - path_in_repo = os.path.join(repo_path, file) if repo_path else file - - try: - api.upload_file( - path_or_fileobj=local_path, - path_in_repo=path_in_repo, - repo_id=repo_id, - token=hf_token, - repo_type="dataset", - ) - print( - f"Uploaded {local_path} to Hugging Face repo {repo_id} at {path_in_repo}" - ) - except Exception as e: - print( - f"Failed to upload {local_path} to Hugging Face repo {repo_id} at {path_in_repo}: {e}" - ) +def get_env_vars(path=".env"): + env_vars = {} + if not os.path.exists(path): + return env_vars + with open(path, "r") as f: + for line in f: + key, value = line.strip().split("=") + env_vars[key] = value + print('*** env vars are:', env_vars) + return env_vars diff --git a/distributaur/vast.py b/distributaur/vast.py index 30b3340..9a3e846 100644 --- a/distributaur/vast.py +++ b/distributaur/vast.py @@ -5,18 +5,14 @@ from typing import Dict import time import re -from redis import ConnectionPool, Redis -sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) -from distributaur.utils import get_env_vars, get_redis_values +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) + +from distributaur.utils import get_env_vars, get_redis_connection server_url_default = "https://console.vast.ai" headers = {} - -redis_url = get_redis_values() -pool = ConnectionPool.from_url(redis_url) -redis_client = Redis(connection_pool=pool) - +redis_client = get_redis_connection() 
def dump_redis_values(): keys = redis_client.keys("*") @@ -381,7 +377,9 @@ def destroy_instance(instance_id): def rent_nodes(max_price, max_nodes, image, api_key, env=get_env_vars()): - env["VAST_API_KEY"] = api_key or env.get("VAST_API_KEY") + api_key = api_key or env.get("VAST_API_KEY") or os.getenv("VAST_API_KEY") + print("api key") + print(api_key) offers = search_offers(max_price, api_key) rented_nodes = [] for offer in offers: @@ -452,7 +450,13 @@ def attach_to_existing_job(job_id): return False -def handle_sigint(nodes): - print("Received SIGINT. Terminating all running workers...") - terminate_nodes(nodes) - sys.exit(0) +def handle_signal(nodes): + """Handle SIGINT for graceful shutdown.""" + from distributaur.vast import terminate_nodes + + def signal_handler(sig, frame): + print("SIGINT received, shutting down...") + terminate_nodes(nodes) + sys.exit(0) + + return signal_handler diff --git a/example.py b/example.py new file mode 100644 index 0000000..41497ed --- /dev/null +++ b/example.py @@ -0,0 +1,96 @@ +import os +import sys +import subprocess +import signal +import time + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "./")) + +from distributaur.utils import get_env_vars +from distributaur.vast import attach_to_existing_job, monitor_job_status, rent_nodes, terminate_nodes, handle_signal +from distributaur.task_runner import execute_function, register_function + +def setup_and_run(config): + # Uncomment if you want to use rent_nodes and handle_signal + # nodes = rent_nodes( + # config["max_price"], + # config["max_nodes"], + # config["docker_image"], + # config["api_key"], + # ) + # signal.signal(signal.SIGINT, handle_signal(nodes)) + + tasks = [ + execute_function(config["task_func"], config["task_params"]) + ] + + for task in tasks: + print(f"Task {task.id} dispatched.") + + while not all(task.ready() for task in tasks): + time.sleep(1) + + print("All tasks have been completed!") + # Uncomment if you want to use terminate_nodes + # terminate_nodes(nodes) + +def start_worker(): + worker_cmd = [ + "celery", + "-A", + "example_worker", + "worker", + "--loglevel=info", + "--concurrency=1" + ] + worker_process = subprocess.Popen(worker_cmd) + return worker_process + +@register_function +def run_workload(arg1, arg2): + # Perform your rendering task here + print(f"Rendering object with arg1={arg1} and arg2={arg2}") + # Simulating rendering time + time.sleep(5) + # Return the result or any relevant information + return f"Rendered object with arg1={arg1} and arg2={arg2}" + +register_function(run_workload) + +if __name__ == "__main__": + # Start the worker process + worker_process = start_worker() + + try: + env = get_env_vars() + api_key = env.get("VAST_API_KEY") or os.getenv("VAST_API_KEY") + if not api_key: + raise ValueError("Vast API key not found in environment variables.") + + # Configure your job + config = { + "job_id": "example_job", + "max_price": 0.10, + "max_nodes": 1, + "docker_image": "your-docker-image", + "api_key": api_key, + "task_func": run_workload.__name__, + "task_params": {"arg1": 1, "arg2": "a"} + } + + # Check if the job is already running + if attach_to_existing_job(config["job_id"]): + print("Attaching to an existing job...") + # Monitor job status and handle success/failure conditions + monitor_job_status(config["job_id"]) + else: + # Run the job + setup_and_run(config) + # Monitor job status and handle success/failure conditions + monitor_job_status(config["job_id"]) + + finally: + # Terminate the worker process + 
worker_process.terminate() + worker_process.wait() + print("Worker process terminated.") diff --git a/example_worker.py b/example_worker.py new file mode 100644 index 0000000..6376508 --- /dev/null +++ b/example_worker.py @@ -0,0 +1,13 @@ +# /Users/shawwalters/distributoor/example_worker.py + +from distributaur.task_runner import register_function, app +import example + +# Ensure the Celery app is available as `celery` +celery = app + +# Define and register the example_function +def example_function(arg1, arg2): + return f"Result: arg1={arg1}, arg2={arg2}" + +register_function(example_function) diff --git a/scripts/kill_redis_connections.sh b/scripts/kill_redis_connections.sh new file mode 100644 index 0000000..d160bbd --- /dev/null +++ b/scripts/kill_redis_connections.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Define the port you want to kill connections for +PORT=17504 + +# Use lsof to find all PIDs for the given port and store them in an array +PIDS=($(lsof -i TCP:$PORT -t)) + +# Check if there are any PIDs to kill +if [ ${#PIDS[@]} -eq 0 ]; then + echo "No processes found using port $PORT." + exit 0 +fi + +# Loop through each PID and kill it +for PID in "${PIDS[@]}"; do + echo "Killing process $PID" + sudo kill -9 $PID +done + +echo "All processes using port $PORT have been killed."
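
For reference, here is a minimal usage sketch of how the new pieces in this diff are meant to fit together: registering a workload with `register_function`, dispatching it by name with `execute_function`, and optionally renting Vast.ai nodes with a SIGINT handler built by `handle_signal`. The module name `quickstart_sketch`, the `render_frame` workload, and all parameter values are illustrative only and not part of this change; the sketch assumes Redis credentials in `.env` and a worker started separately with `celery -A quickstart_sketch worker --loglevel=info --concurrency=1`.

# quickstart_sketch.py -- illustrative sketch only; module name and workload are hypothetical.
# Assumes REDIS_* values in .env; VAST_API_KEY is optional (node rental is skipped without it).
import signal
import time

from distributaur.task_runner import app, register_function, execute_function
from distributaur.utils import get_env_vars
from distributaur.vast import rent_nodes, terminate_nodes, handle_signal

# Expose the Celery app so `celery -A quickstart_sketch worker` can discover it;
# the worker imports this module, which also registers render_frame on its side.
celery = app

@register_function
def render_frame(scene, frame):
    # Stand-in workload; replace with the real rendering call.
    time.sleep(1)
    return f"rendered {scene} frame {frame}"

if __name__ == "__main__":
    env = get_env_vars()
    api_key = env.get("VAST_API_KEY")

    nodes = []
    if api_key:
        # Rent a worker node and install a Ctrl-C handler that tears it down.
        nodes = rent_nodes(max_price=0.10, max_nodes=1,
                           image="arfx/simian-worker:latest", api_key=api_key)
        signal.signal(signal.SIGINT, handle_signal(nodes))

    # Dispatch by registered name; execute_function JSON-encodes the kwargs
    # and returns the Celery AsyncResult from call_function.delay().
    task = execute_function("render_frame", {"scene": "demo", "frame": 1})
    print(f"dispatched task {task.id}")

    try:
        print(task.get(timeout=60))  # blocks until a worker finishes the task
    finally:
        if nodes:
            terminate_nodes(nodes)

This mirrors the flow in the new example.py / example_worker.py, condensed into a single module so the dispatcher and the worker share one registration point.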