From 936488abd4c0cb5bf6dfbadc7361f6f67d0a8276 Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Wed, 16 Oct 2024 17:28:30 +0800 Subject: [PATCH] refactor common part of lr-context and xgb-context into SecureModelContext --- python/ppc_common/ppc_utils/utils.py | 6 +- python/ppc_model/common/context.py | 1 - python/ppc_model/datasets/dataset.py | 7 +- .../model_result/task_result_handler.py | 14 +- python/ppc_model/ppc_model_app.py | 45 ++--- .../secure_lgbm/secure_lgbm_context.py | 169 ++++-------------- .../ppc_model/secure_lgbm/vertical/booster.py | 39 +--- .../ppc_model/secure_lr/secure_lr_context.py | 127 +------------ .../secure_lr/secure_lr_prediction_engine.py | 2 +- .../ppc_model/secure_lr/vertical/booster.py | 82 ++++----- .../ppc_model/secure_model_base/__init__.py | 0 .../secure_model_base/secure_model_booster.py | 48 +++++ .../secure_model_base/secure_model_context.py | 159 ++++++++++++++++ 13 files changed, 334 insertions(+), 365 deletions(-) create mode 100644 python/ppc_model/secure_model_base/__init__.py create mode 100644 python/ppc_model/secure_model_base/secure_model_booster.py create mode 100644 python/ppc_model/secure_model_base/secure_model_context.py diff --git a/python/ppc_common/ppc_utils/utils.py b/python/ppc_common/ppc_utils/utils.py index 51605969..6450a4a4 100644 --- a/python/ppc_common/ppc_utils/utils.py +++ b/python/ppc_common/ppc_utils/utils.py @@ -67,11 +67,6 @@ MPC_TRAIN_SET_METRIC_PR_FILE = "mpc_train_metric_pr.svg" MPC_TRAIN_SET_METRIC_ACCURACY_FILE = "mpc_train_metric_accuracy.svg" MPC_TRAIN_SET_METRIC_KS_TABLE = "mpc_train_metric_ks.csv" -MPC_EVAL_METRIC_ROC_FILE = "mpc_eval_metric_roc.svg" -MPC_EVAL_METRIC_KS_FILE = "mpc_eval_metric_ks.svg" -MPC_EVAL_METRIC_PR_FILE = "mpc_eval_metric_pr.svg" -MPC_EVAL_METRIC_ACCURACY_FILE = "mpc_eval_metric_accuracy.svg" -MPC_EVAL_METRIC_KS_TABLE = "mpc_eval_metric_ks.csv" MPC_TRAIN_METRIC_CONFUSION_MATRIX_FILE = "mpc_metric_confusion_matrix.svg" METRICS_OVER_ITERATION_FILE = "metrics_over_iterations.svg" @@ -111,6 +106,7 @@ class CryptoType(Enum): ECDSA = 1 GM = 2 + @unique class HashType(Enum): BYTES = 1 diff --git a/python/ppc_model/common/context.py b/python/ppc_model/common/context.py index 02614e47..551b9f83 100644 --- a/python/ppc_model/common/context.py +++ b/python/ppc_model/common/context.py @@ -1,4 +1,3 @@ -import unittest from ppc_model.common.base_context import BaseContext from ppc_model.common.initializer import Initializer diff --git a/python/ppc_model/datasets/dataset.py b/python/ppc_model/datasets/dataset.py index ab3329a5..451ee44a 100644 --- a/python/ppc_model/datasets/dataset.py +++ b/python/ppc_model/datasets/dataset.py @@ -190,9 +190,10 @@ def _dataset_fe_selected(self, file_path, feature_name): def _construct_dataset(self): if self.algorithm_type == AlgorithmType.Predict.name: - my_fields = [ - item["fields"] for item in self.ctx.model_predict_algorithm['participant_agency_list'] - if item["agency"] == self.ctx.components.config_data['AGENCY_ID']] + my_fields = [] + for item in self.ctx.model_predict_algorithm['participant_agency_list']: + if item["agency"] == self.ctx.components.config_data['AGENCY_ID']: + my_fields = item["fields"] if 'y' in self.model_data.columns and 'y' not in my_fields: my_fields = ['y'] + my_fields if 'id' in self.model_data.columns and 'id' not in my_fields: diff --git a/python/ppc_model/model_result/task_result_handler.py b/python/ppc_model/model_result/task_result_handler.py index 02697ee2..d5ef6ca6 100644 --- a/python/ppc_model/model_result/task_result_handler.py +++ b/python/ppc_model/model_result/task_result_handler.py @@ -65,10 +65,10 @@ class JobEvaluationResult: EvaluationType.ACCURACY: utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE} DEFAULT_EVAL_EVALUATION_FILES = { - EvaluationType.ROC: utils.MPC_EVAL_METRIC_ROC_FILE, - EvaluationType.PR: utils.MPC_EVAL_METRIC_PR_FILE, - EvaluationType.KS: utils.MPC_EVAL_METRIC_KS_FILE, - EvaluationType.ACCURACY: utils.MPC_EVAL_METRIC_ACCURACY_FILE + EvaluationType.ROC: utils.MPC_TRAIN_METRIC_ROC_FILE, + EvaluationType.PR: utils.MPC_TRAIN_METRIC_PR_FILE, + EvaluationType.KS: utils.MPC_TRAIN_METRIC_KS_FILE, + EvaluationType.ACCURACY: utils.MPC_TRAIN_METRIC_ACCURACY_FILE } def __init__(self, property_name, classification_type, @@ -357,7 +357,7 @@ def _get_evaluation_result(self): components=self.components) # load the ks table self.train_evaluation_result.load_ks_table( - "mpc_train_metric_ks.csv", "TrainKSTable") + utils.MPC_TRAIN_METRIC_KS_TABLE, "TrainKSTable") self.result_list.append(self.train_evaluation_result) self.validation_evaluation_result = JobEvaluationResult( @@ -368,7 +368,7 @@ def _get_evaluation_result(self): components=self.components) # load the ks_table self.validation_evaluation_result.load_ks_table( - "mpc_metric_ks.csv", "KSTable") + utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable") self.result_list.append(self.validation_evaluation_result) self.model = ModelJobResult(self.xgb_job, @@ -396,7 +396,7 @@ def _get_evaluation_result(self): components=self.components) # load ks_table self.predict_evaluation_result.load_ks_table( - "mpc_eval_metric_ks.csv", "KSTable") + utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable") self.result_list.append(self.predict_evaluation_result) # load model_result diff --git a/python/ppc_model/ppc_model_app.py b/python/ppc_model/ppc_model_app.py index 63d77156..6c40bbc3 100644 --- a/python/ppc_model/ppc_model_app.py +++ b/python/ppc_model/ppc_model_app.py @@ -2,28 +2,29 @@ import sys sys.path.append("../") -from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine -from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine -from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine -from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine -from ppc_model.network.http.restx import api -from ppc_model.network.http.model_controller import ns2 as log_namespace -from ppc_model.network.http.model_controller import ns as task_namespace -from ppc_model.network.grpc.grpc_server import ModelService -from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine -from ppc_model.common.protocol import ModelTask -from ppc_model.common.global_context import components -from ppc_common.ppc_utils import utils -from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc -from paste.translogger import TransLogger -from flask import Flask, Blueprint -from cheroot.wsgi import Server as WSGIServer -from cheroot.ssl.builtin import BuiltinSSLAdapter -import grpc -from threading import Thread -from concurrent import futures -import os import multiprocessing +import os +from concurrent import futures +from threading import Thread +import grpc +from cheroot.ssl.builtin import BuiltinSSLAdapter +from cheroot.wsgi import Server as WSGIServer +from flask import Flask, Blueprint +from paste.translogger import TransLogger +from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc +from ppc_common.ppc_utils import utils +from ppc_model.common.global_context import components +from ppc_model.common.protocol import ModelTask +from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine +from ppc_model.network.grpc.grpc_server import ModelService +from ppc_model.network.http.model_controller import ns as task_namespace +from ppc_model.network.http.model_controller import ns2 as log_namespace +from ppc_model.network.http.restx import api +from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine +from ppc_model.secure_lr.secure_lr_prediction_engine import SecureLRPredictionEngine +from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine +from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine +from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine app = Flask(__name__) @@ -53,6 +54,8 @@ def register_task_handler(): ModelTask.XGB_PREDICTING, SecureLGBMPredictionEngine.run) task_manager.register_task_handler( ModelTask.LR_TRAINING, SecureLRTrainingEngine.run) + task_manager.register_task_handler( + ModelTask.LR_PREDICTING, SecureLRPredictionEngine.run) def model_serve(): diff --git a/python/ppc_model/secure_lgbm/secure_lgbm_context.py b/python/ppc_model/secure_lgbm/secure_lgbm_context.py index 4871c3a6..804c5b97 100644 --- a/python/ppc_model/secure_lgbm/secure_lgbm_context.py +++ b/python/ppc_model/secure_lgbm/secure_lgbm_context.py @@ -1,18 +1,19 @@ + import os from enum import Enum -from typing import Any, Dict from sklearn.base import BaseEstimator from ppc_common.ppc_utils.utils import AlgorithmType from ppc_common.ppc_crypto.phe_factory import PheCipherFactory -from ppc_model.common.context import Context from ppc_model.common.initializer import Initializer from ppc_model.common.protocol import TaskRole from ppc_common.ppc_utils import common_func from ppc_model.common.model_setting import ModelSetting +from ppc_model.secure_model_base.secure_model_context import SecureModel +from ppc_model.secure_model_base.secure_model_context import SecureModelContext -class LGBMModel(BaseEstimator): +class LGBMModel(SecureModel): def __init__( self, @@ -36,7 +37,6 @@ def __init__( importance_type: str = 'split', **kwargs ): - self.boosting_type = boosting_type self.objective = objective self.num_leaves = num_leaves @@ -55,60 +55,7 @@ def __init__( self.random_state = random_state self.n_jobs = n_jobs self.importance_type = importance_type - self._other_params: Dict[str, Any] = {} - self.set_params(**kwargs) - - def get_params(self, deep: bool = True) -> Dict[str, Any]: - """Get parameters for this estimator. - - Parameters - ---------- - deep : bool, optional (default=True) - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - params = super().get_params(deep=deep) - params.update(self._other_params) - return params - - def set_model_setting(self, model_setting: ModelSetting) -> "LGBMModel": - # 获取对象的所有属性名 - attrs = dir(model_setting) - # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) - attrs = [attr for attr in attrs if not attr.startswith('_')] - - params = {} - for attr in attrs: - try: - setattr(self, attr, getattr(model_setting, attr)) - except Exception as e: - pass - return self - - def set_params(self, **params: Any) -> "LGBMModel": - """Set the parameters of this estimator. - - Parameters - ---------- - **params - Parameter names with their new values. - - Returns - ------- - self : object - Returns self. - """ - for key, value in params.items(): - setattr(self, key, value) - if hasattr(self, f"_{key}"): - setattr(self, f"_{key}", value) - self._other_params[key] = value - return self + super().__init__(**kwargs) class ModelTaskParams(LGBMModel): @@ -169,98 +116,54 @@ def _get_params(self): """返回LGBMClassifier所有参数""" return LGBMModel().get_params() - def get_all_params(self): - """返回SecureLGBMParams所有参数""" - # 获取对象的所有属性名 - attrs = dir(self) - # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) - attrs = [attr for attr in attrs if not attr.startswith('_')] - - params = {} - for attr in attrs: - try: - # 使用getattr来获取属性的值 - value = getattr(self, attr) - # 检查value是否可调用(例如,方法或函数),如果是,则不打印其值 - if not callable(value): - params[attr] = value - except Exception as e: - pass - return params - -class SecureLGBMContext(Context): +class SecureLGBMContext(SecureModelContext): def __init__(self, args, components: Initializer ): - - if args['is_label_holder']: - role = TaskRole.ACTIVE_PARTY - else: - role = TaskRole.PASSIVE_PARTY - - super().__init__(args['job_id'], - args['task_id'], - components, - role) + super().__init__(args, components) self.phe = PheCipherFactory.build_phe( components.homo_algorithm, components.public_key_length) self.codec = PheCipherFactory.build_codec(components.homo_algorithm) - self.is_label_holder = args['is_label_holder'] - self.result_receiver_id_list = args['result_receiver_id_list'] - self.participant_id_list = args['participant_id_list'] - self.model_predict_algorithm = common_func.get_config_value( - "model_predict_algorithm", None, args, False) - self.algorithm_type = args['algorithm_type'] - if 'dataset_id' in args and args['dataset_id'] is not None: - self.dataset_file_path = os.path.join( - self.workspace, args['dataset_id']) - else: - self.dataset_file_path = None - - self.model_params = SecureLGBMParams() - model_setting = ModelSetting(args['model_dict']) - self.set_model_params(model_setting) - if model_setting.train_features is not None and len(model_setting.train_features) > 0: - self.model_params.train_feature = model_setting.train_features.split( - ',') - if model_setting.categorical is not None and len(model_setting.categorical) > 0: - self.model_params.categorical_feature = model_setting.categorical.split( - ',') - self.model_params.n_estimators = model_setting.num_trees - self.model_params.feature_rate = model_setting.colsample_bytree - self.model_params.min_split_gain = model_setting.gamma - self.model_params.random_state = model_setting.seed - self.sync_file_list = {} - if self.algorithm_type == AlgorithmType.Train.name: - self.set_sync_file() - - def set_model_params(self, model_setting: ModelSetting): - """设置lgbm参数""" - self.model_params.set_model_setting(model_setting) + def create_model_param(self): + return SecureLGBMParams() def get_model_params(self): """获取lgbm参数""" return self.model_params def set_sync_file(self): - self.sync_file_list['metrics_iteration'] = [self.metrics_iteration_file, self.remote_metrics_iteration_file] - self.sync_file_list['feature_importance'] = [self.feature_importance_file, self.remote_feature_importance_file] - self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file] - self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table] - self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file] - self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file] - self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file] - self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file] - self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table] - self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file] - self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file] - self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file] - self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file] + self.sync_file_list['metrics_iteration'] = [ + self.metrics_iteration_file, self.remote_metrics_iteration_file] + self.sync_file_list['feature_importance'] = [ + self.feature_importance_file, self.remote_feature_importance_file] + self.sync_file_list['summary_evaluation'] = [ + self.summary_evaluation_file, self.remote_summary_evaluation_file] + self.sync_file_list['train_ks_table'] = [ + self.train_metric_ks_table, self.remote_train_metric_ks_table] + self.sync_file_list['train_metric_roc'] = [ + self.train_metric_roc_file, self.remote_train_metric_roc_file] + self.sync_file_list['train_metric_ks'] = [ + self.train_metric_ks_file, self.remote_train_metric_ks_file] + self.sync_file_list['train_metric_pr'] = [ + self.train_metric_pr_file, self.remote_train_metric_pr_file] + self.sync_file_list['train_metric_acc'] = [ + self.train_metric_acc_file, self.remote_train_metric_acc_file] + self.sync_file_list['test_ks_table'] = [ + self.test_metric_ks_table, self.remote_test_metric_ks_table] + self.sync_file_list['test_metric_roc'] = [ + self.test_metric_roc_file, self.remote_test_metric_roc_file] + self.sync_file_list['test_metric_ks'] = [ + self.test_metric_ks_file, self.remote_test_metric_ks_file] + self.sync_file_list['test_metric_pr'] = [ + self.test_metric_pr_file, self.remote_test_metric_pr_file] + self.sync_file_list['test_metric_acc'] = [ + self.test_metric_acc_file, self.remote_test_metric_acc_file] + class LGBMMessage(Enum): FEATURE_NAME = "FEATURE_NAME" diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py index 4d0c184b..27000262 100644 --- a/python/ppc_model/secure_lgbm/vertical/booster.py +++ b/python/ppc_model/secure_lgbm/vertical/booster.py @@ -13,11 +13,12 @@ from ppc_model.network.stub import PushRequest, PullRequest from ppc_model.common.model_result import ResultFileHandling from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning +from ppc_model.secure_model_base.secure_model_booster import SecureModelBooster from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext, LGBMMessage # 抽离sgb的公共部分 -class VerticalBooster(VerticalModel): +class VerticalBooster(SecureModelBooster): def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: super().__init__(ctx) self.dataset = dataset @@ -33,6 +34,7 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: self._train_praba = None self._test_weights = None self._test_praba = None + self.logger = ctx.components.logger() random.seed(ctx.model_params.random_state) np.random.seed(ctx.model_params.random_state) @@ -215,13 +217,13 @@ def _split_test_data(ctx, test_X, X_split): return feat_bin.data_binning(test_X, X_split)[0] def save_model(self, file_path=None): - log = self.ctx.components.logger() + super().save_model(file_path, "lgbm_model") + + def save_model_hook(self, file_path): + # save the feature_bin if file_path is not None: self.ctx.feature_bin_file = os.path.join( file_path, self.ctx.FEATURE_BIN_FILE) - self.ctx.model_data_file = os.path.join( - file_path, self.ctx.MODEL_DATA_FILE) - if self._X_split is not None and not os.path.exists(self.ctx.feature_bin_file): X_split_dict = {k: v for k, v in zip( self.dataset.feature_name, self._X_split)} @@ -232,32 +234,8 @@ def save_model(self, file_path=None): log.info( f"task {self.ctx.task_id}: Saved x_split to {self.ctx.feature_bin_file} finished.") - if not os.path.exists(self.ctx.model_data_file): - serial_trees = [self._serial_tree(tree) for tree in self._trees] - with open(self.ctx.model_data_file, 'w') as f: - json.dump(serial_trees, f) - ResultFileHandling._upload_file(self.ctx.components.storage_client, - self.ctx.model_data_file, self.ctx.remote_model_data_file) - log.info( - f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.") - - self.merge_model_file() - - def merge_model_file(self): - + def merge_model_file(self, lgbm_model): # 加密文件 - lgbm_model = {} - lgbm_model['model_type'] = 'xgb_model' - lgbm_model['label_provider'] = self.ctx.participant_id_list[0] - lgbm_model['label_column'] = 'y' - lgbm_model['participant_agency_list'] = [] - for partner_index in range(0, len(self.ctx.participant_id_list)): - agency_info = { - 'agency': self.ctx.participant_id_list[partner_index]} - agency_info['fields'] = self._all_feature_name[partner_index] - lgbm_model['participant_agency_list'].append(agency_info) - - lgbm_model['model_dict'] = self.ctx.model_params.get_all_params() model_text = {} with open(self.ctx.feature_bin_file, 'rb') as f: feature_bin_data = f.read() @@ -299,6 +277,7 @@ def merge_model_file(self): f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.") def split_model_file(self): + # 传入模型 my_agency_id = self.ctx.components.config_data['AGENCY_ID'] model_text = self.ctx.model_predict_algorithm['model_text'] diff --git a/python/ppc_model/secure_lr/secure_lr_context.py b/python/ppc_model/secure_lr/secure_lr_context.py index 7ecbe338..d4e74a01 100644 --- a/python/ppc_model/secure_lr/secure_lr_context.py +++ b/python/ppc_model/secure_lr/secure_lr_context.py @@ -1,18 +1,16 @@ import os from enum import Enum -from typing import Any, Dict -from sklearn.base import BaseEstimator - from ppc_common.ppc_utils.utils import AlgorithmType from ppc_common.ppc_crypto.phe_factory import PheCipherFactory -from ppc_model.common.context import Context from ppc_model.common.initializer import Initializer from ppc_model.common.protocol import TaskRole from ppc_common.ppc_utils import common_func from ppc_model.common.model_setting import ModelSetting +from ppc_model.secure_model_base.secure_model_context import SecureModel +from ppc_model.secure_model_base.secure_model_context import SecureModelContext -class LRModel(BaseEstimator): +class LRModel(SecureModel): def __init__( self, @@ -29,60 +27,7 @@ def __init__( self.learning_rate = learning_rate self.random_state = random_state self.n_jobs = n_jobs - self._other_params: Dict[str, Any] = {} - self.set_params(**kwargs) - - def get_params(self, deep: bool = True) -> Dict[str, Any]: - """Get parameters for this estimator. - - Parameters - ---------- - deep : bool, optional (default=True) - If True, will return the parameters for this estimator and - contained subobjects that are estimators. - - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - params = super().get_params(deep=deep) - params.update(self._other_params) - return params - - def set_model_setting(self, model_setting: ModelSetting) -> "LRModel": - # 获取对象的所有属性名 - attrs = dir(model_setting) - # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) - attrs = [attr for attr in attrs if not attr.startswith('_')] - - params = {} - for attr in attrs: - try: - setattr(self, attr, getattr(model_setting, attr)) - except Exception as e: - pass - return self - - def set_params(self, **params: Any) -> "LRModel": - """Set the parameters of this estimator. - - Parameters - ---------- - **params - Parameter names with their new values. - - Returns - ------- - self : object - Returns self. - """ - for key, value in params.items(): - setattr(self, key, value) - if hasattr(self, f"_{key}"): - setattr(self, f"_{key}", value) - self._other_params[key] = value - return self + super().__init__(**kwargs) class ModelTaskParams(LRModel): @@ -124,75 +69,21 @@ def _get_params(self): """返回LRClassifier所有参数""" return LRModel().get_params() - def get_all_params(self): - """返回SecureLRParams所有参数""" - # 获取对象的所有属性名 - attrs = dir(self) - # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) - attrs = [attr for attr in attrs if not attr.startswith('_')] - - params = {} - for attr in attrs: - try: - # 使用getattr来获取属性的值 - value = getattr(self, attr) - # 检查value是否可调用(例如,方法或函数),如果是,则不打印其值 - if not callable(value): - params[attr] = value - except Exception as e: - pass - return params - -class SecureLRContext(Context): +class SecureLRContext(SecureModelContext): def __init__(self, args, components: Initializer ): - - if args['is_label_holder']: - role = TaskRole.ACTIVE_PARTY - else: - role = TaskRole.PASSIVE_PARTY - - super().__init__(args['job_id'], - args['task_id'], - components, - role) + super().__init__(args, components) self.phe = PheCipherFactory.build_phe( components.homo_algorithm, components.public_key_length) self.codec = PheCipherFactory.build_codec(components.homo_algorithm) - self.is_label_holder = args['is_label_holder'] - self.result_receiver_id_list = args['result_receiver_id_list'] - self.participant_id_list = args['participant_id_list'] - self.model_predict_algorithm = common_func.get_config_value( - "model_predict_algorithm", None, args, False) - self.algorithm_type = args['algorithm_type'] - if 'dataset_id' in args and args['dataset_id'] is not None: - self.dataset_file_path = os.path.join( - self.workspace, args['dataset_id']) - else: - self.dataset_file_path = None - - self.model_params = SecureLRParams() - model_setting = ModelSetting(args['model_dict']) - self.set_model_params(model_setting) - if model_setting.train_features is not None and len(model_setting.train_features) > 0: - self.model_params.train_feature = model_setting.train_features.split( - ',') - if model_setting.categorical is not None and len(model_setting.categorical) > 0: - self.model_params.categorical_feature = model_setting.categorical.split( - ',') - self.model_params.random_state = model_setting.seed - self.sync_file_list = {} - if self.algorithm_type == AlgorithmType.Train.name: - self.set_sync_file() - - def set_model_params(self, model_setting: ModelSetting): - """设置lr参数""" - self.model_params.set_model_setting(model_setting) + + def create_model_param(self): + return SecureLRParams() def get_model_params(self): """获取lr参数""" diff --git a/python/ppc_model/secure_lr/secure_lr_prediction_engine.py b/python/ppc_model/secure_lr/secure_lr_prediction_engine.py index 218264b6..92776269 100644 --- a/python/ppc_model/secure_lr/secure_lr_prediction_engine.py +++ b/python/ppc_model/secure_lr/secure_lr_prediction_engine.py @@ -9,7 +9,7 @@ from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty -class SecureLGBMPredictionEngine(TaskEngine): +class SecureLRPredictionEngine(TaskEngine): task_type = ModelTask.LR_PREDICTING @staticmethod diff --git a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py index 9ab285b2..8ed8a068 100644 --- a/python/ppc_model/secure_lr/vertical/booster.py +++ b/python/ppc_model/secure_lr/vertical/booster.py @@ -16,10 +16,12 @@ from ppc_model.common.model_result import ResultFileHandling from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning from ppc_model.secure_lr.secure_lr_context import SecureLRContext, LRMessage - +from ppc_model.secure_model_base.secure_model_booster import SecureModelBooster # 抽离sgb的公共部分 -class VerticalBooster(VerticalModel): + + +class VerticalBooster(SecureModelBooster): def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None: super().__init__(ctx) self.dataset = dataset @@ -77,8 +79,8 @@ def _get_categorical_idx(feature_name, categorical_feature=[]): def _init_each_iter(self): - idx = self._get_sample_idx(self._iter_id-1, self.dataset.train_X.shape[0], - size = self.params.batch_size) + idx = self._get_sample_idx(self._iter_id-1, self.dataset.train_X.shape[0], + size=self.params.batch_size) feature_select = FeatureSelection.feature_selecting( list(self.dataset.feature_name), self.params.train_feature, self.params.feature_rate) @@ -93,7 +95,8 @@ def _send_d_instance_list(self, d): start_time = time.time() self.log.info(f'task {self.ctx.task_id}: Starting iter-{self._iter_id} ' f'encrypt d in {my_agency_id} party.') - enc_dlist = self.ctx.phe.encrypt_batch_parallel((d_list).astype('object')) + enc_dlist = self.ctx.phe.encrypt_batch_parallel( + (d_list).astype('object')) self.log.info(f'task {self.ctx.task_id}: Finished iter-{self._iter_id} ' f'encrypt d time_costs: {time.time() - start_time}.') @@ -105,7 +108,7 @@ def _send_d_instance_list(self, d): def _receive_d_instance_list(self): my_agency_id = self.ctx.components.config_data['AGENCY_ID'] - + public_key_list = [] d_other_list = [] partner_index_list = [] @@ -127,24 +130,29 @@ def _calculate_deriv(self, x_, d, partner_index_list, d_other_list): # 计算明文*密文 matmul # deriv_other_i = np.matmul(x.T, d_other_list[i]) deriv_other_i = self.enc_matmul(x.T, d_other_list[i]) - + # 发送密文,接受密文并解密 self._send_enc_data(self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', deriv_other_i, partner_index) _, enc_deriv_i = self._receive_enc_data( self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index) - deriv_i_rec = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object') - deriv_i = self.recover_d(self.ctx, deriv_i_rec, is_square=True) / x_.shape[0] - + deriv_i_rec = np.array( + self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object') + deriv_i = self.recover_d( + self.ctx, deriv_i_rec, is_square=True) / x_.shape[0] + # 发送明文,接受明文并计算 self._send_byte_data(self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', deriv_i.astype('float').tobytes(), partner_index) deriv_x_i = np.frombuffer(self._receive_byte_data( - self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', partner_index), dtype=np.float) - self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') - self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv_x_i: {deriv_x_i}.') + self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', partner_index), dtype=np.float) + self.log.info( + f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') + self.log.info( + f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv_x_i: {deriv_x_i}.') deriv += deriv_x_i - self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') + self.log.info( + f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.') return deriv def _calculate_deriv1(self, x_, d, partner_index_list, d_other_list): @@ -159,8 +167,10 @@ def _calculate_deriv1(self, x_, d, partner_index_list, d_other_list): deriv_other_i, partner_index) _, enc_deriv_i = self._receive_enc_data( self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index) - deriv_i = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object') - deriv += (self.recover_d(self.ctx, deriv_i, is_square=True) / x_.shape[0]) + deriv_i = np.array(self.ctx.phe.decrypt_batch( + enc_deriv_i), dtype='object') + deriv += (self.recover_d(self.ctx, deriv_i, + is_square=True) / x_.shape[0]) return deriv def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False): @@ -245,41 +255,19 @@ def _receive_byte_data(self, ctx, key_type, partner_index): return byte_data def save_model(self, file_path=None): - log = self.ctx.components.logger() - if file_path is not None: - self.ctx.model_data_file = os.path.join( - file_path, self.ctx.MODEL_DATA_FILE) - - if not os.path.exists(self.ctx.model_data_file): - serial_weight = list(self._train_weights) - with open(self.ctx.model_data_file, 'w') as f: - json.dump(serial_weight, f) - ResultFileHandling._upload_file(self.ctx.components.storage_client, - self.ctx.model_data_file, self.ctx.remote_model_data_file) - log.info( - f"task {self.ctx.task_id}: Saved serial_weight to {self.ctx.model_data_file} finished.") + super().save_model(file_path, "lr_model") - self.merge_model_file() + def save_model_hook(self, model_file_path): + pass - def merge_model_file(self): + def merge_model_file(self, lr_model): # 加密文件 - lr_model = {} - lr_model['model_type'] = 'lr_model' - lr_model['label_provider'] = self.ctx.participant_id_list[0] - lr_model['label_column'] = 'y' - lr_model['participant_agency_list'] = [] - for partner_index in range(0, len(self.ctx.participant_id_list)): - agency_info = {'agency': self.ctx.participant_id_list[partner_index]} - agency_info['fields'] = self._all_feature_name[partner_index] - lr_model['participant_agency_list'].append(agency_info) - - lr_model['model_dict'] = self.ctx.model_params.get_all_params() model_text = {} with open(self.ctx.model_data_file, 'rb') as f: model_data = f.read() model_data_enc = encrypt_data(self.ctx.key, model_data) - + my_agency_id = self.ctx.components.config_data['AGENCY_ID'] model_text[my_agency_id] = cipher_to_base64(model_data_enc) @@ -293,7 +281,8 @@ def merge_model_file(self): if self.ctx.participant_id_list[partner_index] != my_agency_id: model_data_enc = self._receive_byte_data( self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', partner_index) - model_text[self.ctx.participant_id_list[partner_index]] = cipher_to_base64(model_data_enc) + model_text[self.ctx.participant_id_list[partner_index] + ] = cipher_to_base64(model_data_enc) lr_model['model_text'] = model_text # 上传密文模型 @@ -356,8 +345,9 @@ def rounding_d(d_list: np.ndarray, expand=1000): @staticmethod def recover_d(ctx, d_sum_list: np.ndarray, is_square=False, expand=1000): - - d_sum_list[d_sum_list > 2**(ctx.phe.key_length-1)] -= 2**(ctx.phe.key_length) + + d_sum_list[d_sum_list > 2 ** + (ctx.phe.key_length-1)] -= 2**(ctx.phe.key_length) if is_square: return (d_sum_list / expand / expand).astype('float') diff --git a/python/ppc_model/secure_model_base/__init__.py b/python/ppc_model/secure_model_base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/secure_model_base/secure_model_booster.py b/python/ppc_model/secure_model_base/secure_model_booster.py new file mode 100644 index 00000000..abcfec94 --- /dev/null +++ b/python/ppc_model/secure_model_base/secure_model_booster.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +import os +import json +from ppc_model.interface.model_base import VerticalModel +from ppc_model.common.model_result import ResultFileHandling +from abc import abstractmethod + + +class SecureModelBooster(VerticalModel): + def __init__(self, ctx) -> None: + super().__init__(ctx) + + def save_model(self, file_path=None, model_type=None): + log = self.ctx.components.logger() + if file_path is not None: + self.ctx.model_data_file = os.path.join( + file_path, self.ctx.MODEL_DATA_FILE) + + self.save_model_hook() + if not os.path.exists(self.ctx.model_data_file): + serial_weight = list(self._train_weights) + with open(self.ctx.model_data_file, 'w') as f: + json.dump(serial_weight, f) + ResultFileHandling._upload_file(self.ctx.components.storage_client, + self.ctx.model_data_file, self.ctx.remote_model_data_file) + log.info( + f"task {self.ctx.task_id}: Saved serial_weight to {self.ctx.model_data_file} finished.") + model = {} + model['model_type'] = model_type + model['label_provider'] = self.ctx.participant_id_list[0] + model['label_column'] = 'y' + model['participant_agency_list'] = [] + for partner_index in range(0, len(self.ctx.participant_id_list)): + agency_info = { + 'agency': self.ctx.participant_id_list[partner_index]} + agency_info['fields'] = self._all_feature_name[partner_index] + model['participant_agency_list'].append(agency_info) + + model['model_dict'] = self.ctx.model_params.get_all_params() + self.merge_model_file(model) + + @abstractmethod + def merge_model_file(self, lr_model): + pass + + @abstractmethod + def save_model_hook(self, model_file_path): + pass diff --git a/python/ppc_model/secure_model_base/secure_model_context.py b/python/ppc_model/secure_model_base/secure_model_context.py new file mode 100644 index 00000000..8292d0cb --- /dev/null +++ b/python/ppc_model/secure_model_base/secure_model_context.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +from abc import abstractmethod +from typing import Any, Dict +import json +import os +from ppc_model.common.context import Context +from ppc_model.common.initializer import Initializer +from ppc_model.common.protocol import TaskRole +from ppc_common.ppc_utils import common_func +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.common.model_setting import ModelSetting + +from sklearn.base import BaseEstimator + + +class SecureModel(BaseEstimator): + + def __init__( + self, + **kwargs): + self.train_feature = [] + self.categorical_feature = None + self.random_state = None + self._other_params: Dict[str, Any] = {} + self.set_params(**kwargs) + + def get_params(self, deep: bool = True) -> Dict[str, Any]: + """Get parameters for this estimator. + + Parameters + ---------- + deep : bool, optional (default=True) + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + params = super().get_params(deep=deep) + params.update(self._other_params) + return params + + def set_model_setting(self, model_setting: ModelSetting): + # 获取对象的所有属性名 + attrs = dir(model_setting) + # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) + attrs = [attr for attr in attrs if not attr.startswith('_')] + + params = {} + for attr in attrs: + try: + setattr(self, attr, getattr(model_setting, attr)) + except Exception as e: + pass + return self + + def set_params(self, **params: Any): + """Set the parameters of this estimator. + + Parameters + ---------- + **params + Parameter names with their new values. + + Returns + ------- + self : object + Returns self. + """ + for key, value in params.items(): + setattr(self, key, value) + if hasattr(self, f"_{key}"): + setattr(self, f"_{key}", value) + self._other_params[key] = value + return self + + def get_all_params(self): + """返回SecureLRParams所有参数""" + # 获取对象的所有属性名 + attrs = dir(self) + # 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性) + attrs = [attr for attr in attrs if not attr.startswith('_')] + + params = {} + for attr in attrs: + try: + # 使用getattr来获取属性的值 + value = getattr(self, attr) + # 检查value是否可调用(例如,方法或函数),如果是,则不打印其值 + if not callable(value): + params[attr] = value + except Exception as e: + pass + return params + + +class SecureModelContext(Context): + def __init__(self, + args, + components: Initializer): + + if args['is_label_holder']: + role = TaskRole.ACTIVE_PARTY + else: + role = TaskRole.PASSIVE_PARTY + + super().__init__(args['job_id'], + args['task_id'], + components, + role) + self.is_label_holder = args['is_label_holder'] + self.result_receiver_id_list = args['result_receiver_id_list'] + self.participant_id_list = args['participant_id_list'] + + model_predict_algorithm_str = common_func.get_config_value( + "model_predict_algorithm", None, args, False) + if model_predict_algorithm_str is not None: + self.model_predict_algorithm = json.loads( + model_predict_algorithm_str) + self.algorithm_type = args['algorithm_type'] + self.predict = False + if self.algorithm_type == AlgorithmType.Predict.name: + self.predict = True + # check for the predict task + if self.predict and self.model_predict_algorithm is None: + raise f"Not set model_predict_algorithm for the job: {self.task_id}" + + if 'dataset_id' in args and args['dataset_id'] is not None: + self.dataset_file_path = os.path.join( + self.workspace, args['dataset_id']) + else: + self.dataset_file_path = None + self.model_params = self.create_model_param() + self.reset_model_params(ModelSetting(args['model_dict'])) + self.sync_file_list = {} + if self.algorithm_type == AlgorithmType.Train.name: + self.set_sync_file() + + @abstractmethod + def set_sync_file(self): + pass + + @abstractmethod + def create_model_param(self): + pass + + def reset_model_params(self, model_setting: ModelSetting): + """设置lr参数""" + self.model_params.set_model_setting(model_setting) + if model_setting.train_features is not None and len(model_setting.train_features) > 0: + self.model_params.train_feature = model_setting.train_features.split( + ',') + if model_setting.categorical is not None and len(model_setting.categorical) > 0: + self.model_params.categorical_feature = model_setting.categorical.split( + ',') + if model_setting.seed is not None: + self.model_params.random_state = model_setting.seed