Skip to content

Commit

Permalink
Switching background tasks for threads on client and server updates (#…
Browse files Browse the repository at this point in the history
…122)

While I was testing the app this week, I realized FastAPI's Background Tasks, which we use for client and server metric updates, are not really suitable for long-running tasks, as they are executed serially instead of in parallel.

Here I am switching from that to plain python threads, and also making a few other changes:

    Adding a database connection in the listeners as FastAPI's db instance cannot be shared between threads
    Changing the listener functions to be async
    Converting the sync database functions to async in the Job
    Getting rid of the sync database connection, as we don't need it anymore
    For the metrics reporter, avoid saving the metrics if they are exactly the same as what's already stored. This eliminates an issue where overly frequent Redis updates carrying no new information made the app do a lot of unnecessary work.
    Fixing the metrics in the UI progress section for the updated metric names
  • Loading branch information
lotif authored Dec 9, 2024
1 parent a6d0cf3 commit 968df3b
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 246 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ docker start redis-florist-client
To start the client back-end service:

```shell
uvicorn florist.api.client:app --reload --port 8001
python -m uvicorn florist.api.client:app --reload --port 8001
```

The service will be available at `http://localhost:8001`.
Expand Down
4 changes: 4 additions & 0 deletions florist/api/db/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Database configuration parameters.

Defaults target a local development MongoDB instance; the URI can be
overridden via the ``MONGODB_URI`` environment variable without changing
this file, which keeps the default behavior backward-compatible.
"""

import os

# Connection string for the MongoDB server. Overridable via the
# MONGODB_URI environment variable; falls back to a local instance.
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/")

# Name of the database used by the FLorist server.
DATABASE_NAME = "florist-server"
39 changes: 15 additions & 24 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from fastapi.encoders import jsonable_encoder
from motor.motor_asyncio import AsyncIOMotorDatabase
from pydantic import BaseModel, Field
from pymongo.database import Database
from pymongo.results import UpdateResult

from florist.api.clients.common import Client
Expand Down Expand Up @@ -164,47 +163,38 @@ async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase[Any
update_result = await job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
assert_updated_successfully(update_result)

def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]]) -> None:
"""
Sync function to save the status in the database under the current job's id.
:param status: (JobStatus) the status to be saved in the database.
:param database: (pymongo.database.Database) The database where the job collection is stored.
"""
job_collection = database[JOB_COLLECTION_NAME]
self.status = status
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
assert_updated_successfully(update_result)

def set_server_metrics(
async def set_server_metrics(
self,
server_metrics: Dict[str, Any],
database: Database[Dict[str, Any]],
database: AsyncIOMotorDatabase[Any],
) -> None:
"""
Sync function to save the server's metrics in the database under the current job's id.
Save the server's metrics in the database under the current job's id.
:param server_metrics: (Dict[str, Any]) the server metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
job_collection = database[JOB_COLLECTION_NAME]

self.server_metrics = json.dumps(server_metrics)
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}})
update_result = await job_collection.update_one(
{"_id": self.id},
{"$set": {"server_metrics": self.server_metrics}},
)
assert_updated_successfully(update_result)

def set_client_metrics(
async def set_client_metrics(
self,
client_uuid: str,
client_metrics: Dict[str, Any],
database: Database[Dict[str, Any]],
database: AsyncIOMotorDatabase[Any],
) -> None:
"""
Sync function to save a clients' metrics in the database under the current job's id.
Save a client's metrics in the database under the current job's id.
:param client_uuid: (str) the client's uuid whose produced the metrics.
:param client_metrics: (Dict[str, Any]) the client's metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
assert (
self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info]
Expand All @@ -215,8 +205,9 @@ def set_client_metrics(
for i in range(len(self.clients_info)):
if client_uuid == self.clients_info[i].uuid:
self.clients_info[i].metrics = json.dumps(client_metrics)
update_result = job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}
update_result = await job_collection.update_one(
{"_id": self.id},
{"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}},
)
assert_updated_successfully(update_result)

Expand Down
11 changes: 11 additions & 0 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def dump(self) -> None:
assert self.run_id is not None, "Run ID is None, ensure reporter is initialized prior to dumping metrics."

encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)

previous_metrics_blob = self.redis_connection.get(self.run_id)
if previous_metrics_blob is not None and isinstance(previous_metrics_blob, bytes):
previous_metrics = json.loads(previous_metrics_blob)
current_metrics = json.loads(encoded_metrics)
if current_metrics == previous_metrics:
log(
DEBUG, f"Skipping dumping: previous metrics are the same as current metrics at key '{self.run_id}'"
)
return

log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)
log(DEBUG, f"Notifying redis channel '{self.run_id}'")
Expand Down
92 changes: 53 additions & 39 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""FastAPI routes for training."""

import asyncio
import logging
from json import JSONDecodeError
from typing import Any, Dict, List
from threading import Thread
from typing import Any, List

import requests
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from pymongo.database import Database
from motor.motor_asyncio import AsyncIOMotorClient

from florist.api.db.config import DATABASE_NAME, MONGODB_URI
from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
Expand All @@ -25,15 +28,13 @@


@router.post("/start")
async def start(job_id: str, request: Request, background_tasks: BackgroundTasks) -> JSONResponse:
async def start(job_id: str, request: Request) -> JSONResponse:
"""
Start FL training for a job id by starting a FL server and its clients.
:param job_id: (str) The id of the Job record in the DB which contains the information
necessary to start training.
:param request: (fastapi.Request) the FastAPI request object.
:param background_tasks: (BackgroundTasks) A BackgroundTasks instance to launch the training listener,
which will update the progress of the training job.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
the clients in the format below. The UUIDs can be used to pull metrics from Redis.
{
Expand Down Expand Up @@ -105,11 +106,13 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks

await job.set_uuids(server_uuid, client_uuids, request.app.database)

# Start the server training listener as a background task to update
# Start the server training listener and client training listeners as threads to update
# the job's metrics and status once the training is done
background_tasks.add_task(server_training_listener, job, request.app.synchronous_database)
server_listener_thread = Thread(target=asyncio.run, args=(server_training_listener(job),))
server_listener_thread.start()
for client_info in job.clients_info:
background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database)
client_listener_thread = Thread(target=asyncio.run, args=(client_training_listener(job, client_info),))
client_listener_thread.start()

# Return the UUIDs
return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})
Expand All @@ -126,30 +129,31 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
return JSONResponse({"error": str(ex)}, status_code=500)


def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None:
async def client_training_listener(job: Job, client_info: ClientInfo) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL client.
Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.
:param job: (Job) The job that has this client's metrics.
:param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

assert client_info.uuid is not None, "client_info.uuid is None."

db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI)
database = db_client[DATABASE_NAME]

# check if training has already finished before start listening
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
LOGGER.debug(f"Client listener: Current metrics for client {client_info.uuid}: {client_metrics}")
if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
LOGGER.info(f"Client listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
await job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Client listener: Client metrics for client {client_info.uuid} on {job.id} have been updated.")
if "shutdown" in client_metrics:
db_client.close()
return

subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
Expand All @@ -158,16 +162,22 @@ def client_training_listener(job: Job, client_info: ClientInfo, database: Databa
if message["type"] == "message":
# The contents of the message do not matter, we just use it to get notified
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
LOGGER.debug(f"Client listener: Current metrics for client {client_info.uuid}: {client_metrics}")

if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
LOGGER.info(f"Client listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
await job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(
f"Client listener: Client metrics for client {client_info.uuid} on {job.id} have been updated."
)
if "shutdown" in client_metrics:
db_client.close()
return

db_client.close()


def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None:
async def server_training_listener(job: Job) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL server.
Expand All @@ -176,27 +186,28 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
to the job in the database.
:param job: (Job) The job with the server_uuid to listen to.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}")

assert job.server_uuid is not None, "job.server_uuid is None."
assert job.redis_host is not None, "job.redis_host is None."
assert job.redis_port is not None, "job.redis_port is None."

db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI)
database = db_client[DATABASE_NAME]

# check if training has already finished before start listening
server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}")
LOGGER.debug(f"Server listener: Current metrics for job {job.id}: {server_metrics}")
if server_metrics is not None:
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
LOGGER.info(f"Server listener: Updating server metrics for job {job.id}")
await job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.")
if "fit_end" in server_metrics:
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
LOGGER.info(f"Server listener: Training finished for job {job.id}")
await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.")
db_client.close()
return

subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
Expand All @@ -205,14 +216,17 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
if message["type"] == "message":
# The contents of the message do not matter, we just use it to get notified
server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}")
LOGGER.debug(f"Server listener: Message received for job {job.id}. Metrics: {server_metrics}")

if server_metrics is not None:
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
LOGGER.info(f"Server listener: Updating server metrics for job {job.id}")
await job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.")
if "fit_end" in server_metrics:
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
LOGGER.info(f"Server listener: Training finished for job {job.id}")
await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.")
db_client.close()
return

db_client.close()
10 changes: 1 addition & 9 deletions florist/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,26 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient

from florist.api.clients.common import Client
from florist.api.db.config import DATABASE_NAME, MONGODB_URI
from florist.api.routes.server.job import router as job_router
from florist.api.routes.server.status import router as status_router
from florist.api.routes.server.training import router as training_router
from florist.api.servers.common import Model


MONGODB_URI = "mongodb://localhost:27017/"
DATABASE_NAME = "florist-server"


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]:
"""Set up function for app startup and shutdown."""
# Set up mongodb
app.db_client = AsyncIOMotorClient(MONGODB_URI) # type: ignore[attr-defined]
app.database = app.db_client[DATABASE_NAME] # type: ignore[attr-defined]
# Setting up a synchronous database connection for background tasks
app.synchronous_db_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined]
app.synchronous_database = app.synchronous_db_client[DATABASE_NAME] # type: ignore[attr-defined]

yield

# Shut down mongodb
app.db_client.close() # type: ignore[attr-defined]
app.synchronous_db_client.close() # type: ignore[attr-defined]


app = FastAPI(lifespan=lifespan)
Expand Down
Loading

0 comments on commit 968df3b

Please sign in to comment.