class Properties:
    """Minimal reader for ``key=value`` property files."""

    def __init__(self, file_path):
        """Remember the path of the property file; nothing is read yet."""
        self.file_path = file_path

    def getProperties(self):
        """Parse the file and return its entries as a ``dict`` of strings.

        A line contributes an entry when it contains ``=`` somewhere after
        its first character. The key is the text before the first ``=`` and
        the value is everything after it, both stripped of surrounding
        whitespace. Lines whose first non-blank character is ``#`` are
        treated as comments and skipped. A value wrapped in a matching pair
        of single or double quotes is returned without the quotes.

        Raises:
            OSError: if the file cannot be opened or read.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        properties = {}
        # The context manager closes the handle even when reading fails;
        # the previous try/else only closed it on the success path.
        with open(self.file_path, 'r', encoding='utf-8') as pro_file:
            for line in pro_file:
                if line.lstrip().startswith('#'):
                    continue
                if line.find('=') <= 0:
                    continue
                # Split on the FIRST '=' only, so values that contain '='
                # themselves (e.g. padded base64 keys) are kept whole.
                key, _, value = line.partition('=')
                properties[key.strip()] = self._unquote(value.strip())
        return properties

    @staticmethod
    def _unquote(value):
        """Drop one matching pair of surrounding quotes from *value*, if any."""
        # The sample config in this project writes values as key="value";
        # without this the quotes would end up inside paths and URLs.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            return value[1:-1]
        return value
class BaseConfig:
    """Base class for configuration sections populated from keyword arguments."""

    def set_params(self, **params: Any):
        """Assign every ``key=value`` pair in *params* as an attribute.

        Keys are not filtered: a key that does not match an existing
        attribute is still stored on the object.

        Returns:
            This object, so calls can be chained.
        """
        for key, value in params.items():
            # A single assignment is enough. The former follow-up
            # ``if hasattr(...)`` check was always true right after
            # ``setattr`` and only repeated the same assignment.
            setattr(self, key, value)
        return self
class PreprocessingModelSetting:
    """Default values for the preprocessing settings."""

    def __init__(self):
        self.use_psi = False
        self.fillna = False
        self.na_select = 1.0
        self.filloutlier = False
        self.normalized = False
        self.standardized = False
        self.categorical = ''
        self.psi_select_col = ''
        # NOTE(review): this attribute used to be assigned twice ('' and then
        # 0.3); only the effective final value is kept. 0.3 looks like a
        # threshold rather than a "base", so the second assignment may have
        # been meant for a different attribute -- confirm against the server.
        self.psi_select_base = 0.3
        self.psi_select_bins = 4
        self.corr_select = 0
        self.use_goss = False


class FeatureEngineeringEngineModelSetting:
    """Default values for the feature-engineering settings."""

    def __init__(self):
        self.use_iv = False
        self.group_num = 4
        self.iv_thresh = 0.1


class CommmonSecureModelSetting:
    """Default values shared by the secure LGBM and LR model settings."""

    def __init__(self):
        self.learning_rate = 0.1
        self.eval_set_column = ''
        self.train_set_value = ''
        self.eval_set_value = ''
        self.verbose_eval = 1
        self.silent = False
        self.train_features = ''
        self.random_state = None
        self.n_jobs = 0


class SecureLGBMModelSetting(CommmonSecureModelSetting):
    """Common secure-model defaults plus the LGBM-specific ones."""

    def __init__(self):
        super().__init__()
        self.test_size = 0.3
        self.num_trees = 6
        self.max_depth = 3
        self.max_bin = 4
        self.subsample = 1.0
        self.colsample_bytree = 1
        self.colsample_bylevel = 1
        self.reg_alpha = 0
        self.reg_lambda = 1.0
        self.gamma = 0.0
        self.min_child_weight = 0.0
        self.min_child_samples = 10
        self.seed = 2024
        self.early_stopping_rounds = 5
        self.eval_metric = "auc"
        self.threads = 8
        self.one_hot = 0


class SecureLRModelSetting(CommmonSecureModelSetting):
    """Common secure-model defaults plus the LR-specific ones."""

    def __init__(self):
        super().__init__()
        self.feature_rate = 1.0
        self.batch_size = 16
        self.epochs = 3


class ModelSetting(PreprocessingModelSetting, FeatureEngineeringEngineModelSetting, SecureLGBMModelSetting, SecureLRModelSetting):
    """Union of every setting group, each initialised to its defaults."""

    def __init__(self):
        # The base initialisers take no arguments and the first two do not
        # chain through super(), so each one is invoked explicitly. The
        # previous super(Base, self).__init__(model_dict) calls referenced an
        # undefined name and skipped the class they named.
        PreprocessingModelSetting.__init__(self)
        FeatureEngineeringEngineModelSetting.__init__(self)
        SecureLGBMModelSetting.__init__(self)
        # Re-running the LR defaults is harmless: they are plain assignments.
        SecureLRModelSetting.__init__(self)
return self.remote_job_client.submit_job(self.build()) + + @abstractmethod + def parse_result(self, result_detail): pass - def submit(self, project_name): - return self.submit(self.build(project_name)) + def fetch_job_result(self, job_id, block_until_success): + job_result = self.query_job_status(job_id, block_until_success) + # TODO: determine success or not here + return self.parse_result(self.remote_job_client.query_job_detail(job_id)) class PSIJobContext(JobContext): @@ -75,8 +93,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, d super().__init__(remote_job_client, project_name, dataset, my_agency) self.merge_field = merge_field - def get_job_type(self) -> str: - return "PSI" + def get_job_type(self) -> JobType: + return JobType.PSI def build(self) -> JobParam: self.dataset_list = self.dataset.to_psi_format( @@ -92,8 +110,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m super().__init__(remote_job_client, project_name, dataset, my_agency) self.model_setting = model_setting - def get_job_type(self) -> str: - return "PREPROCESSING" + def get_job_type(self) -> JobType: + return JobType.PREPROCESSING # TODO: build the request def build(self) -> JobParam: @@ -105,8 +123,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m super().__init__(remote_job_client, project_name, dataset, my_agency) self.model_setting = model_setting - def get_job_type(self) -> str: - return "FEATURE_ENGINEERING" + def get_job_type(self) -> JobType: + return JobType.FEATURE_ENGINEERING # TODO: build the jobParam def build(self) -> JobParam: @@ -118,8 +136,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m super().__init__(remote_job_client, project_name, dataset, my_agency) self.model_setting = model_setting - def get_job_type(self) -> str: - return "XGB_TRAINING" + def get_job_type(self) -> JobType: + return JobType.XGB_TRAINING # TODO: build the 
jobParam def build(self) -> JobParam: @@ -131,8 +149,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m super().__init__(remote_job_client, project_name, dataset, my_agency) self.model_setting = model_setting - def get_job_type(self) -> str: - return "XGB_PREDICTING" + def get_job_type(self) -> JobType: + return JobType.XGB_PREDICTING # TODO: build the jobParam def build(self) -> JobParam: diff --git a/python/wedpr_ml_toolkit/test/config.properties b/python/wedpr_ml_toolkit/test/config.properties new file mode 100644 index 00000000..23ecf94e --- /dev/null +++ b/python/wedpr_ml_toolkit/test/config.properties @@ -0,0 +1,9 @@ +access_key_id="NmNmNzUxZmMtNDNjZi00NzUwLWFjNmQtMTM3OTliZmZmMmU5" +access_key_secret="ODVkY2U2OTUtZjJmZS00MjQ1LTljY2YtMzhkYWIwNDlhMDIy" +remote_entrypoints="[http://139.159.202.235:16000]" + +agency_name="SGD" +workspace_path="/user/ppc/milestone2/sgd/" +user="flyhuang1" +storage_endpoint="http://192.168.0.18:50070" + diff --git a/python/wedpr_ml_toolkit/test/test_dev.py b/python/wedpr_ml_toolkit/test/test_dev.py deleted file mode 100644 index db5efc9a..00000000 --- a/python/wedpr_ml_toolkit/test/test_dev.py +++ /dev/null @@ -1,78 +0,0 @@ -import unittest -import numpy as np -import pandas as pd -from sklearn import metrics - -from wedpr_ml_toolkit.common.base_context import BaseContext -from wedpr_ml_toolkit.utils.agency import Agency -from wedpr_ml_toolkit.toolkit import DatasetToolkit -from wedpr_ml_toolkit.context.data_context import DataContext -from wedpr_ml_toolkit.context.job_context import JobContext - - -# 从jupyter环境中获取project_id等信息 -# create workspace -# 相同项目/刷新专家模式project_id固定 -project_id = '测试-xinyi' -user = 'flyhuang1' -my_agency = 'sgd' -pws_endpoint = 'http://139.159.202.235:8005' # http -hdfs_endpoint = 'http://192.168.0.18:50070' # client -token = 'abc...' 
- - -# 自定义合作方机构 -partner_agency1 = 'webank' -partner_agency2 = 'TX' - -# 初始化project ctx 信息 -ctx = BaseContext(project_id, user, pws_endpoint, hdfs_endpoint, token) - -# 注册 agency -agency1 = Agency(agency_id=my_agency) -agency2 = Agency(agency_id=partner_agency1) - -# 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path -# pd.Dataframe -df = pd.DataFrame({ - 'id': np.arange(0, 100), # id列,顺序整数 - 'y': np.random.randint(0, 2, size=100), - **{f'x{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 -}) - -dataset1 = DatasetToolkit(ctx, values=df, agency=agency1, is_label_holder=True) -dataset1.storage_client = None -# './milestone2\\sgd\\flyhuang1\\share\\d-101' -dataset1.save_values(path='d-101') - -# hdfs_path -dataset2 = DatasetToolkit( - ctx, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) -dataset2.storage_client = None -# dataset2.load_values() -if dataset2.storage_client is None: - # 支持更新dataset的values数据 - df2 = pd.DataFrame({ - 'id': np.arange(0, 100), # id列,顺序整数 - **{f'z{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 - }) - dataset2.update_values(values=df2) -if dataset1.storage_client is not None: - dataset1.update_values( - path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') - dataset1.load_values() - -# 构建 dataset context -dataset = DataContext(dataset1, dataset2) - -# 初始化 wedpr task session(含数据) -task = JobContext(dataset, my_agency=my_agency) -print(task.participant_id_list, task.result_receiver_id_list) -# 执行psi任务 -psi_result = task.psi() - -# 初始化 wedpr task session(不含数据) (推荐:使用更灵活) -task = JobContext(my_agency=my_agency) -# 执行psi任务 -fe_result = task.proprecessing(dataset) -print(task.participant_id_list, task.result_receiver_id_list) diff --git a/python/wedpr_ml_toolkit/test/test_ml_toolkit.py b/python/wedpr_ml_toolkit/test/test_ml_toolkit.py new file mode 100644 index 00000000..98ab3ee8 --- /dev/null +++ b/python/wedpr_ml_toolkit/test/test_ml_toolkit.py @@ -0,0 +1,72 @@ +# -*- coding: 
utf-8 -*- +import unittest +import numpy as np +import pandas as pd +from sklearn import metrics +from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder +from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit +from wedpr_ml_toolkit.toolkit.dataset_toolkit import DatasetToolkit +from wedpr_ml_toolkit.context.data_context import DataContext +from wedpr_ml_toolkit.context.job_context import JobType +from wedpr_ml_toolkit.config.wedpr_model_setting import PreprocessingModelSetting + +wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file( + "config.properties") + +wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config) + +# 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path +df = pd.DataFrame({ + 'id': np.arange(0, 100), # id列,顺序整数 + 'y': np.random.randint(0, 2, size=100), + **{f'x{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 +}) + +dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(), + storage_workspace=wedpr_config.user_config.get_workspace_path(), + agency=wedpr_config.user_config.agency_name, + values=df, + is_label_holder=True) +dataset1.save_values(path='d-101') + +# hdfs_path +dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(), + dataset_path="d-9606695119693829", agency="WeBank") + +dataset2.storage_client = None +# dataset2.load_values() +if dataset2.storage_client is None: + # 支持更新dataset的values数据 + df2 = pd.DataFrame({ + 'id': np.arange(0, 100), # id列,顺序整数 + **{f'z{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 + }) + dataset2.update_values(values=df2) +if dataset1.storage_client is not None: + dataset1.update_values( + path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') + dataset1.load_values() + +# 构建 dataset context +dataset = DataContext(dataset1, dataset2) + +# init the job context +project_name = "1" + +psi_job_context = wedpr_ml_toolkit.build_job_context( + JobType.PSI, project_name, dataset, None, "id") 
class DatasetToolkit:
    """In-memory handle on one dataset plus its location in remote storage.

    The object keeps the dataset values (expected to expose ``columns`` and
    ``shape``, e.g. a ``pandas.DataFrame``), the path it is stored under and
    an optional storage client used to upload and download the values.
    """

    def __init__(self,
                 storage_entrypoint: "StorageEntryPoint",
                 storage_workspace=None,
                 dataset_id=None,
                 dataset_path=None,
                 agency=None,
                 values=None,
                 is_label_holder=False):
        """Store the given fields; ``columns``/``shape`` mirror *values* if set.

        ``storage_workspace`` is optional so that a dataset referenced only by
        an existing path can be created without a workspace, as the toolkit's
        own test script does.
        """
        self.dataset_id = dataset_id
        self.dataset_path = dataset_path
        self.agency = agency
        self.values = values
        self.is_label_holder = is_label_holder
        self.columns = None
        self.shape = None

        self.storage_client = storage_entrypoint
        self.storage_workspace = storage_workspace

        if self.values is not None:
            self.columns = self.values.columns
            self.shape = self.values.shape

    def load_values(self):
        """Download the dataset at ``dataset_path`` through the storage client.

        Does nothing when no storage client is attached.
        """
        if self.storage_client is not None:
            self.values = self.storage_client.download(self.dataset_path)
            self.columns = self.values.columns
            self.shape = self.values.shape

    def save_values(self, path=None):
        """Upload the current values, placing the path inside the workspace.

        Args:
            path: new dataset path; when omitted the current
                ``dataset_path`` is used.

        Raises:
            ValueError: if neither *path* nor ``dataset_path`` is set.
        """
        if path is not None:
            self.dataset_path = path
        if self.dataset_path is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                "save_values requires a dataset path: pass `path` or set dataset_path first")
        workspace = self.storage_workspace
        # Paths outside the workspace are treated as relative to it; without
        # a workspace the path is used unchanged.
        if workspace and not self.dataset_path.startswith(workspace):
            self.dataset_path = os.path.join(workspace, self.dataset_path)
        if self.storage_client is not None:
            self.storage_client.upload(self.values, self.dataset_path)

    def update_values(self, values: pd.DataFrame = None, path: str = None):
        """Replace the values and/or the path, then upload the new values.

        The upload happens only when new *values* were given and a storage
        client is attached; it targets the (possibly updated) dataset path.
        """
        if values is not None:
            self.values = values
            self.columns = self.values.columns
            self.shape = self.values.shape
        if path is not None:
            self.dataset_path = path
        if values is not None and self.storage_client is not None:
            self.storage_client.upload(self.values, self.dataset_path)

    def update_path(self, path: str = None):
        """Point the dataset at *path* and drop the locally held values.

        ``columns`` and ``shape`` are left untouched.
        """
        if path is not None:
            self.dataset_path = path
        # NOTE(review): the values are dropped even when path is None --
        # confirm that is intended.
        self.values = None
hdfs_path): diff --git a/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py index 0803bc9a..9ffe3671 100644 --- a/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py +++ b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import requests -from transport.credential_generator import CredentialGenerator -from common.utils.constant import Constant +from wedpr_ml_toolkit.transport.credential_generator import CredentialGenerator +from wedpr_ml_toolkit.common.utils.constant import Constant import json diff --git a/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py b/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py index 45246c63..1fb1aeb0 100644 --- a/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py +++ b/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py @@ -1,8 +1,7 @@ -from transport.wedpr_entrypoint import WeDPREntryPoint -from common.utils.constant import Constant -from config.wedpr_ml_config import AuthConfig -from config.wedpr_ml_config import JobConfig -import random +from wedpr_ml_toolkit.transport.wedpr_entrypoint import WeDPREntryPoint +from wedpr_ml_toolkit.common.utils.constant import Constant +from wedpr_ml_toolkit.config.wedpr_ml_config import AuthConfig +from wedpr_ml_toolkit.config.wedpr_ml_config import JobConfig import json import time diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py index 02fa8d47..b2493114 100644 --- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py +++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py @@ -1,8 +1,15 @@ # -*- coding: utf-8 -*- -from config.wedpr_ml_config import WeDPRMlConfig +from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfig from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint +from 
def build_job_context(self, job_type: JobType, project_name: str, dataset: DataContext, model_setting=None, id_fields='id'):
    """Create the job context object that matches *job_type*.

    Args:
        job_type: which kind of job to build a context for.
        project_name: forwarded to the context constructor.
        dataset: forwarded to the context constructor.
        model_setting: forwarded to every context except the PSI one.
        id_fields: forwarded to the PSI context only.

    Returns:
        A new context bound to this toolkit's remote job client and to the
        agency name taken from the user configuration.

    Raises:
        ValueError: if *job_type* is not one of the handled job types.
    """
    # The agency name is held by user_config; the config object built by
    # WeDPRMlConfig has no agency_config attribute.
    agency_name = self.config.user_config.agency_name
    if job_type == JobType.PSI:
        # PSI is the only context that takes id fields and no model setting.
        return PSIJobContext(self.remote_job_client, project_name, dataset, agency_name, id_fields)
    model_context_types = {
        JobType.PREPROCESSING: PreprocessingJobContext,
        JobType.FEATURE_ENGINEERING: FeatureEngineeringJobContext,
        JobType.XGB_TRAINING: SecureLGBMTrainingJobContext,
        JobType.XGB_PREDICTING: SecureLGBMPredictJobContext,
    }
    context_type = model_context_types.get(job_type)
    if context_type is None:
        # Previously fell through and returned None, which only surfaced
        # later as an AttributeError at the call site.
        raise ValueError(f"unsupported job type: {job_type}")
    return context_type(self.remote_job_client, project_name, model_setting, dataset, agency_name)