Skip to content

Commit

Permalink
Feature/SK-923 | Improve continue session (#647)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman authored Jul 3, 2024
1 parent 2485fb4 commit 153dd91
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
34 changes: 27 additions & 7 deletions fedn/network/api/v1/session_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb
from fedn.network.combiner.interfaces import CombinerUnavailableError
from fedn.network.state import ReducerState
from fedn.network.storage.statestore.stores.session_store import SessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound

Expand Down Expand Up @@ -354,6 +356,18 @@ def post():
return jsonify({"message": "An unexpected error occurred"}), 500


def _get_number_of_available_clients():
result = 0
for combiner in control.network.get_combiners():
try:
nr_active_clients = len(combiner.list_active_clients())
result = result + int(nr_active_clients)
except CombinerUnavailableError:
return 0

return result


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
Expand All @@ -367,24 +381,30 @@ def start_session():
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")
rounds: int = data.get("rounds", "")
round_timeout: int = data.get("round_timeout", None)

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

if not rounds or rounds == "":
return jsonify({"message": "Rounds is required"}), 400

if not isinstance(rounds, int):
return jsonify({"message": "Rounds must be an integer"}), 400

session = session_store.get(session_id, use_typing=False)

session_config = session["session_config"]
model_id = session_config["model_id"]
min_clients = session_config["clients_required"]

if control.state() == ReducerState.monitoring:
return jsonify({"message": "A session is already running."})

if not rounds or not isinstance(rounds, int):
rounds = session_config["rounds"]
nr_available_clients = _get_number_of_available_clients()

if nr_available_clients < min_clients:
return jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400

_ = model_store.get(model_id, use_typing=False)

threading.Thread(target=control.start_session, args=(session_id, rounds)).start()
threading.Thread(target=control.start_session, args=(session_id, rounds, round_timeout)).start()

return jsonify({"message": "Session started"}), 200
except Exception:
Expand Down
7 changes: 6 additions & 1 deletion fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, statestore):
super().__init__(statestore)
self.name = "DefaultControl"

def start_session(self, session_id: str, rounds: int) -> None:
def start_session(self, session_id: str, rounds: int, round_timeout: int) -> None:
if self._state == ReducerState.instructing:
logger.info("Controller already in INSTRUCTING state. A session is in progress.")
return
Expand All @@ -116,6 +116,9 @@ def start_session(self, session_id: str, rounds: int) -> None:
logger.error("Session not properly configured.")
return

if round_timeout is not None:
session_config["round_timeout"] = round_timeout

self._state = ReducerState.monitoring

last_round = int(self.get_latest_round_id())
Expand Down Expand Up @@ -151,6 +154,8 @@ def start_session(self, session_id: str, rounds: int) -> None:
self.set_session_status(session_id, "Finished")
self._state = ReducerState.idle

self.set_session_config(session_id, session_config)

def session(self, config: RoundConfig) -> None:
"""Execute a new training session. A session consists of one
or several global rounds. All rounds in the same session
Expand Down
10 changes: 10 additions & 0 deletions fedn/network/controller/controlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ def get_session_status(self, session_id):
"""
return self.statestore.get_session_status(session_id)

def set_session_config(self, session_id: str, config: dict):
"""Set the model id for a session.
:param session_id: The session unique identifier
:type session_id: str
:param config: The session config
:type config: dict
"""
self.statestore.set_session_config_v2(session_id, config)

def create_round(self, round_data):
"""Initialize a new round in backend db."""
self.statestore.create_round(round_data)
Expand Down
11 changes: 11 additions & 0 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,17 @@ def set_session_config(self, id: str, config: RoundConfig) -> None:
"""
self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True)

# Added to accomodate new session config structure
def set_session_config_v2(self, id: str, config: RoundConfig) -> None:
"""Set the session configuration.
:param id: The session id
:type id: str
:param config: Session configuration
:type config: dict
"""
self.sessions.update_one({"session_id": str(id)}, {"$set": {"session_config": config}}, True)

def set_session_status(self, id, status):
"""Set session status.
Expand Down

0 comments on commit 153dd91

Please sign in to comment.