Skip to content

Commit

Permalink
refactor common part of lr-context and xgb-context into SecureModelContext
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Oct 16, 2024
1 parent f33a82a commit 936488a
Show file tree
Hide file tree
Showing 13 changed files with 334 additions and 365 deletions.
6 changes: 1 addition & 5 deletions python/ppc_common/ppc_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@
MPC_TRAIN_SET_METRIC_PR_FILE = "mpc_train_metric_pr.svg"
MPC_TRAIN_SET_METRIC_ACCURACY_FILE = "mpc_train_metric_accuracy.svg"
MPC_TRAIN_SET_METRIC_KS_TABLE = "mpc_train_metric_ks.csv"
MPC_EVAL_METRIC_ROC_FILE = "mpc_eval_metric_roc.svg"
MPC_EVAL_METRIC_KS_FILE = "mpc_eval_metric_ks.svg"
MPC_EVAL_METRIC_PR_FILE = "mpc_eval_metric_pr.svg"
MPC_EVAL_METRIC_ACCURACY_FILE = "mpc_eval_metric_accuracy.svg"
MPC_EVAL_METRIC_KS_TABLE = "mpc_eval_metric_ks.csv"
MPC_TRAIN_METRIC_CONFUSION_MATRIX_FILE = "mpc_metric_confusion_matrix.svg"
METRICS_OVER_ITERATION_FILE = "metrics_over_iterations.svg"

Expand Down Expand Up @@ -111,6 +106,7 @@ class CryptoType(Enum):
ECDSA = 1
GM = 2


@unique
class HashType(Enum):
BYTES = 1
Expand Down
1 change: 0 additions & 1 deletion python/ppc_model/common/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest

from ppc_model.common.base_context import BaseContext
from ppc_model.common.initializer import Initializer
Expand Down
7 changes: 4 additions & 3 deletions python/ppc_model/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@ def _dataset_fe_selected(self, file_path, feature_name):

def _construct_dataset(self):
if self.algorithm_type == AlgorithmType.Predict.name:
my_fields = [
item["fields"] for item in self.ctx.model_predict_algorithm['participant_agency_list']
if item["agency"] == self.ctx.components.config_data['AGENCY_ID']]
my_fields = []
for item in self.ctx.model_predict_algorithm['participant_agency_list']:
if item["agency"] == self.ctx.components.config_data['AGENCY_ID']:
my_fields = item["fields"]
if 'y' in self.model_data.columns and 'y' not in my_fields:
my_fields = ['y'] + my_fields
if 'id' in self.model_data.columns and 'id' not in my_fields:
Expand Down
14 changes: 7 additions & 7 deletions python/ppc_model/model_result/task_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ class JobEvaluationResult:
EvaluationType.ACCURACY: utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE}

DEFAULT_EVAL_EVALUATION_FILES = {
EvaluationType.ROC: utils.MPC_EVAL_METRIC_ROC_FILE,
EvaluationType.PR: utils.MPC_EVAL_METRIC_PR_FILE,
EvaluationType.KS: utils.MPC_EVAL_METRIC_KS_FILE,
EvaluationType.ACCURACY: utils.MPC_EVAL_METRIC_ACCURACY_FILE
EvaluationType.ROC: utils.MPC_TRAIN_METRIC_ROC_FILE,
EvaluationType.PR: utils.MPC_TRAIN_METRIC_PR_FILE,
EvaluationType.KS: utils.MPC_TRAIN_METRIC_KS_FILE,
EvaluationType.ACCURACY: utils.MPC_TRAIN_METRIC_ACCURACY_FILE
}

def __init__(self, property_name, classification_type,
Expand Down Expand Up @@ -357,7 +357,7 @@ def _get_evaluation_result(self):
components=self.components)
# load the ks table
self.train_evaluation_result.load_ks_table(
"mpc_train_metric_ks.csv", "TrainKSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "TrainKSTable")
self.result_list.append(self.train_evaluation_result)

self.validation_evaluation_result = JobEvaluationResult(
Expand All @@ -368,7 +368,7 @@ def _get_evaluation_result(self):
components=self.components)
# load the ks_table
self.validation_evaluation_result.load_ks_table(
"mpc_metric_ks.csv", "KSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable")
self.result_list.append(self.validation_evaluation_result)

self.model = ModelJobResult(self.xgb_job,
Expand Down Expand Up @@ -396,7 +396,7 @@ def _get_evaluation_result(self):
components=self.components)
# load ks_table
self.predict_evaluation_result.load_ks_table(
"mpc_eval_metric_ks.csv", "KSTable")
utils.MPC_TRAIN_METRIC_KS_TABLE, "KSTable")
self.result_list.append(self.predict_evaluation_result)

# load model_result
Expand Down
45 changes: 24 additions & 21 deletions python/ppc_model/ppc_model_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,29 @@
import sys
sys.path.append("../")

from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine
from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine
from ppc_model.network.http.restx import api
from ppc_model.network.http.model_controller import ns2 as log_namespace
from ppc_model.network.http.model_controller import ns as task_namespace
from ppc_model.network.grpc.grpc_server import ModelService
from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine
from ppc_model.common.protocol import ModelTask
from ppc_model.common.global_context import components
from ppc_common.ppc_utils import utils
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
from paste.translogger import TransLogger
from flask import Flask, Blueprint
from cheroot.wsgi import Server as WSGIServer
from cheroot.ssl.builtin import BuiltinSSLAdapter
import grpc
from threading import Thread
from concurrent import futures
import os
import multiprocessing
import os
from concurrent import futures
from threading import Thread
import grpc
from cheroot.ssl.builtin import BuiltinSSLAdapter
from cheroot.wsgi import Server as WSGIServer
from flask import Flask, Blueprint
from paste.translogger import TransLogger
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
from ppc_common.ppc_utils import utils
from ppc_model.common.global_context import components
from ppc_model.common.protocol import ModelTask
from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine
from ppc_model.network.grpc.grpc_server import ModelService
from ppc_model.network.http.model_controller import ns as task_namespace
from ppc_model.network.http.model_controller import ns2 as log_namespace
from ppc_model.network.http.restx import api
from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine
from ppc_model.secure_lr.secure_lr_prediction_engine import SecureLRPredictionEngine
from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine


app = Flask(__name__)
Expand Down Expand Up @@ -53,6 +54,8 @@ def register_task_handler():
ModelTask.XGB_PREDICTING, SecureLGBMPredictionEngine.run)
task_manager.register_task_handler(
ModelTask.LR_TRAINING, SecureLRTrainingEngine.run)
task_manager.register_task_handler(
ModelTask.LR_PREDICTING, SecureLRPredictionEngine.run)


def model_serve():
Expand Down
169 changes: 36 additions & 133 deletions python/ppc_model/secure_lgbm/secure_lgbm_context.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@

import os
from enum import Enum
from typing import Any, Dict
from sklearn.base import BaseEstimator

from ppc_common.ppc_utils.utils import AlgorithmType
from ppc_common.ppc_crypto.phe_factory import PheCipherFactory
from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
from ppc_model.common.protocol import TaskRole
from ppc_common.ppc_utils import common_func
from ppc_model.common.model_setting import ModelSetting
from ppc_model.secure_model_base.secure_model_context import SecureModel
from ppc_model.secure_model_base.secure_model_context import SecureModelContext


class LGBMModel(BaseEstimator):
class LGBMModel(SecureModel):

def __init__(
self,
Expand All @@ -36,7 +37,6 @@ def __init__(
importance_type: str = 'split',
**kwargs
):

self.boosting_type = boosting_type
self.objective = objective
self.num_leaves = num_leaves
Expand All @@ -55,60 +55,7 @@ def __init__(
self.random_state = random_state
self.n_jobs = n_jobs
self.importance_type = importance_type
self._other_params: Dict[str, Any] = {}
self.set_params(**kwargs)

def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters for this estimator.
Parameters
----------
deep : bool, optional (default=True)
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = super().get_params(deep=deep)
params.update(self._other_params)
return params

def set_model_setting(self, model_setting: ModelSetting) -> "LGBMModel":
# 获取对象的所有属性名
attrs = dir(model_setting)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
setattr(self, attr, getattr(model_setting, attr))
except Exception as e:
pass
return self

def set_params(self, **params: Any) -> "LGBMModel":
"""Set the parameters of this estimator.
Parameters
----------
**params
Parameter names with their new values.
Returns
-------
self : object
Returns self.
"""
for key, value in params.items():
setattr(self, key, value)
if hasattr(self, f"_{key}"):
setattr(self, f"_{key}", value)
self._other_params[key] = value
return self
super().__init__(**kwargs)


class ModelTaskParams(LGBMModel):
Expand Down Expand Up @@ -169,98 +116,54 @@ def _get_params(self):
"""返回LGBMClassifier所有参数"""
return LGBMModel().get_params()

def get_all_params(self):
"""返回SecureLGBMParams所有参数"""
# 获取对象的所有属性名
attrs = dir(self)
# 过滤掉以_或者__开头的属性(这些通常是特殊方法或内部属性)
attrs = [attr for attr in attrs if not attr.startswith('_')]

params = {}
for attr in attrs:
try:
# 使用getattr来获取属性的值
value = getattr(self, attr)
# 检查value是否可调用(例如,方法或函数),如果是,则不打印其值
if not callable(value):
params[attr] = value
except Exception as e:
pass
return params


class SecureLGBMContext(Context):
class SecureLGBMContext(SecureModelContext):

def __init__(self,
args,
components: Initializer
):

if args['is_label_holder']:
role = TaskRole.ACTIVE_PARTY
else:
role = TaskRole.PASSIVE_PARTY

super().__init__(args['job_id'],
args['task_id'],
components,
role)
super().__init__(args, components)

self.phe = PheCipherFactory.build_phe(
components.homo_algorithm, components.public_key_length)
self.codec = PheCipherFactory.build_codec(components.homo_algorithm)
self.is_label_holder = args['is_label_holder']
self.result_receiver_id_list = args['result_receiver_id_list']
self.participant_id_list = args['participant_id_list']
self.model_predict_algorithm = common_func.get_config_value(
"model_predict_algorithm", None, args, False)
self.algorithm_type = args['algorithm_type']
if 'dataset_id' in args and args['dataset_id'] is not None:
self.dataset_file_path = os.path.join(
self.workspace, args['dataset_id'])
else:
self.dataset_file_path = None

self.model_params = SecureLGBMParams()
model_setting = ModelSetting(args['model_dict'])
self.set_model_params(model_setting)
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
self.model_params.n_estimators = model_setting.num_trees
self.model_params.feature_rate = model_setting.colsample_bytree
self.model_params.min_split_gain = model_setting.gamma
self.model_params.random_state = model_setting.seed

self.sync_file_list = {}
if self.algorithm_type == AlgorithmType.Train.name:
self.set_sync_file()

def set_model_params(self, model_setting: ModelSetting):
"""设置lgbm参数"""
self.model_params.set_model_setting(model_setting)
def create_model_param(self):
return SecureLGBMParams()

def get_model_params(self):
"""获取lgbm参数"""
return self.model_params

def set_sync_file(self):
self.sync_file_list['metrics_iteration'] = [self.metrics_iteration_file, self.remote_metrics_iteration_file]
self.sync_file_list['feature_importance'] = [self.feature_importance_file, self.remote_feature_importance_file]
self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file]
self.sync_file_list['metrics_iteration'] = [
self.metrics_iteration_file, self.remote_metrics_iteration_file]
self.sync_file_list['feature_importance'] = [
self.feature_importance_file, self.remote_feature_importance_file]
self.sync_file_list['summary_evaluation'] = [
self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [
self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [
self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [
self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [
self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [
self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [
self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [
self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [
self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [
self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [
self.test_metric_acc_file, self.remote_test_metric_acc_file]


class LGBMMessage(Enum):
FEATURE_NAME = "FEATURE_NAME"
Expand Down
Loading

0 comments on commit 936488a

Please sign in to comment.