Commit

fix saveModel (#53)
* add modelData response

* fix saveModel
cyjseagull authored Oct 16, 2024
1 parent 904b457 commit abf582f
Showing 4 changed files with 33 additions and 12 deletions.
4 changes: 4 additions & 0 deletions python/ppc_model/interface/model_base.py
@@ -31,3 +31,7 @@ def load_model(self, file_path):

class VerticalModel(ModelBase):
mode = "VERTICAL"

def __init__(self, ctx):
super().__init__(ctx)
self._all_feature_name = []
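
Note (not part of the commit): the only change here is that VerticalModel now owns the per-participant feature-name list itself, so the boosters read self._all_feature_name instead of reaching into the context. A minimal sketch of the resulting pattern; the DemoBooster subclass is purely illustrative:

class ModelBase:
    mode = "BASE"

    def __init__(self, ctx):
        self.ctx = ctx


class VerticalModel(ModelBase):
    mode = "VERTICAL"

    def __init__(self, ctx):
        super().__init__(ctx)
        # one feature-name list per participant, filled in by the boosters
        self._all_feature_name = []


class DemoBooster(VerticalModel):
    def register_features(self, partner_index, fields):
        # boosters read and write self._all_feature_name directly,
        # instead of self.ctx._all_feature_name as before this commit
        while len(self._all_feature_name) <= partner_index:
            self._all_feature_name.append([])
        self._all_feature_name[partner_index] = fields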
16 changes: 15 additions & 1 deletion python/ppc_model/model_result/task_result_handler.py
@@ -206,6 +206,7 @@ class ModelJobResult:
def __init__(self, xgb_job, job_id, components, property_name=DEFAULT_PROPERTY_NAME):
self.job_id = job_id
self.xgb_job = xgb_job
self.base_context = BaseContext(job_id, ".tmp")
self.components = components
self.logger = components.logger()
self.property_name = property_name
@@ -278,6 +279,12 @@ def load_feature_importance_table(self, feature_importance_path, property):
self.feature_importance_table = {property: DataItem(name=property, data=feature_importance_table.to_dict(),
type=DataType.TABLE).to_dict()}

def load_encrypted_model_data(self):
try:
return self.components.storage_client.get_data(self.base_context.remote_model_enc_file).decode("utf-8")
except Exception:
# a missing encrypted model artifact is not an error here; just report no model data
return None
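
Note (not part of the commit): this new helper fetches the encrypted model blob from remote storage, decodes it as UTF-8 text, and falls back to None when the artifact is unavailable. A hedged sketch of that contract; the stub storage client and the paths below are illustrative only:

class StubStorageClient:
    """Illustrative stand-in for the real storage client: get_data returns bytes."""

    def __init__(self, blobs):
        self._blobs = blobs  # maps remote path -> bytes

    def get_data(self, path):
        return self._blobs[path]  # raises KeyError when the artifact is absent


def load_encrypted_model_data(storage_client, remote_model_enc_file):
    try:
        # the stored blob is bytes; callers expect text for the JSON response
        return storage_client.get_data(remote_model_enc_file).decode("utf-8")
    except Exception:
        # missing artifact (or storage error) simply means "no modelData field"
        return None


client = StubStorageClient({"job-123/model_enc": b"ZW5jcnlwdGVkLW1vZGVs"})
assert load_encrypted_model_data(client, "job-123/model_enc") == "ZW5jcnlwdGVkLW1vZGVs"
assert load_encrypted_model_data(client, "job-123/missing") is None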

def load_iteration_metrics(self, iteration_path, property):
if not self.xgb_job:
return
@@ -317,6 +324,7 @@ def __init__(self, task_result_request: TaskResultRequest, components):
self.result_list = []
self.predict = False
self.xgb_job = False
self.model_data = None
if self.task_result_request.task_type == ModelTask.XGB_PREDICTING.name or self.task_result_request.task_type == ModelTask.LR_PREDICTING.name:
self.predict = True
if self.task_result_request.task_type == ModelTask.XGB_PREDICTING.name or self.task_result_request.task_type == ModelTask.XGB_TRAINING.name:
@@ -330,7 +338,12 @@ def get_response(self):
merged_result = dict()
for result in self.result_list:
merged_result.update(result.to_dict())
response = {"jobPlanetResult": merged_result}

if self.model_data is None:
response = {"jobPlanetResult": merged_result}
else:
response = {"jobPlanetResult": merged_result,
"modelData": self.model_data}
return utils.make_response(PpcErrorCode.SUCCESS.get_code(), PpcErrorCode.SUCCESS.get_msg(), response)

def _get_evaluation_result(self):
@@ -371,6 +384,7 @@ def _get_evaluation_result(self):
# the metrics iteration graph
self.model.load_iteration_metrics(
utils.METRICS_OVER_ITERATION_FILE, "IterationGraph")
self.model_data = self.model.load_encrypted_model_data()

if self.predict:
# the train evaluation result
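
Note (not part of the commit): taken together, _get_evaluation_result now stashes the encrypted model text in self.model_data, and get_response attaches a modelData field only when that blob exists. A small sketch of the resulting branching; the make_response stand-in below is a simplification of the repo's utils.make_response and is illustrative only:

def make_response(code, msg, data):
    # simplified stand-in for the repo's response wrapper
    return {"errorCode": code, "message": msg, "data": data}


def build_response(merged_result, model_data):
    # modelData is attached only when the encrypted model blob was found
    if model_data is None:
        response = {"jobPlanetResult": merged_result}
    else:
        response = {"jobPlanetResult": merged_result, "modelData": model_data}
    return make_response(0, "success", response)


print(build_response({"ModelResult": {}}, None))
print(build_response({"ModelResult": {}}, "ZW5jcnlwdGVkLW1vZGVs"))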
21 changes: 12 additions & 9 deletions python/ppc_model/secure_lgbm/vertical/booster.py
@@ -33,7 +33,6 @@ def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
self._train_praba = None
self._test_weights = None
self._test_praba = None

random.seed(ctx.model_params.random_state)
np.random.seed(ctx.model_params.random_state)

@@ -253,21 +252,23 @@ def merge_model_file(self):
lgbm_model['label_column'] = 'y'
lgbm_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self.ctx._all_feature_name[partner_index]
agency_info = {
'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self._all_feature_name[partner_index]
lgbm_model['participant_agency_list'].append(agency_info)
lgbm_model['model_dict'] = self.ctx.model_params

lgbm_model['model_dict'] = self.ctx.model_params.get_all_params()
model_text = {}
with open(self.ctx.feature_bin_file, 'rb') as f:
feature_bin_data = f.read()
with open(self.ctx.model_data_file, 'rb') as f:
model_data = f.read()
feature_bin_enc = encrypt_data(self.ctx.key, feature_bin_data)
model_data_enc = encrypt_data(self.ctx.key, model_data)

my_agency_id = self.ctx.components.config_data['AGENCY_ID']
model_text[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
model_text[my_agency_id] = [cipher_to_base64(
feature_bin_enc), cipher_to_base64(model_data_enc)]

# send & receive the files
for partner_index in range(0, len(self.ctx.participant_id_list)):
@@ -285,7 +286,8 @@ def merge_model_file(self):
model_data_enc = self._receive_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', partner_index)
model_text[self.ctx.participant_id_list[partner_index]] = \
[cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
[cipher_to_base64(feature_bin_enc),
cipher_to_base64(model_data_enc)]
lgbm_model['model_text'] = model_text

# upload the encrypted model
@@ -300,7 +302,8 @@ def split_model_file(self):
# load the model passed in
my_agency_id = self.ctx.components.config_data['AGENCY_ID']
model_text = self.ctx.model_predict_algorithm['model_text']
feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in model_text[my_agency_id]]
feature_bin_enc, model_data_enc = [
base64_to_cipher(i) for i in model_text[my_agency_id]]

# decrypt the files
feature_bin_data = decrypt_data(self.ctx.key, feature_bin_enc)
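
Note (not part of the commit): for context, merge_model_file encrypts each party's feature-bin and model files, base64-encodes the ciphertext, and collects one [feature_bin, model_data] pair per agency into model_text, while split_model_file reverses that for the local agency. A hedged round-trip sketch of the pattern, with Fernet from the cryptography package standing in for the repo's encrypt_data/decrypt_data helpers, whose real implementation is not shown in this diff:

import base64

from cryptography.fernet import Fernet


def encrypt_data(key, data):
    return Fernet(key).encrypt(data)


def decrypt_data(key, cipher):
    return Fernet(key).decrypt(cipher)


def cipher_to_base64(cipher):
    return base64.b64encode(cipher).decode("utf-8")


def base64_to_cipher(text):
    return base64.b64decode(text)


key = Fernet.generate_key()
feature_bin_data, model_data = b"feature-bin bytes", b"model-data bytes"

# merge side: one [feature_bin, model_data] pair per agency, ciphertext kept as base64 text
model_text = {"agencyA": [cipher_to_base64(encrypt_data(key, feature_bin_data)),
                          cipher_to_base64(encrypt_data(key, model_data))]}

# split side: the local agency restores and decrypts its own two files
fb_enc, md_enc = [base64_to_cipher(t) for t in model_text["agencyA"]]
assert decrypt_data(key, fb_enc) == feature_bin_data
assert decrypt_data(key, md_enc) == model_data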
4 changes: 2 additions & 2 deletions python/ppc_model/secure_lr/vertical/booster.py
@@ -271,10 +271,10 @@ def merge_model_file(self):
lr_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self.ctx._all_feature_name[partner_index]
agency_info['fields'] = self._all_feature_name[partner_index]
lr_model['participant_agency_list'].append(agency_info)

lr_model['model_dict'] = self.ctx.model_params
lr_model['model_dict'] = self.ctx.model_params.get_all_params()
model_text = {}
with open(self.ctx.model_data_file, 'rb') as f:
model_data = f.read()
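
Note (not part of the commit): the LR booster gets the same two fixes as the LGBM one: read the feature names from self._all_feature_name, and store a plain dict from get_all_params() rather than the params object itself, presumably so the merged model stays serializable. A sketch of why the dict form matters; this ModelParams class is illustrative only:

import json


class ModelParams:
    """Illustrative params holder; the repo's real class has many more fields."""

    def __init__(self, learning_rate=0.1, epochs=10, random_state=2024):
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.random_state = random_state

    def get_all_params(self):
        # a plain dict of the configured fields
        return dict(vars(self))


lr_model = {"label_column": "y", "participant_agency_list": []}
params = ModelParams()

# json.dumps(params) would raise "Object of type ModelParams is not JSON serializable";
# the dict form serializes cleanly alongside the rest of the merged model
lr_model["model_dict"] = params.get_all_params()
print(json.dumps(lr_model))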
