Skip to content

Commit

Permalink
Merge pull request #1 from RaccoonResearch/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
lalalune authored May 25, 2024
2 parents 0040b31 + a20279b commit b6ae95f
Show file tree
Hide file tree
Showing 14 changed files with 277 additions and 305 deletions.
5 changes: 5 additions & 0 deletions .env.default
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USER=
REDIS_PASSWORD=
VAST_API_KEY=
8 changes: 0 additions & 8 deletions .env.example

This file was deleted.

5 changes: 2 additions & 3 deletions distributaur/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .task_runner import *
from .vast import *
from .utils import *
from .core import *
from .vast import *
136 changes: 136 additions & 0 deletions distributaur/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# distributaur/task_runner.py
from celery import Celery
import os
import sys
import json
import os
import redis
from redis import ConnectionPool

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))

app = None
registered_functions = {}
pool = None


def get_env_vars(path=".env.default"):
print("get_env_vars")
env_vars = {}
if os.path.exists(path):
with open(path, "r") as f:
for line in f:
key, value = line.strip().split("=")
env_vars[key] = value
return env_vars

class Config:
def __init__(self):
self.settings = {}
self.settings.update(get_env_vars())

def configure(self, **kwargs):
self.settings.update(kwargs)

def get(self, key, default=None):
return self.settings.get(key, default)

config = Config()

def get_redis_values(config):
host = config.get("REDIS_HOST", None)
password = config.get("REDIS_PASSWORD", None)
port = config.get("REDIS_PORT", None)
username = config.get("REDIS_USER", None)

print("host", host)
print("password", password)
print("port", port)
print("username", username)

if None in [host, password, port, username]:
raise ValueError("Missing required Redis configuration values")

redis_url = f"redis://{username}:{password}@{host}:{port}"
return redis_url

def get_redis_connection(config, force_new=False):
"""Retrieve Redis connection from the connection pool."""
global pool
if pool is None or force_new:
redis_url = get_redis_values(config)
pool = ConnectionPool.from_url(redis_url)
return redis.Redis(connection_pool=pool)


def close_redis_connection(client):
"""Close the Redis connection."""
client.close()

def configure(**kwargs):
global app
print('configuring')
config.configure(**kwargs)
redis_url = get_redis_values(config)
app = Celery(
"distributaur",
broker=redis_url,
backend=redis_url
)
# Disable task events
app.conf.worker_send_task_events = False
print("Celery configured.")

env_vars = get_env_vars(".env")
print("env_vars")
print(env_vars)
configure(**env_vars)

@app.task(name='call_function_task')
def call_function_task(func_name, args_json):
"""
Handle a task by executing the registered function with the provided arguments.
Args:
func_name (str): The name of the registered function to execute.
args_json (str): The JSON string representation of the arguments for the function.
"""
print(f"Received task with function: {func_name}, and args: {args_json}")
if func_name not in registered_functions:
print("registered_functions are", registered_functions)
raise ValueError(f"Function '{func_name}' is not registered.")

func = registered_functions[func_name]
args = json.loads(args_json)
print(f"Executing task with function: {func_name}, and args: {args}")
result = func(**args)
update_function_status(call_function_task.request.id, "completed")
return result

def register_function(func):
"""Decorator to register a function in the dictionary."""
registered_functions[func.__name__] = func
return func

def execute_function(func_name, args):
"""
Execute a task by passing the function name and arguments.
Args:
func_name (str): The name of the registered function to execute.
args (dict): The dictionary of arguments for the function.
"""
args_json = json.dumps(args)
print(f"Dispatching task with function: {func_name}, and args: {args_json}")
return call_function_task.delay(func_name, args_json)

def update_function_status(task_id, status):
"""
Update the status of a task in Redis.
Args:
task_id (str): The ID of the task.
status (str): The new status of the task.
"""
redis_client = get_redis_connection(config)
redis_client.set(f"task_status:{task_id}", status)
64 changes: 0 additions & 64 deletions distributaur/task_runner.py

This file was deleted.

3 changes: 1 addition & 2 deletions distributaur/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .utils_test import *
from .task_runner_test import *
from .core_test import *
from .vast_test import *
Original file line number Diff line number Diff line change
@@ -1,46 +1,78 @@
# /Users/shawwalters/distributoor/distributaur/tests/utils_test.py
import json
import pytest
from unittest.mock import MagicMock, patch

import subprocess
import time
import pytest
from distributaur.task_runner import execute_function, register_function, registered_functions, update_function_status
from distributaur.utils import get_env_vars, get_redis_connection, get_redis_values, close_redis_connection

from distributaur.core import execute_function, register_function, registered_functions, close_redis_connection, get_redis_connection, config, configure, registered_functions, update_function_status, get_env_vars

@pytest.fixture
def env_file(tmpdir):
env_content = """\
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USER=user
REDIS_PASSWORD=password\
"""
env_file = tmpdir.join(".env")
env_file.write(env_content)
return env_file
def mock_task_function():
"""
Fixture that returns a mock task function.
"""
return MagicMock()


def test_register_function(mock_task_function):
"""
Test the register_function function.
"""
mock_task_function.__name__ = "mock_task" # Set the __name__ attribute
decorated_task = register_function(mock_task_function)

assert callable(decorated_task)
assert mock_task_function.__name__ in registered_functions
assert registered_functions[mock_task_function.__name__] == mock_task_function
print("Test passed")

@patch("distributaur.core.call_function_task.delay")
def test_execute_function(mock_delay, mock_task_function):
"""
Test the execute_function function.
"""
mock_task_function.__name__ = "mock_task" # Set the __name__ attribute
register_function(mock_task_function)

params = {'arg1': 1, 'arg2': 2}
execute_function(mock_task_function.__name__, params)

mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params))
print("Test passed")

@patch("distributaur.core.get_redis_connection")
def test_update_function_status(mock_get_redis_connection):
"""
Test the update_function_status function.
"""
mock_redis_client = MagicMock()
mock_get_redis_connection.return_value = mock_redis_client

task_id = "task_123"
status = "SUCCESS"

update_function_status(task_id, status)

mock_redis_client.set.assert_called_once_with(f"task_status:{task_id}", status)
print("Test passed")

# Add teardown to close Redis connections
def teardown_module(module):
client = get_redis_connection(config)
close_redis_connection(client)

@pytest.fixture
def redis_client():
client = get_redis_connection()
client = get_redis_connection(config, force_new=True)
yield client
close_redis_connection(client)

def test_redis_connection(redis_client):
assert redis_client.ping()
print("Redis connection test passed")

def test_get_redis_values(redis_client, env_file):
redis_url = get_redis_values(env_file)
assert redis_url == "redis://user:password@localhost:6379"

def test_get_env_vars(env_file):
env_vars = get_env_vars(env_file)
assert env_vars == {
"REDIS_HOST": "localhost",
"REDIS_PORT": "6379",
"REDIS_USER": "user",
"REDIS_PASSWORD": "password",
}

def test_get_redis_connection(redis_client):
assert redis_client.ping()
print("Redis connection test passed")
Expand Down Expand Up @@ -75,17 +107,17 @@ def example_function(arg1, arg2):
"-A",
"distributaur.tests.test_worker",
"worker",
"--loglevel=info",
"--concurrency=1",
"--heartbeat-interval=1",
"--loglevel=info"
]
print("worker_cmd")
print(worker_cmd)
worker_process = subprocess.Popen(worker_cmd)

time.sleep(5)

task_params = {"arg1": 10, "arg2": 20}
task = execute_function("example_function", task_params)
result = task.get(timeout=10)
result = task.get(timeout=3)

assert result == "Result: arg1=10, arg2=20"

Expand All @@ -95,7 +127,7 @@ def example_function(arg1, arg2):
print("Worker task execution test passed")

def test_task_status_update():
redis_client = get_redis_connection()
redis_client = get_redis_connection(config)

try:
task_status_keys = redis_client.keys("task_status:*")
Expand Down
Loading

0 comments on commit b6ae95f

Please sign in to comment.