Skip to content

Commit

Permalink
Merge pull request #8 from yanxinyi620/feature-milestone2
Browse files Browse the repository at this point in the history
Feature milestone2
  • Loading branch information
yanxinyi620 authored Aug 23, 2024
2 parents 34befa0 + 07ef9b7 commit b839a40
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 49 deletions.
4 changes: 0 additions & 4 deletions python/ppc_model/common/global_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import threading

from ppc_model.common.initializer import Initializer

Expand All @@ -9,6 +8,3 @@

components = Initializer(
log_config_path='logging.conf', config_path=config_path)

# matplotlib 线程不安全,并行任务绘图增加全局锁
plot_lock = threading.Lock()
7 changes: 6 additions & 1 deletion python/ppc_model/common/initializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import logging.config
import os
import threading

import yaml

Expand All @@ -14,7 +15,7 @@


class Initializer:
def __init__(self, log_config_path, config_path):
def __init__(self, log_config_path, config_path, plot_lock=None):
self.log_config_path = log_config_path
self.config_path = config_path
self.config_data = None
Expand All @@ -27,6 +28,10 @@ def __init__(self, log_config_path, config_path):
self.mock_logger = None
self.public_key_length = 2048
self.homo_algorithm = 0
# matplotlib 线程不安全,并行任务绘图增加全局锁
self.plot_lock = plot_lock
if plot_lock is None:
self.plot_lock = threading.Lock()

def init_all(self):
self.init_log()
Expand Down
35 changes: 5 additions & 30 deletions python/ppc_model/common/model_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def __init__(self, ctx: Context) -> None:
if ctx.algorithm_type == AlgorithmType.Train.name:
self._process_fe_result()

# remove job workspace
# self._remove_workspace()

# Synchronization result file
if (len(ctx.result_receiver_id_list) == 1 and ctx.participant_id_list[0] != ctx.result_receiver_id_list[0]) \
or len(ctx.result_receiver_id_list) > 1:
self._sync_result_files()

# remove job workspace
self._remove_workspace()

def _process_fe_result(self):
if os.path.exists(self.ctx.preprocessing_result_file):
column_info_fm = pd.read_csv(
Expand Down Expand Up @@ -134,33 +134,8 @@ def _remove_workspace(self):
f'job {self.ctx.job_id}: {self.ctx.workspace} does not exist.')

def _sync_result_files(self):
if self.ctx.algorithm_type == AlgorithmType.Train.name:
self.sync_result_file(self.ctx, self.ctx.metrics_iteration_file,
self.ctx.remote_metrics_iteration_file, 'f1')
self.sync_result_file(self.ctx, self.ctx.feature_importance_file,
self.ctx.remote_feature_importance_file, 'f2')
self.sync_result_file(self.ctx, self.ctx.summary_evaluation_file,
self.ctx.remote_summary_evaluation_file, 'f3')
self.sync_result_file(self.ctx, self.ctx.train_metric_ks_table,
self.ctx.remote_train_metric_ks_table, 'f4')
self.sync_result_file(self.ctx, self.ctx.train_metric_roc_file,
self.ctx.remote_train_metric_roc_file, 'f5')
self.sync_result_file(self.ctx, self.ctx.train_metric_ks_file,
self.ctx.remote_train_metric_ks_file, 'f6')
self.sync_result_file(self.ctx, self.ctx.train_metric_pr_file,
self.ctx.remote_train_metric_pr_file, 'f7')
self.sync_result_file(self.ctx, self.ctx.train_metric_acc_file,
self.ctx.remote_train_metric_acc_file, 'f8')
self.sync_result_file(self.ctx, self.ctx.test_metric_ks_table,
self.ctx.remote_test_metric_ks_table, 'f9')
self.sync_result_file(self.ctx, self.ctx.test_metric_roc_file,
self.ctx.remote_test_metric_roc_file, 'f10')
self.sync_result_file(self.ctx, self.ctx.test_metric_ks_file,
self.ctx.remote_test_metric_ks_file, 'f11')
self.sync_result_file(self.ctx, self.ctx.test_metric_pr_file,
self.ctx.remote_test_metric_pr_file, 'f12')
self.sync_result_file(self.ctx, self.ctx.test_metric_acc_file,
self.ctx.remote_test_metric_acc_file, 'f13')
for key, value in self.ctx.sync_file_list.items():
self.sync_result_file(self.ctx, value[0], value[1], key)

@staticmethod
def sync_result_file(ctx, local_file, remote_file, key_file):
Expand Down
3 changes: 1 addition & 2 deletions python/ppc_model/metrics/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sklearn.metrics import roc_curve, auc

from ppc_model.common.context import Context
from ppc_model.common.global_context import plot_lock
from ppc_model.datasets.dataset import SecureDataset
from ppc_model.common.model_result import ResultFileHandling
from ppc_model.secure_lgbm.monitor.feature.feature_evaluation_info import EvaluationType
Expand Down Expand Up @@ -156,7 +155,7 @@ def evaluation_file(self, ctx, data_index: np.ndarray,
while retry_num < max_retry:
retry_num += 1
try:
with plot_lock:
with ctx.components.plot_lock:
ks_value, auc_value = Evaluation.plot_two_class_graph(
self, y_true, y_praba)
except:
Expand Down
3 changes: 1 addition & 2 deletions python/ppc_model/metrics/model_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from networkx.drawing.nx_pydot import graphviz_layout

from ppc_model.common.model_result import ResultFileHandling
from ppc_model.common.global_context import plot_lock
from ppc_model.secure_lgbm.vertical.booster import VerticalBooster


Expand Down Expand Up @@ -50,7 +49,7 @@ def plot_tree(self):
while retry_num < max_retry:
retry_num += 1
try:
with plot_lock:
with self.ctx.components.plot_lock:
self._G.tree_plot(
figsize=(10, 5), save_filename=tree_file_path)
except:
Expand Down
4 changes: 1 addition & 3 deletions python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import matplotlib.pyplot as plt

from ppc_common.ppc_utils.utils import METRICS_OVER_ITERATION_FILE
from ppc_model.common.global_context import plot_lock
from ppc_model.secure_lgbm.monitor.callback import TrainingCallback
from ppc_model.secure_lgbm.monitor.core import _Model

Expand Down Expand Up @@ -110,8 +109,7 @@ def after_training(self, model: _Model) -> _Model:
while retry_num < max_retry:
retry_num += 1
try:
with plot_lock:
_draw_figure(model)
_draw_figure(model)
except:
self.logger.info(f'scores = {model.get_history()}')
self.logger.info(f'path = {model.get_workspace()}')
Expand Down
19 changes: 19 additions & 0 deletions python/ppc_model/secure_lgbm/secure_lgbm_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict
from sklearn.base import BaseEstimator

from ppc_common.ppc_utils.utils import AlgorithmType
from ppc_common.ppc_crypto.phe_factory import PheCipherFactory
from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
Expand Down Expand Up @@ -231,6 +232,10 @@ def __init__(self,
self.lgbm_params.min_split_gain = model_setting.gamma
self.lgbm_params.random_state = model_setting.seed

self.sync_file_list = {}
if self.algorithm_type == AlgorithmType.Train.name:
self.set_sync_file()

def set_lgbm_params(self, model_setting: ModelSetting):
"""设置lgbm参数"""
self.lgbm_params.set_model_setting(model_setting)
Expand All @@ -239,6 +244,20 @@ def get_lgbm_params(self):
"""获取lgbm参数"""
return self.lgbm_params

def set_sync_file(self):
self.sync_file_list['metrics_iteration'] = [self.metrics_iteration_file, self.remote_metrics_iteration_file]
self.sync_file_list['feature_importance'] = [self.feature_importance_file, self.remote_feature_importance_file]
self.sync_file_list['summary_evaluation'] = [self.summary_evaluation_file, self.remote_summary_evaluation_file]
self.sync_file_list['train_ks_table'] = [self.train_metric_ks_table, self.remote_train_metric_ks_table]
self.sync_file_list['train_metric_roc'] = [self.train_metric_roc_file, self.remote_train_metric_roc_file]
self.sync_file_list['train_metric_ks'] = [self.train_metric_ks_file, self.remote_train_metric_ks_file]
self.sync_file_list['train_metric_pr'] = [self.train_metric_pr_file, self.remote_train_metric_pr_file]
self.sync_file_list['train_metric_acc'] = [self.train_metric_acc_file, self.remote_train_metric_acc_file]
self.sync_file_list['test_ks_table'] = [self.test_metric_ks_table, self.remote_test_metric_ks_table]
self.sync_file_list['test_metric_roc'] = [self.test_metric_roc_file, self.remote_test_metric_roc_file]
self.sync_file_list['test_metric_ks'] = [self.test_metric_ks_file, self.remote_test_metric_ks_file]
self.sync_file_list['test_metric_pr'] = [self.test_metric_pr_file, self.remote_test_metric_pr_file]
self.sync_file_list['test_metric_acc'] = [self.test_metric_acc_file, self.remote_test_metric_acc_file]

class LGBMMessage(Enum):
FEATURE_NAME = "FEATURE_NAME"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def active_worker():
booster_a.load_model()
booster_a.predict()
test_praba = booster_a.get_test_praba()
task_info_a.algorithm_type = 'PPC_PREDICT'
task_info_a.algorithm_type = 'Predict'
task_info_a.sync_file_list = {}
Evaluation(task_info_a, secure_dataset_a,
test_praba=test_praba)
ResultFileHandling(task_info_a)
Expand All @@ -159,7 +160,8 @@ def passive_worker():
booster_b.load_model()
booster_b.predict()
test_praba = booster_b.get_test_praba()
task_info_b.algorithm_type = 'PPC_PREDICT'
task_info_b.algorithm_type = 'Predict'
task_info_b.sync_file_list = {}
Evaluation(task_info_b, secure_dataset_b,
test_praba=test_praba)
ResultFileHandling(task_info_b)
Expand Down
12 changes: 8 additions & 4 deletions python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ def setUp(self):

def test_fit(self):
args_a, args_b = mock_args()
plot_lock = threading.Lock()

active_components = Initializer(log_config_path='', config_path='')
active_components = Initializer(log_config_path='', config_path='', plot_lock=plot_lock)
active_components.stub = self._active_stub
active_components.config_data = {
'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
Expand All @@ -114,7 +115,8 @@ def test_fit(self):
print(secure_dataset_a.test_X.shape)
print(secure_dataset_a.test_y.shape)

passive_components = Initializer(log_config_path='', config_path='')
passive_components = Initializer(log_config_path='', config_path='', plot_lock=plot_lock)
passive_components.stub = self._passive_stub
passive_components.stub = self._passive_stub
passive_components.config_data = {
'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY}
Expand All @@ -141,7 +143,8 @@ def active_worker():
booster_a.load_model()
booster_a.predict()
test_praba = booster_a.get_test_praba()
task_info_a.algorithm_type = 'PPC_PREDICT'
task_info_a.algorithm_type = 'Predict'
task_info_a.sync_file_list = {}
Evaluation(task_info_a, secure_dataset_a,
test_praba=test_praba)
ResultFileHandling(task_info_a)
Expand All @@ -161,7 +164,8 @@ def passive_worker():
booster_b.load_model()
booster_b.predict()
test_praba = booster_b.get_test_praba()
task_info_b.algorithm_type = 'PPC_PREDICT'
task_info_b.algorithm_type = 'Predict'
task_info_b.sync_file_list = {}
Evaluation(task_info_b, secure_dataset_b,
test_praba=test_praba)
ResultFileHandling(task_info_b)
Expand Down
3 changes: 2 additions & 1 deletion python/ppc_model/secure_lgbm/vertical/active_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def _end_active_data(self, is_train=True):
remote_file_path=self.ctx.remote_feature_importance_file, storage_client=self.storage_client)

if self.callback_container:
self.callback_container.after_training(self.model)
with self.ctx.components.plot_lock:
self.callback_container.after_training(self.model)

for partner_index in range(1, len(self.ctx.participant_id_list)):
if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list:
Expand Down

0 comments on commit b839a40

Please sign in to comment.