diff --git a/python/ppc_model/interface/model_base.py b/python/ppc_model/interface/model_base.py
index 050980ae..0e96565f 100644
--- a/python/ppc_model/interface/model_base.py
+++ b/python/ppc_model/interface/model_base.py
@@ -31,3 +31,7 @@ def load_model(self, file_path):
 
 class VerticalModel(ModelBase):
     mode = "VERTICAL"
+
+    def __init__(self, ctx):
+        super().__init__(ctx)
+        self._all_feature_name = []
diff --git a/python/ppc_model/model_result/task_result_handler.py b/python/ppc_model/model_result/task_result_handler.py
index 8c5d87de..02697ee2 100644
--- a/python/ppc_model/model_result/task_result_handler.py
+++ b/python/ppc_model/model_result/task_result_handler.py
@@ -206,6 +206,7 @@ class ModelJobResult:
     def __init__(self, xgb_job, job_id, components, property_name=DEFAULT_PROPERTY_NAME):
         self.job_id = job_id
         self.xgb_job = xgb_job
+        self.base_context = BaseContext(job_id, ".tmp")
         self.components = components
         self.logger = components.logger()
         self.property_name = property_name
@@ -278,6 +279,12 @@ def load_feature_importance_table(self, feature_importance_path, property):
         self.feature_importance_table = {property: DataItem(name=property, data=feature_importance_table.to_dict(),
                                                              type=DataType.TABLE).to_dict()}
 
+    def load_encrypted_model_data(self):
+        try:
+            return self.components.storage_client.get_data(self.base_context.remote_model_enc_file).decode("utf-8")
+        except Exception:
+            return None
+
     def load_iteration_metrics(self, iteration_path, property):
         if not self.xgb_job:
             return
@@ -317,6 +324,7 @@ def __init__(self, task_result_request: TaskResultRequest, components):
         self.result_list = []
         self.predict = False
         self.xgb_job = False
+        self.model_data = None
         if self.task_result_request.task_type == ModelTask.XGB_PREDICTING.name or self.task_result_request.task_type == ModelTask.LR_PREDICTING.name:
             self.predict = True
         if self.task_result_request.task_type == ModelTask.XGB_PREDICTING.name or self.task_result_request.task_type == ModelTask.XGB_TRAINING.name:
@@ -330,7 +338,12 @@ def get_response(self):
         merged_result = dict()
         for result in self.result_list:
             merged_result.update(result.to_dict())
-        response = {"jobPlanetResult": merged_result}
+
+        if self.model_data is None:
+            response = {"jobPlanetResult": merged_result}
+        else:
+            response = {"jobPlanetResult": merged_result,
+                        "modelData": self.model_data}
         return utils.make_response(PpcErrorCode.SUCCESS.get_code(), PpcErrorCode.SUCCESS.get_msg(), response)
 
     def _get_evaluation_result(self):
@@ -371,6 +384,7 @@ def _get_evaluation_result(self):
         # the metrics iteration graph
         self.model.load_iteration_metrics(
             utils.METRICS_OVER_ITERATION_FILE, "IterationGraph")
+        self.model_data = self.model.load_encrypted_model_data()
 
         if self.predict:
             # the train evaluation result
diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py
index a8868dba..076f855c 100644
--- a/python/ppc_model/secure_lgbm/vertical/booster.py
+++ b/python/ppc_model/secure_lgbm/vertical/booster.py
@@ -33,7 +33,6 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
         self._train_praba = None
         self._test_weights = None
         self._test_praba = None
-
         random.seed(ctx.model_params.random_state)
         np.random.seed(ctx.model_params.random_state)
 
@@ -253,11 +252,12 @@ def merge_model_file(self):
         lgbm_model['label_column'] = 'y'
         lgbm_model['participant_agency_list'] = []
         for partner_index in range(0, len(self.ctx.participant_id_list)):
-            agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
-            agency_info['fields'] = self.ctx._all_feature_name[partner_index]
+            agency_info = {
+                'agency': self.ctx.participant_id_list[partner_index]}
+            agency_info['fields'] = self._all_feature_name[partner_index]
             lgbm_model['participant_agency_list'].append(agency_info)
-
-        lgbm_model['model_dict'] = self.ctx.model_params
+
+        lgbm_model['model_dict'] = self.ctx.model_params.get_all_params()
         model_text = {}
         with open(self.ctx.feature_bin_file, 'rb') as f:
             feature_bin_data = f.read()
@@ -265,9 +265,10 @@
             model_data = f.read()
         feature_bin_enc = encrypt_data(self.ctx.key, feature_bin_data)
         model_data_enc = encrypt_data(self.ctx.key, model_data)
-
+
         my_agency_id = self.ctx.components.config_data['AGENCY_ID']
-        model_text[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
+        model_text[my_agency_id] = [cipher_to_base64(
+            feature_bin_enc), cipher_to_base64(model_data_enc)]
 
         # send & receive the model files
         for partner_index in range(0, len(self.ctx.participant_id_list)):
@@ -285,7 +286,8 @@
                 model_data_enc = self._receive_byte_data(
                     self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', partner_index)
                 model_text[self.ctx.participant_id_list[partner_index]] = \
-                    [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
+                    [cipher_to_base64(feature_bin_enc),
+                     cipher_to_base64(model_data_enc)]
         lgbm_model['model_text'] = model_text
 
         # upload the encrypted model
@@ -300,7 +302,8 @@ def split_model_file(self):
         # load the incoming model
         my_agency_id = self.ctx.components.config_data['AGENCY_ID']
         model_text = self.ctx.model_predict_algorithm['model_text']
-        feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in model_text[my_agency_id]]
+        feature_bin_enc, model_data_enc = [
+            base64_to_cipher(i) for i in model_text[my_agency_id]]
 
         # decrypt the model files
         feature_bin_data = decrypt_data(self.ctx.key, feature_bin_enc)
diff --git a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py
index c80f1a5d..23623787 100644
--- a/python/ppc_model/secure_lr/vertical/booster.py
+++ b/python/ppc_model/secure_lr/vertical/booster.py
@@ -271,10 +271,10 @@ def merge_model_file(self):
         lr_model['participant_agency_list'] = []
         for partner_index in range(0, len(self.ctx.participant_id_list)):
             agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
-            agency_info['fields'] = self.ctx._all_feature_name[partner_index]
+            agency_info['fields'] = self._all_feature_name[partner_index]
             lr_model['participant_agency_list'].append(agency_info)
 
-        lr_model['model_dict'] = self.ctx.model_params
+        lr_model['model_dict'] = self.ctx.model_params.get_all_params()
         model_text = {}
         with open(self.ctx.model_data_file, 'rb') as f:
             model_data = f.read()
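For reference, a minimal, self-contained sketch of the encrypt-then-base64 round trip that `merge_model_file()` and `split_model_file()` rely on: each agency's feature-bin and model files are encrypted with the local key and base64-encoded into `model_text[agency_id]`, which is what ultimately surfaces as `modelData` in the task-result response. The AES-GCM based `encrypt_data`/`decrypt_data` stand-ins below are assumptions for illustration only; the project's real helpers (and the format of `self.ctx.key`) are defined elsewhere in the ppc_model code base and are not shown in this diff.

```python
# Hypothetical stand-ins mirroring the helper names used in the diff.
# NOT the project's actual implementation.
import base64

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes


def encrypt_data(key: bytes, data: bytes) -> bytes:
    # Assumed scheme: AES-GCM, packing nonce + tag + ciphertext into one blob.
    cipher = AES.new(key, AES.MODE_GCM)
    ciphertext, tag = cipher.encrypt_and_digest(data)
    return cipher.nonce + tag + ciphertext


def decrypt_data(key: bytes, blob: bytes) -> bytes:
    # Split the blob back into nonce (16 bytes), tag (16 bytes) and ciphertext.
    nonce, tag, ciphertext = blob[:16], blob[16:32], blob[32:]
    cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
    return cipher.decrypt_and_verify(ciphertext, tag)


def cipher_to_base64(cipher_bytes: bytes) -> str:
    # JSON-friendly text form, as stored in model_text[agency_id].
    return base64.b64encode(cipher_bytes).decode("utf-8")


def base64_to_cipher(text: str) -> bytes:
    return base64.b64decode(text)


if __name__ == "__main__":
    key = get_random_bytes(16)
    model_data = b"serialized model bytes"

    # merge_model_file() direction: encrypt, then base64-encode for the payload.
    encoded = cipher_to_base64(encrypt_data(key, model_data))

    # split_model_file() direction: base64-decode, then decrypt.
    assert decrypt_data(key, base64_to_cipher(encoded)) == model_data
```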