diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index c7a6d1c26..96a59fc65 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -67,6 +67,20 @@ def __init__(self, message): super().__init__(self.message) +class SessionTerminatedException(Exception): + """Exception class for when session is terminated""" + + def __init__(self, message): + """Constructor method. + + :param message: The exception message. + :type message: str + + """ + self.message = message + super().__init__(self.message) + + class Control(ControlBase): """Controller, implementing the overall global training, validation and inference logic. @@ -122,6 +136,8 @@ def start_session(self, session_id: str, rounds: int) -> None: current_round = round try: + if self.get_session_status(session_id) == "Terminated": + break _, round_data = self.round(session_config, str(current_round)) except TypeError as e: logger.error("Failed to execute round: {0}".format(e)) @@ -130,7 +146,8 @@ def start_session(self, session_id: str, rounds: int) -> None: session_config["model_id"] = self.statestore.get_latest_model() - self.set_session_status(session_id, "Finished") + if self.get_session_status(session_id) == "Started": + self.set_session_status(session_id, "Finished") self._state = ReducerState.idle def session(self, config: RoundConfig) -> None: @@ -172,6 +189,8 @@ def session(self, config: RoundConfig) -> None: current_round = round try: + if self.get_session_status(config["session_id"]) == "Terminated": + break _, round_data = self.round(config, str(current_round)) except TypeError as e: logger.error("Failed to execute round: {0}".format(e)) @@ -181,7 +200,8 @@ def session(self, config: RoundConfig) -> None: config["model_id"] = self.statestore.get_latest_model() # TODO: Report completion of session - self.set_session_status(config["session_id"], "Finished") + if self.get_session_status(config["session_id"]) == "Started": + self.set_session_status(config["session_id"], "Finished") self._state = ReducerState.idle def inference_session(self, config: RoundConfig) -> None: @@ -227,6 +247,7 @@ def round(self, session_config: RoundConfig, round_id: str): : type round_id: str """ + session_id = session_config["session_id"] self.create_round({"round_id": round_id, "status": "Pending"}) if len(self.network.get_combiners()) < 1: @@ -239,7 +260,7 @@ def round(self, session_config: RoundConfig, round_id: str): round_config["rounds"] = 1 round_config["round_id"] = round_id round_config["task"] = "training" - round_config["session_id"] = session_config["session_id"] + round_config["session_id"] = session_id self.set_round_config(round_id, round_config) @@ -263,7 +284,11 @@ def round(self, session_config: RoundConfig, round_id: str): # Wait until participating combiners have produced an updated global model, # or round times out. def do_if_round_times_out(result): - logger.warning("Round timed out!") + if isinstance(result.outcome.exception(), SessionTerminatedException): + logger.warning("Session terminated!") + return None, self.statestore.get_round(round_id) + else: + logger.warning("Round timed out!") @retry( wait=wait_random(min=1.0, max=2.0), @@ -273,6 +298,9 @@ def do_if_round_times_out(result): ) def combiners_done(): round = self.statestore.get_round(round_id) + session_status = self.get_session_status(session_id) + if session_status == "Terminated": + raise SessionTerminatedException("Session terminated!") if "combiners" not in round: logger.info("Waiting for combiners to update model...") raise CombinersNotDoneException("Combiners have not yet reported.") @@ -283,7 +311,7 @@ def combiners_done(): return True - combiners_done() + _ = combiners_done() # Due to the distributed nature of the computation, there might be a # delay before combiners have reported the round data to the db, diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index 141848b78..3ae8e3731 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -183,6 +183,16 @@ def set_session_status(self, session_id, status): """ self.statestore.set_session_status(session_id, status) + def get_session_status(self, session_id): + """Get the status of a session. + + :param session_id: The session unique identifier + :type session_id: str + :return: The status + :rtype: str + """ + return self.statestore.get_session_status(session_id) + def create_round(self, round_data): """Initialize a new round in backend db.""" self.statestore.create_round(round_data) diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index 724077984..b42d88f85 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -168,6 +168,17 @@ def get_session(self, session_id): """ return self.sessions.find_one({"session_id": session_id}) + def get_session_status(self, session_id): + """Get the session status. + + :param session_id: The session id. + :type session_id: str + :return: The session status. + :rtype: str + """ + session = self.sessions.find_one({"session_id": session_id}) + return session["status"] + def set_latest_model(self, model_id, session_id=None): """Set the latest model id.