Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed May 25, 2024
1 parent b6ae95f commit 7480168
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 27 deletions.
2 changes: 1 addition & 1 deletion distributaur/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .core import *
from .vast import *
from .vast import *
34 changes: 20 additions & 14 deletions distributaur/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_env_vars(path=".env.default"):
env_vars[key] = value
return env_vars


class Config:
def __init__(self):
self.settings = {}
Expand All @@ -35,25 +36,28 @@ def configure(self, **kwargs):
def get(self, key, default=None):
return self.settings.get(key, default)


config = Config()


def get_redis_values(config):
host = config.get("REDIS_HOST", None)
password = config.get("REDIS_PASSWORD", None)
port = config.get("REDIS_PORT", None)
username = config.get("REDIS_USER", None)

print("host", host)
print("password", password)
print("port", port)
print("username", username)

if None in [host, password, port, username]:
raise ValueError("Missing required Redis configuration values")

redis_url = f"redis://{username}:{password}@{host}:{port}"
return redis_url


def get_redis_connection(config, force_new=False):
"""Retrieve Redis connection from the connection pool."""
global pool
Expand All @@ -67,30 +71,29 @@ def close_redis_connection(client):
"""Close the Redis connection."""
client.close()


def configure(**kwargs):
global app
print('configuring')
print("configuring")
config.configure(**kwargs)
redis_url = get_redis_values(config)
app = Celery(
"distributaur",
broker=redis_url,
backend=redis_url
)
app = Celery("distributaur", broker=redis_url, backend=redis_url)
# Disable task events
app.conf.worker_send_task_events = False
print("Celery configured.")


env_vars = get_env_vars(".env")
print("env_vars")
print(env_vars)
configure(**env_vars)

@app.task(name='call_function_task')

@app.task(name="call_function_task")
def call_function_task(func_name, args_json):
"""
Handle a task by executing the registered function with the provided arguments.
Args:
func_name (str): The name of the registered function to execute.
args_json (str): The JSON string representation of the arguments for the function.
Expand All @@ -107,15 +110,17 @@ def call_function_task(func_name, args_json):
update_function_status(call_function_task.request.id, "completed")
return result


def register_function(func):
"""Decorator to register a function in the dictionary."""
registered_functions[func.__name__] = func
return func


def execute_function(func_name, args):
"""
Execute a task by passing the function name and arguments.
Args:
func_name (str): The name of the registered function to execute.
args (dict): The dictionary of arguments for the function.
Expand All @@ -124,13 +129,14 @@ def execute_function(func_name, args):
print(f"Dispatching task with function: {func_name}, and args: {args_json}")
return call_function_task.delay(func_name, args_json)


def update_function_status(task_id, status):
"""
Update the status of a task in Redis.
Args:
task_id (str): The ID of the task.
status (str): The new status of the task.
"""
redis_client = get_redis_connection(config)
redis_client.set(f"task_status:{task_id}", status)
redis_client.set(f"task_status:{task_id}", status)
29 changes: 25 additions & 4 deletions distributaur/tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
import time
import pytest

from distributaur.core import execute_function, register_function, registered_functions, close_redis_connection, get_redis_connection, config, configure, registered_functions, update_function_status, get_env_vars
from distributaur.core import (
execute_function,
register_function,
registered_functions,
close_redis_connection,
get_redis_connection,
config,
configure,
registered_functions,
update_function_status,
get_env_vars,
)


@pytest.fixture
def mock_task_function():
Expand All @@ -28,6 +40,7 @@ def test_register_function(mock_task_function):
assert registered_functions[mock_task_function.__name__] == mock_task_function
print("Test passed")


@patch("distributaur.core.call_function_task.delay")
def test_execute_function(mock_delay, mock_task_function):
"""
Expand All @@ -36,12 +49,13 @@ def test_execute_function(mock_delay, mock_task_function):
mock_task_function.__name__ = "mock_task" # Set the __name__ attribute
register_function(mock_task_function)

params = {'arg1': 1, 'arg2': 2}
params = {"arg1": 1, "arg2": 2}
execute_function(mock_task_function.__name__, params)

mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params))
print("Test passed")


@patch("distributaur.core.get_redis_connection")
def test_update_function_status(mock_get_redis_connection):
"""
Expand All @@ -58,25 +72,30 @@ def test_update_function_status(mock_get_redis_connection):
mock_redis_client.set.assert_called_once_with(f"task_status:{task_id}", status)
print("Test passed")


# Add teardown to close Redis connections
def teardown_module(module):
client = get_redis_connection(config)
close_redis_connection(client)


@pytest.fixture
def redis_client():
client = get_redis_connection(config, force_new=True)
yield client
close_redis_connection(client)


def test_redis_connection(redis_client):
assert redis_client.ping()
print("Redis connection test passed")


def test_get_redis_connection(redis_client):
assert redis_client.ping()
print("Redis connection test passed")


def test_register_function():
def example_function(arg1, arg2):
return f"Result: arg1={arg1}, arg2={arg2}"
Expand All @@ -86,6 +105,7 @@ def example_function(arg1, arg2):
assert registered_functions["example_function"] == example_function
print("Task registration test passed")


def test_execute_function():
def example_function(arg1, arg2):
return f"Result: arg1={arg1}, arg2={arg2}"
Expand All @@ -96,6 +116,7 @@ def example_function(arg1, arg2):
assert task.id is not None
print("Task execution test passed")


def test_worker_task_execution():
def example_function(arg1, arg2):
return f"Result: arg1={arg1}, arg2={arg2}"
Expand All @@ -107,7 +128,7 @@ def example_function(arg1, arg2):
"-A",
"distributaur.tests.test_worker",
"worker",
"--loglevel=info"
"--loglevel=info",
]
print("worker_cmd")
print(worker_cmd)
Expand All @@ -126,6 +147,7 @@ def example_function(arg1, arg2):

print("Worker task execution test passed")


def test_task_status_update():
redis_client = get_redis_connection(config)

Expand All @@ -147,4 +169,3 @@ def test_task_status_update():
print("Task status update test passed")
finally:
close_redis_connection(redis_client)

4 changes: 3 additions & 1 deletion distributaur/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
# Disable task events
app.conf.worker_send_task_events = False


# Define and register the example_function
def example_function(arg1, arg2):
return f"Result: arg1={arg1}, arg2={arg2}"

register_function(example_function)

register_function(example_function)
10 changes: 9 additions & 1 deletion distributaur/tests/vast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))

from distributaur.core import configure, config, get_env_vars, close_redis_connection, get_redis_connection
from distributaur.core import (
configure,
config,
get_env_vars,
close_redis_connection,
get_redis_connection,
)
from distributaur.vast import rent_nodes, terminate_nodes, handle_signal, headers

env_vars = get_env_vars(".env")
print("env_vars")
print(env_vars)
configure(**env_vars)


@pytest.fixture(scope="module")
def vast_api_key():
key = os.getenv("VAST_API_KEY") or env_vars.get("VAST_API_KEY")
Expand Down Expand Up @@ -49,6 +56,7 @@ def test_handle_signal():
assert exc_info.value.code == 0
mock_terminate_nodes.assert_called_once_with(nodes)


# Add teardown to close Redis connections
def teardown_module(module):
client = get_redis_connection(config)
Expand Down
15 changes: 9 additions & 6 deletions distributaur/vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
headers = {}
redis_client = get_redis_connection(config, force_new=True)


def dump_redis_values():
keys = redis_client.keys("*")
for key in keys:
Expand Down Expand Up @@ -304,7 +305,9 @@ def search_offers(max_price):
}
url = (
base_url
+ '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":' + str(max_price) + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}'
+ '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":'
+ str(max_price)
+ '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}'
)

print("url", url)
Expand All @@ -324,7 +327,7 @@ def search_offers(max_price):
def create_instance(offer_id, image, env):
if env is None:
raise ValueError("env is required")

print("Creating instance with offer_id", offer_id)
print("env is")
print(env)
Expand Down Expand Up @@ -360,25 +363,25 @@ def create_instance(offer_id, image, env):
},
json=json_blob,
)

# check on response
if response.status_code != 200:
print(f"Failed to create instance: {response.text}")
raise Exception(f"Failed to create instance: {response.text}")

return response.json()


def destroy_instance(instance_id):
api_key = config.get("VAST_API_KEY")
headers = {"Authorization": f'Bearer {api_key}'}
headers = {"Authorization": f"Bearer {api_key}"}
url = apiurl(f"/instances/{instance_id}/")
print(f"Terminating instance: {instance_id}")
response = http_del(url, headers=headers, json={})
return response.json()


def rent_nodes(max_price, max_nodes, image, api_key, env=get_env_vars('.env')):
def rent_nodes(max_price, max_nodes, image, api_key, env=get_env_vars(".env")):
api_key = api_key or env.get("VAST_API_KEY")
print("api key")
print(api_key)
Expand Down

0 comments on commit 7480168

Please sign in to comment.