Skip to content

Commit

Permalink
Make a function to start a server (#13)
Browse files Browse the repository at this point in the history
* Moving util functions for launching a server to florist/api/servers/utils.py and changing affected files
* Adding launch_local_server function that launches a local server with Redis metrics monitoring
  • Loading branch information
lotif authored Mar 22, 2024
1 parent f5ac81f commit 611fa19
Show file tree
Hide file tree
Showing 14 changed files with 310 additions and 160 deletions.
9 changes: 3 additions & 6 deletions florist/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

from florist.api.clients.common import Clients
from florist.api.launchers.local import launch_client
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter


LOG_FOLDER = Path("logs/client/")

app = FastAPI()


Expand Down Expand Up @@ -63,10 +62,8 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
metrics_reporter=metrics_reporter,
)

LOG_FOLDER.mkdir(parents=True, exist_ok=True)
log_file_name = LOG_FOLDER / f"{client_uuid}.out"

launch_client(client_obj, server_address, str(log_file_name))
log_file_name = str(get_client_log_file_path(client_uuid))
launch_client(client_obj, server_address, log_file_name)

return JSONResponse({"uuid": client_uuid})

Expand Down
32 changes: 32 additions & 0 deletions florist/api/monitoring/logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""General functions and definitions for monitoring."""
from pathlib import Path


CLIENT_LOG_FOLDER = Path("logs/client/")
SERVER_LOG_FOLDER = Path("logs/server/")


def get_client_log_file_path(client_uuid: str) -> Path:
"""
Make the client log file path given its UUID.
Will use the default client log folder defined in this class.
:param client_uuid: (str) the uuid for the client to generate the log file.
:return: (pathlib.Path) The client log file path in the format f"{CLIENT_LOG_FOLDER}/{client_uuid}.out".
"""
CLIENT_LOG_FOLDER.mkdir(parents=True, exist_ok=True)
return CLIENT_LOG_FOLDER / f"{client_uuid}.out"


def get_server_log_file_path(server_uuid: str) -> Path:
"""
Make the default server log file path given its UUID.
Will use the default server log folder defined in this class.
:param server_uuid: (str) the uuid for the server to generate the log file.
:return: (Path) The server log file path in the format f"{SERVER_LOG_FOLDER}/{server_uuid}.out".
"""
SERVER_LOG_FOLDER.mkdir(parents=True, exist_ok=True)
return SERVER_LOG_FOLDER / f"{server_uuid}.out"
1 change: 1 addition & 0 deletions florist/api/servers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Implementations for the servers."""
49 changes: 49 additions & 0 deletions florist/api/servers/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Functions and definitions to launch local servers."""
import uuid
from functools import partial
from multiprocessing import Process
from typing import Tuple

from torch import nn

from florist.api.launchers.local import launch_server
from florist.api.monitoring.logs import get_server_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.utils import get_server


def launch_local_server(
model: nn.Module,
n_clients: int,
server_address: str,
n_server_rounds: int,
redis_host: str,
redis_port: str,
) -> Tuple[str, Process]:
"""
Launch a FL server locally.
:param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model.
:param n_clients: (int) The number of clients that will report to this server.
:param server_address: (str) The address the server should start at.
:param n_server_rounds: (int) The number of rounds the training should run for.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull
metrics from Redis, along with its local process object.
"""
server_uuid = str(uuid.uuid4())

metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid)
server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter)

log_file_name = str(get_server_log_file_path(server_uuid))
server_process = launch_server(
server_constructor,
server_address,
n_server_rounds,
log_file_name,
seconds_to_sleep=0,
)

return server_uuid, server_process
66 changes: 66 additions & 0 deletions florist/api/servers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Utilities functions and definitions for starting a server."""
from functools import partial
from typing import Callable, Dict, Union

from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.reporting.metrics import MetricsReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.common.parameter import ndarrays_to_parameters
from flwr.server.strategy import FedAvg
from torch import nn


FitConfigFn = Callable[[int], Dict[str, Union[bool, bytes, float, int, str]]]


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
"""
Return a dictionary used to configure the server's fit function.
:param batch_size: (int) the size of the batch of samples.
:param local_epochs: (int) the number of local epochs the clients will run.
:param current_server_round: (int) the current server round
:return: (Dict[str, int]) A dictionary to the used at the config for the fit function.
"""
return {
"batch_size": batch_size,
"current_server_round": current_server_round,
"local_epochs": local_epochs,
}


def get_server(
model: nn.Module,
fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
metrics_reporter: MetricsReporter = None,
) -> FlServer:
"""
Return a server instance with FedAvg aggregation strategy.
:param model: (torch.nn.Model) the model the server and clients will be using.
:param fit_config: (Callable[[int, int, int], Dict[str, int]]) the function to configure the fit method.
:param n_clients: (int) the number of clients that will participate on training. Optional, default is 2.
:param batch_size: (int) the size of the batch of samples. Optional, default is 8.
:param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1.
:param metrics_reporter: (fl4health.reporting.metrics.MetricsReporter) An optional metrics reporter instance.
Default is None.
:return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy.
"""
fit_config_fn: FitConfigFn = partial(fit_config, batch_size, local_epochs) # type: ignore
initial_model_parameters = ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()])
strategy = FedAvg(
min_fit_clients=n_clients,
min_evaluate_clients=n_clients,
min_available_clients=n_clients,
on_fit_config_fn=fit_config_fn,
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=initial_model_parameters,
)
client_manager = SimpleClientManager()
return FlServer(strategy=strategy, client_manager=client_manager, metrics_reporter=metrics_reporter)
8 changes: 5 additions & 3 deletions florist/tests/integration/api/launchers/test_launch.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import re
import tempfile
from functools import partial
from pathlib import Path

import torch

from florist.api.launchers.local import launch
from florist.api.clients.mnist import MnistClient
from florist.tests.utils.api.launch_utils import get_server
from florist.api.clients.mnist import MnistClient, MnistNet
from florist.api.servers.utils import get_server


def assert_string_in_file(file_path: str, search_string: str) -> bool:
Expand All @@ -28,10 +29,11 @@ def test_launch() -> None:
os.mkdir(client_data_path)
clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths]

server_constructor = partial(get_server, model=MnistNet())
server_path = os.path.join(temp_dir, "server")
client_base_path = f"{temp_dir}/client"
launch(
get_server,
server_constructor,
server_address,
n_server_rounds,
clients,
Expand Down
35 changes: 24 additions & 11 deletions florist/tests/integration/api/test_train.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,45 @@
import json
import tempfile
from functools import partial
import time
from unittest.mock import ANY

import redis

from florist.api import client
from florist.api.launchers.local import launch_server
from florist.tests.utils.api.launch_utils import get_server
from florist.api.clients.mnist import MnistNet
from florist.api.monitoring.logs import get_server_log_file_path
from florist.api.servers.local import launch_local_server


def test_train():
test_server_address = "0.0.0.0:8080"

with tempfile.TemporaryDirectory() as temp_dir:
server_constructor = partial(get_server, n_clients=1)
server_log_file = f"{temp_dir}/server.out"
server_process = launch_server(server_constructor, test_server_address, 2, server_log_file)

test_server_address = "0.0.0.0:8080"
test_client = "MNIST"
test_data_path = f"{temp_dir}/data"
test_redis_host = "localhost"
test_redis_port = "6379"

server_uuid, server_process = launch_local_server(
MnistNet(),
1,
test_server_address,
2,
test_redis_host,
test_redis_port,
)
time.sleep(10) # giving time to start the server

response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port)
json_body = json.loads(response.body.decode())

assert json.loads(response.body.decode()) == {"uuid": ANY}
assert json_body == {"uuid": ANY}

server_process.join()

with open(server_log_file, "r") as f:
redis_conn = redis.Redis(host=test_redis_host, port=test_redis_port)
assert redis_conn.get(json_body["uuid"]) is not None
assert redis_conn.get(server_uuid) is not None

with open(get_server_log_file_path(server_uuid), "r") as f:
file_contents = f.read()
assert "FL finished in" in file_contents
50 changes: 50 additions & 0 deletions florist/tests/unit/api/servers/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import ANY, Mock, patch

from florist.api.clients.mnist import MnistNet
from florist.api.monitoring.logs import get_server_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
from florist.api.servers.local import launch_local_server
from florist.api.servers.utils import get_server


@patch("florist.api.servers.local.launch_server")
def test_launch_local_server(mock_launch_server: Mock) -> None:
test_model = MnistNet()
test_n_clients = 2
test_server_address = "test-server-address"
test_n_server_rounds = 5
test_redis_host = "test-redis-host"
test_redis_port = "test-redis-port"
test_server_process = "test-server-process"
mock_launch_server.return_value = test_server_process

server_uuid, server_process = launch_local_server(
test_model,
test_n_clients,
test_server_address,
test_n_server_rounds,
test_redis_host,
test_redis_port,
)

assert server_uuid is not None
assert server_process == test_server_process

mock_launch_server.assert_called_once()
call_args = mock_launch_server.call_args_list[0][0]
call_kwargs = mock_launch_server.call_args_list[0][1]
assert call_args == (
ANY,
test_server_address,
test_n_server_rounds,
str(get_server_log_file_path(server_uuid)),
)
assert call_kwargs == {"seconds_to_sleep": 0}
assert call_args[0].func == get_server
assert call_args[0].keywords == {"model": test_model, "n_clients": test_n_clients, "metrics_reporter": ANY}

metrics_reporter = call_args[0].keywords["metrics_reporter"]
assert isinstance(metrics_reporter, RedisMetricsReporter)
assert metrics_reporter.host == test_redis_host
assert metrics_reporter.port == test_redis_port
assert metrics_reporter.run_id == server_uuid
3 changes: 2 additions & 1 deletion florist/tests/unit/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from florist.api import client
from florist.api.clients.mnist import MnistClient
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter


Expand All @@ -30,7 +31,7 @@ def test_start_success(mock_launch_client: Mock) -> None:
json_body = json.loads(response.body.decode())
assert json_body == {"uuid": ANY}

log_file_name = str(client.LOG_FOLDER / f"{json_body['uuid']}.out")
log_file_name = str(get_client_log_file_path(json_body["uuid"]))
mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name)

client_obj = mock_launch_client.call_args_list[0][0][0]
Expand Down
Empty file.
Loading

0 comments on commit 611fa19

Please sign in to comment.