Skip to content

Commit

Permalink
Change start training endpoint to take a job id instead (#22)
Browse files Browse the repository at this point in the history
Change the /api/server/training/start to take a 'job_id' url parameter instead of all the parameters that it required before. Some other required changes:

    Removed ClientInfo helper class which is not needed anymore since that info is now in the DB entity.
    Renamed server_info into server_config and made it dynamically pass in all the config information to the server.
    Created a BasicConfigParser to parse the server config.
  • Loading branch information
lotif authored May 10, 2024
1 parent f8b23b6 commit 5b180e5
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 673 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
hooks:
- id: nextjs-lint
name: nextjs-lint
entry: yarn lint-gh-action florist
entry: yarn lint-gh-action florist **/*.tsx
files: "florist/app"
language: system

Expand Down
22 changes: 5 additions & 17 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Definitions for the MongoDB database entities."""

import json
import uuid
from enum import Enum
from typing import Annotated, List, Optional
Expand All @@ -9,9 +8,10 @@

from florist.api.clients.common import Client
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser


JOB_DATABASE_NAME = "job"
JOB_COLLECTION_NAME = "job"
MAX_RECORDS_TO_FETCH = 1000


Expand Down Expand Up @@ -65,24 +65,12 @@ class Job(BaseModel):
status: JobStatus = Field(default=JobStatus.NOT_STARTED)
model: Optional[Annotated[Model, Field(...)]]
server_address: Optional[Annotated[str, Field(...)]]
server_info: Optional[Annotated[str, Field(...)]]
server_config: Optional[Annotated[str, Field(...)]]
config_parser: Optional[Annotated[ConfigParser, Field(...)]]
redis_host: Optional[Annotated[str, Field(...)]]
redis_port: Optional[Annotated[str, Field(...)]]
clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]

@classmethod
def is_valid_server_info(cls, server_info: Optional[str]) -> bool:
"""
Validate if server info is a json string.
:param server_info: (str) the json string with the server info.
:return: True if server_info is None or a valid JSON string, False otherwise.
:raises: (json.JSONDecodeError) if there is an error decoding the server info into json
"""
if server_info is not None:
json.loads(server_info)
return True

class Config:
"""MongoDB config for the Job DB entity."""

Expand All @@ -93,7 +81,7 @@ class Config:
"status": "NOT_STARTED",
"model": "MNIST",
"server_address": "localhost:8080",
"server_info": '{"n_server_rounds": 3, "batch_size": 8}',
"server_config": '{"n_server_rounds": 3, "batch_size": 8}',
"redis_host": "localhost",
"redis_port": "6879",
"clients_info": [
Expand Down
22 changes: 6 additions & 16 deletions florist/api/routes/server/job.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""FastAPI routes for the job."""

from json import JSONDecodeError
from typing import Any, Dict, List

from fastapi import APIRouter, Body, HTTPException, Request, status
from fastapi import APIRouter, Body, Request, status
from fastapi.encoders import jsonable_encoder

from florist.api.db.entities import JOB_DATABASE_NAME, MAX_RECORDS_TO_FETCH, Job, JobStatus
from florist.api.db.entities import JOB_COLLECTION_NAME, MAX_RECORDS_TO_FETCH, Job, JobStatus


router = APIRouter()
Expand All @@ -30,19 +29,10 @@ async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: #
:return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database.
:raises: (HTTPException) status 400 if job.server_info is not None and cannot be parsed into JSON.
"""
try:
is_valid = Job.is_valid_server_info(job.server_info)
if not is_valid:
msg = f"job.server_info is not valid. job.server_info: {job.server_info}."
raise HTTPException(status_code=400, detail=msg)
except JSONDecodeError as e:
msg = f"job.server_info could not be parsed into JSON. job.server_info: {job.server_info}. Error: {e}"
raise HTTPException(status_code=400, detail=msg) from e

json_job = jsonable_encoder(job)
result = await request.app.database[JOB_DATABASE_NAME].insert_one(json_job)
result = await request.app.database[JOB_COLLECTION_NAME].insert_one(json_job)

created_job = await request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id})
created_job = await request.app.database[JOB_COLLECTION_NAME].find_one({"_id": result.inserted_id})
assert isinstance(created_job, dict)

return created_job
Expand All @@ -62,7 +52,7 @@ async def list_jobs_with_status(status: JobStatus, request: Request) -> List[Dic
"""
status = jsonable_encoder(status)

job_db = request.app.database[JOB_DATABASE_NAME]
result = await job_db.find({"status": status}).to_list(MAX_RECORDS_TO_FETCH)
job_collection = request.app.database[JOB_COLLECTION_NAME]
result = await job_collection.find({"status": status}).to_list(MAX_RECORDS_TO_FETCH)
assert isinstance(result, list)
return result
98 changes: 42 additions & 56 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""FastAPI routes for training."""

import logging
from json import JSONDecodeError
from typing import List

import requests
from fastapi import APIRouter, Form
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from typing_extensions import Annotated

from florist.api.db.entities import JOB_COLLECTION_NAME, Job
from florist.api.monitoring.metrics import wait_for_metric
from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.launch import launch_local_server


Expand All @@ -21,41 +23,13 @@


@router.post("/start")
def start(
model: Annotated[str, Form()],
server_address: Annotated[str, Form()],
n_server_rounds: Annotated[int, Form()],
batch_size: Annotated[int, Form()],
local_epochs: Annotated[int, Form()],
redis_host: Annotated[str, Form()],
redis_port: Annotated[str, Form()],
clients_info: Annotated[str, Form()],
) -> JSONResponse:
async def start(job_id: str, request: Request) -> JSONResponse:
"""
Start FL training by starting a FL server and its clients.
Should be called with a POST request and the parameters should be contained in the request's form.
:param model: (str) The name of the model to train. Should be one of the values in the enum
florist.api.servers.common.Model
:param server_address: (str) The address of the FL server to be started. It should be comprised of
the host name and port separated by colon (e.g. "localhost:8080")
:param n_server_rounds: (int) The number of rounds the FL server should run.
:param batch_size: (int) The size of the batch for training.
:param local_epochs: (int) The number of epochs to run by the clients.
:param redis_host: (str) The host name for the Redis instance for metrics reporting.
:param redis_port: (str) The port for the Redis instance for metrics reporting.
:param clients_info: (str) A JSON string containing the client information. It will be parsed by
florist.api.servers.common.ClientInfo and should be in the following format:
[
{
"client": <client name as defined in florist.api.clients.common.Client>,
"client_address": <Florist's client hostname and port, e.g. localhost:8081>,
"data_path": <path where the data is located in the FL client's machine>,
"redis_host": <hostname of the Redis instance the FL client will be reporting to>,
"redis_port": <port of the Redis instance the FL client will be reporting to>,
}
]
Start FL training for a job id by starting a FL server and its clients.
:param job_id: (str) The id of the Job record in the DB which contains the information
necessary to start training.
:param request: (fastapi.Request) the FastAPI request object.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
the clients in the format below. The UUIDs can be used to pull metrics from Redis.
{
Expand All @@ -66,38 +40,50 @@ def start(
{"error": <error message>}
"""
try:
# Parse input data
if model not in Model.list():
error_msg = f"Model '{model}' not supported. Supported models: {Model.list()}"
return JSONResponse(content={"error": error_msg}, status_code=400)
job_collection = request.app.database[JOB_COLLECTION_NAME]
result = await job_collection.find_one({"_id": job_id})
job = Job(**result)

if job.config_parser is None:
job.config_parser = ConfigParser.BASIC

assert job.model is not None, "Missing Job information: model"
assert job.server_config is not None, "Missing Job information: server_config"
assert job.clients_info is not None and len(job.clients_info) > 0, "Missing Job information: clients_info"
assert job.server_address is not None, "Missing Job information: server_address"
assert job.redis_host is not None, "Missing Job information: redis_host"
assert job.redis_port is not None, "Missing Job information: redis_port"

try:
config_parser = ConfigParser.class_for_parser(job.config_parser)
server_config = config_parser.parse(job.server_config)
except JSONDecodeError as err:
raise AssertionError("server_config is not a valid json string.") from err

model_class = Model.class_for_model(Model[model])
clients_info_list = ClientInfo.parse(clients_info)
model_class = Model.class_for_model(job.model)

# Start the server
server_uuid, _ = launch_local_server(
model=model_class(),
n_clients=len(clients_info_list),
server_address=server_address,
n_server_rounds=n_server_rounds,
batch_size=batch_size,
local_epochs=local_epochs,
redis_host=redis_host,
redis_port=redis_port,
n_clients=len(job.clients_info),
server_address=job.server_address,
redis_host=job.redis_host,
redis_port=job.redis_port,
**server_config,
)
wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER)
wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

# Start the clients
client_uuids: List[str] = []
for client_info in clients_info_list:
for client_info in job.clients_info:
parameters = {
"server_address": server_address,
"server_address": job.server_address,
"client": client_info.client.value,
"data_path": client_info.data_path,
"redis_host": client_info.redis_host,
"redis_port": client_info.redis_port,
}
response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters)
response = requests.get(url=f"http://{client_info.service_address}/{START_CLIENT_API}", params=parameters)
json_response = response.json()
LOGGER.debug(f"Client response: {json_response}")

Expand All @@ -112,8 +98,8 @@ def start(
# Return the UUIDs
return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

except (ValueError, ClientInfoParseError) as ex:
return JSONResponse(content={"error": str(ex)}, status_code=400)
except AssertionError as err:
return JSONResponse(content={"error": str(err)}, status_code=400)

except Exception as ex:
LOGGER.exception(ex)
Expand Down
70 changes: 0 additions & 70 deletions florist/api/servers/common.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,13 @@
"""Common functions and definitions for servers."""

import json
from enum import Enum
from typing import List

from torch import nn

from florist.api.clients.common import Client
from florist.api.models.mnist import MnistNet


class ClientInfo:
"""Define the input information necessary to start a client."""

def __init__(self, client: Client, client_address: str, data_path: str, redis_host: str, redis_port: str):
self.client = client
self.client_address = client_address
self.data_path = data_path
self.redis_host = redis_host
self.redis_port = redis_port

@classmethod
def parse(cls, clients_info: str) -> List["ClientInfo"]:
"""
Parse the client information JSON string into a ClientInfo instance.
:param clients_info: (str) A JSON string containing the client information.
Should be in the following format:
[
{
"client": <client name as defined in florist.api.clients.common.Client>,
"client_address": <Florist's client hostname and port, e.g. localhost:8081>,
"data_path": <path where the data is located in the FL client's machine>,
"redis_host": <hostname of the Redis instance the FL client will be reporting to>,
"redis_port": <port of the Redis instance the FL client will be reporting to>,
}
]
:return: (ClientInfo) an instance of ClientInfo containing the information given.
:raises ClientInfoParseError: If any of the required information is missing or has the
wrong type.
"""
client_info_list: List[ClientInfo] = []

json_clients_info = json.loads(clients_info)
for client_info in json_clients_info:
if "client" not in client_info or not isinstance(client_info["client"], str):
raise ClientInfoParseError("clients_info does not contain key 'client'")
if client_info["client"] not in Client.list():
error_msg = f"Client '{client_info['client']}' not supported. Supported clients: {Client.list()}"
raise ClientInfoParseError(error_msg)
client = Client[client_info["client"]]

if "client_address" not in client_info or not isinstance(client_info["client_address"], str):
raise ClientInfoParseError("clients_info does not contain key 'client_address'")
client_address = client_info["client_address"]

if "data_path" not in client_info or not isinstance(client_info["data_path"], str):
raise ClientInfoParseError("clients_info does not contain key 'data_path'")
data_path = client_info["data_path"]

if "redis_host" not in client_info or not isinstance(client_info["redis_host"], str):
raise ClientInfoParseError("clients_info does not contain key 'redis_host'")
redis_host = client_info["redis_host"]

if "redis_port" not in client_info or not isinstance(client_info["redis_port"], str):
raise ClientInfoParseError("clients_info does not contain key 'redis_port'")
redis_port = client_info["redis_port"]

client_info_list.append(ClientInfo(client, client_address, data_path, redis_host, redis_port))

return client_info_list


class ClientInfoParseError(Exception):
"""Defines errors in parsing client info."""

pass


class Model(Enum):
"""Enumeration of supported models."""

Expand Down
Loading

0 comments on commit 5b180e5

Please sign in to comment.