
Commit

refactor wedpr_ml_toolkit
cyjseagull committed Oct 17, 2024
1 parent 25e376d commit 55f7765
Showing 29 changed files with 532 additions and 478 deletions.
13 changes: 0 additions & 13 deletions python/wedpr_ml_toolkit/common/base_context.py

This file was deleted.

8 changes: 0 additions & 8 deletions python/wedpr_ml_toolkit/common/base_result.py

This file was deleted.

14 changes: 14 additions & 0 deletions python/wedpr_ml_toolkit/common/utils/constant.py
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

class Constant:
NUMERIC_ARRAY = [i for i in range(10)]
HTTP_STATUS_OK = 200
DEFAULT_SUBMIT_JOB_URI = '/api/wedpr/v3/project/submitJob'
DEFAULT_QUERY_JOB_STATUS_URL = '/api/wedpr/v3/project/queryJobByCondition'
PSI_RESULT_FILE = "psi_result.csv"

FEATURE_BIN_FILE = "feature_bin.json"
TEST_MODEL_OUTPUT_FILE = "test_output.csv"
TRAIN_MODEL_OUTPUT_FILE = "train_output.csv"

FE_RESULT_FILE = "fe_result.csv"
26 changes: 26 additions & 0 deletions python/wedpr_ml_toolkit/common/utils/utils.py
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import uuid
from enum import Enum
import random
from wedpr_ml_toolkit.common.utils.constant import Constant
from urllib.parse import urlencode, urlparse, parse_qs, quote

class IdPrefixEnum(Enum):
DATASET = "d-"
ALGORITHM = "a-"
JOB = "j-"


def make_id(prefix):
return prefix + str(uuid.uuid4()).replace("-", "")

def generate_nonce(nonce_len):
    # str() each sampled digit so that ''.join works even though NUMERIC_ARRAY holds ints
    return ''.join(str(random.choice(Constant.NUMERIC_ARRAY)) for _ in range(nonce_len))

def add_params_to_url(url, params):
parsed_url = urlparse(url)
query_params = parse_qs(parsed_url.query)
for key, value in params.items():
query_params[key] = value
new_query = urlencode(query_params, doseq=True)
return parsed_url._replace(query=new_query).geturl()
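
A quick, hedged illustration of how these helpers behave (the prefix, nonce length, and URL below are arbitrary examples, not values mandated by the toolkit):

# Illustrative usage only; all concrete values here are made up.
from wedpr_ml_toolkit.common.utils.utils import IdPrefixEnum, make_id, generate_nonce, add_params_to_url

dataset_id = make_id(IdPrefixEnum.DATASET.value)   # e.g. "d-" followed by 32 hex chars
nonce = generate_nonce(8)                          # string of 8 random digits
url = add_params_to_url("http://example.com/api?x=1", {"nonce": nonce})
# -> http://example.com/api?x=1&nonce=<digits>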
File renamed without changes.
@@ -8,15 +8,18 @@ class DataContext:
def __init__(self, *datasets):
self.datasets = list(datasets)
self.ctx = self.datasets[0].ctx

self._check_datasets()

    def _save_dataset(self, dataset):
        if dataset.dataset_path is None:
            dataset.dataset_id = utils.make_id(
                utils.IdPrefixEnum.DATASET.value)
            dataset.dataset_path = os.path.join(
                dataset.storage_workspace, dataset.dataset_id)
            if dataset.storage_client is not None:
                dataset.storage_client.upload(
                    dataset.values, dataset.dataset_path)

def _check_datasets(self):
for dataset in self.datasets:
139 changes: 139 additions & 0 deletions python/wedpr_ml_toolkit/context/job_context.py
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
import json

from wedpr_ml_toolkit.context.data_context import DataContext
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobParam
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo
from abc import abstractmethod
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient


class JobContext:

def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None):
if dataset is None:
raise Exception("Must define the job related datasets!")
self.remote_job_client = remote_job_client
self.project_name = project_name
self.dataset = dataset
self.create_agency = my_agency
self.participant_id_list = []
self.task_parties = []
self.dataset_id_list = []
self.dataset_list = []
self.label_holder_agency = None
self.label_columns = None
self.__init_participant__()
self.__init_label_information__()
        self.result_receiver_id_list = [my_agency]  # restricted to the agency where the Jupyter client runs
self.__check__()

    def __check__(self):
        """
        Check that the configured agencies and datasets match the job requirements
        """
        if len(self.participant_id_list) < 2:
            raise Exception("At least two agencies must be provided")
        if not self.label_holder_agency or self.label_holder_agency not in self.participant_id_list:
            raise Exception("The label provider is mis-configured in the datasets")

def __init_participant__(self):
participant_id_list = []
dataset_id_list = []
for dataset in self.dataset.datasets:
participant_id_list.append(dataset.agency.agency_id)
dataset_id_list.append(dataset.dataset_id)
self.task_parties.append({'userName': dataset.ctx.user_name,
'agency': dataset.agency.agency_id})
self.participant_id_list = participant_id_list
self.dataset_id_list = dataset_id_list

def __init_label_information__(self):
label_holder_agency = None
label_columns = None
for dataset in self.dataset.datasets:
if dataset.is_label_holder:
label_holder_agency = dataset.agency.agency_id
label_columns = 'y'
self.label_holder_agency = label_holder_agency
self.label_columns = label_columns

@abstractmethod
def build(self) -> JobParam:
pass

@abstractmethod
def get_job_type(self) -> str:
pass

    def submit(self):
        # Hand the built JobParam to the remote job client
        # (the exact client-side method name is assumed here).
        return self.remote_job_client.submit_job(self.build())


class PSIJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.merge_field = merge_field

def get_job_type(self) -> str:
return "PSI"

def build(self) -> JobParam:
self.dataset_list = self.dataset.to_psi_format(
self.merge_field, self.result_receiver_id_list)
job_info = JobInfo(self.get_job_type(), self.project_name, json.dumps(
{'dataSetList': self.dataset_list}).replace('"', '\\"'))
job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
return job_param


class PreprocessingJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "PREPROCESSING"

# TODO: build the request
def build(self) -> JobParam:
return None


class FeatureEngineeringJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "FEATURE_ENGINEERING"

# TODO: build the jobParam
def build(self) -> JobParam:
return None


class SecureLGBMTrainingJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "XGB_TRAINING"

# TODO: build the jobParam
def build(self) -> JobParam:
return None


class SecureLGBMPredictJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "XGB_PREDICTING"

# TODO: build the jobParam
def build(self) -> JobParam:
return None
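
A minimal usage sketch for the PSI context, assuming a configured WeDPRRemoteJobClient and a DataContext built from the datasets of at least two agencies; all names below are illustrative, and the final call assumes the client exposes a job-submission method:

# Illustrative sketch; client and dataset construction happen elsewhere.
psi_context = PSIJobContext(
    remote_job_client=client,     # a configured WeDPRRemoteJobClient
    project_name="demo_project",
    dataset=dataset_context,      # DataContext with a label-holder dataset and >= 2 agencies
    my_agency="AGENCY_A",
    merge_field="id")
job_param = psi_context.build()   # packs the dataSetList into the JobInfo parameter
result = psi_context.submit()     # hands the built JobParam to the remote job client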
File renamed without changes.
22 changes: 22 additions & 0 deletions python/wedpr_ml_toolkit/context/result/fe_result_context.py
@@ -0,0 +1,22 @@
import os

from wedpr_ml_toolkit.context.data_context import DataContext
from wedpr_ml_toolkit.common.utils.constant import Constant
from wedpr_ml_toolkit.context.result.result_context import ResultContext
from wedpr_ml_toolkit.context.job_context import JobContext


class FeResultContext(ResultContext):

def __init__(self, job_context: JobContext, job_id: str):
super().__init__(job_context, job_id)

def parse_result(self):
result_list = []
for dataset in self.job_context.dataset.datasets:
dataset.update_path(os.path.join(
self.job_id, Constant.FE_RESULT_FILE))
result_list.append(dataset)

fe_result = DataContext(*result_list)
return fe_result
51 changes: 51 additions & 0 deletions python/wedpr_ml_toolkit/context/result/model_result_context.py
@@ -0,0 +1,51 @@
import os
import numpy as np

from ppc_common.ppc_utils import utils
from wedpr_ml_toolkit.context.result.result_context import ResultContext
from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint
from wedpr_ml_toolkit.common.utils.constant import Constant
from wedpr_ml_toolkit.context.job_context import JobContext


class ModelResultContext(ResultContext):
def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint):
super().__init__(job_context, job_id)
self.storage_entrypoint = storage_entrypoint


class SecureLGBMResultContext(ModelResultContext):
MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json'

def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint):
super().__init__(job_context, job_id, storage_entrypoint)

def parse_result(self):

# train_praba, test_praba, train_y, test_y, feature_importance, split_xbin, trees, params
        # Read the result files from HDFS and expose them as attributes of this context
train_praba_path = os.path.join(
self.job_id, Constant.TRAIN_MODEL_OUTPUT_FILE)
test_praba_path = os.path.join(
self.job_id, Constant.TEST_MODEL_OUTPUT_FILE)
train_output = self.storage_entrypoint.download(train_praba_path)
test_output = self.storage_entrypoint.download(test_praba_path)
self.train_praba = train_output['class_pred'].values
self.test_praba = test_output['class_pred'].values
if 'class_label' in train_output.columns:
self.train_y = train_output['class_label'].values
self.test_y = test_output['class_label'].values
else:
self.train_y = None
self.test_y = None

feature_bin_path = os.path.join(self.job_id, Constant.FEATURE_BIN_FILE)
model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE)
feature_bin_data = self.storage_entrypoint.download_data(
feature_bin_path)
model_data = self.storage_entrypoint.download_data(model_path)

self.feature_importance = ...
self.split_xbin = feature_bin_data
self.trees = model_data
self.params = ...
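
As a hedged usage sketch, once the training job has finished, the parsed probabilities and labels can be evaluated directly; this assumes labels were produced and that scikit-learn is available in the environment:

# Illustrative only; job_context, job_id and storage_entrypoint come from the caller.
from sklearn.metrics import roc_auc_score  # assumed to be installed

result_ctx = SecureLGBMResultContext(job_context, job_id, storage_entrypoint)
if result_ctx.train_y is not None:
    print("train AUC:", roc_auc_score(result_ctx.train_y, result_ctx.train_praba))
    print("test AUC:", roc_auc_score(result_ctx.test_y, result_ctx.test_praba))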
24 changes: 24 additions & 0 deletions python/wedpr_ml_toolkit/context/result/psi_result_context.py
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
import os

from wedpr_ml_toolkit.context.job_context import JobContext
from wedpr_ml_toolkit.context.data_context import DataContext
from wedpr_ml_toolkit.common.utils.constant import Constant
from wedpr_ml_toolkit.context.result.result_context import ResultContext


class PSIResultContext(ResultContext):

PSI_RESULT_FILE = "psi_result.csv"

def __init__(self, job_context: JobContext, job_id: str):
super().__init__(job_context, job_id)

def parse_result(self):
result_list = []
for dataset in self.job_context.dataset.datasets:
dataset.update_path(os.path.join(
self.job_id, Constant.PSI_RESULT_FILE))
result_list.append(dataset)

self.psi_result = DataContext(*result_list)
14 changes: 14 additions & 0 deletions python/wedpr_ml_toolkit/context/result/result_context.py
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from wedpr_ml_toolkit.context.job_context import JobContext
from abc import abstractmethod


class ResultContext:
def __init__(self, job_context: JobContext, job_id: str):
self.job_id = job_id
self.job_context = job_context
self.parse_result()

@abstractmethod
def parse_result(self):
pass