diff --git a/fedn/network/clients/fedn_client.py b/fedn/network/clients/fedn_client.py index 9abbc095f..307564a8a 100644 --- a/fedn/network/clients/fedn_client.py +++ b/fedn/network/clients/fedn_client.py @@ -235,7 +235,7 @@ def update_local_model(self, request): self.send_status( f"\t Starting processing of training request for model_id {model_id}", - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, log_level=fedn.LogLevel.INFO, type=fedn.StatusType.MODEL_UPDATE, @@ -263,7 +263,7 @@ def update_local_model(self, request): log_level=fedn.LogLevel.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update, - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, ) @@ -272,7 +272,7 @@ def validate_global_model(self, request): self.send_status( f"Processing validate request for model_id {model_id}", - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, log_level=fedn.LogLevel.INFO, type=fedn.StatusType.MODEL_VALIDATION, @@ -303,7 +303,7 @@ def validate_global_model(self, request): log_level=fedn.LogLevel.AUDIT, type=fedn.StatusType.MODEL_VALIDATION, request=validation, - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, ) else: @@ -311,7 +311,7 @@ def validate_global_model(self, request): "Client {} failed to complete model validation.".format(self.name), log_level=fedn.LogLevel.WARNING, request=request, - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, ) @@ -343,7 +343,7 @@ def forward_embeddings(self, request): logger.error("No forward callback set") return - self.send_status(f"\t Starting processing of forward request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) + self.send_status(f"\t Starting processing of forward request for model_id {model_id}", session_id=request.session_id, sender_name=self.name) logger.info(f"Running forward callback with model ID: {model_id}") tic = time.time() @@ -365,7 +365,7 @@ def forward_embeddings(self, request): log_level=fedn.LogLevel.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update, - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, ) @@ -386,7 +386,7 @@ def backward_gradients(self, request): logger.error("No backward callback set") return - self.send_status(f"\t Starting processing of backward request for gradient_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) + self.send_status(f"\t Starting processing of backward request for gradient_id {model_id}", session_id=request.session_id, sender_name=self.name) logger.info(f"Running backward callback with gradient ID: {model_id}") tic = time.time() @@ -406,11 +406,11 @@ def backward_gradients(self, request): self.grpc_handler.send_backward_completion(completion) self.send_status( - "Backward pass completed.", + "Backward pass completed. Status: finished_backward", log_level=fedn.LogLevel.AUDIT, type=fedn.StatusType.BACKWARD, # request=update, - sesssion_id=request.session_id, + session_id=request.session_id, sender_name=self.name, ) except Exception as e: @@ -422,7 +422,8 @@ def create_backward_completion_message(self, gradient_id: str, meta: dict, reque receiver_name=request.sender.name, receiver_role=request.sender.role, gradient_id=gradient_id, - meta=meta, + session_id=request.session_id, + meta=meta, ) def create_update_message(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest): @@ -478,8 +479,8 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> def send_model_to_combiner(self, model: BytesIO, id: str): return self.grpc_handler.send_model_to_combiner(model, id) - def send_status(self, msg: str, log_level=fedn.LogLevel.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None): - return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name) + def send_status(self, msg: str, log_level=fedn.LogLevel.INFO, type=None, request=None, session_id: str = None, sender_name: str = None): + return self.grpc_handler.send_status(msg, log_level, type, request, session_id, sender_name) def send_model_update(self, update: fedn.ModelUpdate) -> bool: return self.grpc_handler.send_model_update(update) diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 5572e7ad8..e4036b805 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -361,19 +361,19 @@ def create_backward_completion_message( receiver_name: str, receiver_role: fedn.Role, gradient_id: str, - # correlation_id: str, - # session_id: str, + session_id: str, meta: dict, ): completion = fedn.BackwardCompletion() completion.sender.name = sender_name - completion.sender.role = fedn.WORKER + completion.sender.role = fedn.CLIENT completion.sender.client_id = self.metadata[0][1] completion.receiver.name = receiver_name completion.receiver.role = receiver_role completion.gradient_id = gradient_id completion.timestamp.GetCurrentTime() completion.meta = json.dumps(meta) + completion.session_id = session_id return completion def send_backward_completion(self, update: fedn.BackwardCompletion): diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index 2e911e841..14b0d6cf6 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -14,13 +14,9 @@ import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.certificate.certificate import Certificate -from fedn.common.log_config import (logger, set_log_level_from_string, - set_log_stream) +from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler -from fedn.network.combiner.shared import (client_store, combiner_store, - prediction_store, repository, - statestore, status_store, - validation_store) +from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, statestore, status_store, validation_store from fedn.network.grpc.server import Server, ServerConfig from fedn.network.storage.statestore.stores.shared import EntityNotFound @@ -855,6 +851,27 @@ def SendBackwardCompletion(self, request, context): """ logger.info("Received BackwardCompletion from {}".format(request.sender.name)) + ########### TODO checking if this works + + # Create and send status message for backward completion + status = fedn.Status() + status.timestamp.GetCurrentTime() + status.sender.name = request.sender.name + status.sender.role = request.sender.role + status.sender.client_id = request.sender.client_id + status.status = "finished_backward" + status.type = fedn.StatusType.BACKWARD + status.session_id = request.session_id + + + logger.info(f"Creating status message with session_id: {request.session_id}") + self._send_status(status) + logger.info("Status message sent to MongoDB") + + + ########### + + response = fedn.Response() response.response = "RECEIVED BackwardCompletion from client {}".format(request.sender.name) return response diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index f967b80b7..6a66fc9aa 100644 --- a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -282,9 +282,17 @@ def _backward_pass(self, config: dict, clients: list): self.server.request_backward_pass(session_id=config["session_id"], gradient_id=config["model_id"], config=config, clients=clients) - time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete. - # the above mechanism cannot be used, as the backward pass is not returning any model updates (update_handler.waitforit checks for aggregation on the - # queue) + # time.sleep(1) + + # Wait for backward completions + start_time = time.time() + while time.time() - start_time < meta["timeout"]: + completion_status = self.server.statestore.check_backward_completion(config["session_id"], meta["nr_required_updates"]) + if completion_status: + logger.info("All required clients completed backward pass") + return meta + time.sleep(0.1) + logger.warning("Timeout waiting for backward pass completion") return meta def stage_model(self, model_id, timeout_retry=3, retry=2): diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index de65d3363..d0d3f9f60 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -3,8 +3,7 @@ import time import uuid -from tenacity import (retry, retry_if_exception_type, stop_after_delay, - wait_random) +from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random from fedn.common.log_config import logger from fedn.network.combiner.interfaces import CombinerUnavailableError @@ -554,9 +553,31 @@ def check_combiners_done_reporting(): participating_combiners = [(combiner, backward_config) for combiner, _ in participating_combiners] _ = self.request_model_updates(participating_combiners) - time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete. + # time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete. # the above mechanism cannot be used, as the backward pass is not producing any model updates (unlike the forward pass) + # Add check for backward completion + def check_backward_done(): + events = self.statestore.get_events( + status="finished_backward", + type="BACKWARD", + sessionId=session_config["session_id"] + ) + return events["count"] >= len(round["combiners"]) + + # Wait for backward pass completion with timeout + start_time = time.time() + timeout = float(session_config["round_timeout"]) # or get from config + while time.time() - start_time < timeout: + if check_backward_done(): + logger.info("CONTROLLER: Backward pass completed.") + break + time.sleep(0.1) + else: + logger.error("Backward pass timed out") + self.set_round_status(round_id, "Failed") + return None, self.statestore.get_round(round_id) + logger.info("CONTROLLER: Backward pass completed.") # Record round completion diff --git a/fedn/network/storage/statestore/mongostatestore.py b/fedn/network/storage/statestore/mongostatestore.py index 316cd4965..962beab6c 100644 --- a/fedn/network/storage/statestore/mongostatestore.py +++ b/fedn/network/storage/statestore/mongostatestore.py @@ -573,6 +573,19 @@ def get_model(self, model_id): """ return self.model.find_one({"key": "models", "model": model_id}) + def check_backward_completion(self, session_id: str, expected_count: int): + try: + events = self.get_events( + status="finished_backward", + type="BACKWARD", + sessionId=session_id + ) + completed = events["count"] + return completed >= expected_count + except Exception as e: + logger.error(f"Error checking backward completion: {e}") + return False + def get_events(self, **kwargs): """Get events from the database.