diff --git a/.gitignore b/.gitignore index 8ae22060..7d2c7f53 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ next-env.d.ts /florist/tsconfig.json /metrics/ +/logs/ diff --git a/florist/api/client.py b/florist/api/client.py index 6973b36d..8a145feb 100644 --- a/florist/api/client.py +++ b/florist/api/client.py @@ -1,7 +1,17 @@ """FLorist client FastAPI endpoints.""" +import uuid +from pathlib import Path + +import torch from fastapi import FastAPI from fastapi.responses import JSONResponse +from florist.api.clients.common import Clients +from florist.api.launchers.local import launch_client +from florist.api.monitoring.metrics import RedisMetricsReporter + + +LOG_FOLDER = Path("logs/client/") app = FastAPI() @@ -14,3 +24,51 @@ def connect() -> JSONResponse: :return: JSON `{"status": "ok"}` """ return JSONResponse({"status": "ok"}) + + +@app.get("/api/client/start") +def start(server_address: str, client: str, data_path: str, redis_host: str, redis_port: str) -> JSONResponse: + """ + Start a client. + + :param server_address: (str) the address of the server this client should report to. + It should be comprised of the host name and port separated by colon (e.g. "localhost:8080"). + :param client: (str) the name of the client. Should be one of the enum values of florist.api.client.Clients. + :param data_path: (str) the path where the training data is located. + :param redis_host: (str) the host name for the Redis instance for metrics reporting. + :param redis_port: (str) the port for the Redis instance for metrics reporting. + :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the + format below, which can be used to pull metrics from Redis. + {"uuid": } + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": } + """ + try: + if client not in Clients.list(): + return JSONResponse( + content={"error": f"Client '{client}' not supported. Supported clients: {Clients.list()}"}, + status_code=400, + ) + + client_uuid = str(uuid.uuid4()) + metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + client_class = Clients.class_for_client(Clients[client]) + client_obj = client_class( + data_path=Path(data_path), + metrics=[], + device=device, + metrics_reporter=metrics_reporter, + ) + + LOG_FOLDER.mkdir(parents=True, exist_ok=True) + log_file_name = LOG_FOLDER / f"{client_uuid}.out" + + launch_client(client_obj, server_address, str(log_file_name)) + + return JSONResponse({"uuid": client_uuid}) + + except Exception as ex: + return JSONResponse({"error": str(ex)}, status_code=500) diff --git a/florist/api/clients/__init__.py b/florist/api/clients/__init__.py new file mode 100644 index 00000000..51fda2a4 --- /dev/null +++ b/florist/api/clients/__init__.py @@ -0,0 +1 @@ +"""Implementations for the clients.""" diff --git a/florist/api/clients/common.py b/florist/api/clients/common.py new file mode 100644 index 00000000..121f06ad --- /dev/null +++ b/florist/api/clients/common.py @@ -0,0 +1,36 @@ +"""Common functions and definitions for clients.""" +from enum import Enum +from typing import List + +from fl4health.clients.basic_client import BasicClient + +from florist.api.clients.mnist import MnistClient + + +class Clients(Enum): + """Enumeration of supported clients.""" + + MNIST = "MNIST" + + @classmethod + def class_for_client(cls, client: "Clients") -> type[BasicClient]: + """ + Return the class for a given client. + + :param client: The client enumeration object. + :return: A subclass of BasicClient corresponding to the given client. + :raises ValueError: if the client is not supported. + """ + if client == Clients.MNIST: + return MnistClient + + raise ValueError(f"Client {client.value} not supported.") + + @classmethod + def list(cls) -> List[str]: + """ + List all the supported clients. + + :return: a list of supported clients. + """ + return [client.value for client in Clients] diff --git a/florist/api/clients/mnist.py b/florist/api/clients/mnist.py new file mode 100644 index 00000000..57464694 --- /dev/null +++ b/florist/api/clients/mnist.py @@ -0,0 +1,82 @@ +"""Implementation of the MNIST client and model.""" +from typing import Tuple + +import torch +import torch.nn.functional as f +from fl4health.clients.basic_client import BasicClient +from fl4health.utils.dataset import MnistDataset +from fl4health.utils.load_data import load_mnist_data +from flwr.common.typing import Config +from torch import nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + + +class MnistClient(BasicClient): # type: ignore + """Implementation of the MNIST client.""" + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]: + """ + Return the data loader for MNIST data. + + :param config: (Config) the Config object for this client. + :return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader + and validation data loader respectively. + """ + train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"]) + return train_loader, val_loader + + def get_model(self, config: Config) -> nn.Module: + """ + Return the model for MNIST data. + + :param config: (Config) the Config object for this client. + :return: (torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet. + """ + return MnistNet() + + def get_optimizer(self, config: Config) -> Optimizer: + """ + Return the optimizer for MNIST data. + + :param config: (Config) the Config object for this client. + :return: (torch.optim.Optimizer) An instance of torch.optim.SGD with learning + rate of 0.001 and momentum of 0.9. + """ + return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + + def get_criterion(self, config: Config) -> _Loss: + """ + Return the loss for MNIST data. + + :param config: (Config) the Config object for this client. + :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss. + """ + return torch.nn.CrossEntropyLoss() + + +class MnistNet(nn.Module): + """Implementation of the Mnist model.""" + + def __init__(self) -> None: + """Initialize an instance of MnistNet.""" + super().__init__() + self.conv1 = nn.Conv2d(1, 8, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(8, 16, 5) + self.fc1 = nn.Linear(16 * 4 * 4, 120) + self.fc2 = nn.Linear(120, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass for the given tensor. + + :param x: (torch.Tensor) the tensor to perform the forward pass on. + :return: (torch.Tensor) a result tensor after the forward pass. + """ + x = self.pool(f.relu(self.conv1(x))) + x = self.pool(f.relu(self.conv2(x))) + x = x.view(-1, 16 * 4 * 4) + x = f.relu(self.fc1(x)) + return f.relu(self.fc2(x)) diff --git a/florist/api/launchers/__init__.py b/florist/api/launchers/__init__.py new file mode 100644 index 00000000..f08007bc --- /dev/null +++ b/florist/api/launchers/__init__.py @@ -0,0 +1 @@ +"""Launchers for servers and clients.""" diff --git a/florist/api/launchers/launch.py b/florist/api/launchers/local.py similarity index 99% rename from florist/api/launchers/launch.py rename to florist/api/launchers/local.py index 04e8e81a..1681f5e1 100644 --- a/florist/api/launchers/launch.py +++ b/florist/api/launchers/local.py @@ -1,4 +1,4 @@ -"""Launcher functions for clients and servers.""" +"""Launcher functions for local clients and servers.""" import logging import sys import time diff --git a/florist/api/monitoring/__init__.py b/florist/api/monitoring/__init__.py new file mode 100644 index 00000000..4ed23d32 --- /dev/null +++ b/florist/api/monitoring/__init__.py @@ -0,0 +1 @@ +"""Classes and functions for monitoring of clients and servers' execution.""" diff --git a/florist/api/monitoring/metrics.py b/florist/api/monitoring/metrics.py index 8bc9519a..ed8461fc 100644 --- a/florist/api/monitoring/metrics.py +++ b/florist/api/monitoring/metrics.py @@ -9,23 +9,25 @@ class RedisMetricsReporter(MetricsReporter): # type: ignore - """Save the metrics to a Redis instance while it records them.""" + """ + Save the metrics to a Redis instance while it records them. - def __init__( - self, - redis_connection: redis.client.Redis, - run_id: Optional[str] = None, - ): + Lazily instantiates a Redis connection when the first metrics are recorded. + """ + + def __init__(self, host: str, port: str, run_id: Optional[str] = None): """ Init an instance of RedisMetricsReporter. - :param redis_connection: (redis.client.Redis) the connection object to a Redis. Should be the output - of redis.Redis(host=host, port=port) + :param host: (str) The host address where the Redis instance is running. + :param port: (str) The port where the Redis instance is running on the host. :param run_id: (Optional[str]) the identifier for the run which these metrics are from. It will be used as the name of the object in Redis. Optional, default is a random UUID. """ super().__init__(run_id) - self.redis_connection = redis_connection + self.host = host + self.port = port + self.redis_connection: Optional[redis.Redis] = None def add_to_metrics(self, data: Dict[str, Any]) -> None: """ @@ -51,7 +53,14 @@ def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None: self.dump() def dump(self) -> None: - """Dump the current metrics to Redis under the run_id name.""" + """ + Dump the current metrics to Redis under the run_id name. + + Will instantiate a Redis connection if it's the first time it runs for this instance. + """ + if self.redis_connection is None: + self.redis_connection = redis.Redis(host=self.host, port=self.port) + encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder) log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}") self.redis_connection.set(self.run_id, encoded_metrics) diff --git a/florist/tests/api/monitoring/test_metrics.py b/florist/tests/api/monitoring/test_metrics.py deleted file mode 100644 index 6fc50a99..00000000 --- a/florist/tests/api/monitoring/test_metrics.py +++ /dev/null @@ -1,60 +0,0 @@ -import datetime -import json -from unittest.mock import Mock - -from fl4health.reporting.metrics import DateTimeEncoder -from freezegun import freeze_time - -from florist.api.monitoring.metrics import RedisMetricsReporter - - -@freeze_time("2012-12-11 10:09:08") -def test_add_to_metrics() -> None: - mock_redis_connection = Mock() - test_run_id = "123" - test_data = {"test": "data", "date": datetime.datetime.now()} - - redis_metric_reporter = RedisMetricsReporter(mock_redis_connection, test_run_id) - redis_metric_reporter.add_to_metrics(test_data) - - mock_redis_connection.set.assert_called_once_with(test_run_id, json.dumps(test_data, cls=DateTimeEncoder)) - - -@freeze_time("2012-12-11 10:09:08") -def test_add_to_metrics_at_round() -> None: - mock_redis_connection = Mock() - test_run_id = "123" - test_data = {"test": "data", "date": datetime.datetime.now()} - test_round = 2 - - redis_metric_reporter = RedisMetricsReporter(mock_redis_connection, test_run_id) - redis_metric_reporter.add_to_metrics_at_round(test_round, test_data) - - expected_data = { - "rounds": { - str(test_round): test_data, - } - } - mock_redis_connection.set.assert_called_once_with(test_run_id, json.dumps(expected_data, cls=DateTimeEncoder)) - - -@freeze_time("2012-12-11 10:09:08") -def test_dump() -> None: - mock_redis_connection = Mock() - test_run_id = "123" - test_data = {"test": "data", "date": datetime.datetime.now()} - test_round = 2 - - redis_metric_reporter = RedisMetricsReporter(mock_redis_connection, test_run_id) - redis_metric_reporter.add_to_metrics(test_data) - redis_metric_reporter.add_to_metrics_at_round(test_round, test_data) - redis_metric_reporter.dump() - - expected_data = { - **test_data, - "rounds": { - str(test_round): test_data, - }, - } - assert mock_redis_connection.set.call_args_list[2][0][0] == test_run_id - assert mock_redis_connection.set.call_args_list[2][0][1] == json.dumps(expected_data, cls=DateTimeEncoder) diff --git a/florist/tests/integration/api/launchers/test_launch.py b/florist/tests/integration/api/launchers/test_launch.py index be99748c..fe8ca0a3 100644 --- a/florist/tests/integration/api/launchers/test_launch.py +++ b/florist/tests/integration/api/launchers/test_launch.py @@ -8,9 +8,9 @@ import torch from fl4health.server.base_server import FlServer -from florist.api.launchers.launch import launch -from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg -from florist.tests.utils.api.models import MnistNet +from florist.api.launchers.local import launch +from florist.api.clients.mnist import MnistClient, MnistNet +from florist.tests.utils.api.fl4health_utils import get_server_fedavg def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: diff --git a/florist/tests/unit/api/monitoring/test_metrics.py b/florist/tests/unit/api/monitoring/test_metrics.py new file mode 100644 index 00000000..518d20c5 --- /dev/null +++ b/florist/tests/unit/api/monitoring/test_metrics.py @@ -0,0 +1,96 @@ +import datetime +import json +from unittest.mock import Mock, patch + +from fl4health.reporting.metrics import DateTimeEncoder +from freezegun import freeze_time + +from florist.api.monitoring.metrics import RedisMetricsReporter + + +@freeze_time("2012-12-11 10:09:08") +@patch("florist.api.monitoring.metrics.redis.Redis") +def test_add_to_metrics(mock_redis: Mock) -> None: + mock_redis_connection = Mock() + mock_redis.return_value = mock_redis_connection + + test_host = "test host" + test_port = "test port" + test_run_id = "123" + test_data = {"test": "data", "date": datetime.datetime.now()} + + redis_metric_reporter = RedisMetricsReporter(test_host, test_port, test_run_id) + redis_metric_reporter.add_to_metrics(test_data) + + mock_redis.assert_called_once_with(host=test_host, port=test_port) + mock_redis_connection.set.assert_called_once_with(test_run_id, json.dumps(test_data, cls=DateTimeEncoder)) + + +@freeze_time("2012-12-11 10:09:08") +@patch("florist.api.monitoring.metrics.redis.Redis") +def test_add_to_metrics_at_round(mock_redis: Mock) -> None: + mock_redis_connection = Mock() + mock_redis.return_value = mock_redis_connection + + test_host = "test host" + test_port = "test port" + test_run_id = "123" + test_data = {"test": "data", "date": datetime.datetime.now()} + test_round = 2 + + redis_metric_reporter = RedisMetricsReporter(test_host, test_port, test_run_id) + redis_metric_reporter.add_to_metrics_at_round(test_round, test_data) + + mock_redis.assert_called_once_with(host=test_host, port=test_port) + expected_data = { + "rounds": { + str(test_round): test_data, + } + } + mock_redis_connection.set.assert_called_once_with(test_run_id, json.dumps(expected_data, cls=DateTimeEncoder)) + + +@freeze_time("2012-12-11 10:09:08") +@patch("florist.api.monitoring.metrics.redis.Redis") +def test_dump_without_existing_connection(mock_redis: Mock) -> None: + mock_redis_connection = Mock() + mock_redis.return_value = mock_redis_connection + + test_host = "test host" + test_port = "test port" + test_run_id = "123" + test_data = {"test": "data", "date": datetime.datetime.now()} + test_round = 2 + + redis_metric_reporter = RedisMetricsReporter(test_host, test_port, test_run_id) + redis_metric_reporter.add_to_metrics(test_data) + redis_metric_reporter.add_to_metrics_at_round(test_round, test_data) + redis_metric_reporter.dump() + + mock_redis.assert_called_once_with(host=test_host, port=test_port) + expected_data = { + **test_data, + "rounds": { + str(test_round): test_data, + }, + } + assert mock_redis_connection.set.call_args_list[2][0][0] == test_run_id + assert mock_redis_connection.set.call_args_list[2][0][1] == json.dumps(expected_data, cls=DateTimeEncoder) + + +@freeze_time("2012-12-11 10:09:08") +@patch("florist.api.monitoring.metrics.redis.Redis") +def test_dump_with_existing_connection(mock_redis: Mock) -> None: + mock_redis_connection = Mock() + + test_run_id = "123" + test_data = {"test": "data", "date": datetime.datetime.now()} + + redis_metric_reporter = RedisMetricsReporter("test host", "test port", test_run_id) + redis_metric_reporter.redis_connection = mock_redis_connection + redis_metric_reporter.metrics = test_data + redis_metric_reporter.dump() + + mock_redis.assert_not_called() + assert mock_redis_connection.set.call_args_list[0][0][0] == test_run_id + assert mock_redis_connection.set.call_args_list[0][0][1] == json.dumps(test_data, cls=DateTimeEncoder) diff --git a/florist/tests/unit/api/test_client.py b/florist/tests/unit/api/test_client.py index 5b0c3cc0..91298e47 100644 --- a/florist/tests/unit/api/test_client.py +++ b/florist/tests/unit/api/test_client.py @@ -1,7 +1,10 @@ """Tests for FLorist's client FastAPI endpoints.""" import json +from unittest.mock import ANY, Mock, patch from florist.api import client +from florist.api.clients.mnist import MnistClient +from florist.api.monitoring.metrics import RedisMetricsReporter def test_connect() -> None: @@ -9,4 +12,63 @@ def test_connect() -> None: response = client.connect() assert response.status_code == 200 - assert response.body.decode() == json.dumps({"status": "ok"}, separators=(",", ":")) + json_body = json.loads(response.body.decode()) + assert json_body == {"status": "ok"} + + +@patch("florist.api.client.launch_client") +def test_start_success(mock_launch_client: Mock) -> None: + test_server_address = "test-server-address" + test_client = "MNIST" + test_data_path = "test/data/path" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) + + assert response.status_code == 200 + json_body = json.loads(response.body.decode()) + assert json_body == {"uuid": ANY} + + log_file_name = str(client.LOG_FOLDER / f"{json_body['uuid']}.out") + mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name) + + client_obj = mock_launch_client.call_args_list[0][0][0] + assert isinstance(client_obj, MnistClient) + assert str(client_obj.data_path) == test_data_path + + metrics_reporter = client_obj.metrics_reporter + assert isinstance(metrics_reporter, RedisMetricsReporter) + assert metrics_reporter.host == test_redis_host + assert metrics_reporter.port == test_redis_port + assert metrics_reporter.run_id == json_body["uuid"] + + +def test_start_fail_unsupported_client() -> None: + test_server_address = "test-server-address" + test_client = "WRONG" + test_data_path = "test/data/path" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) + + assert response.status_code == 400 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": ANY} + assert f"Client '{test_client}' not supported" in json_body["error"] + + +@patch("florist.api.client.launch_client", side_effect=Exception("test exception")) +def test_start_fail_exception(mock_launch_client: Mock) -> None: + test_server_address = "test-server-address" + test_client = "MNIST" + test_data_path = "test/data/path" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) + + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": "test exception"} diff --git a/florist/tests/utils/api/fl4health_utils.py b/florist/tests/utils/api/fl4health_utils.py index fac9cc91..9b0e6bd2 100644 --- a/florist/tests/utils/api/fl4health_utils.py +++ b/florist/tests/utils/api/fl4health_utils.py @@ -1,35 +1,11 @@ from typing import Callable, List, Tuple -import torch from fl4health.client_managers.base_sampling_manager import SimpleClientManager -from fl4health.clients.basic_client import BasicClient from fl4health.server.base_server import FlServer -from fl4health.utils.load_data import load_mnist_data from flwr.common.parameter import ndarrays_to_parameters -from flwr.common.typing import Config, Metrics, Parameters +from flwr.common.typing import Metrics, Parameters from flwr.server.strategy import FedAvg from torch import nn -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -from florist.tests.utils.api.models import MnistNet - - -class MnistClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: - train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"]) - return train_loader, val_loader - - def get_model(self, config: Config) -> nn.Module: - return MnistNet() - - def get_optimizer(self, config: Config) -> Optimizer: - opt = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) - return opt - - def get_criterion(self, config: Config) -> _Loss: - return torch.nn.CrossEntropyLoss() def metric_aggregation( diff --git a/florist/tests/utils/api/models.py b/florist/tests/utils/api/models.py deleted file mode 100644 index 76274893..00000000 --- a/florist/tests/utils/api/models.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - - -class MnistNet(nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv1 = nn.Conv2d(1, 8, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(8, 16, 5) - self.fc1 = nn.Linear(16 * 4 * 4, 120) - self.fc2 = nn.Linear(120, 10) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 4 * 4) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return x