-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make a function to start a server (#13)
* Moving util functions for launching a server to florist/api/servers/utils.py and changing affected files * Adding launch_local_server function that launches a local server with Redis metrics monitoring
- Loading branch information
Showing
14 changed files
with
310 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""General functions and definitions for monitoring.""" | ||
from pathlib import Path | ||
|
||
|
||
CLIENT_LOG_FOLDER = Path("logs/client/") | ||
SERVER_LOG_FOLDER = Path("logs/server/") | ||
|
||
|
||
def get_client_log_file_path(client_uuid: str) -> Path: | ||
""" | ||
Make the client log file path given its UUID. | ||
Will use the default client log folder defined in this class. | ||
:param client_uuid: (str) the uuid for the client to generate the log file. | ||
:return: (pathlib.Path) The client log file path in the format f"{CLIENT_LOG_FOLDER}/{client_uuid}.out". | ||
""" | ||
CLIENT_LOG_FOLDER.mkdir(parents=True, exist_ok=True) | ||
return CLIENT_LOG_FOLDER / f"{client_uuid}.out" | ||
|
||
|
||
def get_server_log_file_path(server_uuid: str) -> Path: | ||
""" | ||
Make the default server log file path given its UUID. | ||
Will use the default server log folder defined in this class. | ||
:param server_uuid: (str) the uuid for the server to generate the log file. | ||
:return: (Path) The server log file path in the format f"{SERVER_LOG_FOLDER}/{server_uuid}.out". | ||
""" | ||
SERVER_LOG_FOLDER.mkdir(parents=True, exist_ok=True) | ||
return SERVER_LOG_FOLDER / f"{server_uuid}.out" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Implementations for the servers.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Functions and definitions to launch local servers.""" | ||
import uuid | ||
from functools import partial | ||
from multiprocessing import Process | ||
from typing import Tuple | ||
|
||
from torch import nn | ||
|
||
from florist.api.launchers.local import launch_server | ||
from florist.api.monitoring.logs import get_server_log_file_path | ||
from florist.api.monitoring.metrics import RedisMetricsReporter | ||
from florist.api.servers.utils import get_server | ||
|
||
|
||
def launch_local_server( | ||
model: nn.Module, | ||
n_clients: int, | ||
server_address: str, | ||
n_server_rounds: int, | ||
redis_host: str, | ||
redis_port: str, | ||
) -> Tuple[str, Process]: | ||
""" | ||
Launch a FL server locally. | ||
:param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model. | ||
:param n_clients: (int) The number of clients that will report to this server. | ||
:param server_address: (str) The address the server should start at. | ||
:param n_server_rounds: (int) The number of rounds the training should run for. | ||
:param redis_host: (str) the host name for the Redis instance for metrics reporting. | ||
:param redis_port: (str) the port for the Redis instance for metrics reporting. | ||
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull | ||
metrics from Redis, along with its local process object. | ||
""" | ||
server_uuid = str(uuid.uuid4()) | ||
|
||
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid) | ||
server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter) | ||
|
||
log_file_name = str(get_server_log_file_path(server_uuid)) | ||
server_process = launch_server( | ||
server_constructor, | ||
server_address, | ||
n_server_rounds, | ||
log_file_name, | ||
seconds_to_sleep=0, | ||
) | ||
|
||
return server_uuid, server_process |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Utilities functions and definitions for starting a server.""" | ||
from functools import partial | ||
from typing import Callable, Dict, Union | ||
|
||
from fl4health.client_managers.base_sampling_manager import SimpleClientManager | ||
from fl4health.reporting.metrics import MetricsReporter | ||
from fl4health.server.base_server import FlServer | ||
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn | ||
from flwr.common.parameter import ndarrays_to_parameters | ||
from flwr.server.strategy import FedAvg | ||
from torch import nn | ||
|
||
|
||
FitConfigFn = Callable[[int], Dict[str, Union[bool, bytes, float, int, str]]] | ||
|
||
|
||
def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: | ||
""" | ||
Return a dictionary used to configure the server's fit function. | ||
:param batch_size: (int) the size of the batch of samples. | ||
:param local_epochs: (int) the number of local epochs the clients will run. | ||
:param current_server_round: (int) the current server round | ||
:return: (Dict[str, int]) A dictionary to the used at the config for the fit function. | ||
""" | ||
return { | ||
"batch_size": batch_size, | ||
"current_server_round": current_server_round, | ||
"local_epochs": local_epochs, | ||
} | ||
|
||
|
||
def get_server( | ||
model: nn.Module, | ||
fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config, | ||
n_clients: int = 2, | ||
batch_size: int = 8, | ||
local_epochs: int = 1, | ||
metrics_reporter: MetricsReporter = None, | ||
) -> FlServer: | ||
""" | ||
Return a server instance with FedAvg aggregation strategy. | ||
:param model: (torch.nn.Model) the model the server and clients will be using. | ||
:param fit_config: (Callable[[int, int, int], Dict[str, int]]) the function to configure the fit method. | ||
:param n_clients: (int) the number of clients that will participate on training. Optional, default is 2. | ||
:param batch_size: (int) the size of the batch of samples. Optional, default is 8. | ||
:param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1. | ||
:param metrics_reporter: (fl4health.reporting.metrics.MetricsReporter) An optional metrics reporter instance. | ||
Default is None. | ||
:return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy. | ||
""" | ||
fit_config_fn: FitConfigFn = partial(fit_config, batch_size, local_epochs) # type: ignore | ||
initial_model_parameters = ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()]) | ||
strategy = FedAvg( | ||
min_fit_clients=n_clients, | ||
min_evaluate_clients=n_clients, | ||
min_available_clients=n_clients, | ||
on_fit_config_fn=fit_config_fn, | ||
on_evaluate_config_fn=fit_config_fn, | ||
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, | ||
initial_parameters=initial_model_parameters, | ||
) | ||
client_manager = SimpleClientManager() | ||
return FlServer(strategy=strategy, client_manager=client_manager, metrics_reporter=metrics_reporter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,45 @@ | ||
import json | ||
import tempfile | ||
from functools import partial | ||
import time | ||
from unittest.mock import ANY | ||
|
||
import redis | ||
|
||
from florist.api import client | ||
from florist.api.launchers.local import launch_server | ||
from florist.tests.utils.api.launch_utils import get_server | ||
from florist.api.clients.mnist import MnistNet | ||
from florist.api.monitoring.logs import get_server_log_file_path | ||
from florist.api.servers.local import launch_local_server | ||
|
||
|
||
def test_train(): | ||
test_server_address = "0.0.0.0:8080" | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
server_constructor = partial(get_server, n_clients=1) | ||
server_log_file = f"{temp_dir}/server.out" | ||
server_process = launch_server(server_constructor, test_server_address, 2, server_log_file) | ||
|
||
test_server_address = "0.0.0.0:8080" | ||
test_client = "MNIST" | ||
test_data_path = f"{temp_dir}/data" | ||
test_redis_host = "localhost" | ||
test_redis_port = "6379" | ||
|
||
server_uuid, server_process = launch_local_server( | ||
MnistNet(), | ||
1, | ||
test_server_address, | ||
2, | ||
test_redis_host, | ||
test_redis_port, | ||
) | ||
time.sleep(10) # giving time to start the server | ||
|
||
response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) | ||
json_body = json.loads(response.body.decode()) | ||
|
||
assert json.loads(response.body.decode()) == {"uuid": ANY} | ||
assert json_body == {"uuid": ANY} | ||
|
||
server_process.join() | ||
|
||
with open(server_log_file, "r") as f: | ||
redis_conn = redis.Redis(host=test_redis_host, port=test_redis_port) | ||
assert redis_conn.get(json_body["uuid"]) is not None | ||
assert redis_conn.get(server_uuid) is not None | ||
|
||
with open(get_server_log_file_path(server_uuid), "r") as f: | ||
file_contents = f.read() | ||
assert "FL finished in" in file_contents |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from unittest.mock import ANY, Mock, patch | ||
|
||
from florist.api.clients.mnist import MnistNet | ||
from florist.api.monitoring.logs import get_server_log_file_path | ||
from florist.api.monitoring.metrics import RedisMetricsReporter | ||
from florist.api.servers.local import launch_local_server | ||
from florist.api.servers.utils import get_server | ||
|
||
|
||
@patch("florist.api.servers.local.launch_server") | ||
def test_launch_local_server(mock_launch_server: Mock) -> None: | ||
test_model = MnistNet() | ||
test_n_clients = 2 | ||
test_server_address = "test-server-address" | ||
test_n_server_rounds = 5 | ||
test_redis_host = "test-redis-host" | ||
test_redis_port = "test-redis-port" | ||
test_server_process = "test-server-process" | ||
mock_launch_server.return_value = test_server_process | ||
|
||
server_uuid, server_process = launch_local_server( | ||
test_model, | ||
test_n_clients, | ||
test_server_address, | ||
test_n_server_rounds, | ||
test_redis_host, | ||
test_redis_port, | ||
) | ||
|
||
assert server_uuid is not None | ||
assert server_process == test_server_process | ||
|
||
mock_launch_server.assert_called_once() | ||
call_args = mock_launch_server.call_args_list[0][0] | ||
call_kwargs = mock_launch_server.call_args_list[0][1] | ||
assert call_args == ( | ||
ANY, | ||
test_server_address, | ||
test_n_server_rounds, | ||
str(get_server_log_file_path(server_uuid)), | ||
) | ||
assert call_kwargs == {"seconds_to_sleep": 0} | ||
assert call_args[0].func == get_server | ||
assert call_args[0].keywords == {"model": test_model, "n_clients": test_n_clients, "metrics_reporter": ANY} | ||
|
||
metrics_reporter = call_args[0].keywords["metrics_reporter"] | ||
assert isinstance(metrics_reporter, RedisMetricsReporter) | ||
assert metrics_reporter.host == test_redis_host | ||
assert metrics_reporter.port == test_redis_port | ||
assert metrics_reporter.run_id == server_uuid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.