Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-923 | Improve continue session #647

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions fedn/network/api/v1/session_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from fedn.network.api.auth import jwt_auth_required
from fedn.network.api.shared import control
from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers, mdb
from fedn.network.combiner.interfaces import CombinerUnavailableError
from fedn.network.state import ReducerState
from fedn.network.storage.statestore.stores.session_store import SessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound

Expand Down Expand Up @@ -354,6 +356,18 @@ def post():
return jsonify({"message": "An unexpected error occurred"}), 500


def _get_number_of_available_clients():
    """Count the active clients reported by every combiner in the network.

    :return: The summed length of each combiner's active-client list, or 0
        as soon as any combiner raises ``CombinerUnavailableError``.
    :rtype: int
    """
    total = 0
    for combiner in control.network.get_combiners():
        try:
            active_clients = combiner.list_active_clients()
            total += len(active_clients)
        except CombinerUnavailableError:
            # A single unreachable combiner voids the whole count.
            return 0

    return total


@bp.route("/start", methods=["POST"])
@jwt_auth_required(role="admin")
def start_session():
Expand All @@ -367,24 +381,30 @@ def start_session():
data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict()
session_id: str = data.get("session_id")
rounds: int = data.get("rounds", "")
round_timeout: int = data.get("round_timeout", None)

if not session_id or session_id == "":
return jsonify({"message": "Session ID is required"}), 400

if not rounds or rounds == "":
return jsonify({"message": "Rounds is required"}), 400

if not isinstance(rounds, int):
return jsonify({"message": "Rounds must be an integer"}), 400

session = session_store.get(session_id, use_typing=False)

session_config = session["session_config"]
model_id = session_config["model_id"]
min_clients = session_config["clients_required"]

if control.state() == ReducerState.monitoring:
return jsonify({"message": "A session is already running."})

if not rounds or not isinstance(rounds, int):
rounds = session_config["rounds"]
nr_available_clients = _get_number_of_available_clients()

if nr_available_clients < min_clients:
return jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400

_ = model_store.get(model_id, use_typing=False)

threading.Thread(target=control.start_session, args=(session_id, rounds)).start()
threading.Thread(target=control.start_session, args=(session_id, rounds, round_timeout)).start()

return jsonify({"message": "Session started"}), 200
except Exception:
Expand Down
7 changes: 6 additions & 1 deletion fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, statestore):
super().__init__(statestore)
self.name = "DefaultControl"

def start_session(self, session_id: str, rounds: int) -> None:
def start_session(self, session_id: str, rounds: int, round_timeout: int) -> None:
if self._state == ReducerState.instructing:
logger.info("Controller already in INSTRUCTING state. A session is in progress.")
return
Expand All @@ -116,6 +116,9 @@ def start_session(self, session_id: str, rounds: int) -> None:
logger.error("Session not properly configured.")
return

if round_timeout is not None:
session_config["round_timeout"] = round_timeout

self._state = ReducerState.monitoring

last_round = int(self.get_latest_round_id())
Expand Down Expand Up @@ -151,6 +154,8 @@ def start_session(self, session_id: str, rounds: int) -> None:
self.set_session_status(session_id, "Finished")
self._state = ReducerState.idle

self.set_session_config(session_id, session_config)

def session(self, config: RoundConfig) -> None:
"""Execute a new training session. A session consists of one
or several global rounds. All rounds in the same session
Expand Down
10 changes: 10 additions & 0 deletions fedn/network/controller/controlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ def get_session_status(self, session_id):
"""
return self.statestore.get_session_status(session_id)

def set_session_config(self, session_id: str, config: dict) -> None:
    """Set the configuration for a session.

    Forwards both arguments unchanged to the statestore's
    ``set_session_config_v2``.

    :param session_id: The session unique identifier
    :type session_id: str
    :param config: The session config
    :type config: dict
    """
    self.statestore.set_session_config_v2(session_id, config)

def create_round(self, round_data):
"""Initialize a new round in backend db."""
self.statestore.create_round(round_data)
Expand Down
11 changes: 11 additions & 0 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,17 @@ def set_session_config(self, id: str, config: RoundConfig) -> None:
"""
self.sessions.update_one({"session_id": str(id)}, {"$push": {"session_config": config}}, True)

# Added to accommodate the new session config structure
def set_session_config_v2(self, id: str, config: RoundConfig) -> None:
    """Store the configuration of a session.

    Issues a ``$set`` on the ``session_config`` field of the document whose
    ``session_id`` equals ``str(id)``.

    :param id: The session id
    :type id: str
    :param config: Session configuration
    :type config: RoundConfig
    """
    session_filter = {"session_id": str(id)}
    change = {"$set": {"session_config": config}}
    # Third positional argument is presumably the upsert flag of a pymongo
    # collection -- confirm against the type of ``self.sessions``.
    self.sessions.update_one(session_filter, change, True)

def set_session_status(self, id, status):
"""Set session status.

Expand Down
Loading