Skip to content

Commit

Permalink
more robust waiting for backward pass in roundhandler and controller
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankJonasmoelle committed Dec 20, 2024
1 parent 78096ad commit b8471cb
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 28 deletions.
27 changes: 14 additions & 13 deletions fedn/network/clients/fedn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def update_local_model(self, request):

self.send_status(
f"\t Starting processing of training request for model_id {model_id}",
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
log_level=fedn.LogLevel.INFO,
type=fedn.StatusType.MODEL_UPDATE,
Expand Down Expand Up @@ -263,7 +263,7 @@ def update_local_model(self, request):
log_level=fedn.LogLevel.AUDIT,
type=fedn.StatusType.MODEL_UPDATE,
request=update,
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
)

Expand All @@ -272,7 +272,7 @@ def validate_global_model(self, request):

self.send_status(
f"Processing validate request for model_id {model_id}",
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
log_level=fedn.LogLevel.INFO,
type=fedn.StatusType.MODEL_VALIDATION,
Expand Down Expand Up @@ -303,15 +303,15 @@ def validate_global_model(self, request):
log_level=fedn.LogLevel.AUDIT,
type=fedn.StatusType.MODEL_VALIDATION,
request=validation,
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
)
else:
self.send_status(
"Client {} failed to complete model validation.".format(self.name),
log_level=fedn.LogLevel.WARNING,
request=request,
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
)

Expand Down Expand Up @@ -343,7 +343,7 @@ def forward_embeddings(self, request):
logger.error("No forward callback set")
return

self.send_status(f"\t Starting processing of forward request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name)
self.send_status(f"\t Starting processing of forward request for model_id {model_id}", session_id=request.session_id, sender_name=self.name)

logger.info(f"Running forward callback with model ID: {model_id}")
tic = time.time()
Expand All @@ -365,7 +365,7 @@ def forward_embeddings(self, request):
log_level=fedn.LogLevel.AUDIT,
type=fedn.StatusType.MODEL_UPDATE,
request=update,
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
)

Expand All @@ -386,7 +386,7 @@ def backward_gradients(self, request):
logger.error("No backward callback set")
return

self.send_status(f"\t Starting processing of backward request for gradient_id {model_id}", sesssion_id=request.session_id, sender_name=self.name)
self.send_status(f"\t Starting processing of backward request for gradient_id {model_id}", session_id=request.session_id, sender_name=self.name)

logger.info(f"Running backward callback with gradient ID: {model_id}")
tic = time.time()
Expand All @@ -406,11 +406,11 @@ def backward_gradients(self, request):
self.grpc_handler.send_backward_completion(completion)

self.send_status(
"Backward pass completed.",
"Backward pass completed. Status: finished_backward",
log_level=fedn.LogLevel.AUDIT,
type=fedn.StatusType.BACKWARD,
# request=update,
sesssion_id=request.session_id,
session_id=request.session_id,
sender_name=self.name,
)
except Exception as e:
Expand All @@ -422,7 +422,8 @@ def create_backward_completion_message(self, gradient_id: str, meta: dict, reque
receiver_name=request.sender.name,
receiver_role=request.sender.role,
gradient_id=gradient_id,
meta=meta,
session_id=request.session_id,
meta=meta,
)

def create_update_message(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest):
Expand Down Expand Up @@ -478,8 +479,8 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) ->
def send_model_to_combiner(self, model: BytesIO, id: str):
return self.grpc_handler.send_model_to_combiner(model, id)

def send_status(self, msg: str, log_level=fedn.LogLevel.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name)
def send_status(self, msg: str, log_level=fedn.LogLevel.INFO, type=None, request=None, session_id: str = None, sender_name: str = None):
return self.grpc_handler.send_status(msg, log_level, type, request, session_id, sender_name)

def send_model_update(self, update: fedn.ModelUpdate) -> bool:
return self.grpc_handler.send_model_update(update)
Expand Down
6 changes: 3 additions & 3 deletions fedn/network/clients/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,19 @@ def create_backward_completion_message(
receiver_name: str,
receiver_role: fedn.Role,
gradient_id: str,
# correlation_id: str,
# session_id: str,
session_id: str,
meta: dict,
):
completion = fedn.BackwardCompletion()
completion.sender.name = sender_name
completion.sender.role = fedn.WORKER
completion.sender.role = fedn.CLIENT
completion.sender.client_id = self.metadata[0][1]
completion.receiver.name = receiver_name
completion.receiver.role = receiver_role
completion.gradient_id = gradient_id
completion.timestamp.GetCurrentTime()
completion.meta = json.dumps(meta)
completion.session_id = session_id
return completion

def send_backward_completion(self, update: fedn.BackwardCompletion):
Expand Down
29 changes: 23 additions & 6 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@
import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.certificate.certificate import Certificate
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
from fedn.network.combiner.shared import (client_store, combiner_store,
prediction_store, repository,
statestore, status_store,
validation_store)
from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, statestore, status_store, validation_store
from fedn.network.grpc.server import Server, ServerConfig
from fedn.network.storage.statestore.stores.shared import EntityNotFound

Expand Down Expand Up @@ -855,6 +851,27 @@ def SendBackwardCompletion(self, request, context):
"""
logger.info("Received BackwardCompletion from {}".format(request.sender.name))

########### TODO checking if this works

# Create and send status message for backward completion
status = fedn.Status()
status.timestamp.GetCurrentTime()
status.sender.name = request.sender.name
status.sender.role = request.sender.role
status.sender.client_id = request.sender.client_id
status.status = "finished_backward"
status.type = fedn.StatusType.BACKWARD
status.session_id = request.session_id


logger.info(f"Creating status message with session_id: {request.session_id}")
self._send_status(status)
logger.info("Status message sent to MongoDB")


###########


response = fedn.Response()
response.response = "RECEIVED BackwardCompletion from client {}".format(request.sender.name)
return response
Expand Down
14 changes: 11 additions & 3 deletions fedn/network/combiner/roundhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,17 @@ def _backward_pass(self, config: dict, clients: list):

self.server.request_backward_pass(session_id=config["session_id"], gradient_id=config["model_id"], config=config, clients=clients)

time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete.
# the above mechanism cannot be used, as the backward pass is not returning any model updates (update_handler.waitforit checks for aggregation on the
# queue)
# time.sleep(1)

# Wait for backward completions
start_time = time.time()
while time.time() - start_time < meta["timeout"]:
completion_status = self.server.statestore.check_backward_completion(config["session_id"], meta["nr_required_updates"])
if completion_status:
logger.info("All required clients completed backward pass")
return meta
time.sleep(0.1)
logger.warning("Timeout waiting for backward pass completion")
return meta

def stage_model(self, model_id, timeout_retry=3, retry=2):
Expand Down
27 changes: 24 additions & 3 deletions fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import time
import uuid

from tenacity import (retry, retry_if_exception_type, stop_after_delay,
wait_random)
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random

from fedn.common.log_config import logger
from fedn.network.combiner.interfaces import CombinerUnavailableError
Expand Down Expand Up @@ -554,9 +553,31 @@ def check_combiners_done_reporting():
participating_combiners = [(combiner, backward_config) for combiner, _ in participating_combiners]
_ = self.request_model_updates(participating_combiners)

time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete.
# time.sleep(1) # TODO: this is an easy hack for now. There needs to be some waiting time for the backward pass to complete.
# the above mechanism cannot be used, as the backward pass is not producing any model updates (unlike the forward pass)

# Add check for backward completion
def check_backward_done():
events = self.statestore.get_events(
status="finished_backward",
type="BACKWARD",
sessionId=session_config["session_id"]
)
return events["count"] >= len(round["combiners"])

# Wait for backward pass completion with timeout
start_time = time.time()
timeout = float(session_config["round_timeout"]) # or get from config
while time.time() - start_time < timeout:
if check_backward_done():
logger.info("CONTROLLER: Backward pass completed.")
break
time.sleep(0.1)
else:
logger.error("Backward pass timed out")
self.set_round_status(round_id, "Failed")
return None, self.statestore.get_round(round_id)

logger.info("CONTROLLER: Backward pass completed.")

# Record round completion
Expand Down
13 changes: 13 additions & 0 deletions fedn/network/storage/statestore/mongostatestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,19 @@ def get_model(self, model_id):
"""
return self.model.find_one({"key": "models", "model": model_id})

def check_backward_completion(self, session_id: str, expected_count: int) -> bool:
    """Report whether enough backward-pass completions have been recorded.

    Queries ``self.get_events`` for events with status ``"finished_backward"``
    and type ``"BACKWARD"`` belonging to ``session_id`` and compares the
    reported ``"count"`` with ``expected_count``.

    :param session_id: Session whose backward-completion events are counted.
    :type session_id: str
    :param expected_count: Minimum number of completion events required.
    :type expected_count: int
    :return: True if at least ``expected_count`` matching events exist,
        False otherwise or if the lookup fails for any reason.
    :rtype: bool
    """
    # NOTE(review): the filter is by session only; if a session spans several
    # rounds, completions from earlier rounds would presumably be counted as
    # well -- confirm against get_events and the callers.
    try:
        events = self.get_events(
            status="finished_backward",
            type="BACKWARD",
            sessionId=session_id,
        )
        completed = events["count"]
        return completed >= expected_count
    except Exception as e:
        # Deliberately best-effort: any failure (query error, unexpected
        # result shape) is logged and reported as "not complete yet".
        logger.error(f"Error checking backward completion: {e}")
        return False

def get_events(self, **kwargs):
"""Get events from the database.
Expand Down

0 comments on commit b8471cb

Please sign in to comment.