-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
25e376d
commit 55f7765
Showing
29 changed files
with
532 additions
and
478 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
class Constant: | ||
NUMERIC_ARRAY = [i for i in range(10)] | ||
HTTP_STATUS_OK = 200 | ||
DEFAULT_SUBMIT_JOB_URI = '/api/wedpr/v3/project/submitJob' | ||
DEFAULT_QUERY_JOB_STATUS_URL = '/api/wedpr/v3/project/queryJobByCondition' | ||
PSI_RESULT_FILE = "psi_result.csv" | ||
|
||
FEATURE_BIN_FILE = "feature_bin.json" | ||
TEST_MODEL_OUTPUT_FILE = "test_output.csv" | ||
TRAIN_MODEL_OUTPUT_FILE = "train_output.csv" | ||
|
||
FE_RESULT_FILE = "fe_result.csv" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# -*- coding: utf-8 -*- | ||
import uuid | ||
from enum import Enum | ||
import random | ||
from common.utils.constant import Constant | ||
from urllib.parse import urlencode, urlparse, parse_qs, quote | ||
|
||
class IdPrefixEnum(Enum): | ||
DATASET = "d-" | ||
ALGORITHM = "a-" | ||
JOB = "j-" | ||
|
||
|
||
def make_id(prefix): | ||
return prefix + str(uuid.uuid4()).replace("-", "") | ||
|
||
def generate_nonce(nonce_len): | ||
return ''.join(random.choice(Constant.NUMERIC_ARRAY) for _ in range(nonce_len)) | ||
|
||
def add_params_to_url(url, params): | ||
parsed_url = urlparse(url) | ||
query_params = parse_qs(parsed_url.query) | ||
for key, value in params.items(): | ||
query_params[key] = value | ||
new_query = urlencode(query_params, doseq=True) | ||
return parsed_url._replace(query=new_query).geturl() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# -*- coding: utf-8 -*- | ||
import json | ||
|
||
from wedpr_ml_toolkit.context.data_context import DataContext | ||
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobParam | ||
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo | ||
from abc import abstractmethod | ||
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient | ||
|
||
|
||
class JobContext: | ||
|
||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None): | ||
if dataset is None: | ||
raise Exception("Must define the job related datasets!") | ||
self.remote_job_client = remote_job_client | ||
self.project_name = project_name | ||
self.dataset = dataset | ||
self.create_agency = my_agency | ||
self.participant_id_list = [] | ||
self.task_parties = [] | ||
self.dataset_id_list = [] | ||
self.dataset_list = [] | ||
self.label_holder_agency = None | ||
self.label_columns = None | ||
self.__init_participant__() | ||
self.__init_label_information__() | ||
self.result_receiver_id_list = [my_agency] # 仅限jupyter所在机构 | ||
self.__check__() | ||
|
||
def __check__(self): | ||
""" | ||
校验机构数和任务是否匹配 | ||
""" | ||
if len(self.participant_id_list) < 2: | ||
raise Exception("至少需要传入两个机构") | ||
if not self.label_holder_agency or self.label_holder_agency not in self.participant_id_list: | ||
raise Exception("数据集中标签提供方配置错误") | ||
|
||
def __init_participant__(self): | ||
participant_id_list = [] | ||
dataset_id_list = [] | ||
for dataset in self.dataset.datasets: | ||
participant_id_list.append(dataset.agency.agency_id) | ||
dataset_id_list.append(dataset.dataset_id) | ||
self.task_parties.append({'userName': dataset.ctx.user_name, | ||
'agency': dataset.agency.agency_id}) | ||
self.participant_id_list = participant_id_list | ||
self.dataset_id_list = dataset_id_list | ||
|
||
def __init_label_information__(self): | ||
label_holder_agency = None | ||
label_columns = None | ||
for dataset in self.dataset.datasets: | ||
if dataset.is_label_holder: | ||
label_holder_agency = dataset.agency.agency_id | ||
label_columns = 'y' | ||
self.label_holder_agency = label_holder_agency | ||
self.label_columns = label_columns | ||
|
||
@abstractmethod | ||
def build(self) -> JobParam: | ||
pass | ||
|
||
@abstractmethod | ||
def get_job_type(self) -> str: | ||
pass | ||
|
||
def submit(self, project_name): | ||
return self.submit(self.build(project_name)) | ||
|
||
|
||
class PSIJobContext(JobContext): | ||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'): | ||
super().__init__(remote_job_client, project_name, dataset, my_agency) | ||
self.merge_field = merge_field | ||
|
||
def get_job_type(self) -> str: | ||
return "PSI" | ||
|
||
def build(self) -> JobParam: | ||
self.dataset_list = self.dataset.to_psi_format( | ||
self.merge_field, self.result_receiver_id_list) | ||
job_info = JobInfo(self.get_job_type(), self.project_name, json.dumps( | ||
{'dataSetList': self.dataset_list}).replace('"', '\\"')) | ||
job_param = JobParam(job_info, self.task_parties, self.dataset_id_list) | ||
return job_param | ||
|
||
|
||
class PreprocessingJobContext(JobContext): | ||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): | ||
super().__init__(remote_job_client, project_name, dataset, my_agency) | ||
self.model_setting = model_setting | ||
|
||
def get_job_type(self) -> str: | ||
return "PREPROCESSING" | ||
|
||
# TODO: build the request | ||
def build(self) -> JobParam: | ||
return None | ||
|
||
|
||
class FeatureEngineeringJobContext(JobContext): | ||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): | ||
super().__init__(remote_job_client, project_name, dataset, my_agency) | ||
self.model_setting = model_setting | ||
|
||
def get_job_type(self) -> str: | ||
return "FEATURE_ENGINEERING" | ||
|
||
# TODO: build the jobParam | ||
def build(self) -> JobParam: | ||
return None | ||
|
||
|
||
class SecureLGBMTrainingJobContext(JobContext): | ||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): | ||
super().__init__(remote_job_client, project_name, dataset, my_agency) | ||
self.model_setting = model_setting | ||
|
||
def get_job_type(self) -> str: | ||
return "XGB_TRAINING" | ||
|
||
# TODO: build the jobParam | ||
def build(self) -> JobParam: | ||
return None | ||
|
||
|
||
class SecureLGBMPredictJobContext(JobContext): | ||
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): | ||
super().__init__(remote_job_client, project_name, dataset, my_agency) | ||
self.model_setting = model_setting | ||
|
||
def get_job_type(self) -> str: | ||
return "XGB_PREDICTING" | ||
|
||
# TODO: build the jobParam | ||
def build(self) -> JobParam: | ||
return None |
File renamed without changes.
22 changes: 22 additions & 0 deletions
22
python/wedpr_ml_toolkit/context/result/fe_result_context.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
|
||
from wedpr_ml_toolkit.context.data_context import DataContext | ||
from wedpr_ml_toolkit.common.utils.constant import Constant | ||
from wedpr_ml_toolkit.context.result.result_context import ResultContext | ||
from wedpr_ml_toolkit.context.job_context import JobContext | ||
|
||
|
||
class FeResultContext(ResultContext): | ||
|
||
def __init__(self, job_context: JobContext, job_id: str): | ||
super().__init__(job_context, job_id) | ||
|
||
def parse_result(self): | ||
result_list = [] | ||
for dataset in self.job_context.dataset.datasets: | ||
dataset.update_path(os.path.join( | ||
self.job_id, Constant.FE_RESULT_FILE)) | ||
result_list.append(dataset) | ||
|
||
fe_result = DataContext(*result_list) | ||
return fe_result |
51 changes: 51 additions & 0 deletions
51
python/wedpr_ml_toolkit/context/result/model_result_context.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import numpy as np | ||
|
||
from ppc_common.ppc_utils import utils | ||
from wedpr_ml_toolkit.context.result.result_context import ResultContext | ||
from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint | ||
from wedpr_ml_toolkit.common.utils.constant import Constant | ||
from wedpr_ml_toolkit.context.job_context import JobContext | ||
|
||
|
||
class ModelResultContext(ResultContext): | ||
def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint): | ||
super().__init__(job_context, job_id) | ||
self.storage_entrypoint = storage_entrypoint | ||
|
||
|
||
class SecureLGBMResultContext(ModelResultContext): | ||
MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json' | ||
|
||
def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint): | ||
super().__init__(job_context, job_id, storage_entrypoint) | ||
|
||
def parse_result(self): | ||
|
||
# train_praba, test_praba, train_y, test_y, feature_importance, split_xbin, trees, params | ||
# 从hdfs读取结果文件信息,构造为属性 | ||
train_praba_path = os.path.join( | ||
self.job_id, Constant.TRAIN_MODEL_OUTPUT_FILE) | ||
test_praba_path = os.path.join( | ||
self.job_id, Constant.TEST_MODEL_OUTPUT_FILE) | ||
train_output = self.storage_entrypoint.download(train_praba_path) | ||
test_output = self.storage_entrypoint.download(test_praba_path) | ||
self.train_praba = train_output['class_pred'].values | ||
self.test_praba = test_output['class_pred'].values | ||
if 'class_label' in train_output.columns: | ||
self.train_y = train_output['class_label'].values | ||
self.test_y = test_output['class_label'].values | ||
else: | ||
self.train_y = None | ||
self.test_y = None | ||
|
||
feature_bin_path = os.path.join(self.job_id, Constant.FEATURE_BIN_FILE) | ||
model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE) | ||
feature_bin_data = self.storage_entrypoint.download_data( | ||
feature_bin_path) | ||
model_data = self.storage_entrypoint.download_data(model_path) | ||
|
||
self.feature_importance = ... | ||
self.split_xbin = feature_bin_data | ||
self.trees = model_data | ||
self.params = ... |
24 changes: 24 additions & 0 deletions
24
python/wedpr_ml_toolkit/context/result/psi_result_context.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
|
||
from wedpr_ml_toolkit.context.job_context import JobContext | ||
from wedpr_ml_toolkit.context.data_context import DataContext | ||
from wedpr_ml_toolkit.common.utils.constant import Constant | ||
from wedpr_ml_toolkit.context.result.result_context import ResultContext | ||
|
||
|
||
class PSIResultContext(ResultContext): | ||
|
||
PSI_RESULT_FILE = "psi_result.csv" | ||
|
||
def __init__(self, job_context: JobContext, job_id: str): | ||
super().__init__(job_context, job_id) | ||
|
||
def parse_result(self): | ||
result_list = [] | ||
for dataset in self.job_context.dataset.datasets: | ||
dataset.update_path(os.path.join( | ||
self.job_id, Constant.PSI_RESULT_FILE)) | ||
result_list.append(dataset) | ||
|
||
self.psi_result = DataContext(*result_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- coding: utf-8 -*- | ||
from wedpr_ml_toolkit.context.job_context import JobContext | ||
from abc import abstractmethod | ||
|
||
|
||
class ResultContext: | ||
def __init__(self, job_context: JobContext, job_id: str): | ||
self.job_id = job_id | ||
self.job_context = job_context | ||
self.parse_result() | ||
|
||
@abstractmethod | ||
def parse_result(self): | ||
pass |
Oops, something went wrong.