Skip to content

Commit

Permalink
更新预测任务的特征筛选流程 (#56)
Browse files Browse the repository at this point in the history
* update secure lr

* update model and predict

* update ppc_dev

* update model setting

* Update booster.py

* update wedpr_ml_toolkit

* update predict feature selection
  • Loading branch information
yanxinyi620 authored Oct 16, 2024
1 parent 456e83c commit f33a82a
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 52 deletions.
9 changes: 9 additions & 0 deletions python/ppc_model/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def _dataset_fe_selected(self, file_path, feature_name):
self.model_data = self.model_data.drop(columns=drop_columns)

def _construct_dataset(self):
if self.algorithm_type == AlgorithmType.Predict.name:
my_fields = [
item["fields"] for item in self.ctx.model_predict_algorithm['participant_agency_list']
if item["agency"] == self.ctx.components.config_data['AGENCY_ID']]
if 'y' in self.model_data.columns and 'y' not in my_fields:
my_fields = ['y'] + my_fields
if 'id' in self.model_data.columns and 'id' not in my_fields:
my_fields = ['id'] + my_fields
self.model_data = self.model_data[my_fields]

if os.path.exists(self.iv_selected_file):
self._dataset_fe_selected(self.iv_selected_file, 'feature')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def processing(self):
dataset_file_path = self.ctx.dataset_file_path
storage_client = self.ctx.components.storage_client
job_algorithm_type = self.ctx.job_algorithm_type
if job_algorithm_type == utils.AlgorithmType.Predict.name:
storage_client.download_file(os.path.join(self.ctx.training_job_id, self.ctx.PREPROCESSING_RESULT_FILE),
self.ctx.preprocessing_result_file)
psi_result_path = self.ctx.psi_result_path
model_prepare_file = self.ctx.model_prepare_file
storage_client.download_file(dataset_path, dataset_file_path)
Expand Down
47 changes: 8 additions & 39 deletions python/ppc_model/preprocessing/local_processing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,7 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb

column_info = {}

if ppc_job_type == utils.AlgorithmType.Predict.name:
column_info_fm = pd.read_csv(
ctx.preprocessing_result_file, index_col=0)
column_info_train_str = json.dumps(
column_info_fm.to_dict(orient='index'))
if column_info_train_str is None:
raise PpcException(-1, "column_info_train is None")
try:
# 对应orient='records'
# column_info_train = json.loads(column_info_train_str, orient='records')
column_info_train = json.loads(column_info_train_str)
except Exception as e:
log.error(
f"jobid: {job_id} column_info_train json.loads error, e:{e}")
raise PpcException(-1, "column_info_train json.loads error")
dataset_df = process_train_dataframe(dataset_df, column_info_train)
column_info = column_info_train
elif ppc_job_type == utils.AlgorithmType.Train.name:
if ppc_job_type != utils.AlgorithmType.Predict.name:
# 如果是训练任务 先默认所有数据都存在
column_info = {col: {'isExisted': True} for col in dataset_df.columns}

Expand Down Expand Up @@ -141,7 +124,7 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb
log.info(f"jobid: {job_id} move id column finish.")

# 2.1 缺失值筛选
if ppc_job_type == utils.AlgorithmType.Train.name:
if ppc_job_type != utils.AlgorithmType.Predict.name:
if 0 <= model_setting.na_select <= 1:
log.info(f"jobid: {job_id} run fillna start")
df_filled, column_info = process_na_dataframe(
Expand All @@ -152,13 +135,9 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb
f"jobid: {job_id} xgb_model_dict['na_select'] is range not 0 to 1, xgb_model_dict:{model_setting}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"xgb_model_dict['na_select'] range not 0 to 1")
elif ppc_job_type == utils.AlgorithmType.Predict.name:
log.info(f"jobid: {job_id} don't need run fillna for predict job.")
else:
log.error(
f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"ppc_job_type is not Train or Predict")
log.info(f"jobid: {job_id} don't need run fillna for predict job.")

# 2.2 缺失值处理
if model_setting.fillna == 1:
# 填充
Expand All @@ -184,7 +163,7 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb
), "xgb_model_dict['fillna'] is not 0 or 1")

# 6.1 特征选择 进行 psi稳定性指标筛选 计算特征相关性 降维可以减少模型的复杂度,提高模型的泛化能力
if ppc_job_type == utils.AlgorithmType.Train.name:
if ppc_job_type != utils.AlgorithmType.Predict.name:
if model_setting.psi_select_col in df_filled.columns.tolist() and model_setting.psi_select_col != 0:
log.info(f"jobid: {job_id} run psi_select_col start")
psi_select_base = model_setting.psi_select_base
Expand Down Expand Up @@ -215,17 +194,12 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb
f"jobid: {job_id} xgb_model_dict['psi_select_col'] is not 0 or in col, model_setting:{model_setting}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"xgb_model_dict['psi_select_col'] is not 0 or in col")
elif ppc_job_type == utils.AlgorithmType.Predict.name:
else:
log.info(
f"jobid: {job_id} don't need run psi_select_col for predict job.")
else:
log.error(
f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"ppc_job_type is not Train or Predict")

# 6.2 特征选择 进行 corr_select 计算特征相关性
if ppc_job_type == utils.AlgorithmType.Train.name:
if ppc_job_type != utils.AlgorithmType.Predict.name:
if model_setting.corr_select > 0:
log.info(f"jobid: {job_id} run corr_select start")
corr_select = model_setting.corr_select
Expand All @@ -247,14 +221,9 @@ def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb
f"jobid: {job_id} xgb_model_dict['corr_select'] is not >= 0, model_setting:{model_setting}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"xgb_model_dict['corr_select'] is not >= 0")
elif ppc_job_type == utils.AlgorithmType.Predict.name:
else:
log.info(
f"jobid: {job_id} don't need run corr_select for predict job.")
else:
log.error(
f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}")
raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(),
"ppc_job_type is not Train or Predict")

# 3. 离群值处理 3-sigma 法
if model_setting.filloutlier == 1:
Expand Down
2 changes: 0 additions & 2 deletions python/ppc_model/preprocessing/processing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ def __init__(self,
self.job_algorithm_type = args['algorithm_type']
self.need_run_psi = args['need_run_psi']
self.model_dict = args['model_dict']
self.training_job_id = common_func.get_config_value(
"training_job_id", None, args, False)
if "psi_result_path" in args:
self.remote_psi_result_path = args["psi_result_path"]
self.model_setting = ModelSetting(self.model_dict)
5 changes: 0 additions & 5 deletions python/ppc_model/secure_lgbm/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,6 @@ def load_model(self, file_path=None):
file_path, self.ctx.FEATURE_BIN_FILE)
self.ctx.model_data_file = os.path.join(
file_path, self.ctx.MODEL_DATA_FILE)
if self.ctx.algorithm_type == AlgorithmType.Predict.name:
self.ctx.remote_feature_bin_file = os.path.join(
self.ctx.model_params.training_job_id, self.ctx.FEATURE_BIN_FILE)
self.ctx.remote_model_data_file = os.path.join(
self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)

try:
ResultFileHandling._download_file(self.ctx.components.storage_client,
Expand Down
3 changes: 0 additions & 3 deletions python/ppc_model/secure_lr/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,6 @@ def load_model(self, file_path=None):
if file_path is not None:
self.ctx.model_data_file = os.path.join(
file_path, self.ctx.MODEL_DATA_FILE)
if self.ctx.algorithm_type == AlgorithmType.Predict.name:
self.ctx.remote_model_data_file = os.path.join(
self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)

try:
ResultFileHandling._download_file(self.ctx.components.storage_client,
Expand Down

0 comments on commit f33a82a

Please sign in to comment.