diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index e361537..518890c 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -1,22 +1,76 @@ """Definitions for the MongoDB database entities.""" +import json import uuid -from typing import Annotated, Optional +from enum import Enum +from typing import Annotated, List, Optional from pydantic import BaseModel, Field +from florist.api.clients.common import Client from florist.api.servers.common import Model JOB_DATABASE_NAME = "job" +class JobStatus(Enum): + """Enumeration of all possible statuses of a Job.""" + + NOT_STARTED = "NOT_STARTED" + IN_PROGRESS = "IN_PROGRESS" + FINISHED_WITH_ERROR = "FINISHED_WITH_ERROR" + FINISHED_SUCCESSFULLY = "FINISHED_SUCCESSFULLY" + + +class ClientInfo(BaseModel): + """Define the information of an FL client.""" + + id: str = Field(default_factory=uuid.uuid4, alias="_id") + client: Client = Field(...) + service_address: str = Field(...) + data_path: str = Field(...) + redis_host: str = Field(...) + redis_port: str = Field(...) + + class Config: + """MongoDB config for the ClientInfo DB entity.""" + + allow_population_by_field_name = True + schema_extra = { + "example": { + "client": "MNIST", + "service_address": "locahost:8081", + "data_path": "path/to/data", + "redis_host": "localhost", + "redis_port": "6880", + }, + } + + class Job(BaseModel): """Define the Job DB entity.""" id: str = Field(default_factory=uuid.uuid4, alias="_id") + status: JobStatus = Field(default=JobStatus.NOT_STARTED) model: Optional[Annotated[Model, Field(...)]] + server_address: Optional[Annotated[str, Field(...)]] + server_info: Optional[Annotated[str, Field(...)]] redis_host: Optional[Annotated[str, Field(...)]] redis_port: Optional[Annotated[str, Field(...)]] + clients_info: Optional[Annotated[List[ClientInfo], Field(...)]] + + @classmethod + def is_valid_server_info(cls, server_info: Optional[str]) -> bool: + """ + Validate if server info is a json string. + + :param server_info: (str) the json string with the server info. + :return: True if server_info is None or a valid JSON string, False otherwise. + :raises: (json.JSONDecodeError) if there is an error decoding the server info into json + """ + if server_info is not None: + json.loads(server_info) + return True class Config: """MongoDB config for the Job DB entity.""" @@ -25,8 +79,20 @@ class Config: schema_extra = { "example": { "_id": "066de609-b04a-4b30-b46c-32537c7f1f6e", + "status": "NOT_STARTED", "model": "MNIST", - "redis_host": "locahost", + "server_address": "localhost:8080", + "server_info": '{"n_server_rounds": 3, "batch_size": 8}', + "redis_host": "localhost", "redis_port": "6879", + "client_info": [ + { + "client": "MNIST", + "service_address": "locahost:8081", + "data_path": "path/to/data", + "redis_host": "localhost", + "redis_port": "6880", + }, + ], }, } diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index 6f43e1a..cc87e37 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -1,7 +1,8 @@ """FastAPI routes for the job.""" +from json import JSONDecodeError from typing import Any, Dict -from fastapi import APIRouter, Body, Request, status +from fastapi import APIRouter, Body, HTTPException, Request, status from fastapi.encoders import jsonable_encoder from florist.api.db.entities import JOB_DATABASE_NAME, Job @@ -26,7 +27,17 @@ async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # :param request: (fastapi.Request) the FastAPI request object. :param job: (Job) The Job instance to be saved in the database. :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. + :raises: (HTTPException) status 400 if job.server_info is not None and cannot be parsed into JSON. """ + try: + is_valid = Job.is_valid_server_info(job.server_info) + if not is_valid: + msg = f"job.server_info is not valid. job.server_info: {job.server_info}." + raise HTTPException(status_code=400, detail=msg) + except JSONDecodeError as e: + msg = f"job.server_info could not be parsed into JSON. job.server_info: {job.server_info}. Error: {e}" + raise HTTPException(status_code=400, detail=msg) from e + json_job = jsonable_encoder(job) result = await request.app.database[JOB_DATABASE_NAME].insert_one(json_job) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index 8bea9d1..739c43e 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -1,6 +1,10 @@ from unittest.mock import ANY +from pytest import raises -from florist.api.db.entities import Job +from fastapi import HTTPException + +from florist.api.clients.common import Client +from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.routes.server.job import new_job from florist.api.servers.common import Model from florist.tests.integration.api.utils import mock_request @@ -12,18 +16,77 @@ async def test_new_job(mock_request) -> None: assert result == { "_id": ANY, + "status": JobStatus.NOT_STARTED.value, "model": None, + "server_address": None, + "server_info": None, "redis_host": None, "redis_port": None, + "clients_info": None, } assert isinstance(result["_id"], str) - test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") + test_job = Job( + id="test-id", + status=JobStatus.IN_PROGRESS, + model=Model.MNIST, + server_address="test-server-address", + server_info="{\"test-server-info\": 123}", + redis_host="test-redis-host", + redis_port="test-redis-port", + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address="test-addr-1", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ClientInfo( + client=Client.MNIST, + service_address="test-addr-2", + data_path="test/data/path-2", + redis_host="test-redis-host-2", + redis_port="test-redis-port-2", + ), + ] + ) result = await new_job(mock_request, test_job) assert result == { "_id": test_job.id, + "status": test_job.status.value, "model": test_job.model.value, + "server_address": "test-server-address", + "server_info": "{\"test-server-info\": 123}", "redis_host": test_job.redis_host, "redis_port": test_job.redis_port, + "clients_info": [ + { + "_id": ANY, + "client": test_job.clients_info[0].client.value, + "service_address": test_job.clients_info[0].service_address, + "data_path": test_job.clients_info[0].data_path, + "redis_host": test_job.clients_info[0].redis_host, + "redis_port": test_job.clients_info[0].redis_port, + }, { + "_id": ANY, + "client": test_job.clients_info[1].client.value, + "service_address": test_job.clients_info[1].service_address, + "data_path": test_job.clients_info[1].data_path, + "redis_host": test_job.clients_info[1].redis_host, + "redis_port": test_job.clients_info[1].redis_port, + }, + ], } + assert isinstance(result["clients_info"][0]["_id"], str) + assert isinstance(result["clients_info"][1]["_id"], str) + + +async def test_new_job_fail_bad_server_info(mock_request) -> None: + test_job = Job(server_info="not json") + with raises(HTTPException) as exception_info: + await new_job(mock_request, test_job) + + assert exception_info.value.status_code == 400 + assert "job.server_info could not be parsed into JSON" in exception_info.value.detail