Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 20, 2024
1 parent 36edfd3 commit c54db41
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
38 changes: 33 additions & 5 deletions fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def __init__(self, message):
super().__init__(self.message)


class SessionTerminatedException(Exception):
"""Exception class for when session is terminated"""

def __init__(self, message):
"""Constructor method.
:param message: The exception message.
:type message: str
"""
self.message = message
super().__init__(self.message)


class Control(ControlBase):
"""Controller, implementing the overall global training, validation and inference logic.
Expand Down Expand Up @@ -122,6 +136,8 @@ def start_session(self, session_id: str, rounds: int) -> None:
current_round = round

try:
if self.get_session_status(session_id) == "Terminated":
break
_, round_data = self.round(session_config, str(current_round))
except TypeError as e:
logger.error("Failed to execute round: {0}".format(e))
Expand All @@ -130,7 +146,8 @@ def start_session(self, session_id: str, rounds: int) -> None:

session_config["model_id"] = self.statestore.get_latest_model()

self.set_session_status(session_id, "Finished")
if self.get_session_status(session_id) == "Started":
self.set_session_status(session_id, "Finished")
self._state = ReducerState.idle

def session(self, config: RoundConfig) -> None:
Expand Down Expand Up @@ -172,6 +189,8 @@ def session(self, config: RoundConfig) -> None:
current_round = round

try:
if self.get_session_status(config["session_id"]) == "Terminated":
break
_, round_data = self.round(config, str(current_round))
except TypeError as e:
logger.error("Failed to execute round: {0}".format(e))
Expand All @@ -181,7 +200,8 @@ def session(self, config: RoundConfig) -> None:
config["model_id"] = self.statestore.get_latest_model()

# TODO: Report completion of session
self.set_session_status(config["session_id"], "Finished")
if self.get_session_status(config["session_id"]) == "Started":
self.set_session_status(config["session_id"], "Finished")
self._state = ReducerState.idle

def inference_session(self, config: RoundConfig) -> None:
Expand Down Expand Up @@ -227,6 +247,7 @@ def round(self, session_config: RoundConfig, round_id: str):
: type round_id: str
"""
session_id = session_config["session_id"]
self.create_round({"round_id": round_id, "status": "Pending"})

if len(self.network.get_combiners()) < 1:
Expand All @@ -239,7 +260,7 @@ def round(self, session_config: RoundConfig, round_id: str):
round_config["rounds"] = 1
round_config["round_id"] = round_id
round_config["task"] = "training"
round_config["session_id"] = session_config["session_id"]
round_config["session_id"] = session_id

self.set_round_config(round_id, round_config)

Expand All @@ -263,7 +284,11 @@ def round(self, session_config: RoundConfig, round_id: str):
# Wait until participating combiners have produced an updated global model,
# or round times out.
def do_if_round_times_out(result):
logger.warning("Round timed out!")
if isinstance(result.outcome.exception(), SessionTerminatedException):
logger.warning("Session terminated!")
return None, self.statestore.get_round(round_id)
else:
logger.warning("Round timed out!")

@retry(
wait=wait_random(min=1.0, max=2.0),
Expand All @@ -273,6 +298,9 @@ def do_if_round_times_out(result):
)
def combiners_done():
round = self.statestore.get_round(round_id)
session_status = self.get_session_status(session_id)
if session_status == "Terminated":
raise SessionTerminatedException("Session terminated!")
if "combiners" not in round:
logger.info("Waiting for combiners to update model...")
raise CombinersNotDoneException("Combiners have not yet reported.")
Expand All @@ -283,7 +311,7 @@ def combiners_done():

return True

combiners_done()
_ = combiners_done()

# Due to the distributed nature of the computation, there might be a
# delay before combiners have reported the round data to the db,
Expand Down
10 changes: 10 additions & 0 deletions fedn/network/controller/controlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ def set_session_status(self, session_id, status):
"""
self.statestore.set_session_status(session_id, status)

def get_session_status(self, session_id):
"""Get the status of a session.
:param session_id: The session unique identifier
:type session_id: str
:return: The status
:rtype: str
"""
return self.statestore.get_session_status(session_id)

def create_round(self, round_data):
"""Initialize a new round in backend db."""
self.statestore.create_round(round_data)
Expand Down
11 changes: 11 additions & 0 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def get_session(self, session_id):
"""
return self.sessions.find_one({"session_id": session_id})

def get_session_status(self, session_id):
"""Get the session status.
:param session_id: The session id.
:type session_id: str
:return: The session status.
:rtype: str
"""
session = self.sessions.find_one({"session_id": session_id})
return session["status"]

def set_latest_model(self, model_id, session_id=None):
"""Set the latest model id.
Expand Down

0 comments on commit c54db41

Please sign in to comment.