Skip to content

Commit

Permalink
Adding SQList as client storage solution (#159)
Browse files Browse the repository at this point in the history
In this PR I am introducing SQLite as a solution for the FLorist client to store sensitive information about the FL client. List of changes:

    Making the florist/api/db/client_entities.py which will contain both the SQLite entity definitions and the methods to connect and access the DB.
    Renaming florist/api/db/entities.py to florist/api/db/server_entities.py
    Removing pid and log_file_path from the ClientInfo MongoDB entity
    When the FL client starts, instead of returning the pid and log_file_path, it will now save those to the database
    On both client's stop and get_log functions, change them to receive the client UUID instead of the pid or log_file_path respectivelly
    Changing the set_pids function to set_server_pid as it now only saves the server's PID.
  • Loading branch information
lotif authored Feb 6, 2025
1 parent 1d2fc87 commit 27e0971
Show file tree
Hide file tree
Showing 15 changed files with 647 additions and 249 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,6 @@ next-env.d.ts
/logs/
/.ruff_cache/
/.swc/
/florist/api/client.db
/florist/tests/integration/api/client.db
/florist/tests/unit/api/client.db
65 changes: 42 additions & 23 deletions florist/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import logging
import os
import signal
import uuid
from pathlib import Path
from uuid import uuid4

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from florist.api.clients.common import Client
from florist.api.db.client_entities import ClientDAO
from florist.api.launchers.local import launch_client
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter, get_from_redis
Expand Down Expand Up @@ -43,12 +44,10 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
:param data_path: (str) the path where the training data is located.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID, the PID and the log
file path for the client in the format below:
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
format below, which can be used to pull metrics from Redis.
{
"uuid": (str) The client's uuid, which can be used to pull metrics from Redis,
"log_file_path": (str) The local path of the log file for this client,
"pid": (str) The PID of the client process
}
If not successful, returns the appropriate error code with a JSON with the format below:
{
Expand All @@ -60,7 +59,7 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
error_msg = f"Client '{client}' not supported. Supported clients: {Client.list()}"
return JSONResponse(content={"error": error_msg}, status_code=400)

client_uuid = str(uuid.uuid4())
client_uuid = str(uuid4())
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand All @@ -76,9 +75,13 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
log_file_path = str(get_client_log_file_path(client_uuid))
client_process = launch_client(client_obj, server_address, log_file_path)

return JSONResponse({"uuid": client_uuid, "log_file_path": log_file_path, "pid": str(client_process.pid)})
db_entity = ClientDAO(uuid=client_uuid, log_file_path=log_file_path, pid=client_process.pid)
db_entity.save()

return JSONResponse({"uuid": client_uuid})

except Exception as ex:
LOGGER.exception(ex)
return JSONResponse({"error": str(ex)}, status_code=500)


Expand Down Expand Up @@ -108,35 +111,51 @@ def check_status(client_uuid: str, redis_host: str, redis_port: str) -> JSONResp
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/get_log")
def get_log(log_file_path: str) -> JSONResponse:
@app.get("/api/client/get_log/{uuid}")
def get_log(uuid: str) -> JSONResponse:
"""
Return the contents of the log file under the given path.
Return the contents of the logs for the given client uuid.
:param log_file_path: (str) the path of the logt file.
:param uuid: (str) the uuid of the client.
:return: (JSONResponse) Returns the contents of the file as a string.
:return: (JSONResponse) If successful, returns the contents of the file as a string.
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
with open(log_file_path, "r") as f:
content = f.read()
return JSONResponse(content)
try:
client = ClientDAO.find(uuid)

assert client.log_file_path, "Client log file path is None or empty"

with open(client.log_file_path, "r") as f:
content = f.read()
return JSONResponse(content)

# TODO verify the safety of this call
@app.get("/api/client/stop/{pid}")
def stop(pid: str) -> JSONResponse:
except AssertionError as err:
return JSONResponse(content={"error": str(err)}, status_code=400)
except Exception as ex:
LOGGER.exception(ex)
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/stop/{uuid}")
def stop(uuid: str) -> JSONResponse:
"""
Kills the client process with given PID.
Stop the client with given UUID.
:param pid: (str) the PID of the client to be killed.
:param uuid: (str) the UUID of the client to be stopped.
:return: (JSONResponse) If successful, returns 200. If not successful, returns the appropriate
error code with a JSON with the format below:
{"error": <error message>}
"""
try:
assert pid, "PID is empty or None."
os.kill(int(pid), signal.SIGTERM)
LOGGER.info(f"Killed process with PID {pid}")
assert uuid, "UUID is empty or None."
client = ClientDAO.find(uuid)
assert client.pid, "PID is empty or None."

os.kill(client.pid, signal.SIGTERM)
LOGGER.info(f"Stopped client with UUID {uuid} ({client.pid})")

return JSONResponse(content={"status": "success"})
except AssertionError as err:
return JSONResponse(content={"error": str(err)}, status_code=400)
Expand Down
169 changes: 169 additions & 0 deletions florist/api/db/client_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""Definitions for the SQLIte database entities (client database)."""

import json
import sqlite3
from abc import ABC, abstractmethod
from typing import Optional

from typing_extensions import Self

from florist.api.db.config import SQLITE_DB_PATH


class EntityDAO(ABC):
"""Base Data Access Object (DAO) for SQLite entities."""

table_name = "Entity"
db_path = SQLITE_DB_PATH

@abstractmethod
def __init__(self, uuid: str):
"""
Initialize an Entity.
Abstract method to be implemented by the child classes.
:param uuid: the UUID of the entity
"""
self.uuid = uuid

@classmethod
def get_connection(cls) -> sqlite3.Connection:
"""
Return the SQLite connection object.
Will create the table of the entity in the DB if it doesn't exist.
:return: (sqlite3.Connection) The SQLite connection object
"""
sqlite_db = sqlite3.connect(cls.db_path)
sqlite_db.execute(f"CREATE TABLE IF NOT EXISTS {cls.table_name} (uuid TEXT, data TEXT)")
sqlite_db.commit()
return sqlite_db

@classmethod
def find(cls, uuid: str) -> Self:
"""
Find the entity in the database with the given UUID.
:param uuid: (str) the UUID of the entity.
:return: (Self) an instance of the entity.
:raises ValueError: if no such entity exists in the database with given UUID.
"""
sqlite_db = cls.get_connection()
results = sqlite_db.execute(f"SELECT * FROM {cls.table_name} WHERE uuid=? LIMIT 1", (uuid,))
for result in results:
return cls.from_json(result[1])

raise ValueError(f"Client with uuid '{uuid}' not found.")

@classmethod
def exists(cls, uuid: str) -> bool:
"""
Check if an entity with the given UUID exists in the database.
:param uuid: (str) the UUID of the entity.
:return: (bool) True if the entity exists, False otherwise.
"""
sqlite_db = cls.get_connection()
results = sqlite_db.execute(f"SELECT EXISTS(SELECT 1 FROM {cls.table_name} WHERE uuid=? LIMIT 1);", (uuid,))
for result in results:
return bool(result[0])

return False

def save(self) -> None:
"""
Save the current entity to the database.
Will insert a new record if an entity with self.uuid doesn't yet exist in the database,
will update the database entity at self.uuid otherwise.
"""
sqlite_db = self.__class__.get_connection()
if self.__class__.exists(self.uuid):
sqlite_db.execute(
f"UPDATE {self.__class__.table_name} SET data=? WHERE uuid=?", (self.to_json(), self.uuid)
)
else:
sqlite_db.execute(
f"INSERT INTO {self.__class__.table_name} (uuid, data) VALUES(?, ?)", (self.uuid, self.to_json())
)
sqlite_db.commit()

def __eq__(self, other: object) -> bool:
"""
Check if two instances of this entity have the same values for the same attributes.
:param other: (object) the other instance to check against.
:return: (bool) True if they are equal, False otherwise.
"""
if not isinstance(other, self.__class__):
return False
return self.to_json() == other.to_json()

@classmethod
@abstractmethod
def from_json(cls, json_data: str) -> Self:
"""
Convert from a JSON string to an instance of the entity.
Abstract method, to be implemented by the child classes.
:param json_data: (str) the entity data as a JSON string.
:return: (Self) and instance of the entity populated with the JSON data.
"""
pass

@abstractmethod
def to_json(self) -> str:
"""
Convert the entity data into a JSON string.
Abstract method, to be implemented by the child classes.
:return: (str) the entity data as a JSON string.
"""
pass


class ClientDAO(EntityDAO):
"""Data Access Object (DAO) for the Client SQLite entity."""

table_name = "Client"

def __init__(self, uuid: str, log_file_path: Optional[str] = None, pid: Optional[int] = None):
"""
Initialize a Client entity.
:param uuid: (str) the UUID of the client.
:param log_file_path: the path in the filesystem where the client's log can be located.
:param pid: the PID of the client's process.
"""
super().__init__(uuid=uuid)
self.log_file_path = log_file_path
self.pid = pid

@classmethod
def from_json(cls, json_data: str) -> Self:
"""
Convert from a JSON string into an instance of Client.
:param json_data: the client's data as a JSON string.
:return: (Self) and instancxe of ClientDAO populated with the JSON data.
"""
data = json.loads(json_data)
return cls(data["uuid"], data["log_file_path"], data["pid"])

def to_json(self) -> str:
"""
Convert the client data into a JSON string.
:return: (str) the client data as a JSON string.
"""
return json.dumps(
{
"uuid": self.uuid,
"log_file_path": self.log_file_path,
"pid": self.pid,
}
)
2 changes: 2 additions & 0 deletions florist/api/db/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@

MONGODB_URI = "mongodb://localhost:27017/"
DATABASE_NAME = "florist-server"

SQLITE_DB_PATH = "florist/api/client.db"
25 changes: 3 additions & 22 deletions florist/api/db/entities.py → florist/api/db/server_entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Definitions for the MongoDB database entities."""
"""Definitions for the MongoDB database entities (server database)."""

import json
import uuid
Expand Down Expand Up @@ -48,8 +48,6 @@ class ClientInfo(BaseModel):
redis_port: str = Field(...)
uuid: Optional[Annotated[str, Field(...)]]
metrics: Optional[Annotated[str, Field(...)]]
log_file_path: Optional[Annotated[str, Field(...)]]
pid: Optional[Annotated[str, Field(...)]]

class Config:
"""MongoDB config for the ClientInfo DB entity."""
Expand All @@ -64,8 +62,6 @@ class Config:
"redis_port": "6380",
"uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
"metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
"log_file_path": "/Users/foo/client/logfile.log",
"pid": "123",
},
}

Expand Down Expand Up @@ -256,32 +252,19 @@ async def set_client_log_file_path(
)
assert_updated_successfully(update_result)

async def set_pids(self, server_pid: str, client_pids: List[str], database: AsyncIOMotorDatabase[Any]) -> None:
async def set_server_pid(self, server_pid: str, database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the server and clients' PIDs in the database under the current job's id.
Save the server PID in the database under the current job's id.
:param server_pid: [str] the server PID to be saved in the database.
:param client_pids: List[str] the list of client PIDs to be saved in the database.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
assert self.clients_info is not None and len(self.clients_info) == len(client_pids), (
"self.clients_info and client_pids must have the same length "
f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_pids)})."
)

job_collection = database[JOB_COLLECTION_NAME]

self.server_pid = server_pid
update_result = await job_collection.update_one({"_id": self.id}, {"$set": {"server_pid": server_pid}})
assert_updated_successfully(update_result)

for i in range(len(client_pids)):
self.clients_info[i].pid = client_pids[i]
update_result = await job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.pid": client_pids[i]}}
)
assert_updated_successfully(update_result)

async def set_error_message(self, error_message: str, database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save an error message in the database under the current job's id.
Expand Down Expand Up @@ -319,8 +302,6 @@ class Config:
"redis_host": "localhost",
"redis_port": "6380",
"uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
"metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
"pid": "123",
},
],
"error_message": "Some plain text error message.",
Expand Down
Loading

0 comments on commit 27e0971

Please sign in to comment.