Skip to content

Commit

Permalink
Make the client's "start" endpoint (#10)
Browse files Browse the repository at this point in the history
Making an endpoint to start a client
  • Loading branch information
lotif authored Mar 13, 2024
1 parent 62463ed commit 253e3c0
Show file tree
Hide file tree
Showing 15 changed files with 363 additions and 121 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ next-env.d.ts
/florist/tsconfig.json

/metrics/
/logs/
58 changes: 58 additions & 0 deletions florist/api/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
"""FLorist client FastAPI endpoints."""
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from florist.api.clients.common import Clients
from florist.api.launchers.local import launch_client
from florist.api.monitoring.metrics import RedisMetricsReporter


LOG_FOLDER = Path("logs/client/")

app = FastAPI()

Expand All @@ -14,3 +24,51 @@ def connect() -> JSONResponse:
:return: JSON `{"status": "ok"}`
"""
return JSONResponse({"status": "ok"})


@app.get("/api/client/start")
def start(server_address: str, client: str, data_path: str, redis_host: str, redis_port: str) -> JSONResponse:
    """
    Start a client.

    :param server_address: (str) the address of the server this client should report to.
        It should be comprised of the host name and port separated by colon (e.g. "localhost:8080").
    :param client: (str) the name of the client. Should be one of the enum values of
        florist.api.clients.common.Clients.
    :param data_path: (str) the path where the training data is located.
    :param redis_host: (str) the host name for the Redis instance for metrics reporting.
    :param redis_port: (str) the port for the Redis instance for metrics reporting.
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
        format below, which can be used to pull metrics from Redis.
            {"uuid": <client uuid>}
        If the client name is not supported, returns 400, and for any other failure returns 500,
        both with a JSON with the format below:
            {"error": <error message>}
    """
    try:
        supported_clients = Clients.list()
        if client not in supported_clients:
            return JSONResponse(
                content={"error": f"Client '{client}' not supported. Supported clients: {supported_clients}"},
                status_code=400,
            )

        # The same UUID identifies the run in Redis and names the log file.
        client_uuid = str(uuid.uuid4())
        metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # The validation above is done against the enum *values*, so the member must be
        # resolved by value as well (Clients[client] would resolve by member *name*).
        client_class = Clients.class_for_client(Clients(client))
        client_obj = client_class(
            data_path=Path(data_path),
            metrics=[],
            device=device,
            metrics_reporter=metrics_reporter,
        )

        LOG_FOLDER.mkdir(parents=True, exist_ok=True)
        log_file_name = LOG_FOLDER / f"{client_uuid}.out"

        launch_client(client_obj, server_address, str(log_file_name))

        return JSONResponse({"uuid": client_uuid})

    except Exception as ex:
        # Endpoint boundary: report any failure to the caller as a 500 instead of propagating it.
        return JSONResponse({"error": str(ex)}, status_code=500)
1 change: 1 addition & 0 deletions florist/api/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Implementations for the clients."""
36 changes: 36 additions & 0 deletions florist/api/clients/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Common functions and definitions for clients."""
from enum import Enum
from typing import List

from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient


class Clients(Enum):
    """Enumeration of the client implementations that can be started."""

    MNIST = "MNIST"

    @classmethod
    def class_for_client(cls, client: "Clients") -> type[BasicClient]:
        """
        Map a client enumeration member to the class implementing it.

        :param client: The client enumeration object.
        :return: A subclass of BasicClient corresponding to the given client.
        :raises ValueError: if the client is not supported.
        """
        # Guard clause: anything other than the known members is rejected up front.
        if client != Clients.MNIST:
            raise ValueError(f"Client {client.value} not supported.")

        return MnistClient

    @classmethod
    def list(cls) -> List[str]:
        """
        Enumerate the values of all the supported clients.

        :return: a list of supported clients.
        """
        return [member.value for member in Clients]
82 changes: 82 additions & 0 deletions florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Implementation of the MNIST client and model."""
from typing import Tuple

import torch
import torch.nn.functional as f
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader


class MnistClient(BasicClient):  # type: ignore
    """Client that trains on MNIST data."""

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
        """
        Build the train and validation data loaders for MNIST.

        :param config: (Config) the Config object for this client. Its "batch_size" entry
            is passed on to the data loading function.
        :return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader
            and validation data loader respectively.
        """
        batch_size = config["batch_size"]
        # The third element returned by load_mnist_data is not used by this client.
        training_loader, validation_loader, _ = load_mnist_data(self.data_path, batch_size=batch_size)
        return training_loader, validation_loader

    def get_model(self, config: Config) -> nn.Module:
        """
        Build the model to be trained on MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet.
        """
        model = MnistNet()
        return model

    def get_optimizer(self, config: Config) -> Optimizer:
        """
        Build the optimizer for this client's model.

        :param config: (Config) the Config object for this client.
        :return: (torch.optim.Optimizer) An instance of torch.optim.SGD with learning
            rate of 0.001 and momentum of 0.9.
        """
        trainable_parameters = self.model.parameters()
        return torch.optim.SGD(trainable_parameters, lr=0.001, momentum=0.9)

    def get_criterion(self, config: Config) -> _Loss:
        """
        Build the loss function for MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
        """
        return nn.CrossEntropyLoss()


class MnistNet(nn.Module):
    """Convolutional model for MNIST: two conv + max-pool stages followed by two linear layers."""

    def __init__(self) -> None:
        """Create the layers of the network."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of inputs.

        :param x: (torch.Tensor) the tensor to perform the forward pass on.
        :return: (torch.Tensor) a result tensor after the forward pass, with 10 values per row.
        """
        features = x
        # Each convolution is followed by a ReLU and the shared max-pooling layer.
        for convolution in (self.conv1, self.conv2):
            features = self.pool(f.relu(convolution(features)))

        # Flatten the 16 feature maps of 4x4 into rows of 256 values for the linear layers.
        flattened = features.view(-1, 16 * 4 * 4)
        hidden = f.relu(self.fc1(flattened))
        return f.relu(self.fc2(hidden))
1 change: 1 addition & 0 deletions florist/api/launchers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Launchers for servers and clients."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Launcher functions for clients and servers."""
"""Launcher functions for local clients and servers."""
import logging
import sys
import time
Expand Down
1 change: 1 addition & 0 deletions florist/api/monitoring/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Classes and functions for monitoring of clients and servers' execution."""
29 changes: 19 additions & 10 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@


class RedisMetricsReporter(MetricsReporter): # type: ignore
"""Save the metrics to a Redis instance while it records them."""
"""
Save the metrics to a Redis instance while it records them.
def __init__(
self,
redis_connection: redis.client.Redis,
run_id: Optional[str] = None,
):
Lazily instantiates a Redis connection when the first metrics are recorded.
"""

def __init__(self, host: str, port: str, run_id: Optional[str] = None):
    """
    Init an instance of RedisMetricsReporter.

    No connection to Redis is opened here; only the connection details are stored.

    :param host: (str) The host address where the Redis instance is running.
    :param port: (str) The port where the Redis instance is running on the host.
    :param run_id: (Optional[str]) the identifier for the run which these metrics are from.
        It will be used as the name of the object in Redis. Optional, default is a random UUID.
    """
    super().__init__(run_id)
    self.host = host
    self.port = port
    # Starts empty; dump() creates the connection the first time it runs.
    self.redis_connection: Optional[redis.Redis] = None

def add_to_metrics(self, data: Dict[str, Any]) -> None:
"""
Expand All @@ -51,7 +53,14 @@ def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
self.dump()

def dump(self) -> None:
    """
    Dump the current metrics to Redis under the run_id name.

    Will instantiate a Redis connection if it's the first time it runs for this instance.
    """
    connection = self.redis_connection
    if connection is None:
        # First dump for this instance: open the connection and keep it for later calls.
        connection = redis.Redis(host=self.host, port=self.port)
        self.redis_connection = connection

    serialized_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
    log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {serialized_metrics}")
    connection.set(self.run_id, serialized_metrics)
60 changes: 0 additions & 60 deletions florist/tests/api/monitoring/test_metrics.py

This file was deleted.

6 changes: 3 additions & 3 deletions florist/tests/integration/api/launchers/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import torch
from fl4health.server.base_server import FlServer

from florist.api.launchers.launch import launch
from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg
from florist.tests.utils.api.models import MnistNet
from florist.api.launchers.local import launch
from florist.api.clients.mnist import MnistClient, MnistNet
from florist.tests.utils.api.fl4health_utils import get_server_fedavg


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
Expand Down
Loading

0 comments on commit 253e3c0

Please sign in to comment.