Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor common part of lr-context and xgb-context into SecureModelContext #57

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
refactor common part of lr-context and xgb-context into SecureModelContext
cyjseagull committed Oct 16, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit c67a04714e1db3a4983df1abfa4407804708da80
6 changes: 1 addition & 5 deletions python/ppc_common/ppc_utils/utils.py
Original file line number Diff line number Diff line change
@@ -67,11 +67,6 @@
MPC_TRAIN_SET_METRIC_PR_FILE = "mpc_train_metric_pr.svg"
MPC_TRAIN_SET_METRIC_ACCURACY_FILE = "mpc_train_metric_accuracy.svg"
MPC_TRAIN_SET_METRIC_KS_TABLE = "mpc_train_metric_ks.csv"
MPC_EVAL_METRIC_ROC_FILE = "mpc_eval_metric_roc.svg"
MPC_EVAL_METRIC_KS_FILE = "mpc_eval_metric_ks.svg"
MPC_EVAL_METRIC_PR_FILE = "mpc_eval_metric_pr.svg"
MPC_EVAL_METRIC_ACCURACY_FILE = "mpc_eval_metric_accuracy.svg"
MPC_EVAL_METRIC_KS_TABLE = "mpc_eval_metric_ks.csv"
MPC_TRAIN_METRIC_CONFUSION_MATRIX_FILE = "mpc_metric_confusion_matrix.svg"
METRICS_OVER_ITERATION_FILE = "metrics_over_iterations.svg"

@@ -111,6 +106,7 @@ class CryptoType(Enum):
ECDSA = 1
GM = 2


@unique
class HashType(Enum):
BYTES = 1
1 change: 0 additions & 1 deletion python/ppc_model/common/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest

from ppc_model.common.base_context import BaseContext
from ppc_model.common.initializer import Initializer
7 changes: 4 additions & 3 deletions python/ppc_model/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -190,9 +190,10 @@ def _dataset_fe_selected(self, file_path, feature_name):

def _construct_dataset(self):
if self.algorithm_type == AlgorithmType.Predict.name:
my_fields = [
item["fields"] for item in self.ctx.model_predict_algorithm['participant_agency_list']
if item["agency"] == self.ctx.components.config_data['AGENCY_ID']]
my_fields = []
for item in self.ctx.model_predict_algorithm['participant_agency_list']:
if item["agency"] == self.ctx.components.config_data['AGENCY_ID']:
my_fields = item["fields"]
if 'y' in self.model_data.columns and 'y' not in my_fields:
my_fields = ['y'] + my_fields
if 'id' in self.model_data.columns and 'id' not in my_fields:
14 changes: 7 additions & 7 deletions python/ppc_model/model_result/task_result_handler.py
Original file line number Diff line number Diff line change
@@ -65,10 +65,10 @@ class JobEvaluationResult:
EvaluationType.ACCURACY: utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE}

DEFAULT_EVAL_EVALUATION_FILES = {
EvaluationType.ROC: utils.MPC_EVAL_METRIC_ROC_FILE,
EvaluationType.PR: utils.MPC_EVAL_METRIC_PR_FILE,
EvaluationType.KS: utils.MPC_EVAL_METRIC_KS_FILE,
EvaluationType.ACCURACY: utils.MPC_EVAL_METRIC_ACCURACY_FILE
EvaluationType.ROC: utils.MPC_TRAIN_METRIC_ROC_FILE,
EvaluationType.PR: utils.MPC_TRAIN_METRIC_PR_FILE,
EvaluationType.KS: utils.MPC_TRAIN_METRIC_KS_FILE,
EvaluationType.ACCURACY: utils.MPC_TRAIN_METRIC_ACCURACY_FILE
}

def __init__(self, property_name, classification_type,
@@ -357,7 +357,7 @@ def _get_evaluation_result(self):
components=self.components)
# load the ks table
self.train_evaluation_result.load_ks_table(
"mpc_train_metric_ks.csv", "TrainKSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "TrainKSTable")
self.result_list.append(self.train_evaluation_result)

self.validation_evaluation_result = JobEvaluationResult(
@@ -368,7 +368,7 @@ def _get_evaluation_result(self):
components=self.components)
# load the ks_table
self.validation_evaluation_result.load_ks_table(
"mpc_metric_ks.csv", "KSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable")
self.result_list.append(self.validation_evaluation_result)

self.model = ModelJobResult(self.xgb_job,
@@ -396,7 +396,7 @@ def _get_evaluation_result(self):
components=self.components)
# load ks_table
self.predict_evaluation_result.load_ks_table(
"mpc_eval_metric_ks.csv", "KSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable")
self.result_list.append(self.predict_evaluation_result)

# load model_result
45 changes: 24 additions & 21 deletions python/ppc_model/ppc_model_app.py
Original file line number Diff line number Diff line change
@@ -2,28 +2,29 @@
import sys
sys.path.append("../")

from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine
from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine
from ppc_model.network.http.restx import api
from ppc_model.network.http.model_controller import ns2 as log_namespace
from ppc_model.network.http.model_controller import ns as task_namespace
from ppc_model.network.grpc.grpc_server import ModelService
from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine
from ppc_model.common.protocol import ModelTask
from ppc_model.common.global_context import components
from ppc_common.ppc_utils import utils
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
from paste.translogger import TransLogger
from flask import Flask, Blueprint
from cheroot.wsgi import Server as WSGIServer
from cheroot.ssl.builtin import BuiltinSSLAdapter
import grpc
from threading import Thread
from concurrent import futures
import os
import multiprocessing
import os
from concurrent import futures
from threading import Thread
import grpc
from cheroot.ssl.builtin import BuiltinSSLAdapter
from cheroot.wsgi import Server as WSGIServer
from flask import Flask, Blueprint
from paste.translogger import TransLogger
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
from ppc_common.ppc_utils import utils
from ppc_model.common.global_context import components
from ppc_model.common.protocol import ModelTask
from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine
from ppc_model.network.grpc.grpc_server import ModelService
from ppc_model.network.http.model_controller import ns as task_namespace
from ppc_model.network.http.model_controller import ns2 as log_namespace
from ppc_model.network.http.restx import api
from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine
from ppc_model.secure_lr.secure_lr_prediction_engine import SecureLRPredictionEngine
from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine


app = Flask(__name__)
@@ -53,6 +54,8 @@ def register_task_handler():
ModelTask.XGB_PREDICTING, SecureLGBMPredictionEngine.run)
task_manager.register_task_handler(
ModelTask.LR_TRAINING, SecureLRTrainingEngine.run)
task_manager.register_task_handler(
ModelTask.LR_PREDICTING, SecureLRPredictionEngine.run)


def model_serve():
169 changes: 36 additions & 133 deletions python/ppc_model/secure_lgbm/secure_lgbm_context.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@

import os
from enum import Enum
from typing import Any, Dict
from sklearn.base import BaseEstimator

from ppc_common.ppc_utils.utils import AlgorithmType
from ppc_common.ppc_crypto.phe_factory import PheCipherFactory
from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
from ppc_model.common.protocol import TaskRole
from ppc_common.ppc_utils import common_func
from ppc_model.common.model_setting import ModelSetting
from ppc_model.secure_model_base.secure_model_context import SecureModel
from ppc_model.secure_model_base.secure_model_context import SecureModelContext


class LGBMModel(BaseEstimator):
class LGBMModel(SecureModel):

def __init__(
self,
@@ -36,7 +37,6 @@ def __init__(
importance_type: str = 'split',
**kwargs
):

self.boosting_type = boosting_type
self.objective = objective
self.num_leaves = num_leaves
@@ -55,60 +55,7 @@ def __init__(
self.random_state = random_state
self.n_jobs = n_jobs
self.importance_type = importance_type
self._other_params: Dict[str, Any] = {}
self.set_params(**kwargs)

def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters for this estimator.

Parameters
----------
deep : bool, optional (default=True)
If True, will return the parameters for this estimator and
contained subobjects that are estimators.

Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = super().get_params(deep=deep)
params.update(self._other_params)
return params

def set_model_setting(self, model_setting: ModelSetting) -> "LGBMModel":
# 获取对象的所有属性名
attrs = dir(model_setting)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
setattr(self, attr, getattr(model_setting, attr))
except Exception as e:
pass
return self

def set_params(self, **params: Any) -> "LGBMModel":
"""Set the parameters of this estimator.

Parameters
----------
**params
Parameter names with their new values.

Returns
-------
self : object
Returns self.
"""
for key, value in params.items():
setattr(self, key, value)
if hasattr(self, f"_{key}"):
setattr(self, f"_{key}", value)
self._other_params[key] = value
return self
super().__init__(**kwargs)


class ModelTaskParams(LGBMModel):
@@ -169,98 +116,54 @@ def _get_params(self):
"""返回LGBMClassifier所有参数"""
return LGBMModel().get_params()

def get_all_params(self):
"""返回SecureLGBMParams所有参数"""
# 获取对象的所有属性名
attrs = dir(self)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
# 使用getattr来获取属性的值
value = getattr(self, attr)
# 检查value是否可调用(例如,方法或函数),如果是,则不打印其值
if not callable(value):
params[attr] = value
except Exception as e:
pass
return params


class SecureLGBMContext(Context):
class SecureLGBMContext(SecureModelContext):

def __init__(self,
args,
components: Initializer
):

if args['is_label_holder']:
role = TaskRole.ACTIVE_PARTY
else:
role = TaskRole.PASSIVE_PARTY

super().__init__(args['job_id'],
args['task_id'],
components,
role)
super().__init__(args, components)

self.phe = PheCipherFactory.build_phe(
components.homo_algorithm, components.public_key_length)
self.codec = PheCipherFactory.build_codec(components.homo_algorithm)
self.is_label_holder = args['is_label_holder']
self.result_receiver_id_list = args['result_receiver_id_list']
self.participant_id_list = args['participant_id_list']
self.model_predict_algorithm = common_func.get_config_value(
"model_predict_algorithm", None, args, False)
self.algorithm_type = args['algorithm_type']
if 'dataset_id' in args and args['dataset_id'] is not None:
self.dataset_file_path = os.path.join(
self.workspace, args['dataset_id'])
else:
self.dataset_file_path = None

self.model_params = SecureLGBMParams()
model_setting = ModelSetting(args['model_dict'])
self.set_model_params(model_setting)
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
self.model_params.n_estimators = model_setting.num_trees
self.model_params.feature_rate = model_setting.colsample_bytree
self.model_params.min_split_gain = model_setting.gamma
self.model_params.random_state = model_setting.seed

self.sync_file_list = {}
if self.algorithm_type == AlgorithmType.Train.name:
self.set_sync_file()

def set_model_params(self, model_setting: ModelSetting):
"""设置lgbm参数"""
self.model_params.set_model_setting(model_setting)
def create_model_param(self):
return SecureLGBMParams()

def get_model_params(self):
"""获取lgbm参数"""
return self.model_params

def set_sync_file(self):
self.sync_file_list['metrics_iteration'] = [self.metrics_iteration_file, self.remote_metrics_iteration_file]
self.sync_file_list['feature_importance'] = [self.feature_importance_file, self.remote_feature_importance_file]
self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file]
self.sync_file_list['metrics_iteration'] = [
self.metrics_iteration_file, self.remote_metrics_iteration_file]
self.sync_file_list['feature_importance'] = [
self.feature_importance_file, self.remote_feature_importance_file]
self.sync_file_list['summary_evaluation'] = [
self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [
self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [
self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [
self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [
self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [
self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [
self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [
self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [
self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [
self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [
self.test_metric_acc_file, self.remote_test_metric_acc_file]


class LGBMMessage(Enum):
FEATURE_NAME = "FEATURE_NAME"
53 changes: 17 additions & 36 deletions python/ppc_model/secure_lgbm/vertical/booster.py
Original file line number Diff line number Diff line change
@@ -13,11 +13,12 @@
from ppc_model.network.stub import PushRequest, PullRequest
from ppc_model.common.model_result import ResultFileHandling
from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning
from ppc_model.secure_model_base.secure_model_booster import SecureModelBooster
from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext, LGBMMessage


# 抽离sgb的公共部分
class VerticalBooster(VerticalModel):
class VerticalBooster(SecureModelBooster):
def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
super().__init__(ctx)
self.dataset = dataset
@@ -129,7 +130,6 @@ def _get_leaf_mask(self, split_info, instance):
return left_mask, right_mask

def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -150,12 +150,11 @@ def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=Fal
ctx.codec, ctx.phe.public_key, enc_data)
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_length: {len(enc_data)}, time_costs: {time.time() - start_time}s")

def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -172,13 +171,12 @@ def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False):
public_key, enc_data = PheMessage.unpacking_data(
ctx.codec, byte_data)

log.info(
self.logger.info(
f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
return public_key, enc_data

def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -189,12 +187,11 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
data=byte_data
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")

def _receive_byte_data(self, ctx, key_type, partner_index):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -204,7 +201,7 @@ def _receive_byte_data(self, ctx, key_type, partner_index):
key=key_type
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
return byte_data
@@ -215,49 +212,33 @@ def _split_test_data(ctx, test_X, X_split):
return feat_bin.data_binning(test_X, X_split)[0]

def save_model(self, file_path=None):
log = self.ctx.components.logger()
super().save_model(file_path, "lgbm_model")

def save_model_hook(self, file_path):
# save the feature_bin
if file_path is not None:
self.ctx.feature_bin_file = os.path.join(
file_path, self.ctx.FEATURE_BIN_FILE)
self.ctx.model_data_file = os.path.join(
file_path, self.ctx.MODEL_DATA_FILE)

if self._X_split is not None and not os.path.exists(self.ctx.feature_bin_file):
X_split_dict = {k: v for k, v in zip(
self.dataset.feature_name, self._X_split)}
with open(self.ctx.feature_bin_file, 'w') as f:
json.dump(X_split_dict, f)
ResultFileHandling._upload_file(self.ctx.components.storage_client,
self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file)
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Saved x_split to {self.ctx.feature_bin_file} finished.")

if not os.path.exists(self.ctx.model_data_file):
serial_trees = [self._serial_tree(tree) for tree in self._trees]
with open(self.ctx.model_data_file, 'w') as f:
json.dump(serial_trees, f)
ResultFileHandling._upload_file(self.ctx.components.storage_client,
self.ctx.model_data_file, self.ctx.remote_model_data_file)
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.")

self.merge_model_file()

def merge_model_file(self):

def merge_model_file(self, lgbm_model):
# 加密文件
lgbm_model = {}
lgbm_model['model_type'] = 'xgb_model'
lgbm_model['label_provider'] = self.ctx.participant_id_list[0]
lgbm_model['label_column'] = 'y'
lgbm_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {
'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self._all_feature_name[partner_index]
lgbm_model['participant_agency_list'].append(agency_info)

lgbm_model['model_dict'] = self.ctx.model_params.get_all_params()
model_text = {}
with open(self.ctx.feature_bin_file, 'rb') as f:
feature_bin_data = f.read()
@@ -295,10 +276,11 @@ def merge_model_file(self):
json.dump(lgbm_model, f)
ResultFileHandling._upload_file(self.ctx.components.storage_client,
self.ctx.model_enc_file, self.ctx.remote_model_enc_file)
self.ctx.components.logger().info(
self.logger.info(
f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")

def split_model_file(self):

# 传入模型
my_agency_id = self.ctx.components.config_data['AGENCY_ID']
model_text = self.ctx.model_predict_algorithm['model_text']
@@ -314,7 +296,6 @@ def split_model_file(self):
f.write(model_data)

def load_model(self, file_path=None):
log = self.ctx.components.logger()
if file_path is not None:
self.ctx.feature_bin_file = os.path.join(
file_path, self.ctx.FEATURE_BIN_FILE)
@@ -333,13 +314,13 @@ def load_model(self, file_path=None):
X_split_dict = json.load(f)
feature_name = list(X_split_dict.keys())
x_split = list(X_split_dict.values())
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Load x_split from {self.ctx.feature_bin_file} finished.")
assert len(feature_name) == len(self.dataset.feature_name)

with open(self.ctx.model_data_file, 'r') as f:
serial_trees = json.load(f)
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Load serial_trees from {self.ctx.model_data_file} finished.")

trees = [self._deserial_tree(tree) for tree in serial_trees]
127 changes: 9 additions & 118 deletions python/ppc_model/secure_lr/secure_lr_context.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import os
from enum import Enum
from typing import Any, Dict
from sklearn.base import BaseEstimator

from ppc_common.ppc_utils.utils import AlgorithmType
from ppc_common.ppc_crypto.phe_factory import PheCipherFactory
from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
from ppc_model.common.protocol import TaskRole
from ppc_common.ppc_utils import common_func
from ppc_model.common.model_setting import ModelSetting
from ppc_model.secure_model_base.secure_model_context import SecureModel
from ppc_model.secure_model_base.secure_model_context import SecureModelContext


class LRModel(BaseEstimator):
class LRModel(SecureModel):

def __init__(
self,
@@ -29,60 +27,7 @@ def __init__(
self.learning_rate = learning_rate
self.random_state = random_state
self.n_jobs = n_jobs
self._other_params: Dict[str, Any] = {}
self.set_params(**kwargs)

def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters for this estimator.
Parameters
----------
deep : bool, optional (default=True)
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = super().get_params(deep=deep)
params.update(self._other_params)
return params

def set_model_setting(self, model_setting: ModelSetting) -> "LRModel":
# 获取对象的所有属性名
attrs = dir(model_setting)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
setattr(self, attr, getattr(model_setting, attr))
except Exception as e:
pass
return self

def set_params(self, **params: Any) -> "LRModel":
"""Set the parameters of this estimator.
Parameters
----------
**params
Parameter names with their new values.
Returns
-------
self : object
Returns self.
"""
for key, value in params.items():
setattr(self, key, value)
if hasattr(self, f"_{key}"):
setattr(self, f"_{key}", value)
self._other_params[key] = value
return self
super().__init__(**kwargs)


class ModelTaskParams(LRModel):
@@ -124,75 +69,21 @@ def _get_params(self):
"""返回LRClassifier所有参数"""
return LRModel().get_params()

def get_all_params(self):
"""返回SecureLRParams所有参数"""
# 获取对象的所有属性名
attrs = dir(self)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
# 使用getattr来获取属性的值
value = getattr(self, attr)
# 检查value是否可调用(例如,方法或函数),如果是,则不打印其值
if not callable(value):
params[attr] = value
except Exception as e:
pass
return params


class SecureLRContext(Context):
class SecureLRContext(SecureModelContext):

def __init__(self,
args,
components: Initializer
):

if args['is_label_holder']:
role = TaskRole.ACTIVE_PARTY
else:
role = TaskRole.PASSIVE_PARTY

super().__init__(args['job_id'],
args['task_id'],
components,
role)
super().__init__(args, components)

self.phe = PheCipherFactory.build_phe(
components.homo_algorithm, components.public_key_length)
self.codec = PheCipherFactory.build_codec(components.homo_algorithm)
self.is_label_holder = args['is_label_holder']
self.result_receiver_id_list = args['result_receiver_id_list']
self.participant_id_list = args['participant_id_list']
self.model_predict_algorithm = common_func.get_config_value(
"model_predict_algorithm", None, args, False)
self.algorithm_type = args['algorithm_type']
if 'dataset_id' in args and args['dataset_id'] is not None:
self.dataset_file_path = os.path.join(
self.workspace, args['dataset_id'])
else:
self.dataset_file_path = None

self.model_params = SecureLRParams()
model_setting = ModelSetting(args['model_dict'])
self.set_model_params(model_setting)
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
self.model_params.random_state = model_setting.seed
self.sync_file_list = {}
if self.algorithm_type == AlgorithmType.Train.name:
self.set_sync_file()

def set_model_params(self, model_setting: ModelSetting):
"""设置lr参数"""
self.model_params.set_model_setting(model_setting)

def create_model_param(self):
return SecureLRParams()

def get_model_params(self):
"""获取lr参数"""
2 changes: 1 addition & 1 deletion python/ppc_model/secure_lr/secure_lr_prediction_engine.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty


class SecureLGBMPredictionEngine(TaskEngine):
class SecureLRPredictionEngine(TaskEngine):
task_type = ModelTask.LR_PREDICTING

@staticmethod
100 changes: 46 additions & 54 deletions python/ppc_model/secure_lr/vertical/booster.py
Original file line number Diff line number Diff line change
@@ -16,10 +16,12 @@
from ppc_model.common.model_result import ResultFileHandling
from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning
from ppc_model.secure_lr.secure_lr_context import SecureLRContext, LRMessage

from ppc_model.secure_model_base.secure_model_booster import SecureModelBooster

# 抽离sgb的公共部分
class VerticalBooster(VerticalModel):


class VerticalBooster(SecureModelBooster):
def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None:
super().__init__(ctx)
self.dataset = dataset
@@ -77,8 +79,8 @@ def _get_categorical_idx(feature_name, categorical_feature=[]):

def _init_each_iter(self):

idx = self._get_sample_idx(self._iter_id-1, self.dataset.train_X.shape[0],
size = self.params.batch_size)
idx = self._get_sample_idx(self._iter_id-1, self.dataset.train_X.shape[0],
size=self.params.batch_size)
feature_select = FeatureSelection.feature_selecting(
list(self.dataset.feature_name),
self.params.train_feature, self.params.feature_rate)
@@ -91,11 +93,12 @@ def _send_d_instance_list(self, d):
my_agency_id = self.ctx.components.config_data['AGENCY_ID']

start_time = time.time()
self.log.info(f'task {self.ctx.task_id}: Starting iter-{self._iter_id} '
f'encrypt d in {my_agency_id} party.')
enc_dlist = self.ctx.phe.encrypt_batch_parallel((d_list).astype('object'))
self.log.info(f'task {self.ctx.task_id}: Finished iter-{self._iter_id} '
f'encrypt d time_costs: {time.time() - start_time}.')
self.logger.info(f'task {self.ctx.task_id}: Starting iter-{self._iter_id} '
f'encrypt d in {my_agency_id} party.')
enc_dlist = self.ctx.phe.encrypt_batch_parallel(
(d_list).astype('object'))
self.logger.info(f'task {self.ctx.task_id}: Finished iter-{self._iter_id} '
f'encrypt d time_costs: {time.time() - start_time}.')

for partner_index in range(len(self.ctx.participant_id_list)):
if self.ctx.participant_id_list[partner_index] != my_agency_id:
@@ -105,7 +108,7 @@ def _send_d_instance_list(self, d):
def _receive_d_instance_list(self):

my_agency_id = self.ctx.components.config_data['AGENCY_ID']

public_key_list = []
d_other_list = []
partner_index_list = []
@@ -127,24 +130,29 @@ def _calculate_deriv(self, x_, d, partner_index_list, d_other_list):
# 计算明文*密文 matmul
# deriv_other_i = np.matmul(x.T, d_other_list[i])
deriv_other_i = self.enc_matmul(x.T, d_other_list[i])

# 发送密文,接受密文并解密
self._send_enc_data(self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}',
deriv_other_i, partner_index)
_, enc_deriv_i = self._receive_enc_data(
self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index)
deriv_i_rec = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object')
deriv_i = self.recover_d(self.ctx, deriv_i_rec, is_square=True) / x_.shape[0]

deriv_i_rec = np.array(
self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object')
deriv_i = self.recover_d(
self.ctx, deriv_i_rec, is_square=True) / x_.shape[0]

# 发送明文,接受明文并计算
self._send_byte_data(self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}',
deriv_i.astype('float').tobytes(), partner_index)
deriv_x_i = np.frombuffer(self._receive_byte_data(
self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', partner_index), dtype=np.float)
self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.')
self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv_x_i: {deriv_x_i}.')
self.ctx, f'{LRMessage.D_MATMUL.value}_{self._iter_id}', partner_index), dtype=np.float)
self.logger.info(
f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.')
self.logger.info(
f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv_x_i: {deriv_x_i}.')
deriv += deriv_x_i
self.log.info(f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.')
self.logger.info(
f'{self.ctx.components.config_data["AGENCY_ID"]}, deriv: {deriv}.')
return deriv

def _calculate_deriv1(self, x_, d, partner_index_list, d_other_list):
@@ -159,12 +167,13 @@ def _calculate_deriv1(self, x_, d, partner_index_list, d_other_list):
deriv_other_i, partner_index)
_, enc_deriv_i = self._receive_enc_data(
self.ctx, f'{LRMessage.ENC_D_HIST.value}_{self._iter_id}', partner_index)
deriv_i = np.array(self.ctx.phe.decrypt_batch(enc_deriv_i), dtype='object')
deriv += (self.recover_d(self.ctx, deriv_i, is_square=True) / x_.shape[0])
deriv_i = np.array(self.ctx.phe.decrypt_batch(
enc_deriv_i), dtype='object')
deriv += (self.recover_d(self.ctx, deriv_i,
is_square=True) / x_.shape[0])
return deriv

def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -185,12 +194,11 @@ def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=Fal
ctx.codec, ctx.phe.public_key, enc_data)
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_length: {len(enc_data)}, time_costs: {time.time() - start_time}s")

def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -207,13 +215,12 @@ def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False):
public_key, enc_data = PheMessage.unpacking_data(
ctx.codec, byte_data)

log.info(
self.logger.info(
f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
return public_key, enc_data

def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -224,12 +231,11 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
data=byte_data
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")

def _receive_byte_data(self, ctx, key_type, partner_index):
log = ctx.components.logger()
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]

@@ -239,47 +245,32 @@ def _receive_byte_data(self, ctx, key_type, partner_index):
key=key_type
))

log.info(
self.logger.info(
f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
return byte_data

def save_model(self, file_path=None):
log = self.ctx.components.logger()
if file_path is not None:
self.ctx.model_data_file = os.path.join(
file_path, self.ctx.MODEL_DATA_FILE)
super().save_model(file_path, "lr_model")

def save_model_hook(self, model_file_path):
if not os.path.exists(self.ctx.model_data_file):
serial_weight = list(self._train_weights)
with open(self.ctx.model_data_file, 'w') as f:
json.dump(serial_weight, f)
ResultFileHandling._upload_file(self.ctx.components.storage_client,
self.ctx.model_data_file, self.ctx.remote_model_data_file)
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Saved serial_weight to {self.ctx.model_data_file} finished.")

self.merge_model_file()

def merge_model_file(self):
def merge_model_file(self, lr_model):

# 加密文件
lr_model = {}
lr_model['model_type'] = 'lr_model'
lr_model['label_provider'] = self.ctx.participant_id_list[0]
lr_model['label_column'] = 'y'
lr_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self._all_feature_name[partner_index]
lr_model['participant_agency_list'].append(agency_info)

lr_model['model_dict'] = self.ctx.model_params.get_all_params()
model_text = {}
with open(self.ctx.model_data_file, 'rb') as f:
model_data = f.read()
model_data_enc = encrypt_data(self.ctx.key, model_data)

my_agency_id = self.ctx.components.config_data['AGENCY_ID']
model_text[my_agency_id] = cipher_to_base64(model_data_enc)

@@ -293,15 +284,16 @@ def merge_model_file(self):
if self.ctx.participant_id_list[partner_index] != my_agency_id:
model_data_enc = self._receive_byte_data(
self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', partner_index)
model_text[self.ctx.participant_id_list[partner_index]] = cipher_to_base64(model_data_enc)
model_text[self.ctx.participant_id_list[partner_index]
] = cipher_to_base64(model_data_enc)
lr_model['model_text'] = model_text

# 上传密文模型
with open(self.ctx.model_enc_file, 'w') as f:
json.dump(lr_model, f)
ResultFileHandling._upload_file(self.ctx.components.storage_client,
self.ctx.model_enc_file, self.ctx.remote_model_enc_file)
self.ctx.components.logger().info(
self.logger.info(
f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")

def split_model_file(self):
@@ -316,7 +308,6 @@ def split_model_file(self):
f.write(model_data)

def load_model(self, file_path=None):
log = self.ctx.components.logger()
if file_path is not None:
self.ctx.model_data_file = os.path.join(
file_path, self.ctx.MODEL_DATA_FILE)
@@ -330,7 +321,7 @@ def load_model(self, file_path=None):
with open(self.ctx.model_data_file, 'r') as f:
serial_weight = json.load(f)
self._train_weights = np.array(serial_weight)
log.info(
self.logger.info(
f"task {self.ctx.task_id}: Load serial_weight from {self.ctx.model_data_file} finished.")

def get_weights(self):
@@ -356,8 +347,9 @@ def rounding_d(d_list: np.ndarray, expand=1000):

@staticmethod
def recover_d(ctx, d_sum_list: np.ndarray, is_square=False, expand=1000):

d_sum_list[d_sum_list > 2**(ctx.phe.key_length-1)] -= 2**(ctx.phe.key_length)

d_sum_list[d_sum_list > 2 **
(ctx.phe.key_length-1)] -= 2**(ctx.phe.key_length)

if is_square:
return (d_sum_list / expand / expand).astype('float')
Empty file.
40 changes: 40 additions & 0 deletions python/ppc_model/secure_model_base/secure_model_booster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
import os
import json
from ppc_model.interface.model_base import VerticalModel
from ppc_model.common.model_result import ResultFileHandling
from abc import abstractmethod


class SecureModelBooster(VerticalModel):
    """Shared save logic for secure vertical models.

    ``save_model`` assembles the metadata common to every model type and
    delegates the model-specific work to ``save_model_hook`` and
    ``merge_model_file``, which subclasses must provide.
    """

    def __init__(self, ctx) -> None:
        super().__init__(ctx)
        # Keep one logger on the instance so methods do not have to fetch
        # it from the context's components each time.
        self.logger = self.ctx.components.logger()

    def save_model(self, file_path=None, model_type=None):
        """Persist the model and hand the shared metadata to the subclass.

        Parameters
        ----------
        file_path : optional
            When not None, ``ctx.model_data_file`` is redirected to
            ``ctx.MODEL_DATA_FILE`` under this path. The value is also
            forwarded unchanged to ``save_model_hook``.
        model_type : optional
            Stored under the ``'model_type'`` key of the metadata dict.
        """
        if file_path is not None:
            self.ctx.model_data_file = os.path.join(
                file_path, self.ctx.MODEL_DATA_FILE)

        # The hook runs first; the metadata is assembled afterwards.
        self.save_model_hook(file_path)

        model = {
            'model_type': model_type,
            # The first participant is recorded as the label provider.
            'label_provider': self.ctx.participant_id_list[0],
            'label_column': 'y',
            'participant_agency_list': [
                {'agency': agency_id,
                 'fields': self._all_feature_name[index]}
                for index, agency_id in enumerate(
                    self.ctx.participant_id_list)
            ],
            'model_dict': self.ctx.model_params.get_all_params(),
        }
        self.merge_model_file(model)

    @abstractmethod
    def merge_model_file(self, lr_model):
        """Receive the metadata dict built by ``save_model``."""

    @abstractmethod
    def save_model_hook(self, model_file_path):
        """Called by ``save_model`` before the metadata dict is built."""
159 changes: 159 additions & 0 deletions python/ppc_model/secure_model_base/secure_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
from abc import abstractmethod
from typing import Any, Dict
import json
import os
from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
from ppc_model.common.protocol import TaskRole
from ppc_common.ppc_utils import common_func
from ppc_common.ppc_utils.utils import AlgorithmType
from ppc_model.common.model_setting import ModelSetting

from sklearn.base import BaseEstimator


class SecureModel(BaseEstimator):
    """Estimator-style parameter holder shared by the secure models.

    Keyword arguments passed to the constructor are stored as attributes
    and also remembered in ``_other_params`` so that ``get_params`` can
    report them.
    """

    def __init__(
            self,
            **kwargs):
        self.train_feature = []
        self.categorical_feature = None
        self.random_state = None
        # Must exist before set_params(), which records every key in it.
        self._other_params: Dict[str, Any] = {}
        self.set_params(**kwargs)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, optional (default=True)
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : dict
            Parameter names mapped to their values, including every
            parameter previously supplied through ``set_params``.
        """
        params = super().get_params(deep=deep)
        params.update(self._other_params)
        return params

    def set_model_setting(self, model_setting: ModelSetting):
        """Copy the public attributes of ``model_setting`` onto this object.

        Every name returned by ``dir(model_setting)`` that does not start
        with an underscore is copied; no other filtering is applied.
        Attributes that cannot be read or assigned are skipped.

        Returns
        -------
        self : object
            Returns self.
        """
        for attr in dir(model_setting):
            # Names starting with '_' or '__' are usually special methods
            # or internal attributes.
            if attr.startswith('_'):
                continue
            try:
                setattr(self, attr, getattr(model_setting, attr))
            except Exception:
                # Deliberately best-effort: one attribute that fails to
                # read or assign must not abort the rest of the copy.
                continue
        return self

    def set_params(self, **params: Any):
        """Set the parameters of this estimator.

        Parameters
        ----------
        **params
            Parameter names with their new values.

        Returns
        -------
        self : object
            Returns self.
        """
        for key, value in params.items():
            setattr(self, key, value)
            # Keep an underscore-prefixed twin in sync when one exists.
            if hasattr(self, f"_{key}"):
                setattr(self, f"_{key}", value)
            self._other_params[key] = value
        return self

    def get_all_params(self):
        """Return all public, non-callable attributes of this object.

        Returns
        -------
        params : dict
            Attribute names mapped to their current values.
        """
        params = {}
        for attr in dir(self):
            # Names starting with '_' or '__' are usually special methods
            # or internal attributes.
            if attr.startswith('_'):
                continue
            try:
                value = getattr(self, attr)
            except Exception:
                # Best-effort: an attribute that fails to read is left out.
                continue
            # Methods and other callables are not parameters.
            if not callable(value):
                params[attr] = value
        return params


class SecureModelContext(Context):
def __init__(self,
             args,
             components: Initializer):
    """Initialise the context shared by the secure-model tasks.

    Parameters
    ----------
    args
        Task arguments. The keys ``is_label_holder``, ``job_id``,
        ``task_id``, ``result_receiver_id_list``, ``participant_id_list``,
        ``algorithm_type`` and ``model_dict`` are read unconditionally;
        ``model_predict_algorithm`` and ``dataset_id`` are optional.
    components : Initializer
        Passed through to the base ``Context``.

    Raises
    ------
    ValueError
        If the task is a predict task and no ``model_predict_algorithm``
        is available.
    """
    if args['is_label_holder']:
        role = TaskRole.ACTIVE_PARTY
    else:
        role = TaskRole.PASSIVE_PARTY

    super().__init__(args['job_id'],
                     args['task_id'],
                     components,
                     role)
    self.is_label_holder = args['is_label_holder']
    self.result_receiver_id_list = args['result_receiver_id_list']
    self.participant_id_list = args['participant_id_list']

    model_predict_algorithm_str = common_func.get_config_value(
        "model_predict_algorithm", None, args, False)
    if model_predict_algorithm_str is not None:
        self.model_predict_algorithm = json.loads(
            model_predict_algorithm_str)
    elif not hasattr(self, 'model_predict_algorithm'):
        # Ensure the attribute always exists so the predict check below
        # cannot fail with AttributeError; a value already set by the
        # base class is left as it is.
        self.model_predict_algorithm = None

    self.algorithm_type = args['algorithm_type']
    self.predict = self.algorithm_type == AlgorithmType.Predict.name
    # check for the predict task
    if self.predict and self.model_predict_algorithm is None:
        # Raise a real exception: raising a bare string is itself a
        # TypeError in Python 3 and would hide this message.
        raise ValueError(
            f"Not set model_predict_algorithm for the job: {self.task_id}")

    if 'dataset_id' in args and args['dataset_id'] is not None:
        self.dataset_file_path = os.path.join(
            self.workspace, args['dataset_id'])
    else:
        self.dataset_file_path = None

    # The subclass decides which parameter object to use; the task's
    # model settings are then applied on top of it.
    self.model_params = self.create_model_param()
    self.reset_model_params(ModelSetting(args['model_dict']))
    self.sync_file_list = {}
    if self.algorithm_type == AlgorithmType.Train.name:
        self.set_sync_file()

@abstractmethod
def set_sync_file(self):
    """Hook invoked at the end of ``__init__`` for training tasks.

    It runs after ``sync_file_list`` has been reset to an empty dict.
    """

@abstractmethod
def create_model_param(self):
    """Return the parameter object that ``__init__`` stores as ``model_params``."""

def reset_model_params(self, model_setting: ModelSetting):
"""Apply ``model_setting`` to ``self.model_params``."""
# Hand the settings to the parameter object first, then apply the
# explicit overrides below.
self.model_params.set_model_setting(model_setting)
# A non-empty ``train_features`` is split on ',' into a list of names.
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
# A non-empty ``categorical`` is split on ',' in the same way.
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
# ``seed`` is copied to ``random_state`` only when it is set.
if model_setting.seed is not None:
self.model_params.random_state = model_setting.seed