Skip to content

Commit

Permalink
Dev
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed May 24, 2024
1 parent 0040b31 commit dd9c8d1
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 147 deletions.
5 changes: 5 additions & 0 deletions .env.default
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USER=
REDIS_PASSWORD=
VAST_API_KEY=
8 changes: 0 additions & 8 deletions .env.example

This file was deleted.

17 changes: 17 additions & 0 deletions distributaur/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# distributaur/config.py

from distributaur.utils import get_env_vars


class Config:
def __init__(self):
self.settings = {}
self.settings.update(get_env_vars())

def configure(self, **kwargs):
self.settings.update(kwargs)

def get(self, key, default=None):
return self.settings.get(key, default)

config = Config()
73 changes: 43 additions & 30 deletions distributaur/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,57 @@
from celery import Celery
# distributaur/task_runner.py

from celery import Celery, Task
import os
import sys
import json

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))
from distributaur.utils import get_redis_connection, get_redis_values
from distributaur.config import config

redis_url = get_redis_values()
app = Celery(
"distributaur", broker=redis_url, backend=redis_url
)

app = None
registered_functions = {}

def register_function(func):
"""Decorator to register a function in the dictionary."""
registered_functions[func.__name__] = func
return func
class CallFunctionTask(Task):
name = 'call_function'

@app.task
def call_function(func_name, args_json):
"""
Handle a task by executing the registered function with the provided arguments.
def run(self, func_name, args_json):
"""
Handle a task by executing the registered function with the provided arguments.
Args:
func_name (str): The name of the registered function to execute.
args_json (str): The JSON string representation of the arguments for the function.
"""
print(f"Received task with function: {func_name}, and args: {args_json}")
if func_name not in registered_functions:
print("registered_functions are", registered_functions)
raise ValueError(f"Function '{func_name}' is not registered.")
Args:
func_name (str): The name of the registered function to execute.
args_json (str): The JSON string representation of the arguments for the function.
"""
print(f"Received task with function: {func_name}, and args: {args_json}")
if func_name not in registered_functions:
print("registered_functions are", registered_functions)
raise ValueError(f"Function '{func_name}' is not registered.")

func = registered_functions[func_name]
args = json.loads(args_json)

print(f"Executing task with function: {func_name}, and args: {args}")
result = func(**args)
update_function_status(self.request.id, "completed")
return result

func = registered_functions[func_name]
args = json.loads(args_json)
def configure(**kwargs):
global app
config.configure(**kwargs)
redis_url = get_redis_values(config)
app = Celery(
"distributaur", broker=redis_url, backend=redis_url
)

# Register call_function as a task
app.tasks.register(CallFunctionTask())
print("Celery configured.")

print(f"Executing task with function: {func_name}, and args: {args}")
result = func(**args)
update_function_status(call_function.request.id, "completed")
return result
def register_function(func):
"""Decorator to register a function in the dictionary."""
registered_functions[func.__name__] = func
return func

def execute_function(func_name, args):
"""
Expand All @@ -50,7 +63,7 @@ def execute_function(func_name, args):
"""
args_json = json.dumps(args)
print(f"Dispatching task with function: {func_name}, and args: {args_json}")
return call_function.delay(func_name, args_json)
return app.send_task('call_function', args=[func_name, args_json])

def update_function_status(task_id, status):
"""
Expand All @@ -60,5 +73,5 @@ def update_function_status(task_id, status):
task_id (str): The ID of the task.
status (str): The new status of the task.
"""
redis_client = get_redis_connection()
redis_client = get_redis_connection(config)
redis_client.set(f"task_status:{task_id}", status)
7 changes: 2 additions & 5 deletions distributaur/tests/task_runner_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import json
import os
import sys
import pytest
from unittest.mock import MagicMock, patch

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))

from distributaur.task_runner import execute_function, register_function, update_function_status, registered_functions
from distributaur.utils import close_redis_connection, get_redis_connection
from distributaur.config import config

@pytest.fixture
def mock_task_function():
Expand Down Expand Up @@ -61,5 +58,5 @@ def test_update_function_status(mock_get_redis_connection):

# Add teardown to close Redis connections
def teardown_module(module):
client = get_redis_connection()
client = get_redis_connection(config)
close_redis_connection(client)
6 changes: 4 additions & 2 deletions distributaur/tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# /Users/shawwalters/distributoor/example_worker.py
from distributaur.task_runner import configure, register_function, app
from distributaur.utils import get_env_vars

from distributaur.task_runner import register_function, app
env_vars = get_env_vars()
configure(**env_vars)

# Ensure the Celery app is available as `celery`
celery = app
Expand Down
15 changes: 8 additions & 7 deletions distributaur/tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import subprocess
import time
import pytest
from distributaur.task_runner import execute_function, register_function, registered_functions, update_function_status
from distributaur.config import config
from distributaur.task_runner import configure, execute_function, register_function, registered_functions, update_function_status
from distributaur.utils import get_env_vars, get_redis_connection, get_redis_values, close_redis_connection
from distributaur.config import config

@pytest.fixture
def env_file(tmpdir):
Expand All @@ -18,20 +20,19 @@ def env_file(tmpdir):
env_file.write(env_content)
return env_file

env_vars = get_env_vars()
configure(**env_vars)

@pytest.fixture
def redis_client():
client = get_redis_connection()
client = get_redis_connection(config)
yield client
close_redis_connection(client)

def test_redis_connection(redis_client):
assert redis_client.ping()
print("Redis connection test passed")

def test_get_redis_values(redis_client, env_file):
redis_url = get_redis_values(env_file)
assert redis_url == "redis://user:password@localhost:6379"

def test_get_env_vars(env_file):
env_vars = get_env_vars(env_file)
assert env_vars == {
Expand Down Expand Up @@ -95,7 +96,7 @@ def example_function(arg1, arg2):
print("Worker task execution test passed")

def test_task_status_update():
redis_client = get_redis_connection()
redis_client = get_redis_connection(config)

try:
task_status_keys = redis_client.keys("task_status:*")
Expand Down
6 changes: 5 additions & 1 deletion distributaur/tests/vast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import time
import pytest
from unittest.mock import patch
from distributaur.task_runner import configure
from distributaur.utils import get_env_vars, close_redis_connection, get_redis_connection

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))

from distributaur.vast import rent_nodes, terminate_nodes, handle_signal, headers
from distributaur.config import config

env_vars = get_env_vars()
configure(**env_vars)

@pytest.fixture(scope="module")
def vast_api_key():
Expand Down Expand Up @@ -48,5 +52,5 @@ def test_handle_signal():

# Add teardown to close Redis connections
def teardown_module(module):
client = get_redis_connection()
client = get_redis_connection(config)
close_redis_connection(client)
39 changes: 14 additions & 25 deletions distributaur/utils.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,38 @@
# /Users/shawwalters/distributoor/distributaur/utils.py

import os
import redis
from redis import ConnectionPool

pool = None

def get_redis_values(path=".env"):
env_vars = get_env_vars(path)

host = env_vars.get("REDIS_HOST", os.getenv("REDIS_HOST", "localhost"))
password = env_vars.get("REDIS_PASSWORD", os.getenv("REDIS_PASSWORD", None))
port = env_vars.get("REDIS_PORT", os.getenv("REDIS_PORT", 6379))
username = env_vars.get("REDIS_USER", os.getenv("REDIS_USER", None))
def get_redis_values(config):
print("config is", config)
host = config.get("REDIS_HOST", "localhost")
password = config.get("REDIS_PASSWORD", None)
port = config.get("REDIS_PORT", 6379)
username = config.get("REDIS_USER", None)
if password is None:
redis_url = f"redis://{host}:{port}"
else:
redis_url = f"redis://{username}:{password}@{host}:{port}"
return redis_url

def get_redis_connection():
def get_redis_connection(config):
"""Retrieve Redis connection from the connection pool."""
global pool
if pool is None:
redis_url = get_redis_values()
redis_url = get_redis_values(config)
pool = ConnectionPool.from_url(redis_url)
return redis.Redis(connection_pool=pool)

def close_redis_connection(client):
"""Close the Redis connection."""
client.close()

def get_env_vars(path=".env"):
# combine env vars from .env file and system environment

def get_env_vars(path=".env.default"):
env_vars = {}

for key, value in os.environ.items():
env_vars[key] = value

if not os.path.exists(path):
return env_vars
with open(path, "r") as f:
for line in f:
key, value = line.strip().split("=")
env_vars[key] = value
print('*** env vars are:', env_vars)
if os.path.exists(path):
with open(path, "r") as f:
for line in f:
key, value = line.strip().split("=")
env_vars[key] = value
return env_vars
25 changes: 8 additions & 17 deletions distributaur/vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))

from distributaur.utils import get_env_vars, get_redis_connection
from distributaur.config import config

server_url_default = "https://console.vast.ai"
headers = {}
redis_client = get_redis_connection()
redis_client = get_redis_connection(config)

def dump_redis_values():
keys = redis_client.keys("*")
Expand Down Expand Up @@ -292,7 +293,8 @@ def parse_env(envs):
import urllib.parse


def search_offers(max_price, api_key):
def search_offers(max_price):
api_key = config.get("VAST_API_KEY")
base_url = "https://console.vast.ai/api/v0/bundles/"
headers = {
"Accept": "application/json",
Expand All @@ -301,7 +303,7 @@ def search_offers(max_price, api_key):
}
url = (
base_url
+ '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":0.1480339514817041},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}'
+ '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":' + max_price + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}'
)

print("url", url)
Expand All @@ -319,23 +321,12 @@ def search_offers(max_price, api_key):


def create_instance(offer_id, image, env):
# check that the env is a dictionary and has the vars REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, HF_TOKEN, HF_REPO_ID, HF_PATH, VAST_API_KEY
if env is None:
raise ValueError("env is required")

if not isinstance(env, dict):
raise ValueError("env must be a dictionary")

if not all(
k in env for k in ["REDIS_HOST", "REDIS_PORT", "REDIS_USER", "REDIS_PASSWORD"]
):
# warn about missing redis env vars
print("Warning: Missing Redis environment variables")

if not all(k in env for k in ["HF_TOKEN", "HF_REPO_ID", "HF_PATH"]):
# warn about missing huggingface env vars
print("Warning: Missing Hugging Face environment variables")

if not "VAST_API_KEY" in env:
# warn about missing vast api key
print("Warning: Missing Vast API key")
Expand All @@ -345,7 +336,7 @@ def create_instance(offer_id, image, env):
"image": image,
"env": "",
"disk": 16, # Set a non-zero value for disk
"onstart": f"export PATH=$PATH:/ && cd ../ && REDIS_HOST={env['REDIS_HOST']} REDIS_PORT={env['REDIS_PORT']} REDIS_USER={env['REDIS_USER']} REDIS_PASSWORD={env['REDIS_PASSWORD']} HF_TOKEN={env['HF_TOKEN']} HF_REPO_ID={env['HF_REPO_ID']} HF_PATH={env['HF_PATH']} VAST_API_KEY={env['VAST_API_KEY']} celery -A simian.worker worker --loglevel=info",
"onstart": f"export PATH=$PATH:/ && cd ../ && REDIS_HOST={config.get('REDIS_HOST')} REDIS_PORT={config.get('REDIS_PORT')} REDIS_USER={config.get('REDIS_USER')} REDIS_PASSWORD={config.get('REDIS_PASSWORD')} HF_TOKEN={config.get('HF_TOKEN')} HF_REPO_ID={config.get('HF_REPO_ID')} HF_PATH={config.get('HF_PATH')} VAST_API_KEY={config.get('VAST_API_KEY')} celery -A simian.worker worker --loglevel=info",
"runtype": "ssh ssh_proxy",
"image_login": None,
"python_utf8": False,
Expand All @@ -368,8 +359,8 @@ def create_instance(offer_id, image, env):


def destroy_instance(instance_id):
env = get_env_vars()
headers = {"Authorization": f'Bearer {env["VAST_API_KEY"]}'}
api_key = config.get("VAST_API_KEY")
headers = {"Authorization": f'Bearer {api_key}'}
url = apiurl(f"/instances/{instance_id}/")
print(f"Terminating instance: {instance_id}")
response = http_del(url, headers=headers, json={})
Expand Down
Loading

0 comments on commit dd9c8d1

Please sign in to comment.