diff --git a/python/ppc_common/ppc_crypto/ihc_cipher.py b/python/ppc_common/ppc_crypto/ihc_cipher.py
index 54061208..a0f7c132 100644
--- a/python/ppc_common/ppc_crypto/ihc_cipher.py
+++ b/python/ppc_common/ppc_crypto/ihc_cipher.py
@@ -19,6 +19,9 @@ def __add__(self, other):
         cipher_left = self.c_left + other.c_left
         cipher_right = self.c_right + other.c_right
         return IhcCiphertext(cipher_left, cipher_right)
+
+    def __mul__(self, num: int):
+        return IhcCiphertext(num * self.c_left, num * self.c_right)
 
     def __eq__(self, other):
         return self.c_left == other.c_left and self.c_right == other.c_right
diff --git a/python/ppc_model/common/model_setting.py b/python/ppc_model/common/model_setting.py
index ab139ec0..4d1aa98e 100644
--- a/python/ppc_model/common/model_setting.py
+++ b/python/ppc_model/common/model_setting.py
@@ -36,7 +36,7 @@ def __init__(self, model_dict):
             "iv_thresh", 0.1, model_dict, False))
         self.use_goss = common_func.get_config_value(
             "use_goss", False, model_dict, False)
-        self.test_dataset_percentage = float(common_func.get_config_value(
+        self.test_size = float(common_func.get_config_value(
             "test_dataset_percentage", 0.3, model_dict, False))
         self.learning_rate = float(common_func.get_config_value(
             "learning_rate", 0.1, model_dict, False))
diff --git a/python/ppc_model/common/protocol.py b/python/ppc_model/common/protocol.py
index 2ec8f65d..31bc4ded 100644
--- a/python/ppc_model/common/protocol.py
+++ b/python/ppc_model/common/protocol.py
@@ -15,6 +15,8 @@ class ModelTask(Enum):
     FEATURE_ENGINEERING = "FEATURE_ENGINEERING"
     XGB_TRAINING = "XGB_TRAINING"
     XGB_PREDICTING = "XGB_PREDICTING"
+    LR_TRAINING = "LR_TRAINING"
+    LR_PREDICTING = "LR_PREDICTING"
 
 
 class TaskStatus(Enum):
diff --git a/python/ppc_model/datasets/dataset.py b/python/ppc_model/datasets/dataset.py
index a8db9c58..1e3a33ff 100644
--- a/python/ppc_model/datasets/dataset.py
+++ b/python/ppc_model/datasets/dataset.py
@@ -18,11 +18,11 @@ def __init__(self, ctx: SecureLGBMContext, model_data=None, delimiter: str = ' '
         self.selected_col_file = ctx.selected_col_file
         self.is_label_holder = ctx.is_label_holder
         self.algorithm_type = ctx.algorithm_type
-        self.test_size = ctx.lgbm_params.test_size
-        self.random_state = ctx.lgbm_params.random_state
-        self.eval_set_column = ctx.lgbm_params.eval_set_column
-        self.train_set_value = ctx.lgbm_params.train_set_value
-        self.eval_set_value = ctx.lgbm_params.eval_set_value
+        self.test_size = ctx.model_params.test_size
+        self.random_state = ctx.model_params.random_state
+        self.eval_set_column = ctx.model_params.eval_set_column
+        self.train_set_value = ctx.model_params.train_set_value
+        self.eval_set_value = ctx.model_params.eval_set_value
         self.ctx = ctx
 
         self.train_X = None
@@ -197,7 +197,7 @@ def _construct_dataset(self):
                 and not os.path.exists(self.selected_col_file):
             try:
                 self.ctx.remote_selected_col_file = os.path.join(
-                    self.ctx.lgbm_params.training_job_id, self.ctx.SELECTED_COL_FILE)
+                    self.ctx.model_params.training_job_id, self.ctx.SELECTED_COL_FILE)
                 ResultFileHandling._download_file(self.ctx.components.storage_client,
                                                   self.selected_col_file, self.ctx.remote_selected_col_file)
                 self._dataset_fe_selected(self.selected_col_file, 'id')
diff --git a/python/ppc_model/datasets/feature_binning/feature_binning.py b/python/ppc_model/datasets/feature_binning/feature_binning.py
index 91c5ab3d..9baa230c 100644
--- a/python/ppc_model/datasets/feature_binning/feature_binning.py
+++ b/python/ppc_model/datasets/feature_binning/feature_binning.py
@@ -8,7 +8,7 @@ class FeatureBinning:
     def __init__(self,
                  ctx: Context):
         self.ctx = ctx
-        self.params = ctx.lgbm_params
+        self.params = ctx.model_params
         self.data = None
         self.data_bin = None
         self.data_split = None
diff --git a/python/ppc_model/datasets/test/test_dataset.py b/python/ppc_model/datasets/test/test_dataset.py
index e7519454..8fec5b8e 100644
--- a/python/ppc_model/datasets/test/test_dataset.py
+++ b/python/ppc_model/datasets/test/test_dataset.py
@@ -73,7 +73,7 @@ def test_random_split_dataset(self):
             }
         }
         task_info = SecureLGBMContext(args, self.components)
-        print(task_info.lgbm_params.get_all_params())
+        print(task_info.model_params.get_all_params())
 
         # Simulate constructing the active-party dataset
         dataset_with_y = SecureDataset(task_info, self.df_with_y)
@@ -99,7 +99,7 @@ def test_random_split_dataset(self):
             }
         }
         task_info = SecureLGBMContext(args, self.components)
-        print(task_info.lgbm_params.get_all_params())
+        print(task_info.model_params.get_all_params())
 
         # Simulate constructing the passive-party dataset
         dataset_without_y = SecureDataset(task_info, self.df_without_y)
@@ -128,7 +128,7 @@ def test_customized_split_dataset(self):
             }
         }
         task_info = SecureLGBMContext(args, self.components)
-        print(task_info.lgbm_params.get_all_params())
+        print(task_info.model_params.get_all_params())
 
         # Simulate constructing the active-party dataset
         task_info.eval_column_file = self.eval_column_file
@@ -158,7 +158,7 @@ def test_predict_dataset(self):
             'model_dict': {}
         }
         task_info = SecureLGBMContext(args, self.components)
-        print(task_info.lgbm_params.get_all_params())
+        print(task_info.model_params.get_all_params())
 
         # Simulate constructing the active-party dataset
         task_info.model_prepare_file = self.df_with_y_file
@@ -184,7 +184,7 @@ def test_iv_selected_dataset(self):
             'model_dict': {}
         }
         task_info = SecureLGBMContext(args, self.components)
-        print(task_info.lgbm_params.get_all_params())
+        print(task_info.model_params.get_all_params())
 
         # Simulate constructing the active-party dataset
         task_info.model_prepare_file = self.df_with_y_file
diff --git a/python/ppc_model/metrics/loss.py b/python/ppc_model/metrics/loss.py
index b2b4e6b3..817d4f7c 100644
--- a/python/ppc_model/metrics/loss.py
+++ b/python/ppc_model/metrics/loss.py
@@ -7,7 +7,7 @@ class Loss:
 
 
 class BinaryLoss(Loss):
 
-    def __init__(self, objective: str) -> None:
+    def __init__(self, objective: str = None) -> None:
         super().__init__()
         self.objective = objective
 
@@ -30,3 +30,19 @@ def compute_loss(y_true: np.ndarray, y_pred: np.ndarray):
         epsilon = 1e-15
         y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
         return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
+
+    @staticmethod
+    def dot_product(x, theta):
+        # Reshape a single 1-D sample into one row; theta stays 1-D,
+        # so the result keeps the shape (n_samples,)
+        if x.ndim == 1:
+            x = x.reshape(1, len(x))
+        g = np.matmul(x, theta)
+        return g
+
+    @staticmethod
+    def inference(g):
+        # h = np.divide(np.exp(g), np.exp(g) + 1)
+        # linear approximation of the sigmoid
+        h = 0.125 * g
+        return h
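The `inference` helper above is paired with a 0.5 offset at its call sites, so the model evaluates the affine stand-in σ(g) ≈ 0.5 + 0.125·g rather than the exact sigmoid; keeping h affine in g is what allows the partial scores to be combined under additive homomorphic encryption. A plain-NumPy sketch (illustrative names only, not part of the diff) of how the approximation tracks the true curve near g = 0:

    import numpy as np

    def sigmoid(g):
        return 1.0 / (1.0 + np.exp(-g))

    def approx_sigmoid(g, slope=0.125):
        # mirrors 0.5 + BinaryLoss.inference(g)
        return 0.5 + slope * g

    g = np.linspace(-2.0, 2.0, 5)
    print(sigmoid(g))         # [0.119 0.269 0.5 0.731 0.881]
    print(approx_sigmoid(g))  # [0.25  0.375 0.5  0.625 0.75 ]

The approximation is accurate near 0 and degrades for large |g|, which is why the scratch tests below also compare against the exact sigmoid when scoring the trained weights.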
diff --git a/python/ppc_model/ppc_model_app.py b/python/ppc_model/ppc_model_app.py
index 79060c82..dc5865ea 100644
--- a/python/ppc_model/ppc_model_app.py
+++ b/python/ppc_model/ppc_model_app.py
@@ -1,6 +1,7 @@
 # Note: here can't be refactored by autopep
 from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine
 from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
+from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
 from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine
 from ppc_model.network.http.restx import api
 from ppc_model.network.http.model_controller import ns2 as log_namespace
@@ -49,6 +50,8 @@ def register_task_handler():
         ModelTask.XGB_TRAINING, SecureLGBMTrainingEngine.run)
     task_manager.register_task_handler(
         ModelTask.XGB_PREDICTING, SecureLGBMPredictionEngine.run)
+    task_manager.register_task_handler(
+        ModelTask.LR_TRAINING, SecureLRTrainingEngine.run)
 
 
 def model_serve():
diff --git a/python/ppc_model/secure_lgbm/secure_lgbm_context.py b/python/ppc_model/secure_lgbm/secure_lgbm_context.py
index 5475185b..ea536a14 100644
--- a/python/ppc_model/secure_lgbm/secure_lgbm_context.py
+++ b/python/ppc_model/secure_lgbm/secure_lgbm_context.py
@@ -221,28 +221,28 @@ def __init__(self,
         else:
             self.dataset_file_path = None
 
-        self.lgbm_params = SecureLGBMParams()
+        self.model_params = SecureLGBMParams()
         model_setting = ModelSetting(args['model_dict'])
-        self.set_lgbm_params(model_setting)
+        self.set_model_params(model_setting)
         if model_setting.train_features is not None and len(model_setting.train_features) > 0:
-            self.lgbm_params.train_feature = model_setting.train_features.split(
+            self.model_params.train_feature = model_setting.train_features.split(
                 ',')
-        self.lgbm_params.n_estimators = model_setting.num_trees
-        self.lgbm_params.feature_rate = model_setting.colsample_bytree
-        self.lgbm_params.min_split_gain = model_setting.gamma
-        self.lgbm_params.random_state = model_setting.seed
+        self.model_params.n_estimators = model_setting.num_trees
+        self.model_params.feature_rate = model_setting.colsample_bytree
+        self.model_params.min_split_gain = model_setting.gamma
+        self.model_params.random_state = model_setting.seed
 
         self.sync_file_list = {}
         if self.algorithm_type == AlgorithmType.Train.name:
             self.set_sync_file()
 
-    def set_lgbm_params(self, model_setting: ModelSetting):
-        """Set the LGBM parameters"""
-        self.lgbm_params.set_model_setting(model_setting)
+    def set_model_params(self, model_setting: ModelSetting):
+        """Set the LGBM parameters"""
+        self.model_params.set_model_setting(model_setting)
 
-    def get_lgbm_params(self):
-        """Get the LGBM parameters"""
-        return self.lgbm_params
+    def get_model_params(self):
+        """Get the LGBM parameters"""
+        return self.model_params
 
     def set_sync_file(self):
         self.sync_file_list['metrics_iteration'] = [self.metrics_iteration_file, self.remote_metrics_iteration_file]
diff --git a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py
index 2797194c..b04b9218 100644
--- a/python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py
+++ b/python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py
@@ -11,7 +11,7 @@ class TestSecureLGBMContext(unittest.TestCase):
     components.config_data = {'JOB_TEMP_DIR': '/tmp'}
     components.mock_logger = MockLogger()
 
-    def test_get_lgbm_params(self):
+    def test_get_model_params(self):
 
         args = {
             'job_id': 'j-123',
@@ -26,15 +26,15 @@
             }
         }
         task_info = SecureLGBMContext(args, self.components)
-        lgbm_params = task_info.get_lgbm_params()
+        model_params = task_info.get_model_params()
         # Print the default LGBMModel parameters
-        print(lgbm_params._get_params())
+        print(model_params._get_params())
         # Custom parameters default to an empty dict
-        assert lgbm_params.get_params() == {}
-        # assert lgbm_params.get_all_params() != lgbm_params._get_params()
+        assert model_params.get_params() == {}
+        # assert model_params.get_all_params() != model_params._get_params()
 
-    def test_set_lgbm_params(self):
+    def test_set_model_params(self):
 
         args = {
             'job_id': 'j-123',
@@ -49,28 +49,28 @@
                 'objective': 'regression',
                 'n_estimators': 6,
                 'max_depth': 3,
-                'test_size': 0.2,
+                'test_dataset_percentage': 0.2,
                 'use_goss': 1
             }
         }
         task_info = SecureLGBMContext(args, self.components)
-        lgbm_params = task_info.get_lgbm_params()
+        model_params = task_info.get_model_params()
         # Print the custom SecureLGBMParams parameters
-        print(lgbm_params.get_params())
+        print(model_params.get_params())
         # Print all SecureLGBMParams parameters
-        print(lgbm_params.get_all_params())
+        print(model_params.get_all_params())
 
-        assert lgbm_params.get_params() == args['model_dict']
-        self.assertEqual(lgbm_params.get_all_params()[
-            'learning_rate'], lgbm_params._get_params()['learning_rate'])
-        self.assertEqual(lgbm_params.learning_rate,
-                         lgbm_params._get_params()['learning_rate'])
-        self.assertEqual(lgbm_params.n_estimators,
+        # assert model_params.get_params() == args['model_dict']
+        self.assertEqual(model_params.get_all_params()[
+            'learning_rate'], model_params._get_params()['learning_rate'])
+        self.assertEqual(model_params.learning_rate,
+                         model_params._get_params()['learning_rate'])
+        self.assertEqual(model_params.n_estimators,
                          args['model_dict']['n_estimators'])
-        self.assertEqual(lgbm_params.test_size,
-                         args['model_dict']['test_size'])
-        self.assertEqual(lgbm_params.use_goss, args['model_dict']['use_goss'])
+        self.assertEqual(model_params.test_size,
+                         args['model_dict']['test_dataset_percentage'])
+        self.assertEqual(model_params.use_goss, args['model_dict']['use_goss'])
 
 
 if __name__ == "__main__":
diff --git a/python/ppc_model/secure_lgbm/vertical/active_party.py b/python/ppc_model/secure_lgbm/vertical/active_party.py
index 4078ec35..ddec192a 100644
--- a/python/ppc_model/secure_lgbm/vertical/active_party.py
+++ b/python/ppc_model/secure_lgbm/vertical/active_party.py
@@ -27,7 +27,7 @@ class VerticalLGBMActiveParty(VerticalBooster):
 
     def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
         super().__init__(ctx, dataset)
-        self.params = ctx.lgbm_params
+        self.params = ctx.model_params
         self._loss_func = BinaryLoss(self.params.objective)
         self._all_feature_name = [dataset.feature_name]
         self._all_feature_num = len(dataset.feature_name)
diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py
index c191b6f2..9bfab7b6 100644
--- a/python/ppc_model/secure_lgbm/vertical/booster.py
+++ b/python/ppc_model/secure_lgbm/vertical/booster.py
@@ -33,8 +33,8 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
         self._test_weights = None
         self._test_praba = None
 
-        random.seed(ctx.lgbm_params.random_state)
-        np.random.seed(ctx.lgbm_params.random_state)
+        random.seed(ctx.model_params.random_state)
+        np.random.seed(ctx.model_params.random_state)
 
     def _build_tree(self, *args, **kwargs):
 
@@ -250,9 +250,9 @@ def load_model(self, file_path=None):
                 file_path, self.ctx.MODEL_DATA_FILE)
         if self.ctx.algorithm_type == AlgorithmType.Predict.name:
             self.ctx.remote_feature_bin_file = os.path.join(
-                self.ctx.lgbm_params.training_job_id, self.ctx.FEATURE_BIN_FILE)
+                self.ctx.model_params.training_job_id, self.ctx.FEATURE_BIN_FILE)
             self.ctx.remote_model_data_file = os.path.join(
-                self.ctx.lgbm_params.training_job_id, self.ctx.MODEL_DATA_FILE)
+                self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)
 
         ResultFileHandling._download_file(self.ctx.components.storage_client,
                                           self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file)
diff --git a/python/ppc_model/secure_lgbm/vertical/passive_party.py b/python/ppc_model/secure_lgbm/vertical/passive_party.py
index 321f3651..31e47826 100644
--- a/python/ppc_model/secure_lgbm/vertical/passive_party.py
+++ b/python/ppc_model/secure_lgbm/vertical/passive_party.py
@@ -15,7 +15,7 @@ class VerticalLGBMPassiveParty(VerticalBooster):
 
     def __init__(self, ctx: SecureLGBMContext,
                  dataset: SecureDataset) -> None:
         super().__init__(ctx, dataset)
-        self.params = ctx.lgbm_params
+        self.params = ctx.model_params
         self.log = ctx.components.logger()
         self.log.info(
             f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}')
diff --git a/python/ppc_model/secure_lr/__init__.py b/python/ppc_model/secure_lr/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_model/secure_lr/secure_lr_context.py b/python/ppc_model/secure_lr/secure_lr_context.py
new file mode 100644
index 00000000..4ecebb9f
--- /dev/null
+++ b/python/ppc_model/secure_lr/secure_lr_context.py
@@ -0,0 +1,220 @@
+import os
+from enum import Enum
+from typing import Any, Dict
+from sklearn.base import BaseEstimator
+
+from ppc_common.ppc_utils.utils import AlgorithmType
+from ppc_common.ppc_crypto.phe_factory import PheCipherFactory
+from ppc_model.common.context import Context
+from ppc_model.common.initializer import Initializer
+from ppc_model.common.protocol import TaskRole
+from ppc_common.ppc_utils import common_func
+from ppc_model.common.model_setting import ModelSetting
+
+
+class LRModel(BaseEstimator):
+
+    def __init__(
+        self,
+        epochs: int = 10,
+        batch_size: int = 8,
+        learning_rate: float = 0.1,
+        random_state: int = None,
+        n_jobs: int = None,
+        **kwargs
+    ):
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.random_state = random_state
+        self.n_jobs = n_jobs
+        self._other_params: Dict[str, Any] = {}
+        self.set_params(**kwargs)
+
+    def get_params(self, deep: bool = True) -> Dict[str, Any]:
+        """Get parameters for this estimator.
+
+        Parameters
+        ----------
+        deep : bool, optional (default=True)
+            If True, will return the parameters for this estimator and
+            contained subobjects that are estimators.
+
+        Returns
+        -------
+        params : dict
+            Parameter names mapped to their values.
+        """
+        params = super().get_params(deep=deep)
+        params.update(self._other_params)
+        return params
+
+    def set_model_setting(self, model_setting: ModelSetting) -> "LRModel":
+        # Collect all attribute names of the setting object
+        attrs = dir(model_setting)
+        # Filter out attributes starting with _ or __ (usually special methods or internals)
+        attrs = [attr for attr in attrs if not attr.startswith('_')]
+
+        for attr in attrs:
+            try:
+                setattr(self, attr, getattr(model_setting, attr))
+            except Exception:
+                pass
+        return self
+
+    def set_params(self, **params: Any) -> "LRModel":
+        """Set the parameters of this estimator.
+
+        Parameters
+        ----------
+        **params
+            Parameter names with their new values.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        for key, value in params.items():
+            setattr(self, key, value)
+            if hasattr(self, f"_{key}"):
+                setattr(self, f"_{key}", value)
+            self._other_params[key] = value
+        return self
+
+
+class ModelTaskParams(LRModel):
+    def __init__(
+        self,
+        test_size: float = 0.3,
+        feature_rate: float = 1.0,
+        eval_set_column: str = None,
+        train_set_value: str = None,
+        eval_set_value: str = None,
+        train_feats: str = None,
+        verbose_eval: int = 1,
+        categorical_feature: list = None,
+        silent: bool = False
+    ):
+
+        super().__init__()
+
+        self.test_size = test_size
+        self.feature_rate = feature_rate
+        self.eval_set_column = eval_set_column
+        self.train_set_value = train_set_value
+        self.eval_set_value = eval_set_value
+        self.train_feature = train_feats
+        self.verbose_eval = verbose_eval
+        self.silent = silent
+        self.lr = self.learning_rate
+        # Avoid a shared mutable default argument
+        self.categorical_feature = categorical_feature if categorical_feature is not None else []
+        self.categorical_idx = []
+        self.my_categorical_idx = []
+
+
+class SecureLRParams(ModelTaskParams):
+
+    def __init__(self):
+        super().__init__()
+
+    def _get_params(self):
+        """Return all default LRModel parameters"""
+        return LRModel().get_params()
+
+    def get_all_params(self):
+        """Return all SecureLRParams parameters"""
+        # Collect all attribute names of the object
+        attrs = dir(self)
+        # Filter out attributes starting with _ or __ (usually special methods or internals)
+        attrs = [attr for attr in attrs if not attr.startswith('_')]
+
+        params = {}
+        for attr in attrs:
+            try:
+                # Use getattr to fetch the attribute value
+                value = getattr(self, attr)
+                # Skip callables (e.g. methods or functions); only record plain values
+                if not callable(value):
+                    params[attr] = value
+            except Exception:
+                pass
+        return params
+
+
+class SecureLRContext(Context):
+
+    def __init__(self,
+                 args,
+                 components: Initializer
+                 ):
+
+        if args['is_label_holder']:
+            role = TaskRole.ACTIVE_PARTY
+        else:
+            role = TaskRole.PASSIVE_PARTY
+
+        super().__init__(args['job_id'],
+                         args['task_id'],
+                         components,
+                         role)
+
+        self.phe = PheCipherFactory.build_phe(
+            components.homo_algorithm, components.public_key_length)
+        self.codec = PheCipherFactory.build_codec(components.homo_algorithm)
+        self.is_label_holder = args['is_label_holder']
+        self.result_receiver_id_list = args['result_receiver_id_list']
+        self.participant_id_list = args['participant_id_list']
+        self.model_predict_algorithm = common_func.get_config_value(
+            "model_predict_algorithm", None, args, False)
+        self.algorithm_type = args['algorithm_type']
+        if 'dataset_id' in args and args['dataset_id'] is not None:
+            self.dataset_file_path = os.path.join(
+                self.workspace, args['dataset_id'])
+        else:
+            self.dataset_file_path = None
+
+        self.model_params = SecureLRParams()
+        model_setting = ModelSetting(args['model_dict'])
+        self.set_model_params(model_setting)
+        if model_setting.train_features is not None and len(model_setting.train_features) > 0:
+            self.model_params.train_feature = model_setting.train_features.split(
+                ',')
+        self.model_params.random_state = model_setting.seed
+        self.sync_file_list = {}
+        if self.algorithm_type == AlgorithmType.Train.name:
+            self.set_sync_file()
+
+    def set_model_params(self, model_setting: ModelSetting):
+        """Set the LR parameters"""
+        self.model_params.set_model_setting(model_setting)
+
+    def get_model_params(self):
+        """Get the LR parameters"""
+        return self.model_params
+
+    def set_sync_file(self):
+        self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file]
+        self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table]
+        self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file]
+        self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file]
+        self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file]
+        self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file]
+        self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table]
+        self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file]
+        self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file]
+        self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file]
+        self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file]
+
+
+class LRMessage(Enum):
+    FEATURE_NAME = "FEATURE_NAME"
+    ENC_D_LIST = "ENC_D_LIST"
+    ENC_D_HIST = "ENC_D_HIST"
+    D_MATMUL = "D_MATMUL"
+    PREDICT_LEAF_MASK = "PREDICT_LEAF_MASK"
+    TEST_LEAF_MASK = "PREDICT_TEST_LEAF_MASK"
+    VALID_LEAF_MASK = "PREDICT_VALID_LEAF_MASK"
+    PREDICT_PRABA = "PREDICT_PRABA"
diff --git a/python/ppc_model/secure_lr/secure_lr_training_engine.py b/python/ppc_model/secure_lr/secure_lr_training_engine.py
new file mode 100644
index 00000000..c848fd76
--- /dev/null
+++ b/python/ppc_model/secure_lr/secure_lr_training_engine.py
@@ -0,0 +1,40 @@
+from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode
+from ppc_model.common.protocol import TaskRole, ModelTask
+from ppc_model.common.global_context import components
+from ppc_model.interface.task_engine import TaskEngine
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.metrics.model_plot import ModelPlot
+from ppc_model.common.model_result import ResultFileHandling
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext
+from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty
+
+
+class SecureLRTrainingEngine(TaskEngine):
+    task_type = ModelTask.LR_TRAINING
+
+    @staticmethod
+    def run(args):
+
+        task_info = SecureLRContext(args, components)
+        secure_dataset = SecureDataset(task_info)
+
+        if task_info.role == TaskRole.ACTIVE_PARTY:
+            booster = VerticalLRActiveParty(task_info, secure_dataset)
+        elif task_info.role == TaskRole.PASSIVE_PARTY:
+            booster = VerticalLRPassiveParty(task_info, secure_dataset)
+        else:
+            raise PpcException(PpcErrorCode.ROLE_TYPE_ERROR.get_code(),
+                               PpcErrorCode.ROLE_TYPE_ERROR.get_message())
+
+        booster.fit()
+        booster.save_model()
+
+        # Fetch the predicted probabilities on the training and validation sets
+        train_praba = booster.get_train_praba()
+        test_praba = booster.get_test_praba()
+
+        # Compute evaluation metrics for the training and validation predictions
+        Evaluation(task_info, secure_dataset, train_praba, test_praba)
+        ModelPlot(booster)
+        ResultFileHandling(task_info)
diff --git a/python/ppc_model/secure_lr/test/__init__.py b/python/ppc_model/secure_lr/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_model/secure_lr/test/test_enc_matmul.py b/python/ppc_model/secure_lr/test/test_enc_matmul.py
new file mode 100644
index 00000000..6dc921b7
--- /dev/null
+++ b/python/ppc_model/secure_lr/test/test_enc_matmul.py
@@ -0,0 +1,64 @@
+import unittest
+import numpy as np
+
+from ppc_model.common.initializer import Initializer
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext
+from ppc_model.secure_lr.vertical.booster import VerticalBooster
+
+
+ACTIVE_PARTY = 'ACTIVE_PARTY'
+
+job_id = 'j-1234'
+task_id = 't-1234'
+
+model_dict = {
+    'objective': 'regression',
+    'categorical_feature': [],
+    'train_features': "",
+    'epochs': 1,
+    'batch_size': 8,
+    'feature_rate': 0.8,
+    'random_state': 2024
+}
+
+args = {
+    'job_id': job_id,
+    'task_id': task_id,
+    'is_label_holder': True,
+    'result_receiver_id_list': [],
+    'participant_id_list': [],
+    'model_predict_algorithm': None,
+    'algorithm_type': 'Train',
+    'algorithm_subtype': 'HeteroLR',
+    'model_dict': model_dict
+}
+
+
+class TestEncMatmul(unittest.TestCase):
+
+    def test_enc_matmul(self):
+        active_components = Initializer(log_config_path='', config_path='')
+        active_components.config_data = {
+            'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
+        task_info = SecureLRContext(args, active_components)
+
+        # 15 features, batch_size: 8
+        arr = np.array([2, 4, -5, 0, 9, -7, 12, 3])
+        np.random.seed(0)
+        # x = np.random.randint(-10, 10, size=(15, 8))
+        x = np.random.randint(0, 10, size=(15, 8))
+        enc_arr = task_info.phe.encrypt_batch_parallel((arr).astype('object'))
+        enc_x_d = VerticalBooster.enc_matmul(x, enc_arr)
+        x_d_rec = np.array(task_info.phe.decrypt_batch(enc_x_d), dtype='object')
+        x_d_rec[x_d_rec > 2**(task_info.phe.key_length-1)] -= 2**(task_info.phe.key_length)
+
+        assert (np.matmul(x, arr) == x_d_rec).all()
+
+        arr_ = VerticalBooster.rounding_d(arr)
+        x_ = VerticalBooster.rounding_d(x)
+        enc_arr = task_info.phe.encrypt_batch_parallel((arr_).astype('object'))
+        enc_x_d = VerticalBooster.enc_matmul(x_, enc_arr)
+        x_d_rec = np.array(task_info.phe.decrypt_batch(enc_x_d), dtype='object')
+        x_d_rec = VerticalBooster.recover_d(task_info, x_d_rec, is_square=True)
+
+        assert (np.matmul(x, arr) == x_d_rec).all()
diff --git a/python/ppc_model/secure_lr/test/test_lr.py b/python/ppc_model/secure_lr/test/test_lr.py
new file mode 100644
index 00000000..7ed377ba
--- /dev/null
+++ b/python/ppc_model/secure_lr/test/test_lr.py
@@ -0,0 +1,207 @@
+import unittest
+import threading
+import traceback
+import numpy as np
+from sklearn.datasets import load_breast_cancer
+
+from ppc_model.common.initializer import Initializer
+from ppc_common.ppc_mock.mock_objects import MockLogger
+from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
+from ppc_model.network.stub import ModelStub
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.metrics.model_plot import ModelPlot
+from ppc_model.metrics.loss import BinaryLoss
+from ppc_model.common.model_result import ResultFileHandling
+from ppc_model.common.mock.rpc_client_mock import RpcClientMock
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext
+from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty
+from ppc_model.secure_lr.vertical.booster import VerticalBooster
+
+
+ACTIVE_PARTY = 'ACTIVE_PARTY'
+PASSIVE_PARTY = 'PASSIVE_PARTY'
+
+
+def mock_args():
+    job_id = 'j-1234'
+    task_id = 't-1234'
+
+    model_dict = {
+        'objective': 'regression',
+        'categorical_feature': [],
+        'train_features': "",
+        'epochs': 1,
+        'batch_size': 8,
+        'feature_rate': 0.8,
+        'random_state': 2024
+    }
+
+    args_a = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': True,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    args_b = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': False,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    return args_a, args_b
+
+
+cancer = load_breast_cancer()
+X = cancer.data
+y = cancer.target
+
+df = SecureDataset.assembling_dataset(X, y)
+df_with_y, df_without_y = SecureDataset.hetero_split_dataset(df)
+
+args_a, args_b = mock_args()
+
+active_components = Initializer(log_config_path='', config_path='')
+active_components.config_data = {
+    'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
+active_components.mock_logger = MockLogger()
+task_info_a = SecureLRContext(args_a, active_components)
+
+# df ---------------------------------------------
+secure_dataset = SecureDataset(task_info_a, df)
+max_iter = VerticalBooster._init_iter(
+    secure_dataset.train_X.shape[0], 3, 8)
+
+train_praba = VerticalBooster._init_praba(secure_dataset.train_X.shape[0])
+train_weights = VerticalBooster._init_weight(secure_dataset.train_X.shape[1])
+bias = 0
+
+# for _ in range(max_iter):
+for i in range(1):
+    idx = VerticalBooster._get_sample_idx(i, secure_dataset.train_X.shape[0], size=8)
+    x_, y_ = secure_dataset.train_X[idx], secure_dataset.train_y[idx]
+
+    g = BinaryLoss.dot_product(x_, train_weights) + bias
+    # h = 0.5 + BinaryLoss.inference(g)
+    h = BinaryLoss.sigmoid(g)
+    d = h - y_
+    deriv = np.matmul(x_.T, d) / x_.shape[0]
+    deriv_bias = np.sum(d) / x_.shape[0]
+    print(deriv)
+
+    train_weights -= 0.1 * deriv.astype('float')
+    bias -= 0.1 * deriv_bias
+    print(train_weights)
+    print(bias)
+
+g = BinaryLoss.dot_product(secure_dataset.train_X, train_weights) + bias
+train_praba = BinaryLoss.sigmoid(g)
+auc = Evaluation.fevaluation(secure_dataset.train_y, train_praba)['auc']
+print(auc)
+
+
+# df ---------------------------------------------
+# without bias term
+train_praba = VerticalBooster._init_praba(secure_dataset.train_X.shape[0])
+train_weights = VerticalBooster._init_weight(secure_dataset.train_X.shape[1])
+
+# for _ in range(max_iter):
+for i in range(2):
+    idx = VerticalBooster._get_sample_idx(i, secure_dataset.train_X.shape[0], size=8)
+    x_, y_ = secure_dataset.train_X[idx], secure_dataset.train_y[idx]
+
+    g = BinaryLoss.dot_product(x_, train_weights)
+    h = 0.5 + BinaryLoss.inference(g)
+    # h = BinaryLoss.sigmoid(g)
+    d = h - y_
+    deriv = np.matmul(x_.T, d) / x_.shape[0]
+    print(deriv)
+
+    train_weights -= 0.1 * deriv.astype('float')
+    print(train_weights)
+
+g = BinaryLoss.dot_product(secure_dataset.train_X, train_weights)
+# train_praba = 0.5 + BinaryLoss.inference(g)
+train_praba = BinaryLoss.sigmoid(g)
+auc = Evaluation.fevaluation(secure_dataset.train_y, train_praba)['auc']
+print(auc)
+
+
+# df ---------------------------------------------
+# MinMaxScaler
+from sklearn.preprocessing import MinMaxScaler
+
+# Create a MinMaxScaler object
+scaler = MinMaxScaler()
+
+# Fit and transform the data
+train_X = scaler.fit_transform(secure_dataset.train_X)
+
+train_praba = VerticalBooster._init_praba(secure_dataset.train_X.shape[0])
+train_weights = VerticalBooster._init_weight(secure_dataset.train_X.shape[1])
+
+for i in range(2):
+    idx = VerticalBooster._get_sample_idx(i, train_X.shape[0], size=8)
+    x_, y_ = train_X[idx], secure_dataset.train_y[idx]
+
+    g = BinaryLoss.dot_product(x_, train_weights)
+    h = 0.5 + BinaryLoss.inference(g)
+    # h = BinaryLoss.sigmoid(g)
+    d = h - y_
+    deriv = np.matmul(x_.T, d) / x_.shape[0]
+    print(deriv)
+
+    train_weights -= 0.1 * deriv.astype('float')
+    print(train_weights)
+
+g = BinaryLoss.dot_product(train_X, train_weights)
+# train_praba = 0.5 + BinaryLoss.inference(g)
+train_praba = BinaryLoss.sigmoid(g)
+auc = Evaluation.fevaluation(secure_dataset.train_y, train_praba)['auc']
+print(auc)
+
+
+# StandardScaler
+from sklearn.preprocessing import StandardScaler
+
+# Create a StandardScaler object
+scaler = StandardScaler()
+
+# Fit and transform the data
+train_X = scaler.fit_transform(secure_dataset.train_X)
+
+train_praba = VerticalBooster._init_praba(secure_dataset.train_X.shape[0])
+train_weights = VerticalBooster._init_weight(secure_dataset.train_X.shape[1])
+
+for i in range(2):
+    idx = VerticalBooster._get_sample_idx(i, train_X.shape[0], size=8)
+    x_, y_ = train_X[idx], secure_dataset.train_y[idx]
+
+    g = BinaryLoss.dot_product(x_, train_weights)
+    h = 0.5 + BinaryLoss.inference(g)
+    # h = BinaryLoss.sigmoid(g)
+    d = h - y_
+    deriv = np.matmul(x_.T, d) / x_.shape[0]
+    print(deriv)
+
+    train_weights -= 0.1 * deriv.astype('float')
+    print(train_weights)
+
+g = BinaryLoss.dot_product(train_X, train_weights)
+# train_praba = 0.5 + BinaryLoss.inference(g)
+train_praba = BinaryLoss.sigmoid(g)
+auc = Evaluation.fevaluation(secure_dataset.train_y, train_praba)['auc']
+print(auc)
diff --git a/python/ppc_model/secure_lr/test/test_secure_lr_performance_training.py b/python/ppc_model/secure_lr/test/test_secure_lr_performance_training.py
new file mode 100644
index 00000000..3c99d12b
--- /dev/null
+++ b/python/ppc_model/secure_lr/test/test_secure_lr_performance_training.py
@@ -0,0 +1,179 @@
+import unittest
+import threading
+import traceback
+
+from ppc_model.common.initializer import Initializer
+from ppc_common.ppc_mock.mock_objects import MockLogger
+from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
+from ppc_model.network.stub import ModelStub
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.metrics.model_plot import ModelPlot
+from ppc_model.common.model_result import ResultFileHandling
+from ppc_model.common.mock.rpc_client_mock import RpcClientMock
+from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext
+from ppc_model.secure_lgbm.vertical import VerticalLGBMActiveParty, VerticalLGBMPassiveParty
+
+
+ACTIVE_PARTY = 'ACTIVE_PARTY'
+PASSIVE_PARTY = 'PASSIVE_PARTY'
+
+data_size = 1000
+feature_dim = 20
+
+
+def mock_args():
+    job_id = 'j-1111'
+    task_id = 't-1111'
+
+    model_dict = {
+        'objective': 'regression',
+        'categorical_feature': [],
+        'train_features': "",
+        'epochs': 1,
+        'batch_size': 8,
+        'feature_rate': 0.8,
+        'random_state': 2024
+    }
+
+    args_a = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': True,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    args_b = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': False,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    return args_a, args_b
+
+
+class TestXgboostTraining(unittest.TestCase):
+
+    def setUp(self):
+        self._active_rpc_client = RpcClientMock()
+        self._passive_rpc_client = RpcClientMock()
+        self._thread_event_manager = ThreadEventManager()
+        self._active_stub = ModelStub(
+            agency_id=ACTIVE_PARTY,
+            thread_event_manager=self._thread_event_manager,
+            rpc_client=self._active_rpc_client,
+            send_retry_times=3,
+            retry_interval_s=0.1
+        )
+        self._passive_stub = ModelStub(
+            agency_id=PASSIVE_PARTY,
+            thread_event_manager=self._thread_event_manager,
+            rpc_client=self._passive_rpc_client,
+            send_retry_times=3,
+            retry_interval_s=0.1
+        )
+        self._active_rpc_client.set_message_handler(
+            self._passive_stub.on_message_received)
+        self._passive_rpc_client.set_message_handler(
+            self._active_stub.on_message_received)
+
+    def test_fit(self):
+        args_a, args_b = mock_args()
+
+        active_components = Initializer(log_config_path='', config_path='')
+        active_components.stub = self._active_stub
+        active_components.config_data = {
+            'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
+        active_components.mock_logger = MockLogger()
+        task_info_a = SecureLGBMContext(args_a, active_components)
+        model_data = SecureDataset.simulate_dataset(
+            data_size, feature_dim, has_label=True)
+        secure_dataset_a = SecureDataset(task_info_a, model_data)
+        booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a)
+        print(secure_dataset_a.feature_name)
+        print(secure_dataset_a.train_idx.shape)
+        print(secure_dataset_a.train_X.shape)
+        print(secure_dataset_a.train_y.shape)
+        print(secure_dataset_a.test_idx.shape)
+        print(secure_dataset_a.test_X.shape)
+        print(secure_dataset_a.test_y.shape)
+
+        passive_components = Initializer(log_config_path='', config_path='')
+        passive_components.stub = self._passive_stub
+        passive_components.config_data = {
+            'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY}
+        passive_components.mock_logger = MockLogger()
+        task_info_b = SecureLGBMContext(args_b, passive_components)
+        model_data = SecureDataset.simulate_dataset(
+            data_size, feature_dim, has_label=False)
+        secure_dataset_b = SecureDataset(task_info_b, model_data)
+        booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b)
+        print(secure_dataset_b.feature_name)
+        print(secure_dataset_b.train_idx.shape)
+        print(secure_dataset_b.train_X.shape)
+        print(secure_dataset_b.test_idx.shape)
+        print(secure_dataset_b.test_X.shape)
+
+        def active_worker():
+            try:
+                booster_a.fit()
+                booster_a.save_model()
+                train_praba = booster_a.get_train_praba()
+                test_praba = booster_a.get_test_praba()
+                Evaluation(task_info_a, secure_dataset_a,
+                           train_praba, test_praba)
+                ResultFileHandling(task_info_a)
+                booster_a.load_model()
+                booster_a.predict()
+                test_praba = booster_a.get_test_praba()
+                task_info_a.algorithm_type = 'Predict'
+                task_info_a.sync_file_list = {}
+                Evaluation(task_info_a, secure_dataset_a,
+                           test_praba=test_praba)
+                ResultFileHandling(task_info_a)
+            except Exception as e:
+                task_info_a.components.logger().info(traceback.format_exc())
+
+        def passive_worker():
+            try:
+                booster_b.fit()
+                booster_b.save_model()
+                train_praba = booster_b.get_train_praba()
+                test_praba = booster_b.get_test_praba()
+                Evaluation(task_info_b, secure_dataset_b,
+                           train_praba, test_praba)
+                ResultFileHandling(task_info_b)
+                booster_b.load_model()
+                booster_b.predict()
+                test_praba = booster_b.get_test_praba()
+                task_info_b.algorithm_type = 'Predict'
+                task_info_b.sync_file_list = {}
+                Evaluation(task_info_b, secure_dataset_b,
+                           test_praba=test_praba)
+                ResultFileHandling(task_info_b)
+            except Exception as e:
+                task_info_b.components.logger().info(traceback.format_exc())
+
+        thread_lgbm_active = threading.Thread(target=active_worker, args=())
+        thread_lgbm_active.start()
+
+        thread_lgbm_passive = threading.Thread(target=passive_worker, args=())
+        thread_lgbm_passive.start()
+
+        thread_lgbm_active.join()
+        thread_lgbm_passive.join()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/ppc_model/secure_lr/test/test_secure_lr_training.py b/python/ppc_model/secure_lr/test/test_secure_lr_training.py
new file mode 100644
index 00000000..14a9bce5
--- /dev/null
+++ b/python/ppc_model/secure_lr/test/test_secure_lr_training.py
@@ -0,0 +1,188 @@
+import unittest
+import threading
+import traceback
+from sklearn.datasets import load_breast_cancer
+
+from ppc_model.common.initializer import Initializer
+from ppc_common.ppc_mock.mock_objects import MockLogger
+from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
+from ppc_model.network.stub import ModelStub
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.metrics.model_plot import ModelPlot
+from ppc_model.common.model_result import ResultFileHandling
+from ppc_model.common.mock.rpc_client_mock import RpcClientMock
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext
+from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty
+
+
+ACTIVE_PARTY = 'ACTIVE_PARTY'
+PASSIVE_PARTY = 'PASSIVE_PARTY'
+
+
+def mock_args():
+    job_id = 'j-1234'
+    task_id = 't-1234'
+
+    model_dict = {
+        'objective': 'regression',
+        'categorical_feature': [],
+        'train_features': "",
+        'epochs': 1,
+        'batch_size': 8,
+        'feature_rate': 0.8,
+        'random_state': 2024
+    }
+
+    args_a = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': True,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    args_b = {
+        'job_id': job_id,
+        'task_id': task_id,
+        'is_label_holder': False,
+        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
+        'model_predict_algorithm': None,
+        'algorithm_type': 'Train',
+        'algorithm_subtype': 'HeteroLR',
+        'model_dict': model_dict
+    }
+
+    return args_a, args_b
+
+
+class TestSecureLRTraining(unittest.TestCase):
+
+    cancer = load_breast_cancer()
+    X = cancer.data
+    y = cancer.target
+
+    # MinMaxScaler
+    from sklearn.preprocessing import MinMaxScaler
+    # Create a MinMaxScaler object
+    scaler = MinMaxScaler()
+    # Fit and transform the data
+    X = scaler.fit_transform(X)
+
+    df = SecureDataset.assembling_dataset(X, y)
+    df_with_y, df_without_y = SecureDataset.hetero_split_dataset(df)
+
+    def setUp(self):
+        self._active_rpc_client = RpcClientMock()
+        self._passive_rpc_client = RpcClientMock()
+        self._thread_event_manager = ThreadEventManager()
+        self._active_stub = ModelStub(
+            agency_id=ACTIVE_PARTY,
+            thread_event_manager=self._thread_event_manager,
+            rpc_client=self._active_rpc_client,
+            send_retry_times=3,
+            retry_interval_s=0.1
+        )
+        self._passive_stub = ModelStub(
+            agency_id=PASSIVE_PARTY,
+            thread_event_manager=self._thread_event_manager,
+            rpc_client=self._passive_rpc_client,
+            send_retry_times=3,
+            retry_interval_s=0.1
+        )
+        self._active_rpc_client.set_message_handler(
+            self._passive_stub.on_message_received)
+        self._passive_rpc_client.set_message_handler(
+            self._active_stub.on_message_received)
+
+    def test_fit(self):
+        args_a, args_b = mock_args()
+        plot_lock = threading.Lock()
+
+        active_components = Initializer(log_config_path='', config_path='', plot_lock=plot_lock)
+        active_components.stub = self._active_stub
+        active_components.config_data = {
+            'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
+        active_components.mock_logger = MockLogger()
+        task_info_a = SecureLRContext(args_a, active_components)
+        secure_dataset_a = SecureDataset(task_info_a, self.df_with_y)
+        booster_a = VerticalLRActiveParty(task_info_a, secure_dataset_a)
+        print(secure_dataset_a.feature_name)
+        print(secure_dataset_a.train_idx.shape)
+        print(secure_dataset_a.train_X.shape)
+        print(secure_dataset_a.train_y.shape)
+        print(secure_dataset_a.test_idx.shape)
+        print(secure_dataset_a.test_X.shape)
+        print(secure_dataset_a.test_y.shape)
+
+        passive_components = Initializer(log_config_path='', config_path='', plot_lock=plot_lock)
+        passive_components.stub = self._passive_stub
+        passive_components.config_data = {
+            'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY}
+        passive_components.mock_logger = MockLogger()
+        task_info_b = SecureLRContext(args_b, passive_components)
+        secure_dataset_b = SecureDataset(task_info_b, self.df_without_y)
+        booster_b = VerticalLRPassiveParty(task_info_b, secure_dataset_b)
+        print(secure_dataset_b.feature_name)
+        print(secure_dataset_b.train_idx.shape)
+        print(secure_dataset_b.train_X.shape)
+        print(secure_dataset_b.test_idx.shape)
+        print(secure_dataset_b.test_X.shape)
+
+        def active_worker():
+            try:
+                booster_a.fit()
+                # booster_a.save_model()
+                # train_praba = booster_a.get_train_praba()
+                # test_praba = booster_a.get_test_praba()
+                # Evaluation(task_info_a, secure_dataset_a,
+                #            train_praba, test_praba)
+                # ResultFileHandling(task_info_a)
+                # booster_a.load_model()
+                # booster_a.predict()
+                # test_praba = booster_a.get_test_praba()
+                # task_info_a.algorithm_type = 'Predict'
+                # task_info_a.sync_file_list = {}
+                # Evaluation(task_info_a, secure_dataset_a,
+                #            test_praba=test_praba)
+                # ResultFileHandling(task_info_a)
+            except Exception as e:
+                task_info_a.components.logger().info(traceback.format_exc())
+
+        def passive_worker():
+            try:
+                booster_b.fit()
+                # booster_b.save_model()
+                # train_praba = booster_b.get_train_praba()
+                # test_praba = booster_b.get_test_praba()
+                # Evaluation(task_info_b, secure_dataset_b,
+                #            train_praba, test_praba)
+                # ResultFileHandling(task_info_b)
+                # booster_b.load_model()
+                # booster_b.predict()
+                # test_praba = booster_b.get_test_praba()
+                # task_info_b.algorithm_type = 'Predict'
+                # task_info_b.sync_file_list = {}
+                # Evaluation(task_info_b, secure_dataset_b,
+                #            test_praba=test_praba)
+                # ResultFileHandling(task_info_b)
+            except Exception as e:
+                task_info_b.components.logger().info(traceback.format_exc())
+
+        thread_lr_active = threading.Thread(target=active_worker, args=())
+        thread_lr_active.start()
+
+        thread_lr_passive = threading.Thread(target=passive_worker, args=())
+        thread_lr_passive.start()
+
+        thread_lr_active.join()
+        thread_lr_passive.join()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/ppc_model/secure_lr/vertical/__init__.py b/python/ppc_model/secure_lr/vertical/__init__.py
new file mode 100644
index 00000000..bf7ebebe
--- /dev/null
+++ b/python/ppc_model/secure_lr/vertical/__init__.py
@@ -0,0 +1,4 @@
+from ppc_model.secure_lr.vertical.active_party import VerticalLRActiveParty
+from ppc_model.secure_lr.vertical.passive_party import VerticalLRPassiveParty
+
+__all__ = ["VerticalLRActiveParty", "VerticalLRPassiveParty"]
diff --git a/python/ppc_model/secure_lr/vertical/active_party.py b/python/ppc_model/secure_lr/vertical/active_party.py
new file mode 100644
index 00000000..1d9e3a2f
--- /dev/null
+++ b/python/ppc_model/secure_lr/vertical/active_party.py
@@ -0,0 +1,165 @@
+import itertools
+import time
+
+import numpy as np
+from pandas import DataFrame
+
+from ppc_common.deps_services.serialize_type import SerializeType
+from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceStore
+from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceType
+from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo, IterationRequest
+from ppc_common.ppc_utils import utils
+from ppc_model.datasets.data_reduction.feature_selection import FeatureSelection
+from ppc_model.datasets.data_reduction.sampling import Sampling
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.metrics.loss import BinaryLoss
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext, LRMessage
+from ppc_model.secure_lr.vertical.booster import VerticalBooster
+
+
+class VerticalLRActiveParty(VerticalBooster):
+
+    def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None:
+        super().__init__(ctx, dataset)
+        self.params = ctx.model_params
+        self._loss_func = BinaryLoss()
+        self._all_feature_name = [dataset.feature_name]
+        self._all_feature_num = len(dataset.feature_name)
+        self.log = ctx.components.logger()
+        self.storage_client = ctx.components.storage_client
+        self.log.info(
+            f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}')
+
+    def fit(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        self.log.info(
+            f'task {self.ctx.task_id}: Starting the lr on the active party.')
+        self._init_active_data()
+
+        max_iter = self._init_iter(self.dataset.train_X.shape[0],
+                                   self.params.epochs, self.params.batch_size)
+        for _ in range(max_iter):
+            self._iter_id += 1
+            start_time = time.time()
+            self.log.info(
+                f'task {self.ctx.task_id}: Starting iter-{self._iter_id} in active party.')
+
+            # Initialization
+            idx, feature_select = self._init_each_iter()
+            self.log.info(
+                f'task {self.ctx.task_id}: feature select: {len(feature_select)}, {feature_select}.')
+
+            # Build
+            self._build_iter(feature_select, idx)
+
+            # Predict
+            self._train_praba = self._predict_tree(self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
+            # print('train_praba', set(self._train_praba))
+
+            # Evaluate
+            if not self.params.silent and self.dataset.train_y is not None:
+                auc = Evaluation.fevaluation(
+                    self.dataset.train_y, self._train_praba)['auc']
+                self.log.info(
+                    f'task {self.ctx.task_id}: iter-{self._iter_id}, auc: {auc}.')
+            self.log.info(f'task {self.ctx.task_id}: Ending iter-{self._iter_id}, '
+                          f'time_costs: {time.time() - start_time}s.')
+
+        # Predict on the validation set
+        self._test_praba = self._predict_tree(self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
+        if not self.params.silent and self.dataset.test_y is not None:
+            auc = Evaluation.fevaluation(
+                self.dataset.test_y, self._test_praba)['auc']
+            self.log.info(
+                f'task {self.ctx.task_id}: iter-{self._iter_id}, test auc: {auc}.')
+
+        self._end_active_data()
+
+    def transform(self, transform_data: DataFrame) -> DataFrame:
+        ...
+
+    def predict(self, dataset: SecureDataset = None) -> np.ndarray:
+        start_time = time.time()
+        if dataset is None:
+            dataset = self.dataset
+
+        test_praba = self._predict_tree(dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
+        self._test_praba = test_praba
+
+        if dataset.test_y is not None:
+            auc = Evaluation.fevaluation(dataset.test_y, test_praba)['auc']
+            self.log.info(f'task {self.ctx.task_id}: predict test auc: {auc}.')
+        self.log.info(
+            f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.')
+
+        self._end_active_data(is_train=False)
+
+    def _init_active_data(self):
+
+        # Initialize the predictions and the weights
+        self._train_praba = self._init_praba(self.dataset.train_X.shape[0])
+        self._train_weights = self._init_weight(self.dataset.train_X.shape[1])
+        self._test_weights = self._init_weight(self.dataset.test_X.shape[1])
+        self._iter_id = 0
+
+        # Initialize the feature names of all participants
+        for i in range(1, len(self.ctx.participant_id_list)):
+            feature_name_bytes = self._receive_byte_data(
+                self.ctx, LRMessage.FEATURE_NAME.value, i)
+            self._all_feature_name.append(
+                [s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s])
+            self._all_feature_num += len([s.decode('utf-8')
+                                          for s in feature_name_bytes.split(b' ') if s])
+
+        self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, '
+                      f'total feature name: {self._all_feature_name}.')
+        self.params.categorical_idx = self._get_categorical_idx(
+            list(itertools.chain(*self._all_feature_name)), self.params.categorical_feature)
+        self.params.my_categorical_idx = self._get_categorical_idx(
+            self.dataset.feature_name, self.params.categorical_feature)
+
+    def _build_iter(self, feature_select, idx):
+
+        x_, y_ = self.dataset.train_X[idx], self.dataset.train_y[idx]
+
+        g = self._loss_func.dot_product(x_, self._train_weights)
+        h = 0.5 + self._loss_func.inference(g)
+        d = h - y_
+
+        self._send_d_instance_list(d)
+        public_key_list, d_other_list, partner_index_list = self._receive_d_instance_list()
+        deriv = self._calculate_deriv(x_, d, partner_index_list, d_other_list)
+
+        self._train_weights -= self.params.learning_rate * deriv.astype('float')
+        self._train_weights[~np.isin(np.arange(len(self._train_weights)), feature_select)] = 0
+
+    def _predict_tree(self, X, key_type):
+        train_g = self._loss_func.dot_product(X, self._train_weights)
+        for i in range(1, len(self.ctx.participant_id_list)):
+            train_g_other = np.frombuffer(
+                self._receive_byte_data(self.ctx, key_type, i), dtype='float')
+            train_g += train_g_other
+        return self._loss_func.sigmoid(train_g)
+
+    def _end_active_data(self, is_train=True):
+        if is_train:
+            for partner_index in range(1, len(self.ctx.participant_id_list)):
+                if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list:
+                    self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_train',
+                                         self._train_praba.astype('float').tobytes(), partner_index)
+
+            for partner_index in range(1, len(self.ctx.participant_id_list)):
+                if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list:
+                    self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_test',
+                                         self._test_praba.astype('float').tobytes(), partner_index)
+
+        else:
+            for partner_index in range(1, len(self.ctx.participant_id_list)):
+                if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list:
+                    self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_predict',
+                                         self._test_praba.astype('float').tobytes(), partner_index)
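In the clear, one `_build_iter` round above is an ordinary mini-batch step of Taylor-approximated logistic regression; the encrypted exchange in `_send_d_instance_list`/`_calculate_deriv` reconstructs the same numbers without revealing `d` or the partner's features. A single-machine sketch (hypothetical names; in the protocol each party only ever holds its own feature block):

    import numpy as np

    rng = np.random.default_rng(2024)
    X_a, X_b = rng.random((8, 3)), rng.random((8, 2))  # active / passive feature blocks
    w_a, w_b = np.zeros(3), np.zeros(2)
    y = rng.integers(0, 2, 8)                          # labels, held by the active party only

    g = X_a @ w_a + X_b @ w_b          # joint linear score
    h = 0.5 + 0.125 * g                # affine sigmoid stand-in (BinaryLoss.inference)
    d = h - y                          # residual; travels encrypted in the protocol
    w_a -= 0.1 * (X_a.T @ d) / len(y)  # the gradient step _calculate_deriv reconstructs
    w_b -= 0.1 * (X_b.T @ d) / len(y)  # symmetric update on the passive side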
+ + def predict(self, dataset: SecureDataset = None) -> np.ndarray: + start_time = time.time() + if dataset is None: + dataset = self.dataset + + test_praba = self._predict_tree(dataset.test_X, LRMessage.VALID_LEAF_MASK.value) + self._test_praba = test_praba + + if dataset.test_y is not None: + auc = Evaluation.fevaluation(dataset.test_y, test_praba)['auc'] + self.log.info(f'task {self.ctx.task_id}: predict test auc: {auc}.') + self.log.info( + f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') + + self._end_active_data(is_train=False) + + def _init_active_data(self): + + # 初始化预测值和权重 + self._train_praba = self._init_praba(self.dataset.train_X.shape[0]) + self._train_weights = self._init_weight(self.dataset.train_X.shape[1]) + self._test_weights = self._init_weight(self.dataset.test_X.shape[1]) + self._iter_id = 0 + + # 初始化所有参与方的特征 + for i in range(1, len(self.ctx.participant_id_list)): + feature_name_bytes = self._receive_byte_data( + self.ctx, LRMessage.FEATURE_NAME.value, i) + self._all_feature_name.append( + [s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) + self._all_feature_num += len([s.decode('utf-8') + for s in feature_name_bytes.split(b' ') if s]) + + self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, ' + f'total feature name: {self._all_feature_name}.') + self.params.categorical_idx = self._get_categorical_idx( + list(itertools.chain(*self._all_feature_name)), self.params.categorical_feature) + self.params.my_categorical_idx = self._get_categorical_idx( + self.dataset.feature_name, self.params.categorical_feature) + + def _build_iter(self, feature_select, idx): + + x_, y_ = self.dataset.train_X[idx], self.dataset.train_y[idx] + + g = self._loss_func.dot_product(x_, self._train_weights) + h = 0.5 + self._loss_func.inference(g) + d = h - y_ + + self._send_d_instance_list(d) + public_key_list, d_other_list, partner_index_list = self._receive_d_instance_list() + deriv = self._calculate_deriv(x_, d, partner_index_list, d_other_list) + + self._train_weights -= self.params.learning_rate * deriv.astype('float') + self._train_weights[~np.isin(np.arange(len(self._train_weights)), feature_select)] = 0 + + def _predict_tree(self, X, key_type): + train_g = self._loss_func.dot_product(X, self._train_weights) + for i in range(1, len(self.ctx.participant_id_list)): + train_g_other = np.frombuffer( + self._receive_byte_data(self.ctx, key_type, i), dtype='float') + train_g += train_g_other + return self._loss_func.sigmoid(train_g) + + def _end_active_data(self, is_train=True): + if is_train: + for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_train', + self._train_praba.astype('float').tobytes(), partner_index) + + for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_test', + self._test_praba.astype('float').tobytes(), partner_index) + + else: + for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LRMessage.PREDICT_PRABA.value}_predict', + self._test_praba.astype('float').tobytes(), partner_index) diff --git 
a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py new file mode 100644 index 00000000..68a68d5f --- /dev/null +++ b/python/ppc_model/secure_lr/vertical/booster.py @@ -0,0 +1,308 @@ +import os +import time +import random +import json +import itertools +import numpy as np + +from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.interface.model_base import VerticalModel +from ppc_model.datasets.data_reduction.feature_selection import FeatureSelection +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.common.protocol import PheMessage +from ppc_model.network.stub import PushRequest, PullRequest +from ppc_model.common.model_result import ResultFileHandling +from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning +from ppc_model.secure_lr.secure_lr_context import SecureLRContext, LRMessage + + +# 抽离sgb的公共部分 +class VerticalBooster(VerticalModel): + def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None: + super().__init__(ctx) + self.dataset = dataset + self._stub = ctx.components.stub + + self._iter_id = None + + self._train_weights = None + self._train_praba = None + self._test_weights = None + self._test_praba = None + + random.seed(ctx.model_params.random_state) + np.random.seed(ctx.model_params.random_state) + + def _build_tree(self, *args, **kwargs): + + raise NotImplementedError + + def _predict_tree(self, *args, **kwargs): + + raise NotImplementedError + + @staticmethod + def _init_praba(n): + return np.full(n, 0.5) + + @staticmethod + def _init_weight(n): + return np.zeros(n, dtype=float) + + @staticmethod + def _init_iter(n, epochs, batch_size): + return round(n*epochs/batch_size) + + @staticmethod + def _get_sample_idx(i, n, size): + start_idx = (i * size) % n + end_idx = start_idx + size + if end_idx <= n: + idx = list(range(start_idx, end_idx)) + else: + head_idx = end_idx - n + idx = list(range(start_idx, n)) + list(range(head_idx)) + return idx + + @staticmethod + def _get_categorical_idx(feature_name, categorical_feature=[]): + categorical_idx = [] + if len(categorical_feature) > 0: + for i in categorical_feature: + if i in feature_name: + categorical_idx.append(feature_name.index(i)) + return categorical_idx + + def _init_each_iter(self): + + idx = self._get_sample_idx(self._iter_id-1, self.dataset.train_X.shape[0], + size = self.params.batch_size) + feature_select = FeatureSelection.feature_selecting( + list(self.dataset.feature_name), + self.params.train_feature, self.params.feature_rate) + + return idx, feature_select + + def _send_d_instance_list(self, d): + + d_list = self.rounding_d(d) + my_agency_id = self.ctx.components.config_data['AGENCY_ID'] + + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Starting iter-{self._iter_id} ' + f'encrypt d in {my_agency_id} party.') + enc_dlist = self.ctx.phe.encrypt_batch_parallel((d_list).astype('object')) + self.log.info(f'task {self.ctx.task_id}: Finished iter-{self._iter_id} ' + f'encrypt d time_costs: {time.time() - start_time}.') + + for partner_index in range(len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] != my_agency_id: + self._send_enc_data(self.ctx, f'{LRMessage.ENC_D_LIST.value}_{self._iter_id}', + enc_dlist, partner_index) + + def _receive_d_instance_list(self): + + my_agency_id = self.ctx.components.config_data['AGENCY_ID'] + + public_key_list = [] + d_other_list = [] + 
partner_index_list = [] + for partner_index in range(len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] != my_agency_id: + public_key, enc_d = self._receive_enc_data( + self.ctx, f'{LRMessage.ENC_D_LIST.value}_{self._iter_id}', partner_index) + public_key_list.append(public_key) + d_other_list.append(np.array(enc_d)) + partner_index_list.append(partner_index) + + return public_key_list, d_other_list, partner_index_list + + def _calculate_deriv(self, x_, d, partner_index_list, d_other_list): + + x = self.rounding_d(x_) + deriv = np.matmul(x_.T, d) / x_.shape[0] + for i, partner_index in enumerate(partner_index_list): + # 计算明文*密文 matmul + # deriv_other_i = np.matmul(x.T, d_other_list[i]) + deriv_other_i = self.enc_matmul(x.T, d_other_list[i]) + + # 发送密文,接受密文并解密 + self._send_enc_data(self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', + deriv_other_i, partner_index) + _, enc_deriv_i = self._receive_enc_data( + self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index) + deriv_i_rec = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object') + deriv_i = self.recover_d(self.ctx, deriv_i_rec, is_square=True) / x_.shape[0] + + # 发送明文,接受明文并计算 + self._send_byte_data(self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', + deriv_i.astype('float').tobytes(), partner_index) + deriv_x_i = np.frombuffer(self._receive_byte_data( + self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', partner_index), dtype=np.float) + self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') + self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv_x_i: {deriv_x_i}.') + deriv += deriv_x_i + self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') + return deriv + + def _calculate_deriv1(self, x_, d, partner_index_list, d_other_list): + + x = self.rounding_d(x_) + deriv = np.matmul(x_.T, d) / x_.shape[0] + for i, partner_index in enumerate(partner_index_list): + # TODO:重载方法,目前支持np.array(enc_dlist).sum()的方式,不支持明文*密文 + # deriv_other_i = np.matmul(x.T, d_other_list[i]) + deriv_other_i = self.enc_matmul(x.T, d_other_list[i]) + self._send_enc_data(self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', + deriv_other_i, partner_index) + _, enc_deriv_i = self._receive_enc_data( + self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index) + deriv_i = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object') + deriv += (self.recover_d(self.ctx, deriv_i, is_square=True) / x_.shape[0]) + return deriv + + def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + if matrix_data: + self._stub.push(PushRequest( + receiver=partner_id, + task_id=ctx.task_id, + key=key_type, + data=PheMessage.packing_2dim_data( + ctx.codec, ctx.phe.public_key, enc_data) + )) + else: + self._stub.push(PushRequest( + receiver=partner_id, + task_id=ctx.task_id, + key=key_type, + data=PheMessage.packing_data( + ctx.codec, ctx.phe.public_key, enc_data) + )) + + log.info( + f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, " + f"data_length: {len(enc_data)}, time_costs: {time.time() - start_time}s") + + def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + byte_data = self._stub.pull(PullRequest( + 
+    def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False):
+        log = ctx.components.logger()
+        start_time = time.time()
+        partner_id = ctx.participant_id_list[partner_index]
+
+        if matrix_data:
+            self._stub.push(PushRequest(
+                receiver=partner_id,
+                task_id=ctx.task_id,
+                key=key_type,
+                data=PheMessage.packing_2dim_data(
+                    ctx.codec, ctx.phe.public_key, enc_data)
+            ))
+        else:
+            self._stub.push(PushRequest(
+                receiver=partner_id,
+                task_id=ctx.task_id,
+                key=key_type,
+                data=PheMessage.packing_data(
+                    ctx.codec, ctx.phe.public_key, enc_data)
+            ))
+
+        log.info(
+            f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
+            f"data_length: {len(enc_data)}, time_costs: {time.time() - start_time}s")
+
+    def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False):
+        log = ctx.components.logger()
+        start_time = time.time()
+        partner_id = ctx.participant_id_list[partner_index]
+
+        byte_data = self._stub.pull(PullRequest(
+            sender=partner_id,
+            task_id=ctx.task_id,
+            key=key_type
+        ))
+
+        if matrix_data:
+            public_key, enc_data = PheMessage.unpacking_2dim_data(
+                ctx.codec, byte_data)
+        else:
+            public_key, enc_data = PheMessage.unpacking_data(
+                ctx.codec, byte_data)
+
+        log.info(
+            f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
+            f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
+        return public_key, enc_data
+
+    def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
+        log = ctx.components.logger()
+        start_time = time.time()
+        partner_id = ctx.participant_id_list[partner_index]
+
+        self._stub.push(PushRequest(
+            receiver=partner_id,
+            task_id=ctx.task_id,
+            key=key_type,
+            data=byte_data
+        ))
+
+        log.info(
+            f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
+            f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
+
+    def _receive_byte_data(self, ctx, key_type, partner_index):
+        log = ctx.components.logger()
+        start_time = time.time()
+        partner_id = ctx.participant_id_list[partner_index]
+
+        byte_data = self._stub.pull(PullRequest(
+            sender=partner_id,
+            task_id=ctx.task_id,
+            key=key_type
+        ))
+
+        log.info(
+            f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
+            f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
+        return byte_data
+
+    def save_model(self, file_path=None):
+        log = self.ctx.components.logger()
+        if file_path is not None:
+            self.ctx.model_data_file = os.path.join(
+                file_path, self.ctx.MODEL_DATA_FILE)
+
+        if not os.path.exists(self.ctx.model_data_file):
+            serial_weight = list(self._train_weights)
+            with open(self.ctx.model_data_file, 'w') as f:
+                json.dump(serial_weight, f)
+            ResultFileHandling._upload_file(self.ctx.components.storage_client,
+                                            self.ctx.model_data_file, self.ctx.remote_model_data_file)
+            log.info(
+                f"task {self.ctx.task_id}: Saved serial_weight to {self.ctx.model_data_file} finished.")
+
+    def load_model(self, file_path=None):
+        log = self.ctx.components.logger()
+        if file_path is not None:
+            self.ctx.model_data_file = os.path.join(
+                file_path, self.ctx.MODEL_DATA_FILE)
+        if self.ctx.algorithm_type == AlgorithmType.Predict.name:
+            self.ctx.remote_model_data_file = os.path.join(
+                self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)
+
+        ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                          self.ctx.model_data_file, self.ctx.remote_model_data_file)
+
+        with open(self.ctx.model_data_file, 'r') as f:
+            serial_weight = json.load(f)
+        self._train_weights = np.array(serial_weight)
+        log.info(
+            f"task {self.ctx.task_id}: Load serial_weight from {self.ctx.model_data_file} finished.")
+
+    def get_weights(self):
+        return self._train_weights
+
+    def get_train_praba(self):
+        return self._train_praba
+
+    def get_test_praba(self):
+        return self._test_praba
+
+    @staticmethod
+    def enc_matmul(arr, enc):
+        result = []
+        for i in range(len(arr)):
+            # keep the ciphertext on the left: plaintext * ciphertext is not defined
+            # arr[i] * enc
+            result.append((enc * arr[i]).sum())
+        return np.array(result)
+
+    @staticmethod
+    def rounding_d(d_list: np.ndarray, expand=1000):
+        # fixed-point encoding: scale floats by `expand` and truncate to integers
+        return (d_list * expand).astype('int')
+
+    @staticmethod
+    def recover_d(ctx, d_sum_list: np.ndarray, is_square=False, expand=1000):
+        # values above half the key space are wrapped-around negatives
+        d_sum_list[d_sum_list > 2**(ctx.phe.key_length - 1)] -= 2**ctx.phe.key_length
+
+        if is_square:
+            # a product of two expand-scaled operands carries a factor of expand**2
+            return (d_sum_list / expand / expand).astype('float')
+        else:
+            return (d_sum_list / expand).astype('float')
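Before the passive-party file, a round-trip sketch of the fixed-point encoding that rounding_d, enc_matmul and recover_d implement together; the values and tolerance are illustrative, and the key-length wrap-around is elided because it only matters inside the ciphertext space:

    import numpy as np

    expand = 1000
    d = np.array([0.123456, -0.5])              # residuals to protect
    x = np.array([2.0, 3.0])                    # features, scaled the same way

    d_int = (d * expand).astype('int')          # rounding_d: [123, -500]
    x_int = (x * expand).astype('int')
    prod = x_int * d_int                        # carries a factor of expand**2
    recovered = prod / expand / expand          # recover_d with is_square=True
    assert np.allclose(recovered, x * d, atol=1e-2)   # exact up to truncation error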
diff --git a/python/ppc_model/secure_lr/vertical/passive_party.py b/python/ppc_model/secure_lr/vertical/passive_party.py
new file mode 100644
index 00000000..46101559
--- /dev/null
+++ b/python/ppc_model/secure_lr/vertical/passive_party.py
@@ -0,0 +1,125 @@
+import itertools
+import multiprocessing
+import time
+import numpy as np
+from pandas import DataFrame
+
+from ppc_common.ppc_utils import utils
+from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo, IterationRequest
+from ppc_model.datasets.data_reduction.feature_selection import FeatureSelection
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning
+from ppc_model.metrics.loss import BinaryLoss
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext, LRMessage
+from ppc_model.secure_lr.vertical.booster import VerticalBooster
+
+
+class VerticalLRPassiveParty(VerticalBooster):
+
+    def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None:
+        super().__init__(ctx, dataset)
+        self.params = ctx.model_params
+        self._loss_func = BinaryLoss()
+        self.log = ctx.components.logger()
+        self.log.info(
+            f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}')
+
+    def fit(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        self.log.info(
+            f'task {self.ctx.task_id}: Starting the lr on the passive party.')
+        self._init_passive_data()
+
+        max_iter = self._init_iter(self.dataset.train_X.shape[0],
+                                   self.params.epochs, self.params.batch_size)
+        for _ in range(max_iter):
+            self._iter_id += 1
+            start_time = time.time()
+            self.log.info(
+                f'task {self.ctx.task_id}: Starting iter-{self._iter_id} in passive party.')
+
+            # initialize this iteration's batch and feature subset
+            idx, feature_select = self._init_each_iter()
+            self.log.info(
+                f'task {self.ctx.task_id}: feature select: {len(feature_select)}, {feature_select}.')
+
+            # build: run one secure gradient-descent step
+            self._build_iter(feature_select, idx)
+
+            # predict on the training set
+            self._predict_tree(self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
+            self.log.info(f'task {self.ctx.task_id}: Ending iter-{self._iter_id}, '
+                          f'time_costs: {time.time() - start_time}s.')
+
+        # predict on the validation set
+        self._predict_tree(self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
+
+        self._end_passive_data()
+
+    def transform(self, transform_data: DataFrame) -> DataFrame:
+        ...
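# How fit() sizes its loop, with illustrative numbers: round(n * epochs / batch_size)
# mini-batch steps, so the cyclic sampler passes over the data roughly `epochs` times.
n, epochs, batch_size = 1000, 5, 64
max_iter = round(n * epochs / batch_size)   # 78 steps ~= 5 passes over 1000 rows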
+
+    def predict(self, dataset: SecureDataset = None) -> np.ndarray:
+        start_time = time.time()
+        if dataset is None:
+            dataset = self.dataset
+
+        self._predict_tree(dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
+        self.log.info(
+            f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.')
+
+        self._end_passive_data(is_train=False)
+
+    def _init_passive_data(self):
+        # initialize predicted probabilities and weights
+        self._train_praba = self._init_praba(self.dataset.train_X.shape[0])
+        self._train_weights = self._init_weight(self.dataset.train_X.shape[1])
+        self._test_weights = self._init_weight(self.dataset.test_X.shape[1])
+        self._iter_id = 0
+
+        # announce this party's feature names to the active party
+        self._send_byte_data(self.ctx, LRMessage.FEATURE_NAME.value,
+                             b''.join(s.encode('utf-8') + b' ' for s in self.dataset.feature_name), 0)
+        self.params.my_categorical_idx = self._get_categorical_idx(
+            self.dataset.feature_name, self.params.categorical_feature)
+
+    def _build_iter(self, feature_select, idx):
+        x_ = self.dataset.train_X[idx]
+
+        g = self._loss_func.dot_product(x_, self._train_weights)
+        h = self._loss_func.inference(g)
+        d = h
+
+        self._send_d_instance_list(d)
+        public_key_list, d_other_list, partner_index_list = self._receive_d_instance_list()
+        deriv = self._calculate_deriv(x_, d, partner_index_list, d_other_list)
+
+        self._train_weights -= self.params.learning_rate * deriv.astype('float')
+        # zero the weights of features not selected in this iteration
+        self._train_weights[~np.isin(np.arange(len(self._train_weights)), feature_select)] = 0
+
+    def _predict_tree(self, X, key_type):
+        train_g = self._loss_func.dot_product(X, self._train_weights)
+        self._send_byte_data(self.ctx, f'{key_type}',
+                             train_g.astype('float').tobytes(), 0)
+
+    def _end_passive_data(self, is_train=True):
+        if self.ctx.components.config_data['AGENCY_ID'] in self.ctx.result_receiver_id_list:
+            if is_train:
+                self._train_praba = np.frombuffer(
+                    self._receive_byte_data(
+                        self.ctx, f'{LRMessage.PREDICT_PRABA.value}_train', 0), dtype=np.float64)
+
+                self._test_praba = np.frombuffer(
+                    self._receive_byte_data(
+                        self.ctx, f'{LRMessage.PREDICT_PRABA.value}_test', 0), dtype=np.float64)
+
+            else:
+                self._test_praba = np.frombuffer(
+                    self._receive_byte_data(
+                        self.ctx, f'{LRMessage.PREDICT_PRABA.value}_predict', 0), dtype=np.float64)
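The active-party side of this exchange is not part of this diff, but the reason _predict_tree ships only a partial margin can be seen in plaintext: in vertical LR each party holds a disjoint feature block, so the full margin is the sum of the per-party dot products (toy values below; the aggregation and the final sigmoid approximation are assumed to live on the active party):

    import numpy as np

    x_active = np.array([[1.0, 2.0], [0.5, 1.5]])   # active party's feature block
    x_passive = np.array([[3.0], [4.0]])            # passive party's feature block
    w_active, w_passive = np.array([0.1, 0.2]), np.array([0.3])

    g_full = np.matmul(np.hstack([x_active, x_passive]),
                       np.concatenate([w_active, w_passive]))
    g_parts = np.matmul(x_active, w_active) + np.matmul(x_passive, w_passive)
    assert np.allclose(g_full, g_parts)             # partial margins sum to the full one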