From 7480168ef4b3fce5bfb686f9b161921122fe6519 Mon Sep 17 00:00:00 2001 From: moon Date: Fri, 24 May 2024 18:11:00 -0700 Subject: [PATCH] black --- distributaur/__init__.py | 2 +- distributaur/core.py | 34 ++++++++++++++++++------------- distributaur/tests/core_test.py | 29 ++++++++++++++++++++++---- distributaur/tests/test_worker.py | 4 +++- distributaur/tests/vast_test.py | 10 ++++++++- distributaur/vast.py | 15 ++++++++------ 6 files changed, 67 insertions(+), 27 deletions(-) diff --git a/distributaur/__init__.py b/distributaur/__init__.py index 7ea1148..dd18643 100644 --- a/distributaur/__init__.py +++ b/distributaur/__init__.py @@ -1,2 +1,2 @@ from .core import * -from .vast import * \ No newline at end of file +from .vast import * diff --git a/distributaur/core.py b/distributaur/core.py index 3ea43cf..fe7d35a 100644 --- a/distributaur/core.py +++ b/distributaur/core.py @@ -24,6 +24,7 @@ def get_env_vars(path=".env.default"): env_vars[key] = value return env_vars + class Config: def __init__(self): self.settings = {} @@ -35,25 +36,28 @@ def configure(self, **kwargs): def get(self, key, default=None): return self.settings.get(key, default) + config = Config() + def get_redis_values(config): host = config.get("REDIS_HOST", None) password = config.get("REDIS_PASSWORD", None) port = config.get("REDIS_PORT", None) username = config.get("REDIS_USER", None) - + print("host", host) print("password", password) print("port", port) print("username", username) - + if None in [host, password, port, username]: raise ValueError("Missing required Redis configuration values") - + redis_url = f"redis://{username}:{password}@{host}:{port}" return redis_url + def get_redis_connection(config, force_new=False): """Retrieve Redis connection from the connection pool.""" global pool @@ -67,30 +71,29 @@ def close_redis_connection(client): """Close the Redis connection.""" client.close() + def configure(**kwargs): global app - print('configuring') + print("configuring") config.configure(**kwargs) redis_url = get_redis_values(config) - app = Celery( - "distributaur", - broker=redis_url, - backend=redis_url - ) + app = Celery("distributaur", broker=redis_url, backend=redis_url) # Disable task events app.conf.worker_send_task_events = False print("Celery configured.") + env_vars = get_env_vars(".env") print("env_vars") print(env_vars) configure(**env_vars) -@app.task(name='call_function_task') + +@app.task(name="call_function_task") def call_function_task(func_name, args_json): """ Handle a task by executing the registered function with the provided arguments. - + Args: func_name (str): The name of the registered function to execute. args_json (str): The JSON string representation of the arguments for the function. @@ -107,15 +110,17 @@ def call_function_task(func_name, args_json): update_function_status(call_function_task.request.id, "completed") return result + def register_function(func): """Decorator to register a function in the dictionary.""" registered_functions[func.__name__] = func return func + def execute_function(func_name, args): """ Execute a task by passing the function name and arguments. - + Args: func_name (str): The name of the registered function to execute. args (dict): The dictionary of arguments for the function. @@ -124,13 +129,14 @@ def execute_function(func_name, args): print(f"Dispatching task with function: {func_name}, and args: {args_json}") return call_function_task.delay(func_name, args_json) + def update_function_status(task_id, status): """ Update the status of a task in Redis. - + Args: task_id (str): The ID of the task. status (str): The new status of the task. """ redis_client = get_redis_connection(config) - redis_client.set(f"task_status:{task_id}", status) \ No newline at end of file + redis_client.set(f"task_status:{task_id}", status) diff --git a/distributaur/tests/core_test.py b/distributaur/tests/core_test.py index 75b842b..99f4672 100644 --- a/distributaur/tests/core_test.py +++ b/distributaur/tests/core_test.py @@ -6,7 +6,19 @@ import time import pytest -from distributaur.core import execute_function, register_function, registered_functions, close_redis_connection, get_redis_connection, config, configure, registered_functions, update_function_status, get_env_vars +from distributaur.core import ( + execute_function, + register_function, + registered_functions, + close_redis_connection, + get_redis_connection, + config, + configure, + registered_functions, + update_function_status, + get_env_vars, +) + @pytest.fixture def mock_task_function(): @@ -28,6 +40,7 @@ def test_register_function(mock_task_function): assert registered_functions[mock_task_function.__name__] == mock_task_function print("Test passed") + @patch("distributaur.core.call_function_task.delay") def test_execute_function(mock_delay, mock_task_function): """ @@ -36,12 +49,13 @@ def test_execute_function(mock_delay, mock_task_function): mock_task_function.__name__ = "mock_task" # Set the __name__ attribute register_function(mock_task_function) - params = {'arg1': 1, 'arg2': 2} + params = {"arg1": 1, "arg2": 2} execute_function(mock_task_function.__name__, params) mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params)) print("Test passed") + @patch("distributaur.core.get_redis_connection") def test_update_function_status(mock_get_redis_connection): """ @@ -58,25 +72,30 @@ def test_update_function_status(mock_get_redis_connection): mock_redis_client.set.assert_called_once_with(f"task_status:{task_id}", status) print("Test passed") + # Add teardown to close Redis connections def teardown_module(module): client = get_redis_connection(config) close_redis_connection(client) + @pytest.fixture def redis_client(): client = get_redis_connection(config, force_new=True) yield client close_redis_connection(client) + def test_redis_connection(redis_client): assert redis_client.ping() print("Redis connection test passed") + def test_get_redis_connection(redis_client): assert redis_client.ping() print("Redis connection test passed") + def test_register_function(): def example_function(arg1, arg2): return f"Result: arg1={arg1}, arg2={arg2}" @@ -86,6 +105,7 @@ def example_function(arg1, arg2): assert registered_functions["example_function"] == example_function print("Task registration test passed") + def test_execute_function(): def example_function(arg1, arg2): return f"Result: arg1={arg1}, arg2={arg2}" @@ -96,6 +116,7 @@ def example_function(arg1, arg2): assert task.id is not None print("Task execution test passed") + def test_worker_task_execution(): def example_function(arg1, arg2): return f"Result: arg1={arg1}, arg2={arg2}" @@ -107,7 +128,7 @@ def example_function(arg1, arg2): "-A", "distributaur.tests.test_worker", "worker", - "--loglevel=info" + "--loglevel=info", ] print("worker_cmd") print(worker_cmd) @@ -126,6 +147,7 @@ def example_function(arg1, arg2): print("Worker task execution test passed") + def test_task_status_update(): redis_client = get_redis_connection(config) @@ -147,4 +169,3 @@ def test_task_status_update(): print("Task status update test passed") finally: close_redis_connection(redis_client) - diff --git a/distributaur/tests/test_worker.py b/distributaur/tests/test_worker.py index 0b1dcfc..40167e4 100644 --- a/distributaur/tests/test_worker.py +++ b/distributaur/tests/test_worker.py @@ -8,8 +8,10 @@ # Disable task events app.conf.worker_send_task_events = False + # Define and register the example_function def example_function(arg1, arg2): return f"Result: arg1={arg1}, arg2={arg2}" -register_function(example_function) \ No newline at end of file + +register_function(example_function) diff --git a/distributaur/tests/vast_test.py b/distributaur/tests/vast_test.py index 8b45994..fe74c15 100644 --- a/distributaur/tests/vast_test.py +++ b/distributaur/tests/vast_test.py @@ -6,7 +6,13 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) -from distributaur.core import configure, config, get_env_vars, close_redis_connection, get_redis_connection +from distributaur.core import ( + configure, + config, + get_env_vars, + close_redis_connection, + get_redis_connection, +) from distributaur.vast import rent_nodes, terminate_nodes, handle_signal, headers env_vars = get_env_vars(".env") @@ -14,6 +20,7 @@ print(env_vars) configure(**env_vars) + @pytest.fixture(scope="module") def vast_api_key(): key = os.getenv("VAST_API_KEY") or env_vars.get("VAST_API_KEY") @@ -49,6 +56,7 @@ def test_handle_signal(): assert exc_info.value.code == 0 mock_terminate_nodes.assert_called_once_with(nodes) + # Add teardown to close Redis connections def teardown_module(module): client = get_redis_connection(config) diff --git a/distributaur/vast.py b/distributaur/vast.py index abdfe9e..1da3f86 100644 --- a/distributaur/vast.py +++ b/distributaur/vast.py @@ -14,6 +14,7 @@ headers = {} redis_client = get_redis_connection(config, force_new=True) + def dump_redis_values(): keys = redis_client.keys("*") for key in keys: @@ -304,7 +305,9 @@ def search_offers(max_price): } url = ( base_url - + '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":' + str(max_price) + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}' + + '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":' + + str(max_price) + + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}' ) print("url", url) @@ -324,7 +327,7 @@ def search_offers(max_price): def create_instance(offer_id, image, env): if env is None: raise ValueError("env is required") - + print("Creating instance with offer_id", offer_id) print("env is") print(env) @@ -360,25 +363,25 @@ def create_instance(offer_id, image, env): }, json=json_blob, ) - + # check on response if response.status_code != 200: print(f"Failed to create instance: {response.text}") raise Exception(f"Failed to create instance: {response.text}") - + return response.json() def destroy_instance(instance_id): api_key = config.get("VAST_API_KEY") - headers = {"Authorization": f'Bearer {api_key}'} + headers = {"Authorization": f"Bearer {api_key}"} url = apiurl(f"/instances/{instance_id}/") print(f"Terminating instance: {instance_id}") response = http_del(url, headers=headers, json={}) return response.json() -def rent_nodes(max_price, max_nodes, image, api_key, env=get_env_vars('.env')): +def rent_nodes(max_price, max_nodes, image, api_key, env=get_env_vars(".env")): api_key = api_key or env.get("VAST_API_KEY") print("api key") print(api_key)