diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index a8abd78ab..48e62466c 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -169,12 +169,12 @@ def request_model_update(self, session_id, model_id, config, clients=[]): :type clients: list """ - request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients) + clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients) if len(clients) < 20: - logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients)) + logger.info("Sent model update request for model {} to clients {}".format(model_id, clients)) else: - logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients))) + logger.info("Sent model update request for model {} to {} clients".format(model_id, len(clients))) def request_model_validation(self, session_id, model_id, clients=[]): """Ask clients to validate the current global model. @@ -187,12 +187,12 @@ def request_model_validation(self, session_id, model_id, clients=[]): :type clients: list """ - request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients) + clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients) if len(clients) < 20: - logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients)) + logger.info("Sent model validation request for model {} to clients {}".format(model_id, clients)) else: - logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients))) + logger.info("Sent model validation request for model {} to {} clients".format(model_id, len(clients))) def request_model_inference(self, session_id: str, model_id: str, clients: list = []) -> None: """Ask clients to perform inference on the model. @@ -205,12 +205,12 @@ def request_model_inference(self, session_id: str, model_id: str, clients: list :type clients: list """ - request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients) + clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients) if len(clients) < 20: - logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients)) + logger.info("Sent model inference request for model {} to clients {}".format(model_id, clients)) else: - logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients))) + logger.info("Sent model inference request for model {} to {} clients".format(model_id, len(clients))) def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]): """Send a request of a specific type to clients. @@ -223,41 +223,38 @@ def _send_request_type(self, request_type, session_id, model_id, config=None, cl :type config: dict :param clients: the clients to send the request to :type clients: list - :return: the request and the clients - :rtype: tuple + :return: the clients + :rtype: list """ - request = fedn.TaskRequest() - request.model_id = model_id - request.correlation_id = str(uuid.uuid4()) - request.timestamp = str(datetime.now()) - request.type = request_type - request.session_id = session_id - - request.sender.name = self.id - request.sender.role = fedn.COMBINER - - if request_type == fedn.StatusType.MODEL_UPDATE: - request.data = json.dumps(config) - if len(clients) == 0: + if len(clients) == 0: + if request_type == fedn.StatusType.MODEL_UPDATE: clients = self.get_active_trainers() - elif request_type == fedn.StatusType.MODEL_VALIDATION: - if len(clients) == 0: + elif request_type == fedn.StatusType.MODEL_VALIDATION: clients = self.get_active_validators() - elif request_type == fedn.StatusType.INFERENCE: - if len(clients) == 0: + elif request_type == fedn.StatusType.INFERENCE: # TODO: add inference clients type clients = self.get_active_validators() - for client in clients: + request = fedn.TaskRequest() + request.model_id = model_id + request.correlation_id = str(uuid.uuid4()) + request.timestamp = str(datetime.now()) + request.type = request_type + request.session_id = session_id + + request.sender.name = self.id + request.sender.role = fedn.COMBINER request.receiver.name = client request.receiver.role = fedn.WORKER + # Set the request data, not used in validation if request_type == fedn.StatusType.INFERENCE: presigned_url = self.repository.presigned_put_url(self.repository.inference_bucket, f"{client}/{session_id}") # TODO: in inference, request.data should also contain user-defined data/parameters request.data = json.dumps({"presigned_url": presigned_url}) + elif request_type == fedn.StatusType.MODEL_UPDATE: + request.data = json.dumps(config) self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) - - return request, clients + return clients def get_active_trainers(self): """Get a list of active trainers.