From ae06b846efe714810cb932a852021dab282c991a Mon Sep 17 00:00:00 2001 From: Andreas Hellander Date: Thu, 9 Nov 2023 17:41:14 +0100 Subject: [PATCH] Feature/SK-521 | Global model not created if the combiner terminates based on timeout (#478) --- .ci/tests/examples/wait_for.py | 6 +- examples/mnist-keras/bin/build.sh | 2 +- fedn/fedn/common/storage/s3/miniorepo.py | 6 +- fedn/fedn/common/tracer/mongotracer.py | 45 +++- fedn/fedn/network/api/interface.py | 5 +- fedn/fedn/network/controller/control.py | 228 ++++++++++---------- fedn/fedn/network/controller/controlbase.py | 107 ++++++--- 7 files changed, 240 insertions(+), 159 deletions(-) diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index dc3345da0..ccd76859d 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -18,7 +18,7 @@ def _retry(try_func, **func_args): for _ in range(RETRIES): is_success = try_func(**func_args) if is_success: - _eprint('Sucess.') + _eprint('Success.') return True _eprint(f'Sleeping for {SLEEP}.') sleep(SLEEP) @@ -30,7 +30,7 @@ def _test_rounds(n_rounds): client = pymongo.MongoClient( "mongodb://fedn_admin:password@localhost:6534") collection = client['fedn-network']['control']['rounds'] - query = {'reducer.status': 'Success'} + query = {'status': 'Finished'} n = collection.count_documents(query) client.close() _eprint(f'Succeded rounds: {n}.') @@ -60,7 +60,7 @@ def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092 return count == n_nodes except Exception as e: - _eprint(f'Reques exception econuntered: {e}.') + _eprint(f'Request exception enconuntered: {e}.') return False diff --git a/examples/mnist-keras/bin/build.sh b/examples/mnist-keras/bin/build.sh index 18cdb5128..44eda61df 100755 --- a/examples/mnist-keras/bin/build.sh +++ b/examples/mnist-keras/bin/build.sh @@ -5,4 +5,4 @@ set -e client/entrypoint init_seed # Make compute package -tar -czvf package.tgz client \ No newline at end of file +tar -czvf package.tgz client diff --git a/fedn/fedn/common/storage/s3/miniorepo.py b/fedn/fedn/common/storage/s3/miniorepo.py index 9341704e6..154cea7e9 100644 --- a/fedn/fedn/common/storage/s3/miniorepo.py +++ b/fedn/fedn/common/storage/s3/miniorepo.py @@ -62,11 +62,13 @@ def __init__(self, config): self.create_bucket(self.bucket) def create_bucket(self, bucket_name): - """ + """ Create a new bucket. If bucket exists, do nothing. - :param bucket_name: + :param bucket_name: The name of the bucket + :type bucket_name: str """ found = self.client.bucket_exists(bucket_name) + if not found: try: self.client.make_bucket(bucket_name) diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 0a3e28cdc..aa5c0810b 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -52,18 +52,26 @@ def drop_status(self): if self.status: self.status.drop() - def new_session(self, id=None): - """ Create a new session. """ + def create_session(self, id=None): + """ Create a new session. + + :param id: The ID of the created session. + :type id: uuid, str + + """ if not id: id = uuid.uuid4() data = {'session_id': str(id)} self.sessions.insert_one(data) - def new_round(self, id): - """ Create a new session. """ + def create_round(self, round_data): + """ Create a new round. - data = {'round_id': str(id)} - self.rounds.insert_one(data) + :param round_data: Dictionary with round data. + :type round_data: dict + """ + # TODO: Add check if round_id already exists + self.rounds.insert_one(round_data) def set_session_config(self, id, config): self.sessions.update_one({'session_id': str(id)}, { @@ -72,18 +80,35 @@ def set_session_config(self, id, config): def set_round_combiner_data(self, data): """ - :param round_meta: + :param data: The combiner data + :type data: dict """ self.rounds.update_one({'round_id': str(data['round_id'])}, { '$push': {'combiners': data}}, True) - def set_round_data(self, round_data): + def set_round_config(self, round_id, round_config): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_config': round_config}}, True) + + def set_round_status(self, round_id, round_status): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'status': round_status}}, True) + + def set_round_data(self, round_id, round_data): """ :param round_meta: """ - self.rounds.update_one({'round_id': str(round_data['round_id'])}, { - '$push': {'reducer': round_data}}, True) + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_data': round_data}}, True) def update_client_status(self, client_name, status): """ Update client status in statestore. diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 61095e6ec..0821ed176 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -707,9 +707,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { - "round_id": round_object["round_id"], - "reducer": round_object["reducer"], - "combiners": round_object["combiners"], + 'round_id': round_object['round_id'], + 'combiners': round_object['combiners'], } return jsonify(payload) diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index a8e32333d..615edb3b5 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -3,6 +3,9 @@ import time import uuid +from tenacity import (retry, retry_if_exception_type, stop_after_delay, + wait_random) + from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState @@ -48,6 +51,20 @@ def __init__(self, message): super().__init__(self.message) +class CombinersNotDoneException(Exception): + """ Exception class for when model is None """ + + def __init__(self, message): + """ Constructor method. + + :param message: The exception message. + :type message: str + + """ + self.message = message + super().__init__(self.message) + + class Control(ControlBase): """Controller, implementing the overall global training, validation and inference logic. @@ -83,12 +100,10 @@ def session(self, config): return self._state = ReducerState.instructing - - # Must be called to set info in the db config["committed_at"] = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) - self.new_session(config) + self.create_session(config) if not self.statestore.get_latest_model(): print( @@ -106,14 +121,13 @@ def session(self, config): # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number - if last_round: current_round = last_round + round else: current_round = round try: - _, round_data = self.round(config, current_round) + _, round_data = self.round(config, str(current_round)) except TypeError as e: print( "Could not unpack data from round: {0}".format(e), @@ -127,30 +141,27 @@ def session(self, config): flush=True, ) - self.tracer.set_round_data(round_data) - # TODO: Report completion of session self._state = ReducerState.idle def round(self, session_config, round_id): - """Execute a single global round. + """ Execute one global round. + + : param session_config: The session config. + : type session_config: dict + : param round_id: The round id. + : type round_id: str - :param session_config: The session config. - :type session_config: dict - :param round_id: The round id. - :type round_id: str(int) """ - round_data = {"round_id": round_id} + self.create_round({'round_id': round_id, 'status': "Pending"}) if len(self.network.get_combiners()) < 1: - print("REDUCER: No combiners connected!", flush=True) - round_data["status"] = "Failed" - return None, round_data + print("CONTROLLER: Round cannot start, no combiners connected!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 1. Assemble round config for this global round, - # and check which combiners are able to participate - # in the round. + # Assemble round config for this global round round_config = copy.deepcopy(session_config) round_config["rounds"] = 1 round_config["round_id"] = round_id @@ -158,94 +169,85 @@ def round(self, session_config, round_id): round_config["model_id"] = self.statestore.get_latest_model() round_config["helper_type"] = self.statestore.get_helper() - combiners = self.get_participating_combiners(round_config) - round_start = self.evaluate_round_start_policy(combiners) + self.set_round_config(round_id, round_config) + + # Get combiners that are able to participate in round, given round_config + participating_combiners = self.get_participating_combiners(round_config) + + # Check if the policy to start the round is met + round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - print( - "CONTROL: round start policy met, participating combiners {}".format( - combiners - ), - flush=True, - ) + print("CONTROL: round start policy met, {} participating combiners.".format( + len(participating_combiners)), flush=True) else: - print( - "CONTROL: Round start policy not met, skipping round!", - flush=True, - ) - round_data["status"] = "Failed" - return None + print("CONTROL: Round start policy not met, skipping round!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) + + # Ask participating combiners to coordinate model updates + _ = self.request_model_updates(participating_combiners) + # TODO: Check response + + # Wait until participating combiners have produced an updated global model, + # or round times out. + def do_if_round_times_out(result): + print("CONTROL: Round timed out!", flush=True) + + @retry(wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config['round_timeout']), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException)) + def combiners_done(): - round_data["round_config"] = round_config - - # 2. Ask participating combiners to coordinate model updates - _ = self.request_model_updates(combiners) - - # Wait until participating combiners have produced an updated global model. - wait = 0.0 - # dict to store combiners that have successfully produced an updated model - updated = {} - # wait until all combiners have produced an updated model or until round timeout - print( - "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id - ), - flush=True, - ) - while len(updated) < len(combiners): round = self.statestore.get_round(round_id) - if round: - print("CONTROL: Round found!", flush=True) - # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round["combiners"]: - print(combiner, flush=True) - if combiner["status"] == "Success": - if combiner["name"] not in updated.keys(): - # Add combiner to updated dict - updated[combiner["name"]] = combiner["model_id"] - # Print combiner status - print( - "CONTROL: Combiner {name} status: {status}".format( - name=combiner["name"], status=combiner["status"] - ), - flush=True, - ) - else: - # Print every 10 seconds based on value of wait - if wait % 10 == 0: - print( - "CONTROL: Waiting for round to complete...", flush=True - ) - if wait >= session_config["round_timeout"]: - print("CONTROL: Round timeout! Exiting round...", flush=True) - break - # Update wait time used for timeout - time.sleep(1.0) - wait += 1.0 - - round_valid = self.evaluate_round_validity_policy(updated) + if 'combiners' not in round: + # TODO: use logger + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("Combiners have not yet reported.") + + if len(round['combiners']) < len(participating_combiners): + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("All combiners have not yet reported.") + + return True + + combiners_done() + + # Due to the distributed nature of the computation, there might be a + # delay before combiners have reported the round data to the db, + # so we need some robustness here. + @retry(wait=wait_random(min=0.1, max=1.0), + retry=retry_if_exception_type(KeyError)) + def check_combiners_done_reporting(): + round = self.statestore.get_round(round_id) + combiners = round['combiners'] + return combiners + + _ = check_combiners_done_reporting() + + round = self.statestore.get_round(round_id) + round_valid = self.evaluate_round_validity_policy(round) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - print("CONTROL: Reducing models from combiners...", flush=True) - # 3. Reduce combiner models into a global model + print("CONTROL: Reducing combiner level models...", flush=True) + # Reduce combiner models into a new global model + round_data = {} try: - model, data = self.reduce(updated) - round_data["reduce"] = data + round = self.statestore.get_round(round_id) + model, data = self.reduce(round['combiners']) + round_data['reduce'] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print( - "CONTROL: Failed to reduce models from combiners: {}".format( - e - ), - flush=True, - ) - round_data["status"] = "Failed" - return None, round_data + print("CONTROL: Failed to reduce models from combiners: {}".format( + e), flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 6. Commit the global model to model trail + # Commit the new global model to the model trail if model is not None: print( "CONTROL: Committing global model to model trail...", @@ -271,10 +273,10 @@ def round(self, session_config, round_id): ), flush=True, ) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - round_data["status"] = "Success" + self.set_round_status(round_id, 'Success') # 4. Trigger participating combiner nodes to execute a validation round for the current model validate = session_config["validate"] @@ -285,9 +287,8 @@ def round(self, session_config, round_id): combiner_config["task"] = "validation" combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self._select_participating_combiners( - combiner_config - ) + validating_combiners = self.get_participating_combiners( + combiner_config) for combiner, combiner_config in validating_combiners: try: @@ -302,13 +303,15 @@ def round(self, session_config, round_id): self._handle_unavailable_combiner(combiner) pass - return model_id, round_data + self.set_round_data(round_id, round_data) + self.set_round_status(round_id, 'Finished') + return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): """Combine updated models from Combiner nodes into one global model. - :param combiners: dict of combiner names (key) and model IDs (value) to reduce - :type combiners: dict + : param combiners: dict of combiner names(key) and model IDs(value) to reduce + : type combiners: dict """ meta = {} @@ -323,7 +326,9 @@ def reduce(self, combiners): print("REDUCER: No combiners to reduce!", flush=True) return model, meta - for name, model_id in combiners.items(): + for combiner in combiners: + name = combiner['name'] + model_id = combiner['model_id'] # TODO: Handle inactive RPC error in get_model and raise specific error print( "REDUCER: Fetching model ({model_id}) from combiner {name}".format( @@ -333,9 +338,9 @@ def reduce(self, combiners): ) try: tic = time.time() - combiner = self.get_combiner(name) - data = combiner.get_model(model_id) - meta["time_fetch_model"] += time.time() - tic + combiner_interface = self.get_combiner(name) + data = combiner_interface.get_model(model_id) + meta['time_fetch_model'] += (time.time() - tic) except Exception as e: print( "REDUCER: Failed to fetch model from combiner {}: {}".format( @@ -367,7 +372,7 @@ def reduce(self, combiners): def infer_instruct(self, config): """Main entrypoint for executing the inference compute plan. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Check/set instucting state @@ -395,7 +400,7 @@ def infer_instruct(self, config): def inference_round(self, config): """Execute an inference round. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Init meta @@ -413,7 +418,8 @@ def inference_round(self, config): combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners(combiner_config) + validating_combiners = self.get_participating_combiners( + combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index 077620c14..fab6a2027 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -196,8 +196,8 @@ def get_compute_package(self, compute_package=""): else: return None - def new_session(self, config): - """Initialize a new session in backend db.""" + def create_session(self, config): + """ Initialize a new session in backend db. """ if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -205,11 +205,50 @@ def new_session(self, config): else: session_id = config["session_id"] - self.tracer.new_session(id=session_id) + self.tracer.create_session(id=session_id) self.tracer.set_session_config(session_id, config) + def create_round(self, round_data): + """Initialize a new round in backend db. """ + + self.tracer.create_round(round_data) + + def set_round_data(self, round_id, round_data): + """ Set round data. + + :param round_id: The round unique identifier + :type round_id: str + :param round_data: The status + :type status: dict + """ + self.tracer.set_round_data(round_id, round_data) + + def set_round_status(self, round_id, status): + """ Set the round round stats. + + :param round_id: The round unique identifier + :type round_id: str + :param status: The status + :type status: str + """ + self.tracer.set_round_status(round_id, status) + + def set_round_config(self, round_id, round_config): + """ Upate round in backend db. + + :param round_id: The round unique identifier + :type round_id: str + :param round_config: The round configuration + :type round_config: dict + """ + self.tracer.set_round_config(round_id, round_config) + def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update.""" + """Ask Combiner server to produce a model update. + + :param combiners: A list of combiners + :type combiners: tuple (combiner, comboner_round_config) + """ cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) @@ -217,7 +256,15 @@ def request_model_updates(self, combiners): return cl def commit(self, model_id, model=None, session_id=None): - """Commit a model to the global model trail. The model commited becomes the lastest consensus model.""" + """Commit a model to the global model trail. The model commited becomes the lastest consensus model. + + :param model_id: Unique identifier for the model to commit. + :type model_id: str (uuid) + :param model: The model object to commit + :type model: BytesIO + :param session_id: Unique identifier for the session + :type session_id: str + """ helper = self.get_helper() if model is not None: @@ -289,45 +336,47 @@ def evaluate_round_participation_policy( return False def evaluate_round_start_policy(self, combiners): - """Check if the policy to start a round is met.""" + """Check if the policy to start a round is met. + + :param combiners: A list of combiners + :type combiners: list + :return: True if the round policy is mer, otherwise False + :rtype: bool + """ if len(combiners) > 0: return True else: return False - def evaluate_round_validity_policy(self, combiners): - """Check if the round should be seen as valid. + def evaluate_round_validity_policy(self, round): + """ Check if the round is valid. At the end of the round, before committing a model to the global model trail, we check if the round validity policy has been met. This can involve e.g. asserting that a certain number of combiners have reported in an updated model, or that criteria on model performance have been met. - """ - if combiners.keys() == []: - return False - else: - return True - def _select_participating_combiners(self, compute_plan): - participating_combiners = [] - for combiner in self.network.get_combiners(): + :param round: The round object + :rtype round: dict + :return: True if the policy is met, otherwise False + :rtype: bool + """ + model_ids = [] + for combiner in round['combiners']: try: - combiner_state = combiner.report() - except CombinerUnavailableError: - self._handle_unavailable_combiner(combiner) - combiner_state = None + model_ids.append(combiner['model_id']) + except KeyError: + pass - if combiner_state: - is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state - ) - if is_participating: - participating_combiners.append((combiner, compute_plan)) - return participating_combiners + if len(model_ids) == 0: + return False + + return True def state(self): - """ + """ Get the current state of the controller - :return: + :return: The state + :rype: str """ return self._state