Skip to content

Commit

Permalink
refactor modelSetting
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Oct 15, 2024
1 parent 2070004 commit df53c4d
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 68 deletions.
77 changes: 56 additions & 21 deletions python/ppc_model/common/model_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ppc_common.ppc_utils import common_func


class ModelSetting:
class PreprocessingSetting:
def __init__(self, model_dict):
self.use_psi = common_func.get_config_value(
"use_psi", False, model_dict, False)
Expand All @@ -28,26 +28,56 @@ def __init__(self, model_dict):
"psi_select_bins", 4, model_dict, False))
self.corr_select = float(common_func.get_config_value(
"corr_select", 0, model_dict, False))
self.use_goss = common_func.get_config_value(
"use_goss", False, model_dict, False)


class FeatureEngineeringEngineSetting:
    """Configuration for the feature-engineering engine, parsed from a model dict.

    All values are read via ``common_func.get_config_value`` with a default,
    so a missing key never raises.
    """

    def __init__(self, model_dict):
        # Toggle for IV-based feature selection (presumably information value;
        # confirm against the feature-engineering engine).
        self.use_iv = common_func.get_config_value(
            "use_iv", False, model_dict, False)
        # Number of bins used when grouping feature values.
        self.group_num = int(
            common_func.get_config_value("group_num", 4, model_dict, False))
        # Minimum IV score a feature must reach to be retained.
        self.iv_thresh = float(
            common_func.get_config_value("iv_thresh", 0.1, model_dict, False))
        # Toggle for GOSS sampling.
        self.use_goss = common_func.get_config_value(
            "use_goss", False, model_dict, False)
        # Fraction of the dataset held out for testing (config key
        # "test_dataset_percentage").
        self.test_size = float(
            common_func.get_config_value(
                "test_dataset_percentage", 0.3, model_dict, False))


class CommmonModelSetting:
    """Hyper-parameters shared by all secure model settings.

    NOTE(review): the class name keeps the original (misspelled)
    ``CommmonModelSetting`` because subclasses reference it by this exact name.

    All values are read via ``common_func.get_config_value`` with a default,
    so a missing key never raises.
    """

    def __init__(self, model_dict):
        self.learning_rate = float(common_func.get_config_value(
            "learning_rate", 0.1, model_dict, False))

        # Column/value selectors describing the evaluation split.
        self.eval_set_column = common_func.get_config_value(
            "eval_set_column", "", model_dict, False)
        self.train_set_value = common_func.get_config_value(
            "train_set_value", "", model_dict, False)
        self.eval_set_value = common_func.get_config_value(
            "eval_set_value", "", model_dict, False)
        self.verbose_eval = int(common_func.get_config_value(
            "verbose_eval", 1, model_dict, False))
        self.silent = common_func.get_config_value(
            "silent", False, model_dict, False)
        # Comma-separated list of feature names to train on ("" == all).
        self.train_features = common_func.get_config_value(
            "train_features", "", model_dict, False)
        # Bug fix: the original tested ``len(self.random_state)`` before
        # ``self.random_state`` was ever assigned, so every construction
        # raised AttributeError. Read the raw configured value first, then
        # coerce to int only when a non-empty value was supplied.
        # NOTE(review): defaulting to None when unset — confirm downstream
        # consumers don't require an int here.
        random_state = common_func.get_config_value(
            "random_state", "", model_dict, False)
        if str(random_state) != "":
            self.random_state = int(random_state)
        else:
            self.random_state = None
        self.n_jobs = int(common_func.get_config_value(
            "n_jobs", 0, model_dict, False))


class SecureLGBMSetting(CommmonModelSetting):
def __init__(self, model_dict):
super().__init__(model_dict)
self.test_size = float(common_func.get_config_value(
"test_dataset_percentage", 0.3, model_dict, False))
self.num_trees = int(common_func.get_config_value(
"num_trees", 6, model_dict, False))
self.max_depth = int(common_func.get_config_value(
"max_depth", 3, model_dict, False))
self.max_bin = int(common_func.get_config_value(
"max_bin", 4, model_dict, False))
self.silent = common_func.get_config_value(
"silent", False, model_dict, False)

self.subsample = float(common_func.get_config_value(
"subsample", 1, model_dict, False))
self.colsample_bytree = float(common_func.get_config_value(
Expand All @@ -70,21 +100,26 @@ def __init__(self, model_dict):
"early_stopping_rounds", 5, model_dict, False))
self.eval_metric = common_func.get_config_value(
"eval_metric", "auc", model_dict, False)
self.verbose_eval = int(common_func.get_config_value(
"verbose_eval", 1, model_dict, False))
self.eval_set_column = common_func.get_config_value(
"eval_set_column", "", model_dict, False)
self.train_set_value = common_func.get_config_value(
"train_set_value", "", model_dict, False)
self.eval_set_value = common_func.get_config_value(
"eval_set_value", "", model_dict, False)
self.train_features = common_func.get_config_value(
"train_features", "", model_dict, False)
self.epochs = int(common_func.get_config_value(
"epochs", 3, model_dict, False))
self.batch_size = int(common_func.get_config_value(
"batch_size", 16, model_dict, False))
self.threads = int(common_func.get_config_value(
"threads", 8, model_dict, False))
self.one_hot = common_func.get_config_value(
"one_hot", 0, model_dict, False)


class SecureLRSetting(CommmonModelSetting):
    """Hyper-parameters specific to secure logistic-regression training."""

    def __init__(self, model_dict):
        # Populate the shared settings first.
        super().__init__(model_dict)
        # Fraction of features used per round (1.0 == all features).
        self.feature_rate = float(
            common_func.get_config_value("feature_rate", 1.0, model_dict, False))
        # Mini-batch size for training.
        self.batch_size = int(
            common_func.get_config_value("batch_size", 16, model_dict, False))
        # Number of training passes over the data.
        self.epochs = int(
            common_func.get_config_value("epochs", 3, model_dict, False))


class ModelSetting(PreprocessingSetting, FeatureEngineeringEngineSetting, SecureLGBMSetting, SecureLRSetting):
    """Aggregate setting object combining every per-stage setting class.

    Bug fix: ``super(X, self).__init__`` resolves to the class *after* X in
    self's MRO, so the original code never ran
    ``PreprocessingSetting.__init__`` at all (its attributes such as
    ``use_psi`` were never set) and re-ran the tail of the chain several
    times. Call each parent initializer explicitly instead. The calls are
    idempotent (each only assigns attributes from ``model_dict``), so the
    redundant base-class re-initialization triggered by the subclasses'
    own ``super().__init__`` calls is harmless.
    """

    def __init__(self, model_dict):
        PreprocessingSetting.__init__(self, model_dict)
        FeatureEngineeringEngineSetting.__init__(self, model_dict)
        SecureLGBMSetting.__init__(self, model_dict)
        SecureLRSetting.__init__(self, model_dict)
4 changes: 2 additions & 2 deletions python/ppc_model/conf/application-sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ SSL_CRT: "./ssl.crt"
SSL_KEY: "./ssl.key"


PEM_PATH: "/data/app/ppcs-model4ef/wedpr-model-node/ppc_model_service/server.pem"
SHARE_PATH: "/data/app/ppcs-model4ef/wedpr-model-node/ppc_model_service/dataset_share/"
PEM_PATH: "/data/app/wedpr-model/wedpr-model-node/ppc_model_service/server.pem"
SHARE_PATH: "/data/app/wedpr-model/wedpr-model-node/ppc_model_service/dataset_share/"

DB_TYPE: "mysql"
SQLALCHEMY_DATABASE_URI: "mysql://[*user_ppcsmodeladm]:[*pass_ppcsmodeladm]@[@4346-TDSQL_VIP]:[@4346-TDSQL_PORT]/ppcsmodeladm?autocommit=true&charset=utf8mb4"
Expand Down
4 changes: 2 additions & 2 deletions python/ppc_model/conf/logging.conf
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ keys=fileHandler,consoleHandler,accessHandler

[handler_accessHandler]
class=handlers.TimedRotatingFileHandler
args=('/data/app/logs/ppcs-model4ef/appmonitor.log', 'D', 1, 30, 'utf-8')
args=('logs/appmonitor.log', 'D', 1, 30, 'utf-8')
level=INFO
formatter=simpleFormatter

[handler_fileHandler]
class=handlers.TimedRotatingFileHandler
args=('/data/app/logs/ppcs-model4ef/ppcs-model4ef-node.log', 'D', 1, 30, 'utf-8')
args=('logs/wedpr-model.log', 'D', 1, 30, 'utf-8')
level=INFO
formatter=simpleFormatter

Expand Down
5 changes: 3 additions & 2 deletions python/ppc_model/ppc_model_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Note: here can't be refactored by autopep
import sys
sys.path.append("../")

from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine
from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine
from ppc_model.secure_lr.secure_lr_training_engine import SecureLRTrainingEngine
Expand All @@ -21,8 +24,6 @@
from concurrent import futures
import os
import multiprocessing
import sys
sys.path.append("../")


app = Flask(__name__)
Expand Down
3 changes: 3 additions & 0 deletions python/ppc_model/secure_lgbm/secure_lgbm_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def __init__(self,
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
self.model_params.n_estimators = model_setting.num_trees
self.model_params.feature_rate = model_setting.colsample_bytree
self.model_params.min_split_gain = model_setting.gamma
Expand Down
36 changes: 25 additions & 11 deletions python/ppc_model/secure_lr/secure_lr_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def __init__(self,
if model_setting.train_features is not None and len(model_setting.train_features) > 0:
self.model_params.train_feature = model_setting.train_features.split(
',')
if model_setting.categorical is not None and len(model_setting.categorical) > 0:
self.model_params.categorical_feature = model_setting.categorical.split(
',')
self.model_params.random_state = model_setting.seed
self.sync_file_list = {}
if self.algorithm_type == AlgorithmType.Train.name:
Expand All @@ -196,17 +199,28 @@ def get_model_params(self):
return self.model_params

def set_sync_file(self):
self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file]
self.sync_file_list['summary_evaluation'] = [
self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [
self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [
self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [
self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [
self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [
self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [
self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [
self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [
self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [
self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [
self.test_metric_acc_file, self.remote_test_metric_acc_file]


class LRMessage(Enum):
Expand Down
2 changes: 1 addition & 1 deletion python/ppc_model/task/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def record_model_job_log(self, job_id):
log_file = self._get_log_file_path()
if log_file is None or log_file == "":
current_working_dir = os.getcwd()
relative_log_path = "logs/ppcs-model4ef-node.log"
relative_log_path = "logs/wedpr-model.log"
log_file = os.path.join(current_working_dir, relative_log_path)

start_keyword = LOG_START_FLAG_FORMATTER.format(job_id=job_id)
Expand Down
5 changes: 1 addition & 4 deletions python/ppc_model/tools/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

dirpath="$(cd "$(dirname "$0")" && pwd)"
cd $dirpath
LOG_DIR=/data/app/logs/ppcs-model4ef/

# kill crypto process
crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l`
Expand Down Expand Up @@ -34,6 +33,4 @@ check_service() {
}

sleep 5
check_service ppc_model_app.py
rm -rf logs
ln -s ${LOG_DIR} logs
check_service ppc_model_app.py
4 changes: 2 additions & 2 deletions python/ppc_model_gateway/conf/logging.conf
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ keys=fileHandler,consoleHandler,accessHandler

[handler_accessHandler]
class=handlers.TimedRotatingFileHandler
args=('/data/app/logs/ppcs-modelgateway/appmonitor.log', 'D', 1, 30, 'utf-8')
args=('logs/appmonitor.log', 'D', 1, 30, 'utf-8')
level=INFO
formatter=simpleFormatter

[handler_fileHandler]
class=handlers.TimedRotatingFileHandler
args=('/data/app/logs/ppcs-modelgateway/ppcs-modelgateway-gateway.log', 'D', 1, 30, 'utf-8')
args=('logs/ppcs-modelgateway-gateway.log', 'D', 1, 30, 'utf-8')
level=INFO
formatter=simpleFormatter

Expand Down
20 changes: 11 additions & 9 deletions python/ppc_model_gateway/ppc_model_gateway_app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from ppc_model_gateway.endpoints.partner_to_node import PartnerToNodeService
from ppc_model_gateway.endpoints.node_to_partner import NodeToPartnerService
from ppc_model_gateway import config
from ppc_common.ppc_utils import utils
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
import grpc
from threading import Thread
from concurrent import futures
import os
# Note: here can't be refactored by autopep
import sys
sys.path.append("../")

import os
from concurrent import futures
from threading import Thread
import grpc
from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
from ppc_common.ppc_utils import utils
from ppc_model_gateway import config
from ppc_model_gateway.endpoints.node_to_partner import NodeToPartnerService
from ppc_model_gateway.endpoints.partner_to_node import PartnerToNodeService



log = config.get_logger()

Expand Down
3 changes: 0 additions & 3 deletions python/ppc_model_gateway/tools/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

dirpath="$(cd "$(dirname "$0")" && pwd)"
cd $dirpath
LOG_DIR=/data/app/logs/ppcs-modelgateway/

export PYTHONPATH=$dirpath/../
source /data/app/ppcs-modelgateway/gateway_env/bin/deactivate
Expand Down Expand Up @@ -31,5 +30,3 @@ check_service() {

sleep 5
check_service ppc_model_gateway_app.py
rm -rf logs
ln -s ${LOG_DIR} logs
30 changes: 19 additions & 11 deletions python/ppc_scheduler/job/job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,30 @@ def run_task(self, job_id, request_body):
# TODO: The database persists job information
with self._rw_lock.gen_wlock():
if job_id in self._jobs:
self.logger.info(f"Job already exists, job_id: {job_id}, status: {self._jobs[job_id][0]}")
self.logger.info(
f"Job already exists, job_id: {job_id}, status: {self._jobs[job_id][0]}")
return
self._jobs[job_id] = [JobStatus.RUNNING, datetime.datetime.now(), 0]
self._jobs[job_id] = [
JobStatus.RUNNING, datetime.datetime.now(), 0]
self.logger.info(log_utils.job_start_log_info(job_id))

# Create job context
job_context = JobContext.create_job_context(request_body, self._workspace)
job_context = JobContext.create_job_context(
request_body, self._workspace)
# Build job workflow
flow_context = self._flow_builder.build_flow_context(job_id=job_context.job_id, workflow_configs=job_context.workflow_configs)
flow_context = self._flow_builder.build_flow_context(
job_id=job_context.job_id, workflow_configs=job_context.workflow_configs)
# Persistent workflow
self._flow_builder.save_flow_context(job_context.job_id, flow_context)
# Run workflow
self._async_executor.execute(job_id, self._run_job_flow, self._on_task_finish, (job_context, flow_context))
self._async_executor.execute(
job_id, self._run_job_flow, self._on_task_finish, (job_context, flow_context))

def _run_job_flow(self, job_context, flow_context):
"""
run job flow
"""

# the scheduler module starts scheduling tasks
self._scheduler.run(job_context, flow_context)

Expand Down Expand Up @@ -98,17 +103,20 @@ def _on_task_finish(self, job_id: str, is_succeeded: bool, e: Exception = None):
self._jobs[job_id][2] = time_costs
if is_succeeded:
self._jobs[job_id][0] = JobStatus.SUCCESS
self.logger.info(f"Job {job_id} completed, time_costs: {time_costs}s")
self.logger.info(
f"Job {job_id} completed, time_costs: {time_costs}s")
else:
self._jobs[job_id][0] = JobStatus.FAILURE
self.logger.warn(f"Job {job_id} failed, time_costs: {time_costs}s, error: {e}")
self.logger.warn(
f"Job {job_id} failed, time_costs: {time_costs}s, error: {e}")
self.logger.info(log_utils.job_end_log_info(job_id))

def _loop_action(self):
while True:
time.sleep(20)
self._terminate_timeout_jobs()
self._cleanup_finished_jobs()
# TODO: store into the database
# self._cleanup_finished_jobs()
self._report_jobs()

def _terminate_timeout_jobs(self):
Expand Down Expand Up @@ -139,7 +147,7 @@ def _cleanup_finished_jobs(self):
del self._jobs[job_id]
self._thread_event_manager.remove_event(job_id)
self.logger.info(f"Cleanup job cache, job_id: {job_id}")

def _report_jobs(self):
with self._rw_lock.gen_rlock():
job_count = len(self._jobs)
Expand Down

0 comments on commit df53c4d

Please sign in to comment.