diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py
index 332331309..108a5448a 100644
--- a/fedn/network/clients/client_api.py
+++ b/fedn/network/clients/client_api.py
@@ -207,7 +207,6 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke
             logger.error("Error: Could not initialize GRPC connection")
             return False
 
-
     def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
         self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency)
 
@@ -220,8 +219,8 @@ def _task_stream_callback(self, request):
         elif request.type == fedn.StatusType.MODEL_VALIDATION:
             self.validate(request)
 
-    def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
-        return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout)
+    def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
+        return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_id, timeout=timeout)
 
     def send_model_to_combiner(self, model: BytesIO, id: str):
         return self.grpc_handler.send_model_to_combiner(model, id)
@@ -229,7 +228,8 @@ def send_model_to_combiner(self, model: BytesIO, id: str):
     def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
         return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name)
 
-    def send_model_update(self,
+    def send_model_update(
+        self,
         sender_name: str,
         sender_role: fedn.Role,
         client_id: str,
@@ -237,7 +237,7 @@ def send_model_update(self,
         model_update_id: str,
         receiver_name: str,
         receiver_role: fedn.Role,
-        meta: dict
+        meta: dict,
     ) -> bool:
         return self.grpc_handler.send_model_update(
             sender_name=sender_name,
@@ -247,17 +247,11 @@ def send_model_update(self,
             model_update_id=model_update_id,
             receiver_name=receiver_name,
             receiver_role=receiver_role,
-            meta=meta
+            meta=meta,
         )
 
-    def send_model_validation(self,
-        sender_name: str,
-        receiver_name: str,
-        receiver_role: fedn.Role,
-        model_id: str,
-        metrics: dict,
-        correlation_id: str,
-        session_id: str
+    def send_model_validation(
+        self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: dict, correlation_id: str, session_id: str
     ) -> bool:
         return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id)
 
diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py
index ab32f6116..7f5ee93cf 100644
--- a/fedn/network/clients/client_v2.py
+++ b/fedn/network/clients/client_v2.py
@@ -44,16 +44,17 @@ def to_json(self):
 
 
 class Client:
-    def __init__(self,
-        api_url: str,
-        api_port: int,
-        client_obj: ClientOptions,
-        combiner_host: str = None,
-        combiner_port: int = None,
-        token: str = None,
-        package_checksum: str = None,
-        helper_type: str = None
-    ):
+    def __init__(
+        self,
+        api_url: str,
+        api_port: int,
+        client_obj: ClientOptions,
+        combiner_host: str = None,
+        combiner_port: int = None,
+        token: str = None,
+        package_checksum: str = None,
+        helper_type: str = None,
+    ):
         self.api_url = api_url
         self.api_port = api_port
         self.combiner_host = combiner_host
@@ -149,7 +150,6 @@ def on_validation(self, request):
         logger.info("Received validation request")
         self._process_validation_request(request)
 
-
     def _process_training_request(self, request) -> Tuple[str, dict]:
         """Process a training (model update) request.
 
@@ -164,16 +164,14 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
         session_id: str = request.session_id
 
         self.client_api.send_status(
-            f"\t Starting processing of training request for model_id {model_id}",
-            sesssion_id=session_id,
-            sender_name=self.client_obj.name
+            f"\t Starting processing of training request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name
         )
 
         try:
             meta = {}
             tic = time.time()
 
-            model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
+            model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
 
             if model is None:
                 logger.error("Could not retrieve model from combiner. Aborting training request.")
@@ -246,7 +244,7 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
             type=fedn.StatusType.MODEL_UPDATE,
             request=request,
             sesssion_id=session_id,
-            sender_name=self.client_obj.name
+            sender_name=self.client_obj.name,
         )
 
     def _process_validation_request(self, request):
@@ -266,7 +264,7 @@ def _process_validation_request(self, request):
         self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name)
 
         try:
-            model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
+            model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
             if model is None:
                 logger.error("Could not retrieve model from combiner. Aborting validation request.")
                 return
@@ -318,7 +316,7 @@ def _process_validation_request(self, request):
                 type=fedn.StatusType.MODEL_VALIDATION,
                 request=validation,
                 sesssion_id=request.session_id,
-                sender_name=self.client_obj.name
+                sender_name=self.client_obj.name,
             )
         else:
             self.client_api.send_status(
@@ -326,5 +324,5 @@ def _process_validation_request(self, request):
                 log_level=fedn.Status.WARNING,
                 request=request,
                 sesssion_id=request.session_id,
-                sender_name=self.client_obj.name
+                sender_name=self.client_obj.name,
             )
diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py
index 4d2b5d569..759161a4c 100644
--- a/fedn/network/clients/grpc_handler.py
+++ b/fedn/network/clients/grpc_handler.py
@@ -17,6 +17,24 @@ from fedn.common.log_config import logger
 from fedn.network.combiner.modelservice import upload_request_generator
 
+# Keepalive settings: these help keep the connection open for long-lived clients
+KEEPALIVE_TIME_MS = 1 * 1000  # send a keepalive ping every second
+KEEPALIVE_TIMEOUT_MS = 30 * 1000  # wait 30 seconds for keepalive ping ack before considering the connection dead
+KEEPALIVE_PERMIT_WITHOUT_CALLS = True  # allow keepalive pings even when there are no RPCs
+MAX_CONNECTION_IDLE_MS = 30000
+MAX_CONNECTION_AGE_GRACE_MS = "INT_MAX"  # keep connection open indefinitely
+CLIENT_IDLE_TIMEOUT_MS = 30000
+
+GRPC_OPTIONS = [
+    ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
+    ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
+    ("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
+    ("grpc.http2.max_pings_without_data", 0),  # unlimited pings without data
+    ("grpc.max_connection_idle_ms", MAX_CONNECTION_IDLE_MS),
+    ("grpc.max_connection_age_grace_ms", MAX_CONNECTION_AGE_GRACE_MS),
+    ("grpc.client_idle_timeout_ms", CLIENT_IDLE_TIMEOUT_MS),
+]
+
 
 class GrpcAuth(grpc.AuthMetadataPlugin):
     def __init__(self, key):
@@ -61,11 +79,6 @@ def _init_secure_channel(self, host: str, port: int, token: str):
         url = f"{host}:{port}"
         logger.info(f"Connecting (GRPC) to {url}")
 
-        # Keepalive settings: these help keep the connection open for long-lived clients
-        KEEPALIVE_TIME_MS = 60 * 1000  # send keepalive ping every 60 seconds
-        KEEPALIVE_TIMEOUT_MS = 20 * 1000  # wait 20 seconds for keepalive ping ack before considering connection dead
-        KEEPALIVE_PERMIT_WITHOUT_CALLS = True  # allow keepalive pings even when there are no RPCs
-
         if os.getenv("FEDN_GRPC_ROOT_CERT_PATH"):
             logger.info("Using root certificate from environment variable for GRPC channel.")
             with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], "rb") as f:
@@ -80,34 +93,48 @@ def _init_secure_channel(self, host: str, port: int, token: str):
         self.channel = grpc.secure_channel(
             "{}:{}".format(host, str(port)),
             grpc.composite_channel_credentials(credentials, auth_creds),
-            options=[
-                ("grpc.keepalive_time_ms", KEEPALIVE_TIME_MS),
-                ("grpc.keepalive_timeout_ms", KEEPALIVE_TIMEOUT_MS),
-                ("grpc.keepalive_permit_without_calls", KEEPALIVE_PERMIT_WITHOUT_CALLS),
-                ("grpc.http2.max_pings_without_data", 0),  # unlimited pings without data
-            ],
+            options=GRPC_OPTIONS,
         )
 
     def _init_insecure_channel(self, host: str, port: int):
         url = f"{host}:{port}"
         logger.info(f"Connecting (GRPC) to {url}")
-        self.channel = grpc.insecure_channel(url)
+        self.channel = grpc.insecure_channel(
+            url,
+            options=GRPC_OPTIONS,
+        )
 
-    def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
+    def heartbeat(self, client_name: str, client_id: str):
+        """Send a heartbeat to the combiner.
+
+        :return: Response from the combiner.
+        :rtype: fedn.Response
+        """
         heartbeat = fedn.Heartbeat(sender=fedn.Client(name=client_name, role=fedn.WORKER, client_id=client_id))
 
+        try:
+            logger.info("Sending heartbeat to combiner")
+            response = self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
+        except grpc.RpcError as e:
+            raise e
+        except Exception as e:
+            logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
+            self._disconnect()
+            raise e
+        return response
+
+    def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
         send_hearbeat = True
         while send_hearbeat:
            try:
-                logger.info("Sending heartbeat to combiner")
-                self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
+                response = self.heartbeat(client_name, client_id)
            except grpc.RpcError as e:
                return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
-            except Exception as e:
-                logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
-                self._disconnect()
+            if isinstance(response, fedn.Response):
+                logger.info("Heartbeat successful.")
+            else:
+                logger.error("Heartbeat failed.")
                send_hearbeat = False
-
            time.sleep(update_frequency)
 
    def listen_to_task_stream(self, client_name: str, client_id: str, callback: Callable[[Any], None]):
@@ -179,7 +206,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
             logger.error(f"GRPC (SendStatus): An error occurred: {e}")
             self._disconnect()
 
-    def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
+    def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
         """Fetch a model from the assigned combiner.
 
         Downloads the model update object via a gRPC streaming channel.
@@ -191,7 +218,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
         data = BytesIO()
         time_start = time.time()
         request = fedn.ModelRequest(id=id)
-        request.sender.name = client_name
+        request.sender.client_id = client_id
         request.sender.role = fedn.WORKER
 
         try:
@@ -211,7 +238,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
                         return None
                     continue
         except grpc.RpcError as e:
-            return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_name, timeout))
+            return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
         except Exception as e:
             logger.error(f"GRPC (Download): An error occurred: {e}")
             self._disconnect()
diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py
index 89901dabb..8600b8bab 100644
--- a/fedn/network/combiner/modelservice.py
+++ b/fedn/network/combiner/modelservice.py
@@ -229,7 +229,7 @@ def Download(self, request, context):
         :return: A model response iterator.
         :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse`
         """
-        logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.name} requested model {request.id}")
+        logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.client_id} requested model {request.id}")
        try:
            status = self.temp_model_storage.get_model_metadata(request.id)
            if status != fedn.ModelStatus.OK:
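
A minimal sketch of how the module-level GRPC_OPTIONS list introduced in grpc_handler.py is consumed when a channel is created; the target address below is a hypothetical placeholder, not part of this change, and in the client the address comes from the combiner configuration.

import grpc

from fedn.network.clients.grpc_handler import GRPC_OPTIONS

# Hypothetical combiner address, used only to illustrate the call.
target = "localhost:12080"

# Channel arguments are plain (name, value) tuples read once at channel creation,
# which is why both the secure and insecure code paths can share the same list.
channel = grpc.insecure_channel(target, options=GRPC_OPTIONS)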