
Feature milestone2 (#44)
* update secure lr

* update model and predict

* update ppc_dev
yanxinyi620 authored Oct 15, 2024
1 parent f48a01f commit 2070004
Showing 33 changed files with 933 additions and 35 deletions.
2 changes: 2 additions & 0 deletions python/aes_key.bin
@@ -0,0 +1,2 @@
(binary content, not shown)
Empty file added python/ppc_dev/__init__.py
Empty file.
Empty file added python/ppc_dev/common/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions python/ppc_dev/common/base_context.py
@@ -0,0 +1,13 @@
import os


class BaseContext:

    def __init__(self, project_id, user_name, pws_endpoint=None, hdfs_endpoint=None, token=None):

        self.project_id = project_id
        self.user_name = user_name
        self.pws_endpoint = pws_endpoint
        self.hdfs_endpoint = hdfs_endpoint
        self.token = token
        self.workspace = os.path.join(self.project_id, self.user_name)
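
A minimal usage sketch of this context, reusing the values from test_dev.py below (on POSIX systems the workspace joins to 'p-123/admin'):

ctx = BaseContext(project_id='p-123', user_name='admin',
                  pws_endpoint='0.0.0.0:0000', hdfs_endpoint='0.0.0.0:0001',
                  token='abc...')
print(ctx.workspace)  # p-123/admin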
8 changes: 8 additions & 0 deletions python/ppc_dev/common/base_result.py
@@ -0,0 +1,8 @@
from ppc_dev.common.base_context import BaseContext


class BaseResult:

    def __init__(self, ctx: BaseContext):

        self.ctx = ctx
Empty file added python/ppc_dev/job_exceuter/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions python/ppc_dev/job_exceuter/hdfs_client.py
@@ -0,0 +1,53 @@
import requests
import pandas as pd
import io


class HDFSApi:
    def __init__(self, base_url):
        self.base_url = base_url

    def upload(self, dataframe, hdfs_path):
        """
        Upload a pandas DataFrame to HDFS.
        :param dataframe: the pandas DataFrame to upload
        :param hdfs_path: the target HDFS path
        :return: response information
        """
        # Serialize the DataFrame to CSV
        csv_buffer = io.StringIO()
        dataframe.to_csv(csv_buffer, index=False)

        # Upload the CSV data with a PUT request
        response = requests.put(
            f"{self.base_url}/upload?path={hdfs_path}",
            data=csv_buffer.getvalue(),
            headers={'Content-Type': 'text/csv'}
        )
        return response.json()

    def download(self, hdfs_path):
        """
        Download data from HDFS and return it as a pandas DataFrame.
        :param hdfs_path: the HDFS file path
        :return: pandas DataFrame
        """
        response = requests.get(f"{self.base_url}/download?path={hdfs_path}")
        if response.status_code == 200:
            # Parse the CSV payload into a DataFrame
            dataframe = pd.read_csv(io.StringIO(response.text))
            return dataframe
        else:
            raise Exception(f"Download failed: {response.json()}")

    def download_data(self, hdfs_path):
        """
        Download data from HDFS and return it as raw text.
        :param hdfs_path: the HDFS file path
        :return: text
        """
        response = requests.get(f"{self.base_url}/download?path={hdfs_path}")
        if response.status_code == 200:
            return response.text
        else:
            raise Exception(f"Download failed: {response.json()}")
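
A round-trip usage sketch for HDFSApi, assuming base_url points at an HTTP gateway exposing the /upload and /download endpoints used above (the URL and paths here are hypothetical):

import pandas as pd

client = HDFSApi(base_url='http://0.0.0.0:0001')
df = pd.DataFrame({'id': [1, 2], 'x1': [0.1, 0.2]})
client.upload(df, 'p-123/admin/d-101')           # DataFrame -> CSV -> PUT
df_back = client.download('p-123/admin/d-101')   # GET -> CSV -> DataFrame
raw = client.download_data('p-123/admin/d-101')  # GET -> raw text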
66 changes: 66 additions & 0 deletions python/ppc_dev/job_exceuter/pws_client.py
@@ -0,0 +1,66 @@
import random
import time

from ppc_common.ppc_utils import http_utils
from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode


class PWSApi:
    def __init__(self, endpoint, token,
                 polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5):
        self.endpoint = endpoint
        self.token = token
        self.polling_interval_s = polling_interval_s
        self.max_retries = max_retries
        self.retry_delay_s = retry_delay_s
        self._async_run_task_method = 'asyncRunTask'
        self._get_task_status_method = 'getTaskStatus'
        self._completed_status = 'COMPLETED'
        self._failed_status = 'FAILED'

    def run(self, datasets, params):
        params = {
            'jsonrpc': '1',
            'method': self._async_run_task_method,
            'token': self.token,
            'id': random.randint(1, 65535),
            'dataset': datasets,
            'params': params
        }
        response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params)
        if response.status_code != 200:
            raise Exception(f"Failed to create the task: {response.json()}")
        return self._poll_task_status(response.job_id, self.token)

    def _poll_task_status(self, job_id, token):
        while True:
            params = {
                'jsonrpc': '1',
                'method': self._get_task_status_method,
                'token': token,
                'id': random.randint(1, 65535),
                'params': {
                    'taskID': job_id,
                }
            }
            response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params)
            if response.status_code != 200:
                raise Exception(f"Failed to poll the task status: {response.json()}")
            if response['result']['status'] == self._completed_status:
                return response['result']
            elif response['result']['status'] == self._failed_status:
                raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data'])
            time.sleep(self.polling_interval_s)

    def _send_request_with_retry(self, request_func, *args, **kwargs):
        attempt = 0
        while attempt < self.max_retries:
            try:
                response = request_func(*args, **kwargs)
                return response
            except Exception:
                attempt += 1
                if attempt < self.max_retries:
                    time.sleep(self.retry_delay_s)
                else:
                    raise
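
A usage sketch for PWSApi: run() blocks until the task reaches COMPLETED (returning the result) or FAILED (raising PpcException), retrying each HTTP call up to max_retries times with retry_delay_s seconds between attempts. The dataset paths and params payload below are hypothetical, since the request schema is defined by the remote service:

api = PWSApi(endpoint='0.0.0.0:0000', token='abc...')
result = api.run(datasets=['p-123/admin/d-101'], params={'jobType': 'PSI'})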
Empty file added python/ppc_dev/result/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions python/ppc_dev/result/fe_result.py
@@ -0,0 +1,27 @@
import os

from ppc_dev.wedpr_data.data_context import DataContext
from ppc_dev.common.base_result import BaseResult


class FeResult(BaseResult):

    FE_RESULT_FILE = "fe_result.csv"

    def __init__(self, dataset: DataContext, job_id: str):

        super().__init__(dataset.ctx)
        self.job_id = job_id

        participant_id_list = []
        for d in dataset.datasets:
            participant_id_list.append(d.agency.agency_id)
        self.participant_id_list = participant_id_list

        result_list = []
        for d in dataset.datasets:
            d.update_path(os.path.join(self.job_id, self.FE_RESULT_FILE))
            result_list.append(d)

        # __init__ cannot return a value, so the resulting DataContext
        # is exposed as an attribute instead
        self.fe_result = DataContext(*result_list)
56 changes: 56 additions & 0 deletions python/ppc_dev/result/model_result.py
@@ -0,0 +1,56 @@
import os
import numpy as np

from ppc_common.ppc_utils import utils

from ppc_dev.wedpr_data.data_context import DataContext
from ppc_dev.common.base_result import BaseResult
from ppc_dev.job_exceuter.hdfs_client import HDFSApi


class ModelResult(BaseResult):

    FEATURE_BIN_FILE = "feature_bin.json"
    MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json'
    TEST_MODEL_OUTPUT_FILE = "xgb_output.csv"
    TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv"

    def __init__(self, dataset: DataContext, job_id: str, job_type: str):

        super().__init__(dataset.ctx)
        self.job_id = job_id

        participant_id_list = []
        for d in dataset.datasets:
            participant_id_list.append(d.agency.agency_id)
        self.participant_id_list = participant_id_list

        if job_type == 'xgb_training':
            self._xgb_train_result()

    def _xgb_train_result(self):

        # Read the result files from HDFS and expose them as attributes:
        # train_praba, test_praba, train_y, test_y, feature_importance,
        # split_xbin, trees, params
        hdfs_client = HDFSApi(self.ctx.hdfs_endpoint)
        train_praba_path = os.path.join(self.job_id, self.TRAIN_MODEL_OUTPUT_FILE)
        test_praba_path = os.path.join(self.job_id, self.TEST_MODEL_OUTPUT_FILE)
        train_output = hdfs_client.download(train_praba_path)
        test_output = hdfs_client.download(test_praba_path)
        self.train_praba = train_output['class_pred'].values
        self.test_praba = test_output['class_pred'].values
        if 'class_label' in train_output.columns:
            self.train_y = train_output['class_label'].values
            self.test_y = test_output['class_label'].values
        else:
            self.train_y = None
            self.test_y = None

        feature_bin_path = os.path.join(self.job_id, self.FEATURE_BIN_FILE)
        model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE)
        feature_bin_data = hdfs_client.download_data(feature_bin_path)
        model_data = hdfs_client.download_data(model_path)

        self.feature_importance = ...  # placeholder, not parsed yet
        self.split_xbin = feature_bin_data
        self.trees = model_data
        self.params = ...  # placeholder, not parsed yet
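
Once populated, these attributes plug into standard evaluation code. A sketch, assuming the output files carried class labels (the job id is hypothetical; sklearn.metrics is also used by the tests below):

from sklearn import metrics

model_result = ModelResult(dataset, job_id='j-123', job_type='xgb_training')
if model_result.test_y is not None:
    auc = metrics.roc_auc_score(model_result.test_y, model_result.test_praba)
    print(f'test AUC: {auc:.4f}')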
27 changes: 27 additions & 0 deletions python/ppc_dev/result/psi_result.py
@@ -0,0 +1,27 @@
import os

from ppc_dev.wedpr_data.data_context import DataContext
from ppc_dev.common.base_result import BaseResult


class PSIResult(BaseResult):

    PSI_RESULT_FILE = "psi_result.csv"

    def __init__(self, dataset: DataContext, job_id: str):

        super().__init__(dataset.ctx)
        self.job_id = job_id

        participant_id_list = []
        for d in dataset.datasets:
            participant_id_list.append(d.agency.agency_id)
        self.participant_id_list = participant_id_list

        result_list = []
        for d in dataset.datasets:
            d.update_path(os.path.join(self.job_id, self.PSI_RESULT_FILE))
            result_list.append(d)

        # __init__ cannot return a value, so the resulting DataContext
        # is exposed as an attribute instead
        self.psi_result = DataContext(*result_list)
Empty file added python/ppc_dev/test/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions python/ppc_dev/test/test_dev.py
@@ -0,0 +1,70 @@
import unittest
import numpy as np
import pandas as pd
from sklearn import metrics

from ppc_dev.common.base_context import BaseContext
from ppc_dev.utils.agency import Agency
from ppc_dev.wedpr_data.wedpr_data import WedprData
from ppc_dev.wedpr_data.data_context import DataContext
from ppc_dev.wedpr_session.wedpr_session import WedprSession


# Fetch project_id and related information from the jupyter environment
# create workspace
# project_id stays fixed for the same project / expert-mode refresh
project_id = 'p-123'
user = 'admin'
my_agency = 'WeBank'
pws_endpoint = '0.0.0.0:0000'
hdfs_endpoint = '0.0.0.0:0001'
token = 'abc...'


# Custom partner agencies
partner_agency1 = 'SG'
partner_agency2 = 'TX'

# Initialize the project ctx information
ctx = BaseContext(project_id, user, pws_endpoint, hdfs_endpoint, token)

# Register the agencies
agency1 = Agency(agency_id=my_agency)
agency2 = Agency(agency_id=partner_agency1)

# Register the datasets; two sources are supported: pd.DataFrame, hdfs_path
# pd.DataFrame
df = pd.DataFrame({
    'id': np.arange(0, 100),  # id column, sequential integers
    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10, random values
})
dataset1 = WedprData(ctx, values=df, agency=agency1)
dataset1.storage_client = None
dataset1.save_values(path='./project_id/user/data/d-101')
# hdfs_path
dataset2 = WedprData(ctx, dataset_path='./data_path/d-123', agency=agency2, is_label_holder=True)
dataset2.storage_client = None
dataset2.load_values()

# Updating a dataset's values is also supported
df2 = pd.DataFrame({
    'id': np.arange(0, 100),  # id column, sequential integers
    'y': np.random.randint(0, 2, size=100),
    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10, random values
})
dataset2.update_values(values=df2)

# Build the dataset context
dataset = DataContext(dataset1, dataset2)

# Initialize the wedpr task session (with data)
task = WedprSession(dataset, my_agency=my_agency)
print(task.participant_id_list, task.result_receiver_id_list)
# Run the PSI task
psi_result = task.psi()

# Initialize the wedpr task session (without data) (recommended: more flexible)
task = WedprSession(my_agency=my_agency)
# Run the preprocessing (feature engineering) task
fe_result = task.proprecessing(dataset)
print(task.participant_id_list, task.result_receiver_id_list)
Empty file added python/ppc_dev/utils/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions python/ppc_dev/utils/agency.py
@@ -0,0 +1,5 @@
class Agency:

    def __init__(self, agency_id):

        self.agency_id = agency_id
12 changes: 12 additions & 0 deletions python/ppc_dev/utils/utils.py
@@ -0,0 +1,12 @@
import uuid
from enum import Enum


class IdPrefixEnum(Enum):
    DATASET = "d-"
    ALGORITHM = "a-"
    JOB = "j-"


def make_id(prefix):
    return prefix + str(uuid.uuid4()).replace("-", "")
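
For example, a dataset id is the "d-" prefix followed by the 32 hex characters of a UUID4:

dataset_id = make_id(IdPrefixEnum.DATASET.value)
print(dataset_id)  # e.g. d-9f8c2e... (32 hex characters after the prefix)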
Empty file added python/ppc_dev/wedpr_data/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions python/ppc_dev/wedpr_data/data_context.py
@@ -0,0 +1,35 @@
import os

from ppc_dev.utils import utils


class DataContext:

    def __init__(self, *datasets):
        self.datasets = list(datasets)
        self.ctx = self.datasets[0].ctx

        self._check_datasets()

    def _save_dataset(self, dataset):
        if dataset.dataset_path is None:
            dataset.dataset_id = utils.make_id(utils.IdPrefixEnum.DATASET.value)
            dataset.dataset_path = os.path.join(dataset.ctx.workspace, dataset.dataset_id)
            if dataset.storage_client is not None:
                dataset.storage_client.upload(dataset.values, dataset.dataset_path)

    def _check_datasets(self):
        for dataset in self.datasets:
            self._save_dataset(dataset)

    def to_psi_format(self):
        dataset_psi = []
        for dataset in self.datasets:
            dataset_psi.append(dataset.dataset_path)
        return dataset_psi

    def to_model_format(self):
        dataset_model = []
        for dataset in self.datasets:
            dataset_model.append(dataset.dataset_path)
        return dataset_model
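
A usage sketch, with dataset1 and dataset2 as constructed in test_dev.py above:

data = DataContext(dataset1, dataset2)
print(data.to_psi_format())    # dataset paths handed to PSI jobs
print(data.to_model_format())  # dataset paths handed to model jobs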