diff --git a/fedn/network/clients/client_v2.py b/fedn/network/clients/client_v2.py index 6d1f52fb4..43edc9b79 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/client_v2.py @@ -10,7 +10,7 @@ from fedn.common.log_config import logger from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions from fedn.network.combiner.modelservice import get_tmp_path -from fedn.utils.helpers.helpers import get_helper +from fedn.utils.helpers.helpers import get_helper, save_metadata def get_url(api_url: str, api_port: int) -> str: @@ -132,15 +132,15 @@ def set_helper(self, response: GrpcConnectionOptions = None): # Priority: helper_type from constructor > helper_type from response > default helper_type self.helper = get_helper(helper_type_to_use) - def on_train(self, in_model): - out_model, meta = self._process_training_request(in_model) + def on_train(self, in_model, client_settings): + out_model, meta = self._process_training_request(in_model, client_settings) return out_model, meta def on_validation(self, in_model): metrics = self._process_validation_request(in_model) return metrics - def _process_training_request(self, in_model: BytesIO) -> Tuple[BytesIO, dict]: + def _process_training_request(self, in_model: BytesIO, client_settings: dict) -> Tuple[BytesIO, dict]: """Process a training (model update) request. :param in_model: The model to be updated. @@ -156,6 +156,8 @@ def _process_training_request(self, in_model: BytesIO) -> Tuple[BytesIO, dict]: with open(inpath, "wb") as fh: fh.write(in_model.getbuffer()) + save_metadata(metadata=client_settings, filename=inpath) + outpath = self.helper.get_tmp_path() tic = time.time() diff --git a/fedn/network/clients/fedn_client.py b/fedn/network/clients/fedn_client.py index 828758131..3f7124cb2 100644 --- a/fedn/network/clients/fedn_client.py +++ b/fedn/network/clients/fedn_client.py @@ -225,8 +225,9 @@ def update_local_model(self, request): self.send_status(f"\t Starting processing of training request for model_id {model_id}", sesssion_id=request.session_id, sender_name=self.name) logger.info(f"Running train callback with model ID: {model_id}") + client_settings = json.loads(request.data).get("client_settings", {}) tic = time.time() - out_model, meta = self.train_callback(in_model) + out_model, meta = self.train_callback(in_model, client_settings) meta["processing_time"] = time.time() - tic tic = time.time()