diff --git a/python/aes_key.bin b/python/aes_key.bin
new file mode 100644
index 00000000..0d516172
Binary files /dev/null and b/python/aes_key.bin differ
diff --git a/python/ppc_dev/__init__.py b/python/ppc_dev/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/common/__init__.py b/python/ppc_dev/common/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/common/base_context.py b/python/ppc_dev/common/base_context.py
new file mode 100644
index 00000000..1496f381
--- /dev/null
+++ b/python/ppc_dev/common/base_context.py
@@ -0,0 +1,13 @@
+import os
+
+
+class BaseContext:
+
+    def __init__(self, project_id, user_name, pws_endpoint=None, hdfs_endpoint=None, token=None):
+
+        self.project_id = project_id
+        self.user_name = user_name
+        self.pws_endpoint = pws_endpoint
+        self.hdfs_endpoint = hdfs_endpoint
+        self.token = token
+        self.workspace = os.path.join(self.project_id, self.user_name)
diff --git a/python/ppc_dev/common/base_result.py b/python/ppc_dev/common/base_result.py
new file mode 100644
index 00000000..ace5f8e3
--- /dev/null
+++ b/python/ppc_dev/common/base_result.py
@@ -0,0 +1,8 @@
+from ppc_dev.common.base_context import BaseContext
+
+
+class BaseResult:
+
+    def __init__(self, ctx: BaseContext):
+
+        self.ctx = ctx
diff --git a/python/ppc_dev/job_exceuter/__init__.py b/python/ppc_dev/job_exceuter/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/job_exceuter/hdfs_client.py b/python/ppc_dev/job_exceuter/hdfs_client.py
new file mode 100644
index 00000000..ca77360a
--- /dev/null
+++ b/python/ppc_dev/job_exceuter/hdfs_client.py
@@ -0,0 +1,53 @@
+import io
+
+import requests
+import pandas as pd
+
+
+class HDFSApi:
+    def __init__(self, base_url):
+        self.base_url = base_url
+
+    def upload(self, dataframe, hdfs_path):
+        """
+        Upload a pandas DataFrame to HDFS.
+        :param dataframe: the pandas DataFrame to upload
+        :param hdfs_path: the target HDFS path
+        :return: the response payload
+        """
+        # Serialize the DataFrame to CSV
+        csv_buffer = io.StringIO()
+        dataframe.to_csv(csv_buffer, index=False)
+
+        # Upload the CSV data with a PUT request
+        response = requests.put(
+            f"{self.base_url}/upload?path={hdfs_path}",
+            data=csv_buffer.getvalue(),
+            headers={'Content-Type': 'text/csv'}
+        )
+        return response.json()
+
+    def download(self, hdfs_path):
+        """
+        Download data from HDFS as a pandas DataFrame.
+        :param hdfs_path: the HDFS file path
+        :return: pandas DataFrame
+        """
+        response = requests.get(f"{self.base_url}/download?path={hdfs_path}")
+        if response.status_code == 200:
+            # Parse the CSV payload into a DataFrame
+            return pd.read_csv(io.StringIO(response.text))
+        raise Exception(f"Download failed: {response.json()}")
+
+    def download_data(self, hdfs_path):
+        """
+        Download raw data from HDFS as text.
+        :param hdfs_path: the HDFS file path
+        :return: text
+        """
+        response = requests.get(f"{self.base_url}/download?path={hdfs_path}")
+        if response.status_code == 200:
+            return response.text
+        raise Exception(f"Download failed: {response.json()}")
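A minimal round trip against `HDFSApi`, assuming a reachable HDFS HTTP gateway (the endpoint and paths below are placeholders):

```python
import pandas as pd

from ppc_dev.job_exceuter.hdfs_client import HDFSApi

# Hypothetical gateway address; substitute your real hdfs_endpoint.
api = HDFSApi("http://127.0.0.1:50070")

df = pd.DataFrame({"id": [1, 2, 3], "x1": [0.1, 0.2, 0.3]})
api.upload(df, "p-123/admin/data/demo.csv")   # serialized as CSV on the wire

restored = api.download("p-123/admin/data/demo.csv")
assert list(restored.columns) == ["id", "x1"]
```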
diff --git a/python/ppc_dev/job_exceuter/pws_client.py b/python/ppc_dev/job_exceuter/pws_client.py
new file mode 100644
index 00000000..8404620a
--- /dev/null
+++ b/python/ppc_dev/job_exceuter/pws_client.py
@@ -0,0 +1,66 @@
+import random
+import time
+
+from ppc_common.ppc_utils import http_utils
+from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode
+
+
+class PWSApi:
+    def __init__(self, endpoint, token,
+                 polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5):
+        self.endpoint = endpoint
+        self.token = token
+        self.polling_interval_s = polling_interval_s
+        self.max_retries = max_retries
+        self.retry_delay_s = retry_delay_s
+        self._async_run_task_method = 'asyncRunTask'
+        self._get_task_status_method = 'getTaskStatus'
+        self._completed_status = 'COMPLETED'
+        self._failed_status = 'FAILED'
+
+    def run(self, datasets, params):
+        params = {
+            'jsonrpc': '1',
+            'method': self._async_run_task_method,
+            'token': self.token,
+            'id': random.randint(1, 65535),
+            'dataset': datasets,
+            'params': params
+        }
+        response = self._send_request_with_retry(
+            http_utils.send_post_request, self.endpoint, None, params)
+        if response.status_code != 200:
+            raise Exception(f"Failed to create the task: {response.json()}")
+        return self._poll_task_status(response.job_id, self.token)
+
+    def _poll_task_status(self, job_id, token):
+        while True:
+            params = {
+                'jsonrpc': '1',
+                'method': self._get_task_status_method,
+                'token': token,
+                'id': random.randint(1, 65535),
+                'params': {
+                    'taskID': job_id,
+                }
+            }
+            response = self._send_request_with_retry(
+                http_utils.send_post_request, self.endpoint, None, params)
+            if response.status_code != 200:
+                raise Exception(f"Failed to poll the task status: {response.json()}")
+            if response['result']['status'] == self._completed_status:
+                return response['result']
+            elif response['result']['status'] == self._failed_status:
+                raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data'])
+            time.sleep(self.polling_interval_s)
+
+    def _send_request_with_retry(self, request_func, *args, **kwargs):
+        attempt = 0
+        while attempt < self.max_retries:
+            try:
+                return request_func(*args, **kwargs)
+            except Exception as e:
+                attempt += 1
+                if attempt >= self.max_retries:
+                    raise e
+                time.sleep(self.retry_delay_s)
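A sketch of driving `PWSApi` directly, assuming the PWS gateway speaks the JSON-RPC dialect above (the endpoint, token, and dataset paths are placeholders):

```python
from ppc_dev.job_exceuter.pws_client import PWSApi

client = PWSApi("http://127.0.0.1:8000", token="abc...",
                polling_interval_s=5, max_retries=3)

# Blocks until the task reaches COMPLETED, or raises on FAILED.
result = client.run(datasets=["p-123/admin/data/d-101", "p-123/admin/data/d-123"],
                    params={"merge_field": "id"})
print(result["status"])
```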
diff --git a/python/ppc_dev/result/__init__.py b/python/ppc_dev/result/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/result/fe_result.py b/python/ppc_dev/result/fe_result.py
new file mode 100644
index 00000000..65c4dfc5
--- /dev/null
+++ b/python/ppc_dev/result/fe_result.py
@@ -0,0 +1,27 @@
+import os
+
+from ppc_dev.wedpr_data.data_context import DataContext
+from ppc_dev.common.base_result import BaseResult
+
+
+class FeResult(BaseResult):
+
+    FE_RESULT_FILE = "fe_result.csv"
+
+    def __init__(self, dataset: DataContext, job_id: str):
+
+        super().__init__(dataset.ctx)
+        self.job_id = job_id
+        self.dataset = dataset
+
+        participant_id_list = []
+        for each_dataset in self.dataset.datasets:
+            participant_id_list.append(each_dataset.agency.agency_id)
+        self.participant_id_list = participant_id_list
+
+        result_list = []
+        for each_dataset in self.dataset.datasets:
+            each_dataset.update_path(os.path.join(self.job_id, self.FE_RESULT_FILE))
+            result_list.append(each_dataset)
+
+        # __init__ must not return a value; expose the result as an attribute instead
+        self.fe_result = DataContext(*result_list)
diff --git a/python/ppc_dev/result/model_result.py b/python/ppc_dev/result/model_result.py
new file mode 100644
index 00000000..f1cc4e07
--- /dev/null
+++ b/python/ppc_dev/result/model_result.py
@@ -0,0 +1,56 @@
+import os
+import numpy as np
+
+from ppc_common.ppc_utils import utils
+
+from ppc_dev.wedpr_data.data_context import DataContext
+from ppc_dev.common.base_result import BaseResult
+from ppc_dev.job_exceuter.hdfs_client import HDFSApi
+
+
+class ModelResult(BaseResult):
+
+    FEATURE_BIN_FILE = "feature_bin.json"
+    MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json'
+    TEST_MODEL_OUTPUT_FILE = "xgb_output.csv"
+    TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv"
+
+    def __init__(self, dataset: DataContext, job_id: str, job_type: str):
+
+        super().__init__(dataset.ctx)
+        self.job_id = job_id
+        self.dataset = dataset
+
+        participant_id_list = []
+        for each_dataset in self.dataset.datasets:
+            participant_id_list.append(each_dataset.agency.agency_id)
+        self.participant_id_list = participant_id_list
+
+        if job_type == 'xgb_training':
+            self._xgb_train_result()
+
+    def _xgb_train_result(self):
+
+        # train_praba, test_praba, train_y, test_y, feature_importance, split_xbin, trees, params
+        # Read the result files from HDFS and expose them as attributes
+        storage_client = HDFSApi(self.ctx.hdfs_endpoint)
+        train_praba_path = os.path.join(self.job_id, self.TRAIN_MODEL_OUTPUT_FILE)
+        test_praba_path = os.path.join(self.job_id, self.TEST_MODEL_OUTPUT_FILE)
+        train_output = storage_client.download(train_praba_path)
+        test_output = storage_client.download(test_praba_path)
+        self.train_praba = train_output['class_pred'].values
+        self.test_praba = test_output['class_pred'].values
+        if 'class_label' in train_output.columns:
+            self.train_y = train_output['class_label'].values
+            self.test_y = test_output['class_label'].values
+        else:
+            self.train_y = None
+            self.test_y = None
+
+        feature_bin_path = os.path.join(self.job_id, self.FEATURE_BIN_FILE)
+        model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE)
+        feature_bin_data = storage_client.download_data(feature_bin_path)
+        model_data = storage_client.download_data(model_path)
+
+        self.feature_importance = ...  # TODO: parse feature importance from the job output
+        self.split_xbin = feature_bin_data
+        self.trees = model_data
+        self.params = ...  # TODO: parse training params from the job output
diff --git a/python/ppc_dev/result/psi_result.py b/python/ppc_dev/result/psi_result.py
new file mode 100644
index 00000000..dae03f58
--- /dev/null
+++ b/python/ppc_dev/result/psi_result.py
@@ -0,0 +1,27 @@
+import os
+
+from ppc_dev.wedpr_data.data_context import DataContext
+from ppc_dev.common.base_result import BaseResult
+
+
+class PSIResult(BaseResult):
+
+    PSI_RESULT_FILE = "psi_result.csv"
+
+    def __init__(self, dataset: DataContext, job_id: str):
+
+        super().__init__(dataset.ctx)
+        self.job_id = job_id
+        self.dataset = dataset
+
+        participant_id_list = []
+        for each_dataset in self.dataset.datasets:
+            participant_id_list.append(each_dataset.agency.agency_id)
+        self.participant_id_list = participant_id_list
+
+        result_list = []
+        for each_dataset in self.dataset.datasets:
+            each_dataset.update_path(os.path.join(self.job_id, self.PSI_RESULT_FILE))
+            result_list.append(each_dataset)
+
+        # __init__ must not return a value; expose the result as an attribute instead
+        self.psi_result = DataContext(*result_list)
diff --git a/python/ppc_dev/test/__init__.py b/python/ppc_dev/test/__init__.py
new file mode 100644
index 00000000..e69de29b
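Once a training `ModelResult` is constructed, its `train_praba`/`train_y` arrays plug straight into sklearn metrics (the test module below already imports `sklearn.metrics`). A minimal sketch with synthetic stand-ins for the real attributes:

```python
import numpy as np
from sklearn import metrics

# Stand-ins for ModelResult.train_y / ModelResult.train_praba.
train_y = np.random.randint(0, 2, size=100)
train_praba = np.random.rand(100)

auc = metrics.roc_auc_score(train_y, train_praba)
fpr, tpr, _ = metrics.roc_curve(train_y, train_praba)
ks = np.max(tpr - fpr)
print(f"AUC={auc:.3f}, KS={ks:.3f}")
```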
diff --git a/python/ppc_dev/test/test_dev.py b/python/ppc_dev/test/test_dev.py
new file mode 100644
index 00000000..03bad07a
--- /dev/null
+++ b/python/ppc_dev/test/test_dev.py
@@ -0,0 +1,70 @@
+import unittest
+import numpy as np
+import pandas as pd
+from sklearn import metrics
+
+from ppc_dev.common.base_context import BaseContext
+from ppc_dev.utils.agency import Agency
+from ppc_dev.wedpr_data.wedpr_data import WedprData
+from ppc_dev.wedpr_data.data_context import DataContext
+from ppc_dev.wedpr_session.wedpr_session import WedprSession
+
+
+# Fetch project_id etc. from the Jupyter environment
+# create workspace
+# project_id stays fixed for the same project / expert-mode refreshes
+project_id = 'p-123'
+user = 'admin'
+my_agency = 'WeBank'
+pws_endpoint = '0.0.0.0:0000'
+hdfs_endpoint = '0.0.0.0:0001'
+token = 'abc...'
+
+
+# Partner agencies, user-defined
+partner_agency1 = 'SG'
+partner_agency2 = 'TX'
+
+# Initialize the project ctx
+ctx = BaseContext(project_id, user, pws_endpoint, hdfs_endpoint, token)
+
+# Register the agencies
+agency1 = Agency(agency_id=my_agency)
+agency2 = Agency(agency_id=partner_agency1)
+
+# Register datasets; two forms are supported: pd.DataFrame or an hdfs_path
+# pd.DataFrame
+df = pd.DataFrame({
+    'id': np.arange(0, 100),  # id column, sequential integers
+    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10, random values
+})
+dataset1 = WedprData(ctx, values=df, agency=agency1)
+dataset1.storage_client = None
+dataset1.save_values(path='./project_id/user/data/d-101')
+# hdfs_path
+dataset2 = WedprData(ctx, dataset_path='./data_path/d-123', agency=agency2, is_label_holder=True)
+dataset2.storage_client = None
+dataset2.load_values()
+
+# Updating a dataset's values is supported
+df2 = pd.DataFrame({
+    'id': np.arange(0, 100),  # id column, sequential integers
+    'y': np.random.randint(0, 2, size=100),
+    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10, random values
+})
+dataset2.update_values(values=df2)
+
+# Build the dataset context
+dataset = DataContext(dataset1, dataset2)
+
+# Initialize a wedpr task session (with data)
+task = WedprSession(dataset, my_agency=my_agency)
+print(task.participant_id_list, task.result_receiver_id_list)
+# Run a PSI task
+psi_result = task.psi()
+
+# Initialize a wedpr task session (without data) -- recommended: more flexible
+task = WedprSession(my_agency=my_agency)
+# Run a preprocessing task
+fe_result = task.preprocessing(dataset)
+print(task.participant_id_list, task.result_receiver_id_list)
diff --git a/python/ppc_dev/utils/__init__.py b/python/ppc_dev/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/utils/agency.py b/python/ppc_dev/utils/agency.py
new file mode 100644
index 00000000..461bce03
--- /dev/null
+++ b/python/ppc_dev/utils/agency.py
@@ -0,0 +1,5 @@
+class Agency:
+
+    def __init__(self, agency_id):
+
+        self.agency_id = agency_id
diff --git a/python/ppc_dev/utils/utils.py b/python/ppc_dev/utils/utils.py
new file mode 100644
index 00000000..25e0ec72
--- /dev/null
+++ b/python/ppc_dev/utils/utils.py
@@ -0,0 +1,12 @@
+import uuid
+from enum import Enum
+
+
+class IdPrefixEnum(Enum):
+    DATASET = "d-"
+    ALGORITHM = "a-"
+    JOB = "j-"
+
+
+def make_id(prefix):
+    return prefix + str(uuid.uuid4()).replace("-", "")
diff --git a/python/ppc_dev/wedpr_data/__init__.py b/python/ppc_dev/wedpr_data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/wedpr_data/data_context.py b/python/ppc_dev/wedpr_data/data_context.py
new file mode 100644
index 00000000..177a828f
--- /dev/null
+++ b/python/ppc_dev/wedpr_data/data_context.py
@@ -0,0 +1,35 @@
+import os
+
+from ppc_dev.utils import utils
+
+
+class DataContext:
+
+    def __init__(self, *datasets):
+        self.datasets = list(datasets)
+        self.ctx = self.datasets[0].ctx
+
+        self._check_datasets()
+
+    def _save_dataset(self, dataset):
+        # Register a dataset that has no path yet and upload its values
+        if dataset.dataset_path is None:
+            dataset.dataset_id = utils.make_id(utils.IdPrefixEnum.DATASET.value)
+            dataset.dataset_path = os.path.join(dataset.ctx.workspace, dataset.dataset_id)
+            if dataset.storage_client is not None:
+                dataset.storage_client.upload(dataset.values, dataset.dataset_path)
+
+    def _check_datasets(self):
+        for dataset in self.datasets:
+            self._save_dataset(dataset)
+
+    def to_psi_format(self):
+        dataset_psi = []
+        for dataset in self.datasets:
+            dataset_psi.append(dataset.dataset_path)
+        return dataset_psi
+
+    def to_model_format(self):
+        dataset_model = []
+        for dataset in self.datasets:
+            dataset_model.append(dataset.dataset_path)
+        return dataset_model
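A short illustration of the id scheme and the auto-registration in `_check_datasets`; everything here is local (`storage_client = None`), and the context values are placeholders:

```python
import numpy as np
import pandas as pd

from ppc_dev.common.base_context import BaseContext
from ppc_dev.utils.agency import Agency
from ppc_dev.utils import utils
from ppc_dev.wedpr_data.wedpr_data import WedprData
from ppc_dev.wedpr_data.data_context import DataContext

print(utils.make_id(utils.IdPrefixEnum.DATASET.value))  # e.g. d-<32 hex chars>

ctx = BaseContext('p-123', 'admin')
d = WedprData(ctx, values=pd.DataFrame({'id': np.arange(3)}), agency=Agency('WeBank'))
d.storage_client = None
data_ctx = DataContext(d)        # assigns d.dataset_id / d.dataset_path under ctx.workspace
print(data_ctx.to_psi_format())  # ['p-123/admin/d-...']
```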
diff --git a/python/ppc_dev/wedpr_data/wedpr_data.py b/python/ppc_dev/wedpr_data/wedpr_data.py
new file mode 100644
index 00000000..2e4c9575
--- /dev/null
+++ b/python/ppc_dev/wedpr_data/wedpr_data.py
@@ -0,0 +1,64 @@
+import pandas as pd
+
+from ppc_dev.common.base_context import BaseContext
+from ppc_dev.job_exceuter.hdfs_client import HDFSApi
+
+
+class WedprData:
+
+    def __init__(self,
+                 ctx: BaseContext,
+                 dataset_id=None,
+                 dataset_path=None,
+                 agency=None,
+                 values=None,
+                 is_label_holder=False):
+
+        self.ctx = ctx
+
+        self.dataset_id = dataset_id
+        self.dataset_path = dataset_path
+        self.agency = agency
+        self.values = values
+        self.is_label_holder = is_label_holder
+        self.columns = None
+        self.shape = None
+
+        self.storage_client = HDFSApi(self.ctx.hdfs_endpoint)
+
+        if self.values is not None:
+            self.columns = self.values.columns
+            self.shape = self.values.shape
+
+    def load_values(self):
+        # Load the dataset from HDFS
+        if self.storage_client is not None:
+            self.values = self.storage_client.download(self.dataset_path)
+            self.columns = self.values.columns
+            self.shape = self.values.shape
+
+    def save_values(self, path=None):
+        # Save the data to an HDFS directory
+        if path is not None:
+            self.dataset_path = path
+        if self.storage_client is not None:
+            self.storage_client.upload(self.values, self.dataset_path)
+
+    def update_values(self, values: pd.DataFrame = None, path: str = None):
+        # Write the dataset back to HDFS, replacing the old data
+        if values is not None:
+            self.values = values
+            self.columns = self.values.columns
+            self.shape = self.values.shape
+        if path is not None:
+            self.dataset_path = path
+        if self.storage_client is not None:
+            self.storage_client.upload(self.values, self.dataset_path)
+
+    def update_path(self, path: str = None):
+        # Point the dataset at a new path and drop the cached values
+        if path is not None:
+            self.dataset_path = path
+        if self.values is not None:
+            self.values = None
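`WedprData` caches `columns`/`shape` whenever `values` change, so metadata stays queryable without re-downloading. A minimal local sketch, assuming no HDFS gateway is reachable (`storage_client` disabled, placeholder context):

```python
import numpy as np
import pandas as pd

from ppc_dev.common.base_context import BaseContext
from ppc_dev.utils.agency import Agency
from ppc_dev.wedpr_data.wedpr_data import WedprData

ctx = BaseContext('p-123', 'admin')                 # placeholder context
data = WedprData(ctx, values=pd.DataFrame({'id': np.arange(5)}),
                 agency=Agency('WeBank'))
data.storage_client = None                          # keep everything local

data.update_values(values=pd.DataFrame({'id': np.arange(5),
                                        'y': np.random.randint(0, 2, 5)}))
print(list(data.columns), data.shape)               # ['id', 'y'] (5, 2)
```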
diff --git a/python/ppc_dev/wedpr_session/__init__.py b/python/ppc_dev/wedpr_session/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_dev/wedpr_session/wedpr_session.py b/python/ppc_dev/wedpr_session/wedpr_session.py
new file mode 100644
index 00000000..e3c34e01
--- /dev/null
+++ b/python/ppc_dev/wedpr_session/wedpr_session.py
@@ -0,0 +1,135 @@
+from ppc_dev.wedpr_data.data_context import DataContext
+from ppc_dev.job_exceuter.pws_client import PWSApi
+from ppc_dev.result.psi_result import PSIResult
+from ppc_dev.result.fe_result import FeResult
+from ppc_dev.result.model_result import ModelResult
+
+
+class WedprSession:
+
+    def __init__(self, dataset: DataContext = None, my_agency=None):
+
+        self.dataset = None
+        self.create_agency = my_agency
+        self.participant_id_list = []
+        self.label_holder_agency = None
+        self.result_receiver_id_list = [my_agency]  # restricted to the agency hosting Jupyter
+        self.executor = None
+
+        # The executor needs the ctx carried by the dataset, so it is
+        # created (or refreshed) in update_dataset
+        if dataset is not None:
+            self.update_dataset(dataset)
+
+    def task(self, datasets: list, params: dict = None):
+
+        self.check_agencies()
+        job_response = self.executor.run(datasets, params or {})
+
+        return job_response.job_id
+
+    def psi(self, dataset: DataContext = None, merge_field: str = 'id'):
+
+        if dataset is not None:
+            self.update_dataset(dataset)
+
+        # Build the parameters
+        params = {merge_field: merge_field}
+
+        # Run the task
+        job_id = self.task(self.dataset.to_psi_format(), params)
+
+        # Wrap the result
+        return PSIResult(self.dataset, job_id)
+
+    def preprocessing(self, dataset: DataContext = None, psi_result=None, params: dict = None):
+
+        if dataset is not None:
+            self.update_dataset(dataset)
+
+        job_id = self.task(self.dataset.to_model_format(), params)
+
+        # Wrap the result
+        return FeResult(self.dataset, job_id)
+
+    def feature_engineering(self, dataset: DataContext = None, psi_result=None, params: dict = None):
+
+        if dataset is not None:
+            self.update_dataset(dataset)
+
+        job_id = self.task(self.dataset.to_model_format(), params)
+
+        # Wrap the result
+        return FeResult(self.dataset, job_id)
+
+    def xgb_training(self, dataset: DataContext = None, psi_result=None, params: dict = None):
+
+        if dataset is not None:
+            self.update_dataset(dataset)
+        self.check_datasets()
+
+        job_id = self.task(self.dataset.to_model_format(), params)
+
+        # Wrap the result
+        return ModelResult(self.dataset, job_id, job_type='xgb_training')
+
+    def xgb_predict(self, dataset: DataContext = None, psi_result=None, model=None):
+
+        if dataset is not None:
+            self.update_dataset(dataset)
+        self.check_datasets()
+
+        # Upload the model to HDFS
+        job_id = self.task(self.dataset.to_model_format())
+
+        # Wrap the result
+        return ModelResult(self.dataset, job_id, job_type='xgb_predict')
+
+    def update_dataset(self, dataset: DataContext):
+        self.dataset = dataset
+        self.participant_id_list = self.get_agencies()
+        self.label_holder_agency = self.get_label_holder_agency()
+        self.executor = PWSApi(self.dataset.ctx.pws_endpoint, self.dataset.ctx.token)
+
+    def get_agencies(self):
+        participant_id_list = []
+        for dataset in self.dataset.datasets:
+            participant_id_list.append(dataset.agency.agency_id)
+        return participant_id_list
+
+    def get_label_holder_agency(self):
+        label_holder_agency = None
+        for dataset in self.dataset.datasets:
+            if dataset.is_label_holder:
+                label_holder_agency = dataset.agency.agency_id
+        return label_holder_agency
+
+    def check_agencies(self):
+        """
+        Verify that enough agencies joined the task.
+        """
+        if len(self.participant_id_list) < 2:
+            raise ValueError("At least two agencies are required")
+
+    def check_datasets(self):
+        """
+        Verify that a label holder is configured.
+        """
+        if not self.label_holder_agency or self.label_holder_agency not in self.participant_id_list:
+            raise ValueError("The label holder of the datasets is misconfigured")
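The guard methods make the failure modes explicit. A minimal sketch of the recommended data-less flow, assuming the placeholders from the test module above:

```python
from ppc_dev.wedpr_session.wedpr_session import WedprSession

session = WedprSession(my_agency='WeBank')

try:
    session.check_agencies()   # no dataset yet -> fewer than two agencies
except ValueError as e:
    print(e)                   # "At least two agencies are required"

# Supplying the dataset on the call wires up the participants and the PWS executor:
# psi_result = session.psi(dataset, merge_field='id')
```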
diff --git a/python/ppc_model/common/base_context.py b/python/ppc_model/common/base_context.py
index 43f1e873..9b16680b 100644
--- a/python/ppc_model/common/base_context.py
+++ b/python/ppc_model/common/base_context.py
@@ -22,6 +22,9 @@ class BaseContext:
     # TRAIN_MODEL_OUTPUT_FILE = "train_model_output.csv"
     TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv"
 
+    MODEL_FILE = "model.kpl"
+    MODEL_ENC_FILE = "model_enc.kpl"
+
     def __init__(self, job_id: str, job_temp_dir: str):
         self.job_id = job_id
         self.workspace = os.path.join(job_temp_dir, self.job_id)
@@ -85,6 +88,11 @@ def __init__(self, job_id: str, job_temp_dir: str):
         self.metrics_iteration_file = os.path.join(
             self.workspace, utils.METRICS_OVER_ITERATION_FILE)
 
+        self.model_file = os.path.join(
+            self.workspace, self.MODEL_FILE)
+        self.model_enc_file = os.path.join(
+            self.workspace, self.MODEL_ENC_FILE)
+
         self.remote_summary_evaluation_file = os.path.join(
             self.job_id, utils.MPC_XGB_EVALUATION_TABLE)
         self.remote_feature_importance_file = os.path.join(
@@ -123,6 +131,24 @@ def __init__(self, job_id: str, job_temp_dir: str):
         self.remote_metrics_iteration_file = os.path.join(
             self.job_id, utils.METRICS_OVER_ITERATION_FILE)
 
+        self.remote_model_file = os.path.join(
+            self.job_id, self.MODEL_FILE)
+        self.remote_model_enc_file = os.path.join(
+            self.job_id, self.MODEL_ENC_FILE)
+
+        # self.get_key_pair()
+        self.load_key('aes_key.bin')
+
     @staticmethod
     def feature_engineering_input_path(job_id: str, job_temp_dir: str):
         return os.path.join(job_temp_dir, job_id, BaseContext.MODEL_PREPARE_FILE)
+
+    def get_key_pair(self):
+        with open('public_key.pem', 'rb') as f:
+            self.public_pem = f.read()
+        with open('private_key.pem', 'rb') as f:
+            self.private_pem = f.read()
+
+    def load_key(self, filename):
+        with open(filename, 'rb') as file:
+            self.key = file.read()
diff --git a/python/ppc_model/model_crypto/__init__.py b/python/ppc_model/model_crypto/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_model/model_crypto/crypto_aes.py b/python/ppc_model/model_crypto/crypto_aes.py
new file mode 100644
index 00000000..e63ee2ea
--- /dev/null
+++ b/python/ppc_model/model_crypto/crypto_aes.py
@@ -0,0 +1,78 @@
+import os
+import base64
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives import padding
+
+
+# Generate a 256-bit (32-byte) AES key
+def generate_aes_key():
+    return os.urandom(32)  # 32 bytes == 256 bits
+
+
+# Save the key to a file
+def save_key_to_file(key, filename):
+    with open(filename, 'wb') as file:
+        file.write(key)
+
+
+# Load the key from a file
+def load_key_from_file(filename):
+    with open(filename, 'rb') as file:
+        key = file.read()
+    return key
+
+
+# AES encryption
+def encrypt_data(key, plaintext):
+    # Use a randomly generated initialization vector (IV)
+    iv = os.urandom(16)  # the AES block size is 128 bits (16 bytes)
+
+    # Create the AES encryptor
+    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+    encryptor = cipher.encryptor()
+
+    # Pad the data (AES-CBC requires full 128-bit blocks)
+    padder = padding.PKCS7(128).padder()
+    padded_data = padder.update(plaintext) + padder.finalize()
+
+    # Encrypt
+    ciphertext = encryptor.update(padded_data) + encryptor.finalize()
+
+    # Return the IV followed by the ciphertext
+    return iv + ciphertext
+
+
+# AES decryption
+def decrypt_data(key, ciphertext):
+    # Split the IV off the ciphertext
+    iv = ciphertext[:16]  # the first 16 bytes are the IV
+    actual_ciphertext = ciphertext[16:]
+
+    # Create the AES decryptor
+    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+    decryptor = cipher.decryptor()
+
+    # Decrypt
+    decrypted_padded_data = decryptor.update(actual_ciphertext) + decryptor.finalize()
+
+    # Strip the padding
+    unpadder = padding.PKCS7(128).unpadder()
+    plaintext = unpadder.update(decrypted_padded_data) + unpadder.finalize()
+
+    return plaintext
+
+
+def cipher_to_base64(ciphertext):
+    # Encode the raw bytes as a Base64 string
+    encoded_ciphertext = base64.b64encode(ciphertext).decode('utf-8')
+    return encoded_ciphertext
+
+
+def base64_to_cipher(data):
+    decoded_ciphertext = base64.b64decode(data)
+    return decoded_ciphertext
diff --git a/python/ppc_model/model_crypto/test_aes.py b/python/ppc_model/model_crypto/test_aes.py
new file mode 100644
index 00000000..802ee54e
--- /dev/null
+++ b/python/ppc_model/model_crypto/test_aes.py
@@ -0,0 +1,26 @@
+from ppc_model.model_crypto.crypto_aes import generate_aes_key, save_key_to_file
+from ppc_model.model_crypto.crypto_aes import encrypt_data, decrypt_data
+from ppc_model.model_crypto.crypto_aes import base64_to_cipher, cipher_to_base64
+
+
+key = generate_aes_key()
+save_key_to_file(key, 'aes_key.bin')
+print("AES key generated and saved to aes_key.bin.")
+
+plaintext = "content to encrypt".encode('utf-8')
+ciphertext = encrypt_data(key, plaintext)
+print(f"ciphertext: {ciphertext}")
+
+decrypted_text = decrypt_data(key, ciphertext)
+print(f"decrypted text: {decrypted_text.decode('utf-8')}")
+
+# Persist the ciphertext as Base64
+encoded_ciphertext = cipher_to_base64(ciphertext)
+print(f"encoded_ciphertext: {encoded_ciphertext}")
+
+# Decrypt with the AES key
+decoded_ciphertext = base64_to_cipher(encoded_ciphertext)
+print(f"decoded_ciphertext: {decoded_ciphertext}")
+decrypted_text = decrypt_data(key, decoded_ciphertext)
+print(f"decrypted text: {decrypted_text.decode('utf-8')}")
diff --git a/python/ppc_model/secure_lgbm/secure_lgbm_context.py b/python/ppc_model/secure_lgbm/secure_lgbm_context.py
index ea536a14..c366c8e8 100644
--- a/python/ppc_model/secure_lgbm/secure_lgbm_context.py
+++ b/python/ppc_model/secure_lgbm/secure_lgbm_context.py
@@ -271,3 +271,4 @@ class LGBMMessage(Enum):
     VALID_LEAF_MASK = "PREDICT_VALID_LEAF_MASK"
     STOP_ITERATION = "STOP_ITERATION"
     PREDICT_PRABA = "PREDICT_PRABA"
+    MODEL_DATA = "MODEL_DATA"
diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py
index 9bfab7b6..aedff0c4 100644
--- a/python/ppc_model/secure_lgbm/vertical/booster.py
+++ b/python/ppc_model/secure_lgbm/vertical/booster.py
@@ -6,6 +6,7 @@
 from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo
 from ppc_common.ppc_utils.utils import AlgorithmType
 
+from ppc_model.model_crypto.crypto_aes import encrypt_data, decrypt_data, cipher_to_base64, base64_to_cipher
 from ppc_model.interface.model_base import VerticalModel
 from ppc_model.datasets.dataset import SecureDataset
 from ppc_model.common.protocol import PheMessage
@@ -241,6 +242,89 @@ def save_model(self, file_path=None):
         log.info(
             f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.")
 
+        self.merge_model_file()
+
+    def merge_model_file(self):
+
+        # Encrypt the local files
+        lgbm_model = {}
+        with open(self.ctx.feature_bin_file, 'rb') as f:
+            feature_bin_data = f.read()
+        with open(self.ctx.model_data_file, 'rb') as f:
+            model_data = f.read()
+        feature_bin_enc = encrypt_data(self.ctx.key, feature_bin_data)
+        model_data_enc = encrypt_data(self.ctx.key, model_data)
+
+        my_agency_id = self.ctx.components.config_data['AGENCY_ID']
+        lgbm_model[my_agency_id] = [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
+
+        # Exchange the files with the other parties
+        for partner_index in range(0, len(self.ctx.participant_id_list)):
+            if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                self._send_byte_data(
+                    self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin',
+                    feature_bin_enc, partner_index)
+                self._send_byte_data(
+                    self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data',
+                    model_data_enc, partner_index)
+        for partner_index in range(0, len(self.ctx.participant_id_list)):
+            if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                feature_bin_enc = self._receive_byte_data(
+                    self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin', partner_index)
+                model_data_enc = self._receive_byte_data(
+                    self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', partner_index)
+                lgbm_model[self.ctx.participant_id_list[partner_index]] = \
+                    [cipher_to_base64(feature_bin_enc), cipher_to_base64(model_data_enc)]
+
+        # Upload the encrypted model
+        with open(self.ctx.model_enc_file, 'w') as f:
+            json.dump(lgbm_model, f)
+        ResultFileHandling._upload_file(self.ctx.components.storage_client,
+                                        self.ctx.model_enc_file, self.ctx.remote_model_enc_file)
+        self.ctx.components.logger().info(
+            f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")
+
+    def split_model_file(self):
+        # Download the encrypted model
+        try:
+            ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                              self.ctx.remote_model_enc_file, self.ctx.model_enc_file)
+        except Exception:
+            pass
+
+        # Exchange the files with the other parties
+        my_agency_id = self.ctx.components.config_data['AGENCY_ID']
+        if os.path.exists(self.ctx.model_enc_file):
+
+            with open(self.ctx.model_enc_file, 'r') as f:
+                lgbm_model = json.load(f)
+
+            for partner_index in range(0, len(self.ctx.participant_id_list)):
+                if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                    feature_bin_enc, model_data_enc = \
+                        [base64_to_cipher(i) for i in lgbm_model[self.ctx.participant_id_list[partner_index]]]
+                    self._send_byte_data(
+                        self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin',
+                        feature_bin_enc, partner_index)
+                    self._send_byte_data(
+                        self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data',
+                        model_data_enc, partner_index)
+            feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in lgbm_model[my_agency_id]]
+
+        else:
+            feature_bin_enc = self._receive_byte_data(
+                self.ctx, f'{LGBMMessage.MODEL_DATA.value}_feature_bin', 0)
+            model_data_enc = self._receive_byte_data(
+                self.ctx, f'{LGBMMessage.MODEL_DATA.value}_model_data', 0)
+
+        # Decrypt the files
+        feature_bin_data = decrypt_data(self.ctx.key, feature_bin_enc)
+        model_data = decrypt_data(self.ctx.key, model_data_enc)
+        with open(self.ctx.feature_bin_file, 'wb') as f:
+            f.write(feature_bin_data)
+        with open(self.ctx.model_data_file, 'wb') as f:
+            f.write(model_data)
+
     def load_model(self, file_path=None):
         log = self.ctx.components.logger()
         if file_path is not None:
@@ -254,10 +338,13 @@ def load_model(self, file_path=None):
             self.ctx.remote_model_data_file = os.path.join(
                 self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)
 
-        ResultFileHandling._download_file(self.ctx.components.storage_client,
-                                          self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file)
-        ResultFileHandling._download_file(self.ctx.components.storage_client,
-                                          self.ctx.model_data_file, self.ctx.remote_model_data_file)
+        try:
+            ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                              self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file)
+            ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                              self.ctx.model_data_file, self.ctx.remote_model_data_file)
+        except Exception:
+            self.split_model_file()
 
         with open(self.ctx.feature_bin_file, 'r') as f:
             X_split_dict = json.load(f)
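The merged artifact is a JSON object keyed by agency id, with Base64-encoded AES ciphertexts as values. A standalone sketch of that round trip (file contents and the agency id are illustrative):

```python
import json

from ppc_model.model_crypto.crypto_aes import (
    generate_aes_key, encrypt_data, decrypt_data,
    cipher_to_base64, base64_to_cipher)

key = generate_aes_key()
# Shape produced by merge_model_file: {agency_id: [b64(feature_bin), b64(model_data)]}
lgbm_model = {
    'WeBank': [cipher_to_base64(encrypt_data(key, b'{"bins": []}')),
               cipher_to_base64(encrypt_data(key, b'{"trees": []}'))],
}
blob = json.dumps(lgbm_model)

# split_model_file side: pick our entry and decrypt
feature_bin_enc, model_data_enc = [base64_to_cipher(i) for i in json.loads(blob)['WeBank']]
assert decrypt_data(key, feature_bin_enc) == b'{"bins": []}'
```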
diff --git a/python/ppc_model/secure_lr/secure_lr_context.py b/python/ppc_model/secure_lr/secure_lr_context.py
index 4ecebb9f..f80ad340 100644
--- a/python/ppc_model/secure_lr/secure_lr_context.py
+++ b/python/ppc_model/secure_lr/secure_lr_context.py
@@ -218,3 +218,4 @@ class LRMessage(Enum):
     TEST_LEAF_MASK = "PREDICT_TEST_LEAF_MASK"
     VALID_LEAF_MASK = "PREDICT_VALID_LEAF_MASK"
     PREDICT_PRABA = "PREDICT_PRABA"
+    MODEL_DATA = "MODEL_DATA"
diff --git a/python/ppc_model/secure_lr/secure_lr_prediction_engine.py b/python/ppc_model/secure_lr/secure_lr_prediction_engine.py
new file mode 100644
index 00000000..218264b6
--- /dev/null
+++ b/python/ppc_model/secure_lr/secure_lr_prediction_engine.py
@@ -0,0 +1,38 @@
+from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode
+from ppc_model.common.protocol import TaskRole, ModelTask
+from ppc_model.common.global_context import components
+from ppc_model.interface.task_engine import TaskEngine
+from ppc_model.datasets.dataset import SecureDataset
+from ppc_model.metrics.evaluation import Evaluation
+from ppc_model.common.model_result import ResultFileHandling
+from ppc_model.secure_lr.secure_lr_context import SecureLRContext
+from ppc_model.secure_lr.vertical import VerticalLRActiveParty, VerticalLRPassiveParty
+
+
+class SecureLRPredictionEngine(TaskEngine):
+    task_type = ModelTask.LR_PREDICTING
+
+    @staticmethod
+    def run(args):
+
+        task_info = SecureLRContext(args, components)
+        secure_dataset = SecureDataset(task_info)
+
+        if task_info.role == TaskRole.ACTIVE_PARTY:
+            booster = VerticalLRActiveParty(task_info, secure_dataset)
+        elif task_info.role == TaskRole.PASSIVE_PARTY:
+            booster = VerticalLRPassiveParty(task_info, secure_dataset)
+        else:
+            raise PpcException(PpcErrorCode.ROLE_TYPE_ERROR.get_code(),
+                               PpcErrorCode.ROLE_TYPE_ERROR.get_message())
+
+        booster.load_model()
+        booster.predict()
+
+        # Fetch the predicted probabilities for the test set
+        test_praba = booster.get_test_praba()
+
+        # Compute the evaluation metrics for the test-set predictions
+        Evaluation(task_info, secure_dataset, test_praba=test_praba)
+
+        ResultFileHandling(task_info)
diff --git a/python/ppc_model/secure_lr/secure_lr_training_engine.py b/python/ppc_model/secure_lr/secure_lr_training_engine.py
index c848fd76..66497f23 100644
--- a/python/ppc_model/secure_lr/secure_lr_training_engine.py
+++ b/python/ppc_model/secure_lr/secure_lr_training_engine.py
@@ -36,5 +36,4 @@ def run(args):
 
     # Compute the evaluation metrics for the training and validation predictions
     Evaluation(task_info, secure_dataset, train_praba, test_praba)
-    ModelPlot(booster)
     ResultFileHandling(task_info)
diff --git a/python/ppc_model/secure_lr/test/test_secure_lr_training.py b/python/ppc_model/secure_lr/test/test_secure_lr_training.py
index 14a9bce5..d90e66a7 100644
--- a/python/ppc_model/secure_lr/test/test_secure_lr_training.py
+++ b/python/ppc_model/secure_lr/test/test_secure_lr_training.py
@@ -137,40 +137,40 @@ def test_fit(self):
         def active_worker():
             try:
                 booster_a.fit()
-                # booster_a.save_model()
-                # train_praba = booster_a.get_train_praba()
-                # test_praba = booster_a.get_test_praba()
-                # Evaluation(task_info_a, secure_dataset_a,
-                #            train_praba, test_praba)
-                # ResultFileHandling(task_info_a)
-                # booster_a.load_model()
-                # booster_a.predict()
-                # test_praba = booster_a.get_test_praba()
-                # task_info_a.algorithm_type = 'Predict'
-                # task_info_a.sync_file_list = {}
-                # Evaluation(task_info_a, secure_dataset_a,
-                #            test_praba=test_praba)
-                # ResultFileHandling(task_info_a)
+                booster_a.save_model()
+                train_praba = booster_a.get_train_praba()
+                test_praba = booster_a.get_test_praba()
+                Evaluation(task_info_a, secure_dataset_a,
+                           train_praba, test_praba)
+                ResultFileHandling(task_info_a)
+                booster_a.load_model()
+                booster_a.predict()
+                test_praba = booster_a.get_test_praba()
+                task_info_a.algorithm_type = 'Predict'
+                task_info_a.sync_file_list = {}
+                Evaluation(task_info_a, secure_dataset_a,
+                           test_praba=test_praba)
+                ResultFileHandling(task_info_a)
             except Exception as e:
                 task_info_a.components.logger().info(traceback.format_exc())
 
         def passive_worker():
             try:
                 booster_b.fit()
-                # booster_b.save_model()
-                # train_praba = booster_b.get_train_praba()
-                # test_praba = booster_b.get_test_praba()
-                # Evaluation(task_info_b, secure_dataset_b,
-                #            train_praba, test_praba)
-                # ResultFileHandling(task_info_b)
-                # booster_b.load_model()
-                # booster_b.predict()
-                # test_praba = booster_b.get_test_praba()
-                # task_info_b.algorithm_type = 'Predict'
-                # task_info_b.sync_file_list = {}
-                # Evaluation(task_info_b, secure_dataset_b,
-                #            test_praba=test_praba)
-                # ResultFileHandling(task_info_b)
+                booster_b.save_model()
+                train_praba = booster_b.get_train_praba()
+                test_praba = booster_b.get_test_praba()
+                Evaluation(task_info_b, secure_dataset_b,
+                           train_praba, test_praba)
+                ResultFileHandling(task_info_b)
+                booster_b.load_model()
+                booster_b.predict()
+                test_praba = booster_b.get_test_praba()
+                task_info_b.algorithm_type = 'Predict'
+                task_info_b.sync_file_list = {}
+                Evaluation(task_info_b, secure_dataset_b,
+                           test_praba=test_praba)
+                ResultFileHandling(task_info_b)
             except Exception as e:
                 task_info_b.components.logger().info(traceback.format_exc())
diff --git a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py
index 68a68d5f..24f233cc 100644
--- a/python/ppc_model/secure_lr/vertical/booster.py
+++ b/python/ppc_model/secure_lr/vertical/booster.py
@@ -7,6 +7,7 @@
 from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo
 from ppc_common.ppc_utils.utils import AlgorithmType
 
+from ppc_model.model_crypto.crypto_aes import encrypt_data, decrypt_data, cipher_to_base64, base64_to_cipher
 from ppc_model.interface.model_base import VerticalModel
 from ppc_model.datasets.data_reduction.feature_selection import FeatureSelection
 from ppc_model.datasets.dataset import SecureDataset
@@ -258,6 +259,71 @@ def save_model(self, file_path=None):
         log.info(
             f"task {self.ctx.task_id}: Saved serial_weight to {self.ctx.model_data_file} finished.")
 
+        self.merge_model_file()
+
+    def merge_model_file(self):
+
+        # Encrypt the local file
+        lr_model = {}
+        with open(self.ctx.model_data_file, 'rb') as f:
+            model_data = f.read()
+        model_data_enc = encrypt_data(self.ctx.key, model_data)
+
+        my_agency_id = self.ctx.components.config_data['AGENCY_ID']
+        lr_model[my_agency_id] = cipher_to_base64(model_data_enc)
+
+        # Exchange the files with the other parties
+        for partner_index in range(0, len(self.ctx.participant_id_list)):
+            if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                self._send_byte_data(
+                    self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data',
+                    model_data_enc, partner_index)
+        for partner_index in range(0, len(self.ctx.participant_id_list)):
+            if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                model_data_enc = self._receive_byte_data(
+                    self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', partner_index)
+                lr_model[self.ctx.participant_id_list[partner_index]] = cipher_to_base64(model_data_enc)
+
+        # Upload the encrypted model
+        with open(self.ctx.model_enc_file, 'w') as f:
+            json.dump(lr_model, f)
+        ResultFileHandling._upload_file(self.ctx.components.storage_client,
+                                        self.ctx.model_enc_file, self.ctx.remote_model_enc_file)
+        self.ctx.components.logger().info(
+            f"task {self.ctx.task_id}: Saved enc model to {self.ctx.model_enc_file} finished.")
+
+    def split_model_file(self):
+        # Download the encrypted model
+        try:
+            ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                              self.ctx.remote_model_enc_file, self.ctx.model_enc_file)
+        except Exception:
+            pass
+
+        # Exchange the files with the other parties
+        my_agency_id = self.ctx.components.config_data['AGENCY_ID']
+        if os.path.exists(self.ctx.model_enc_file):
+
+            with open(self.ctx.model_enc_file, 'r') as f:
+                lr_model = json.load(f)
+
+            for partner_index in range(0, len(self.ctx.participant_id_list)):
+                if self.ctx.participant_id_list[partner_index] != my_agency_id:
+                    model_data_enc = base64_to_cipher(lr_model[self.ctx.participant_id_list[partner_index]])
+                    self._send_byte_data(
+                        self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data',
+                        model_data_enc, partner_index)
+            model_data_enc = base64_to_cipher(lr_model[my_agency_id])
+
+        else:
+            model_data_enc = self._receive_byte_data(
+                self.ctx, f'{LRMessage.MODEL_DATA.value}_model_data', 0)
+
+        # Decrypt the file
+        model_data = decrypt_data(self.ctx.key, model_data_enc)
+        with open(self.ctx.model_data_file, 'wb') as f:
+            f.write(model_data)
+
     def load_model(self, file_path=None):
         log = self.ctx.components.logger()
         if file_path is not None:
@@ -267,8 +333,11 @@ def load_model(self, file_path=None):
             self.ctx.remote_model_data_file = os.path.join(
                 self.ctx.model_params.training_job_id, self.ctx.MODEL_DATA_FILE)
 
-        ResultFileHandling._download_file(self.ctx.components.storage_client,
-                                          self.ctx.model_data_file, self.ctx.remote_model_data_file)
+        try:
+            ResultFileHandling._download_file(self.ctx.components.storage_client,
+                                              self.ctx.model_data_file, self.ctx.remote_model_data_file)
+        except Exception:
+            self.split_model_file()
 
         with open(self.ctx.model_data_file, 'r') as f:
             serial_weight = json.load(f)
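For completeness, a sketch of inspecting the uploaded encrypted model offline, assuming you hold the shared `aes_key.bin` that `BaseContext.load_key` reads (the file path and agency id are placeholders):

```python
import json

from ppc_model.model_crypto.crypto_aes import decrypt_data, base64_to_cipher

with open('aes_key.bin', 'rb') as f:
    key = f.read()

# The LR variant of model_enc stores {agency_id: b64_ciphertext}.
with open('model_enc.kpl', 'r') as f:
    lr_model = json.load(f)

model_data = decrypt_data(key, base64_to_cipher(lr_model['WeBank']))
print(model_data[:80])
```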