-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update secure lr * update model and predict * update ppc_dev
- Loading branch information
1 parent
f48a01f
commit 2070004
Showing
33 changed files
with
933 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
m�}9��H� | ||
褊���c�?Ӈ!<��>��� |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
|
||
|
||
class BaseContext:
    """Shared execution context: project/user identity plus service endpoints."""

    def __init__(self, project_id, user_name, pws_endpoint=None, hdfs_endpoint=None, token=None):
        """Store identifiers/endpoints and derive the per-user workspace path.

        :param project_id: id of the project this session belongs to
        :param user_name: name of the current user
        :param pws_endpoint: PWS task-service endpoint (optional)
        :param hdfs_endpoint: HDFS gateway endpoint (optional)
        :param token: auth token used by service clients (optional)
        """
        self.project_id = project_id
        self.user_name = user_name
        self.pws_endpoint = pws_endpoint
        self.hdfs_endpoint = hdfs_endpoint
        self.token = token
        # Workspace is "<project_id>/<user_name>", used as the base directory
        # for per-user artifacts.
        self.workspace = os.path.join(project_id, user_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from ppc_dev.common.base_context import BaseContext | ||
|
||
|
||
class BaseResult:
    """Base class for job results; retains the owning execution context."""

    def __init__(self, ctx: BaseContext):
        # The context carries the endpoints/credentials subclasses need in
        # order to fetch their result artifacts.
        self.ctx = ctx
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import requests | ||
import pandas as pd | ||
import io | ||
|
||
|
||
class HDFSApi:
    """Thin HTTP client for uploading/downloading CSV data to/from an HDFS gateway."""

    def __init__(self, base_url):
        # Root URL of the HDFS HTTP gateway, e.g. "http://host:port".
        self.base_url = base_url

    def upload(self, dataframe, hdfs_path):
        """
        Upload a pandas DataFrame to HDFS.

        :param dataframe: the pandas DataFrame to upload
        :param hdfs_path: target path on HDFS
        :return: decoded JSON response from the gateway
        """
        # Serialize the DataFrame to CSV in memory.
        csv_buffer = io.StringIO()
        dataframe.to_csv(csv_buffer, index=False)

        # PUT the CSV payload to the gateway.
        # NOTE(review): the status code is not checked here (unlike the
        # download paths) — a failed upload still returns the gateway's JSON.
        response = requests.put(
            f"{self.base_url}/upload?path={hdfs_path}",
            data=csv_buffer.getvalue(),
            headers={'Content-Type': 'text/csv'}
        )
        return response.json()

    def download(self, hdfs_path):
        """
        Download a file from HDFS and parse it as a pandas DataFrame.

        :param hdfs_path: file path on HDFS
        :return: pandas DataFrame parsed from the CSV payload
        :raises Exception: if the gateway answers with a non-200 status
        """
        # BUGFIX(dedup): delegate to download_data instead of repeating the
        # GET / status-check / raise logic.
        return pd.read_csv(io.StringIO(self.download_data(hdfs_path)))

    def download_data(self, hdfs_path):
        """
        Download a file from HDFS as raw text.

        :param hdfs_path: file path on HDFS
        :return: response body as text
        :raises Exception: if the gateway answers with a non-200 status
        """
        response = requests.get(f"{self.base_url}/download?path={hdfs_path}")
        if response.status_code == 200:
            return response.text
        raise Exception(f"下载失败: {response.json()}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import random | ||
import time | ||
|
||
from ppc_common.ppc_utils import http_utils | ||
from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode | ||
|
||
|
||
class PWSApi:
    """Client for the PWS JSON-RPC-style task service.

    ``run`` submits an asynchronous task and then blocks, polling the task
    status until it completes or fails. All HTTP calls go through a simple
    retry wrapper for transport-level failures.
    """

    def __init__(self, endpoint, token,
                 polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5):
        # Service endpoint and auth token sent with every request.
        self.endpoint = endpoint
        self.token = token
        # Seconds to sleep between status polls.
        self.polling_interval_s = polling_interval_s
        # Retry budget and back-off delay for failed HTTP calls.
        self.max_retries = max_retries
        self.retry_delay_s = retry_delay_s
        # JSON-RPC method names and terminal task states used by the service.
        self._async_run_task_method = 'asyncRunTask'
        self._get_task_status_method = 'getTaskStatus'
        self._completed_status = 'COMPLETED'
        self._failed_status = 'FAILED'

    def run(self, datasets, params):
        """Submit an async task and poll until it reaches a terminal state.

        :param datasets: dataset descriptors passed to the service
        :param params: task parameters (shadowed below by the RPC envelope)
        :return: the task's final ``result`` payload (see _poll_task_status)
        :raises Exception: if the submission request is rejected
        """
        params = {
            'jsonrpc': '1',
            'method': self._async_run_task_method,
            'token': self.token,
            # Request correlation id; randomness only, not security-relevant.
            'id': random.randint(1, 65535),
            'dataset': datasets,
            'params': params
        }
        response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params)
        if response.status_code != 200:
            raise Exception(f"创建任务失败: {response.json()}")
        # NOTE(review): response is used as an object here (.status_code,
        # .job_id) but as a dict in _poll_task_status (response['result']) —
        # confirm what http_utils.send_post_request actually returns.
        return self._poll_task_status(response.job_id, self.token)

    def _poll_task_status(self, job_id, token):
        """Poll getTaskStatus until COMPLETED (return result) or FAILED (raise).

        :param job_id: id of the submitted task
        :param token: auth token forwarded on each poll
        :raises Exception: on a non-200 poll response
        :raises PpcException: when the task ends in FAILED state
        """
        while True:
            params = {
                'jsonrpc': '1',
                'method': self._get_task_status_method,
                'token': token,
                'id': random.randint(1, 65535),
                'params': {
                    'taskID': job_id,
                }
            }
            response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params)
            if response.status_code != 200:
                raise Exception(f"轮询任务失败: {response.json()}")
            if response['result']['status'] == self._completed_status:
                return response['result']
            elif response['result']['status'] == self._failed_status:
                raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data'])
            # Task still running — wait before the next poll.
            time.sleep(self.polling_interval_s)

    def _send_request_with_retry(self, request_func, *args, **kwargs):
        """Call request_func, retrying up to max_retries times on any exception.

        Re-raises the last exception once the retry budget is exhausted.
        """
        attempt = 0
        while attempt < self.max_retries:
            try:
                response = request_func(*args, **kwargs)
                return response
            except Exception as e:
                attempt += 1
                if attempt < self.max_retries:
                    time.sleep(self.retry_delay_s)
                else:
                    raise e
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
|
||
from ppc_dev.wedpr_data.data_context import DataContext | ||
from ppc_dev.common.base_result import BaseResult | ||
|
||
|
||
class FeResult(BaseResult):
    """Result of a feature-engineering job.

    Redirects each participating dataset's path to the job's output file
    and exposes the rebuilt DataContext as ``fe_result``.
    """

    # File name of the feature-engineering output inside the job directory.
    FE_RESULT_FILE = "fe_result.csv"

    def __init__(self, dataset: DataContext, job_id: str):
        """
        :param dataset: the DataContext the job ran on
        :param job_id: id of the finished feature-engineering job
        """
        super().__init__(dataset.ctx)
        self.job_id = job_id
        # BUGFIX: the original read ``self.dataset`` below without ever
        # assigning it, which raised AttributeError.
        self.dataset = dataset

        # Collect each participant's agency id.
        participant_id_list = []
        for ds in self.dataset.datasets:
            participant_id_list.append(ds.agency.agency_id)
        self.participant_id_list = participant_id_list

        # Point every dataset at the job's result file.
        result_list = []
        for ds in self.dataset.datasets:
            ds.update_path(os.path.join(self.job_id, self.FE_RESULT_FILE))
            result_list.append(ds)

        # BUGFIX: __init__ must not return a value (the original ``return``
        # raised TypeError); expose the rebuilt context as an attribute.
        self.fe_result = DataContext(*result_list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os | ||
import numpy as np | ||
|
||
from ppc_common.ppc_utils import utils | ||
|
||
from ppc_dev.wedpr_data.data_context import DataContext | ||
from ppc_dev.common.base_result import BaseResult | ||
from ppc_dev.job_exceuter.hdfs_client import HDFSApi | ||
|
||
|
||
class ModelResult(BaseResult):
    """Result of a model job; for xgb training, loads outputs from HDFS."""

    # Result-file names inside the job directory on HDFS.
    FEATURE_BIN_FILE = "feature_bin.json"
    MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json'
    TEST_MODEL_OUTPUT_FILE = "xgb_output.csv"
    TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv"

    def __init__(self, dataset: DataContext, job_id: str, job_type: str):
        """
        :param dataset: the DataContext the job ran on
        :param job_id: id of the finished job
        :param job_type: job kind; only 'xgb_training' triggers result loading
        """
        super().__init__(dataset.ctx)
        self.job_id = job_id
        # BUGFIX: the original read ``self.dataset`` below without ever
        # assigning it, which raised AttributeError.
        self.dataset = dataset

        # Collect each participant's agency id.
        self.participant_id_list = [
            ds.agency.agency_id for ds in self.dataset.datasets
        ]

        if job_type == 'xgb_training':
            self._xgb_train_result()

    def _xgb_train_result(self):
        """Fetch xgb training artifacts from HDFS and expose them as attributes.

        Populates train/test predicted probabilities, labels (when present),
        the feature-binning data and the serialized trees.
        """
        # BUGFIX: HDFSApi.download/download_data are instance methods; the
        # original called them on the class. Build a client from the context's
        # HDFS endpoint instead — assumes ctx.hdfs_endpoint is the gateway
        # base URL (TODO confirm).
        hdfs_client = HDFSApi(self.ctx.hdfs_endpoint)

        # train_praba, test_praba, train_y, test_y, feature_importance,
        # split_xbin, trees, params — reconstructed from the HDFS result files.
        train_praba_path = os.path.join(self.job_id, self.TRAIN_MODEL_OUTPUT_FILE)
        test_praba_path = os.path.join(self.job_id, self.TEST_MODEL_OUTPUT_FILE)
        train_output = hdfs_client.download(train_praba_path)
        test_output = hdfs_client.download(test_praba_path)
        self.train_praba = train_output['class_pred'].values
        self.test_praba = test_output['class_pred'].values
        if 'class_label' in train_output.columns:
            self.train_y = train_output['class_label'].values
            self.test_y = test_output['class_label'].values
        else:
            # Labels absent (e.g. this party is not the label holder).
            self.train_y = None
            self.test_y = None

        feature_bin_path = os.path.join(self.job_id, self.FEATURE_BIN_FILE)
        model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE)
        feature_bin_data = hdfs_client.download_data(feature_bin_path)
        model_data = hdfs_client.download_data(model_path)

        # TODO: feature_importance and params are still unimplemented
        # placeholders (Ellipsis) upstream.
        self.feature_importance = ...
        self.split_xbin = feature_bin_data
        self.trees = model_data
        self.params = ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
|
||
from ppc_dev.wedpr_data.data_context import DataContext | ||
from ppc_dev.common.base_result import BaseResult | ||
|
||
|
||
class PSIResult(BaseResult):
    """Result of a PSI job.

    Redirects each participating dataset's path to the job's output file
    and exposes the rebuilt DataContext as ``psi_result``.
    """

    # File name of the PSI output inside the job directory.
    PSI_RESULT_FILE = "psi_result.csv"

    def __init__(self, dataset: DataContext, job_id: str):
        """
        :param dataset: the DataContext the job ran on
        :param job_id: id of the finished PSI job
        """
        super().__init__(dataset.ctx)
        self.job_id = job_id
        # BUGFIX: the original read ``self.dataset`` below without ever
        # assigning it, which raised AttributeError.
        self.dataset = dataset

        # Collect each participant's agency id.
        participant_id_list = []
        for ds in self.dataset.datasets:
            participant_id_list.append(ds.agency.agency_id)
        self.participant_id_list = participant_id_list

        # Point every dataset at the job's result file.
        result_list = []
        for ds in self.dataset.datasets:
            ds.update_path(os.path.join(self.job_id, self.PSI_RESULT_FILE))
            result_list.append(ds)

        # BUGFIX: __init__ must not return a value (the original ``return``
        # raised TypeError); expose the rebuilt context as an attribute.
        self.psi_result = DataContext(*result_list)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import unittest | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn import metrics | ||
|
||
from ppc_dev.common.base_context import BaseContext | ||
from ppc_dev.utils.agency import Agency | ||
from ppc_dev.wedpr_data.wedpr_data import WedprData | ||
from ppc_dev.wedpr_data.data_context import DataContext | ||
from ppc_dev.wedpr_session.wedpr_session import WedprSession | ||
|
||
|
||
# Demo / smoke-test script for the ppc_dev developer API.
# Obtain project_id etc. from the jupyter environment.
# create workspace
# project_id stays fixed for the same project / expert-mode refresh.
project_id = 'p-123'
user = 'admin'
my_agency='WeBank'
pws_endpoint = '0.0.0.0:0000'
hdfs_endpoint = '0.0.0.0:0001'
token = 'abc...'


# Custom partner agencies.
partner_agency1='SG'
partner_agency2='TX'

# Initialize the project ctx info.
ctx = BaseContext(project_id, user, pws_endpoint, hdfs_endpoint, token)

# Register agencies.
agency1 = Agency(agency_id=my_agency)
agency2 = Agency(agency_id=partner_agency1)

# Register datasets; two styles are supported: pd.DataFrame or an hdfs_path.
# pd.DataFrame
df = pd.DataFrame({
    'id': np.arange(0, 100),  # id column: sequential integers
    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10: random values
})
dataset1 = WedprData(ctx, values=df, agency=agency1)
dataset1.storage_client = None
dataset1.save_values(path='./project_id/user/data/d-101')
# hdfs_path
dataset2 = WedprData(ctx, dataset_path='./data_path/d-123', agency=agency2, is_label_holder=True)
dataset2.storage_client = None
dataset2.load_values()

# A dataset's values can be updated in place.
df2 = pd.DataFrame({
    'id': np.arange(0, 100),  # id column: sequential integers
    'y': np.random.randint(0, 2, size=100),
    **{f'x{i}': np.random.rand(100) for i in range(1, 11)}  # columns x1..x10: random values
})
dataset2.update_values(values=df2)

# Build the dataset context.
dataset = DataContext(dataset1, dataset2)

# Initialize a wedpr task session (with data).
task = WedprSession(dataset, my_agency=my_agency)
print(task.participant_id_list, task.result_receiver_id_list)
# Run the psi job.
psi_result = task.psi()

# Initialize a wedpr task session (without data) — recommended: more flexible.
task = WedprSession(my_agency=my_agency)
# Run the preprocessing job (the original comment said "psi" here, but the
# call below is proprecessing — note the upstream method-name typo).
fe_result = task.proprecessing(dataset)
print(task.participant_id_list, task.result_receiver_id_list)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
class Agency:
    """A participating organization, identified solely by its agency id."""

    def __init__(self, agency_id):
        # Unique identifier of this agency within the platform.
        self.agency_id = agency_id
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import uuid | ||
from enum import Enum | ||
|
||
|
||
class IdPrefixEnum(Enum):
    """Prefix tags that distinguish generated resource ids by kind."""
    DATASET = "d-"
    ALGORITHM = "a-"
    JOB = "j-"
|
||
|
||
def make_id(prefix):
    """Return a fresh unique id: *prefix* followed by a hyphen-less UUID4."""
    # uuid4().hex is the 32-char hex digest, i.e. str(uuid4()) minus the dashes.
    return prefix + uuid.uuid4().hex
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os | ||
|
||
from ppc_dev.utils import utils | ||
|
||
|
||
class DataContext:
    """Ordered collection of datasets sharing one execution context."""

    def __init__(self, *datasets):
        """
        :param datasets: one or more dataset objects; all are assumed to share
            the same ctx — the first dataset's ctx is adopted.
        """
        self.datasets = list(datasets)
        self.ctx = self.datasets[0].ctx

        self._check_datasets()

    def _save_dataset(self, dataset):
        """Assign an id/path to an unsaved dataset, then upload its values."""
        if dataset.dataset_path is None:
            dataset.dataset_id = utils.make_id(utils.IdPrefixEnum.DATASET.value)
            dataset.dataset_path = os.path.join(dataset.ctx.workspace, dataset.dataset_id)
        # BUGFIX: the original read self.storage_client / self.values /
        # self.dataset_path, none of which exist on DataContext — these
        # attributes live on the dataset itself.
        if dataset.storage_client is not None:
            dataset.storage_client.upload(dataset.values, dataset.dataset_path)

    def _check_datasets(self):
        """Ensure every dataset in the context has been persisted."""
        for dataset in self.datasets:
            self._save_dataset(dataset)

    def to_psi_format(self):
        """Return dataset paths in the form the PSI job expects."""
        return [dataset.dataset_path for dataset in self.datasets]

    def to_model_formort(self):
        """Return dataset paths in the form model jobs expect.

        NOTE(review): the name keeps the original typo ("formort") for
        backward compatibility with existing callers.
        """
        return [dataset.dataset_path for dataset in self.datasets]
Oops, something went wrong.