diff --git a/python/ppc_model/interface/model_base.py b/python/ppc_model/interface/model_base.py index 050980ae..0e96565f 100644 --- a/python/ppc_model/interface/model_base.py +++ b/python/ppc_model/interface/model_base.py @@ -31,3 +31,7 @@ def load_model(self, file_path): class VerticalModel(ModelBase): mode = "VERTICAL" + + def __init__(self, ctx): + super().__init__(ctx) + self._all_feature_name = [] diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py index a8868dba..076f855c 100644 --- a/python/ppc_model/secure_lgbm/vertical/booster.py +++ b/python/ppc_model/secure_lgbm/vertical/booster.py @@ -33,7 +33,6 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: self._train_praba = None self._test_weights = None self._test_praba = None - random.seed(ctx.model_params.random_state) np.random.seed(ctx.model_params.random_state) @@ -253,11 +252,12 @@ def merge_model_file(self): lgbm_model['label_column'] = 'y' lgbm_model['participant_agency_list'] = [] for partner_index in range(0, len(self.ctx.participant_id_list)): - agency_info = {'agency': self.ctx.participant_id_list[partner_index]} - agency_info['fields'] = self.ctx._all_feature_name[partner_index] + agency_info = { + 'agency': self.ctx.participant_id_list[partner_index]} + agency_info['fields'] = self._all_feature_name[partner_index] lgbm_model['participant_agency_list'].append(agency_info) - - lgbm_model['model_dict'] = self.ctx.model_params + + lgbm_model['model_dict'] = self.ctx.model_params.get_all_params() model_text = {} with open(self.ctx.feature_bin_file, 'rb') as f: feature_bin_data = f.read() @@ -265,9 +265,10 @@ def merge_model_file(self): model_data = f.read() feature_bin_enc = encrypt_data(self.ctx.key, feature_bin_data) model_data_enc = encrypt_data(self.ctx.key, model_data) - + my_agency_id = self.ctx.components.config_data['AGENCY_ID'] - model_text[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)] + model_text[my_agency_id] = [cipher_to_base64( + feature_bin_enc), cipher_to_base64(model_data_enc)] # 发送&接受文件 for partner_index in range(0, len(self.ctx.participant_id_list)): @@ -285,7 +286,8 @@ def merge_model_file(self): model_data_enc = self._receive_byte_data( self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', partner_index) model_text[self.ctx.participant_id_list[partner_index]] = \ - [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)] + [cipher_to_base64(feature_bin_enc), + cipher_to_base64(model_data_enc)] lgbm_model['model_text'] = model_text # 上传密文模型 @@ -300,7 +302,8 @@ def split_model_file(self): # 传入模型 my_agency_id = self.ctx.components.config_data['AGENCY_ID'] model_text = self.ctx.model_predict_algorithm['model_text'] - feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in model_text[my_agency_id]] + feature_bin_enc, model_data_enc = [ + base64_to_cipher(i) for i in model_text[my_agency_id]] # 解密文件 feature_bin_data = decrypt_data(self.ctx.key, feature_bin_enc) diff --git a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py index c80f1a5d..23623787 100644 --- a/python/ppc_model/secure_lr/vertical/booster.py +++ b/python/ppc_model/secure_lr/vertical/booster.py @@ -271,10 +271,10 @@ def merge_model_file(self): lr_model['participant_agency_list'] = [] for partner_index in range(0, len(self.ctx.participant_id_list)): agency_info = {'agency': self.ctx.participant_id_list[partner_index]} - agency_info['fields'] = self.ctx._all_feature_name[partner_index] + agency_info['fields'] = self._all_feature_name[partner_index] lr_model['participant_agency_list'].append(agency_info) - lr_model['model_dict'] = self.ctx.model_params + lr_model['model_dict'] = self.ctx.model_params.get_all_params() model_text = {} with open(self.ctx.model_data_file, 'rb') as f: model_data = f.read()