diff --git a/fedn/network/clients/client_api.py b/fedn/network/clients/client_api.py
index 332331309..108a5448a 100644
--- a/fedn/network/clients/client_api.py
+++ b/fedn/network/clients/client_api.py
@@ -207,7 +207,6 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke
             logger.error("Error: Could not initialize GRPC connection")
             return False
 
-
     def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0):
         self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency)
 
@@ -220,8 +219,8 @@ def _task_stream_callback(self, request):
         elif request.type == fedn.StatusType.MODEL_VALIDATION:
             self.validate(request)
 
-    def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
-        return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_name, timeout=timeout)
+    def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
+        return self.grpc_handler.get_model_from_combiner(id=id, client_name=client_id, timeout=timeout)
 
     def send_model_to_combiner(self, model: BytesIO, id: str):
         return self.grpc_handler.send_model_to_combiner(model, id)
@@ -229,7 +228,8 @@ def send_model_to_combiner(self, model: BytesIO, id: str):
     def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
         return self.grpc_handler.send_status(msg, log_level, type, request, sesssion_id, sender_name)
 
-    def send_model_update(self,
+    def send_model_update(
+        self,
         sender_name: str,
         sender_role: fedn.Role,
         client_id: str,
@@ -237,7 +237,7 @@ def send_model_update(self,
         model_update_id: str,
         receiver_name: str,
         receiver_role: fedn.Role,
-        meta: dict
+        meta: dict,
     ) -> bool:
         return self.grpc_handler.send_model_update(
             sender_name=sender_name,
@@ -247,17 +247,11 @@ def send_model_update(self,
             model_update_id=model_update_id,
             receiver_name=receiver_name,
             receiver_role=receiver_role,
-            meta=meta
+            meta=meta,
         )
 
-    def send_model_validation(self,
-        sender_name: str,
-        receiver_name: str,
-        receiver_role: fedn.Role,
-        model_id: str,
-        metrics: dict,
-        correlation_id: str,
-        session_id: str
+    def send_model_validation(
+        self, sender_name: str, receiver_name: str, receiver_role: fedn.Role, model_id: str, metrics: dict, correlation_id: str, session_id: str
     ) -> bool:
         return self.grpc_handler.send_model_validation(sender_name, receiver_name, receiver_role, model_id, metrics, correlation_id, session_id)
 
diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py
index ab32f6116..7f5ee93cf 100644
--- a/fedn/network/clients/client_v2.py
+++ b/fedn/network/clients/client_v2.py
@@ -44,16 +44,17 @@ def to_json(self):
 
 
 class Client:
-    def __init__(self,
-        api_url: str,
-        api_port: int,
-        client_obj: ClientOptions,
-        combiner_host: str = None,
-        combiner_port: int = None,
-        token: str = None,
-        package_checksum: str = None,
-        helper_type: str = None
-    ):
+    def __init__(
+        self,
+        api_url: str,
+        api_port: int,
+        client_obj: ClientOptions,
+        combiner_host: str = None,
+        combiner_port: int = None,
+        token: str = None,
+        package_checksum: str = None,
+        helper_type: str = None,
+    ):
         self.api_url = api_url
         self.api_port = api_port
         self.combiner_host = combiner_host
@@ -149,7 +150,6 @@ def on_validation(self, request):
         logger.info("Received validation request")
         self._process_validation_request(request)
 
-
     def _process_training_request(self, request) -> Tuple[str, dict]:
         """Process a training (model update) request.
 
@@ -164,16 +164,14 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
         session_id: str = request.session_id
 
         self.client_api.send_status(
-            f"\t Starting processing of training request for model_id {model_id}",
-            sesssion_id=session_id,
-            sender_name=self.client_obj.name
+            f"\t Starting processing of training request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name
         )
 
         try:
             meta = {}
             tic = time.time()
 
-            model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
+            model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
 
             if model is None:
                 logger.error("Could not retrieve model from combiner. Aborting training request.")
@@ -246,7 +244,7 @@ def _process_training_request(self, request) -> Tuple[str, dict]:
             type=fedn.StatusType.MODEL_UPDATE,
             request=request,
             sesssion_id=session_id,
-            sender_name=self.client_obj.name
+            sender_name=self.client_obj.name,
         )
 
     def _process_validation_request(self, request):
@@ -266,7 +264,7 @@ def _process_validation_request(self, request):
         self.client_api.send_status(f"Processing {cmd} request for model_id {model_id}", sesssion_id=session_id, sender_name=self.client_obj.name)
 
         try:
-            model = self.client_api.get_model_from_combiner(id=str(model_id), client_name=self.client_obj.client_id)
+            model = self.client_api.get_model_from_combiner(id=str(model_id), client_id=self.client_obj.client_id)
             if model is None:
                 logger.error("Could not retrieve model from combiner. Aborting validation request.")
                 return
@@ -318,7 +316,7 @@ def _process_validation_request(self, request):
                     type=fedn.StatusType.MODEL_VALIDATION,
                     request=validation,
                     sesssion_id=request.session_id,
-                    sender_name=self.client_obj.name
+                    sender_name=self.client_obj.name,
                 )
             else:
                 self.client_api.send_status(
@@ -326,5 +324,5 @@ def _process_validation_request(self, request):
                     log_level=fedn.Status.WARNING,
                     request=request,
                     sesssion_id=request.session_id,
-                    sender_name=self.client_obj.name
+                    sender_name=self.client_obj.name,
                 )
diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py
index 4b327edba..9edf40233 100644
--- a/fedn/network/clients/grpc_handler.py
+++ b/fedn/network/clients/grpc_handler.py
@@ -206,7 +206,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
             logger.error(f"GRPC (SendStatus): An error occurred: {e}")
             self._disconnect()
 
-    def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20) -> BytesIO:
+    def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
         """Fetch a model from the assigned combiner.
 
         Downloads the model update object via a gRPC streaming channel.
@@ -218,7 +218,7 @@ def get_model_from_combiner(self, id: str, client_name: str, timeout: int = 20)
         data = BytesIO()
         time_start = time.time()
         request = fedn.ModelRequest(id=id)
-        request.sender.client_id = client_name
+        request.sender.client_id = client_id
         request.sender.role = fedn.WORKER
 
         try:
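
For context, a minimal sketch (not part of the diff) of a call site after the rename of `client_name` to `client_id` in `ClientAPI.get_model_from_combiner`; the wrapper `fetch_model` and its arguments are illustrative names only:

```python
from io import BytesIO


# Illustrative sketch only, not part of this changeset. After the rename, callers
# pass `client_id` to ClientAPI.get_model_from_combiner; internally the value is
# still forwarded to the gRPC handler's `client_name` keyword (see client_api.py).
def fetch_model(client_api, model_id: str, client_id: str) -> BytesIO:
    model = client_api.get_model_from_combiner(id=str(model_id), client_id=client_id, timeout=20)
    if model is None:
        # Mirrors the error handling in client_v2.py's request processing.
        raise RuntimeError("Could not retrieve model from combiner.")
    return model
```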