Skip to content

Commit

Permalink
Feature milestone2 (#52)
Browse files Browse the repository at this point in the history
* update secure lr

* update model and predict

* update ppc_dev

* update model setting

* Update booster.py
  • Loading branch information
yanxinyi620 authored Oct 15, 2024
1 parent d2f2e1a commit 904b457
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 58 deletions.
4 changes: 4 additions & 0 deletions python/ppc_model/secure_lgbm/vertical/active_party.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import time
import json

import numpy as np
from pandas import DataFrame
Expand Down Expand Up @@ -143,6 +144,9 @@ def _init_active_data(self):
[s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s])
self._all_feature_num += len([s.decode('utf-8')
for s in feature_name_bytes.split(b' ') if s])
for i in range(1, len(self.ctx.participant_id_list)):
self._send_byte_data(self.ctx, LGBMMessage.FEATURE_NAME.value,
json.dumps(self._all_feature_name).encode('utf-8'), i)

self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, '
f'total feature name: {self._all_feature_name}.')
Expand Down
49 changes: 17 additions & 32 deletions python/ppc_model/secure_lgbm/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ def merge_model_file(self):

# 加密文件
lgbm_model = {}
lgbm_model['model_type'] = 'xgb_model'
lgbm_model['label_provider'] = self.ctx.participant_id_list[0]
lgbm_model['label_column'] = 'y'
lgbm_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self.ctx._all_feature_name[partner_index]
lgbm_model['participant_agency_list'].append(agency_info)

lgbm_model['model_dict'] = self.ctx.model_params
model_text = {}
with open(self.ctx.feature_bin_file, 'rb') as f:
feature_bin_data = f.read()
with open(self.ctx.model_data_file, 'rb') as f:
Expand All @@ -256,7 +267,7 @@ def merge_model_file(self):
model_data_enc = encrypt_data(self.ctx.key, model_data)

my_agency_id = self.ctx.components.config_data['AGENCY_ID']
lgbm_model[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
model_text[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]

# 发送&接受文件
for partner_index in range(0, len(self.ctx.participant_id_list)):
Expand All @@ -273,8 +284,9 @@ def merge_model_file(self):
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin', partner_index)
model_data_enc = self._receive_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', partner_index)
lgbm_model[self.ctx.participant_id_list[partner_index]] = \
model_text[self.ctx.participant_id_list[partner_index]] = \
[cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
lgbm_model['model_text'] = model_text

# 上传密文模型
with open(self.ctx.model_enc_file, 'w') as f:
Expand All @@ -285,38 +297,11 @@ def merge_model_file(self):
f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")

def split_model_file(self):
# 下载密文模型
try:
ResultFileHandling._download_file(self.ctx.components.storage_client,
self.ctx.remote_model_enc_file, self.ctx.model_enc_file)
except:
pass

# 发送/接受文件
# 传入模型
my_agency_id = self.ctx.components.config_data['AGENCY_ID']
if os.path.exists(self.ctx.model_enc_file):

with open(self.ctx.model_enc_file, 'r') as f:
lgbm_model = json.load(f)
model_text = self.ctx.model_predict_algorithm['model_text']
feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in model_text[my_agency_id]]

for partner_index in range(0, len(self.ctx.participant_id_list)):
if self.ctx.participant_id_list[partner_index] != my_agency_id:
feature_bin_enc, model_data_enc = \
[base64_to_cipher(i) for i in lgbm_model[self.ctx.participant_id_list[partner_index]]]
self._send_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin',
feature_bin_enc, partner_index)
self._send_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data',
model_data_enc, partner_index)
feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in lgbm_model[my_agency_id]]

else:
feature_bin_enc = self._receive_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin', 0)
model_data_enc = self._receive_byte_data(
self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', 0)

# 解密文件
feature_bin_data = decrypt_data(self.ctx.key, feature_bin_enc)
model_data = decrypt_data(self.ctx.key, model_data_enc)
Expand Down
5 changes: 5 additions & 0 deletions python/ppc_model/secure_lgbm/vertical/passive_party.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import multiprocessing
import time
import json
import numpy as np
from pandas import DataFrame

Expand All @@ -16,6 +17,7 @@ class VerticalLGBMPassiveParty(VerticalBooster):
def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None:
super().__init__(ctx, dataset)
self.params = ctx.model_params
self._all_feature_name = []
self.log = ctx.components.logger()
self.log.info(
f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}')
Expand Down Expand Up @@ -94,6 +96,9 @@ def _init_passive_data(self):
b''.join(s.encode('utf-8') + b' ' for s in self.dataset.feature_name), 0)
self.params.my_categorical_idx = self._get_categorical_idx(
self.dataset.feature_name, self.params.categorical_feature)
feature_name_bytes = self._receive_byte_data(
self.ctx, LGBMMessage.FEATURE_NAME.value, 0)
self._all_feature_name = json.loads(feature_name_bytes.decode('utf-8'))

# 初始化分桶数据集
feat_bin = FeatureBinning(self.ctx)
Expand Down
4 changes: 4 additions & 0 deletions python/ppc_model/secure_lr/vertical/active_party.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import time
import json

import numpy as np
from pandas import DataFrame
Expand Down Expand Up @@ -115,6 +116,9 @@ def _init_active_data(self):
[s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s])
self._all_feature_num += len([s.decode('utf-8')
for s in feature_name_bytes.split(b' ') if s])
for i in range(1, len(self.ctx.participant_id_list)):
self._send_byte_data(self.ctx, LRMessage.FEATURE_NAME.value,
json.dumps(self._all_feature_name).encode('utf-8'), i)

self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, '
f'total feature name: {self._all_feature_name}.')
Expand Down
43 changes: 17 additions & 26 deletions python/ppc_model/secure_lr/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,23 @@ def merge_model_file(self):

# 加密文件
lr_model = {}
lr_model['model_type'] = 'lr_model'
lr_model['label_provider'] = self.ctx.participant_id_list[0]
lr_model['label_column'] = 'y'
lr_model['participant_agency_list'] = []
for partner_index in range(0, len(self.ctx.participant_id_list)):
agency_info = {'agency': self.ctx.participant_id_list[partner_index]}
agency_info['fields'] = self.ctx._all_feature_name[partner_index]
lr_model['participant_agency_list'].append(agency_info)

lr_model['model_dict'] = self.ctx.model_params
model_text = {}
with open(self.ctx.model_data_file, 'rb') as f:
model_data = f.read()
model_data_enc = encrypt_data(self.ctx.key, model_data)

my_agency_id = self.ctx.components.config_data['AGENCY_ID']
lr_model[my_agency_id] = cipher_to_base64(model_data_enc)
model_text[my_agency_id] = cipher_to_base64(model_data_enc)

# 发送&接受文件
for partner_index in range(0, len(self.ctx.participant_id_list)):
Expand All @@ -282,7 +293,8 @@ def merge_model_file(self):
if self.ctx.participant_id_list[partner_index] != my_agency_id:
model_data_enc = self._receive_byte_data(
self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', partner_index)
lr_model[self.ctx.participant_id_list[partner_index]] = cipher_to_base64(model_data_enc)
model_text[self.ctx.participant_id_list[partner_index]] = cipher_to_base64(model_data_enc)
lr_model['model_text'] = model_text

# 上传密文模型
with open(self.ctx.model_enc_file, 'w') as f:
Expand All @@ -293,32 +305,11 @@ def merge_model_file(self):
f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")

def split_model_file(self):
# 下载密文模型
try:
ResultFileHandling._download_file(self.ctx.components.storage_client,
self.ctx.remote_model_enc_file, self.ctx.model_enc_file)
except:
pass

# 发送/接受文件
# 传入模型
my_agency_id = self.ctx.components.config_data['AGENCY_ID']
if os.path.exists(self.ctx.model_enc_file):

with open(self.ctx.model_enc_file, 'r') as f:
lr_model = json.load(f)

for partner_index in range(0, len(self.ctx.participant_id_list)):
if self.ctx.participant_id_list[partner_index] != my_agency_id:
model_data_enc = base64_to_cipher(lr_model[self.ctx.participant_id_list[partner_index]])
self._send_byte_data(
self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data',
model_data_enc, partner_index)
model_data_enc = base64_to_cipher(lr_model[my_agency_id])
model_text = self.ctx.model_predict_algorithm['model_text']
model_data_enc = base64_to_cipher(model_text[my_agency_id])

else:
model_data_enc = self._receive_byte_data(
self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', 0)

# 解密文件
model_data = decrypt_data(self.ctx.key, model_data_enc)
with open(self.ctx.model_data_file, 'wb') as f:
Expand Down
5 changes: 5 additions & 0 deletions python/ppc_model/secure_lr/vertical/passive_party.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import multiprocessing
import time
import json
import numpy as np
from pandas import DataFrame

Expand All @@ -19,6 +20,7 @@ class VerticalLRPassiveParty(VerticalBooster):
def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None:
super().__init__(ctx, dataset)
self.params = ctx.model_params
self._all_feature_name = []
self._loss_func = BinaryLoss()
self.log = ctx.components.logger()
self.log.info(
Expand Down Expand Up @@ -86,6 +88,9 @@ def _init_passive_data(self):
b''.join(s.encode('utf-8') + b' ' for s in self.dataset.feature_name), 0)
self.params.my_categorical_idx = self._get_categorical_idx(
self.dataset.feature_name, self.params.categorical_feature)
feature_name_bytes = self._receive_byte_data(
self.ctx, LRMessage.FEATURE_NAME.value, 0)
self._all_feature_name = json.loads(feature_name_bytes.decode('utf-8'))

def _build_iter(self, feature_select, idx):

Expand Down

0 comments on commit 904b457

Please sign in to comment.