diff --git a/python/wedpr_ml_toolkit/common/utils/base_object.py b/python/wedpr_ml_toolkit/common/utils/base_object.py new file mode 100644 index 00000000..03b9ef9b --- /dev/null +++ b/python/wedpr_ml_toolkit/common/utils/base_object.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +from typing import Any +import time + + +class BaseObject: + def set_params(self, **params: Any): + for key, value in params.items(): + setattr(self, key, value) + if hasattr(self, f"{key}"): + setattr(self, f"{key}", value) + return self + + def as_dict(obj): + return {attr: getattr(obj, attr) for attr in dir(obj) if not callable(getattr(obj, attr)) and not attr.startswith("__")} + + def execute_with_retry(self, request_func, retry_times, retry_wait_seconds, *args, **kwargs): + attempt = 0 + while attempt < retry_times: + try: + response = request_func(*args, **kwargs) + return response + except Exception as e: + attempt += 1 + if attempt < retry_times: + time.sleep(retry_wait_seconds) + else: + raise e diff --git a/python/wedpr_ml_toolkit/common/utils/constant.py b/python/wedpr_ml_toolkit/common/utils/constant.py index 460f36d8..b323d962 100644 --- a/python/wedpr_ml_toolkit/common/utils/constant.py +++ b/python/wedpr_ml_toolkit/common/utils/constant.py @@ -3,12 +3,14 @@ class Constant: NUMERIC_ARRAY = [i for i in range(10)] HTTP_STATUS_OK = 200 - DEFAULT_SUBMIT_JOB_URI = '/api/wedpr/v3/project/submitJob' - DEFAULT_QUERY_JOB_STATUS_URL = '/api/wedpr/v3/project/queryJobByCondition' + WEDPR_API_PREFIX = '/api/wedpr/v3/' + DEFAULT_SUBMIT_JOB_URI = f'{WEDPR_API_PREFIX}project/submitJob' + DEFAULT_QUERY_JOB_STATUS_URL = f'{WEDPR_API_PREFIX}project/queryJobByCondition' + DEFAULT_QUERY_JOB_DETAIL_URL = f'{WEDPR_API_PREFIX}scheduler/queryJobDetail' PSI_RESULT_FILE = "psi_result.csv" FEATURE_BIN_FILE = "feature_bin.json" TEST_MODEL_OUTPUT_FILE = "test_output.csv" TRAIN_MODEL_OUTPUT_FILE = "train_output.csv" - FE_RESULT_FILE = "fe_result.csv" \ No newline at end of file + FE_RESULT_FILE = "fe_result.csv" diff --git a/python/wedpr_ml_toolkit/common/utils/properies_parser.py b/python/wedpr_ml_toolkit/common/utils/properies_parser.py index d42f83fc..dd05a475 100644 --- a/python/wedpr_ml_toolkit/common/utils/properies_parser.py +++ b/python/wedpr_ml_toolkit/common/utils/properies_parser.py @@ -10,7 +10,7 @@ def getProperties(self): properties = {} for line in pro_file: if line.find('=') > 0: - strs = line.replace('\n', '').split('=') + strs = line.strip("\"").replace('\n', '').split('=') properties[strs[0].strip()] = strs[1].strip() except Exception as e: raise e diff --git a/python/wedpr_ml_toolkit/common/utils/utils.py b/python/wedpr_ml_toolkit/common/utils/utils.py index a90af704..27795c8a 100644 --- a/python/wedpr_ml_toolkit/common/utils/utils.py +++ b/python/wedpr_ml_toolkit/common/utils/utils.py @@ -17,7 +17,7 @@ def make_id(prefix): def generate_nonce(nonce_len): - return ''.join(random.choice(Constant.NUMERIC_ARRAY) for _ in range(nonce_len)) + return ''.join(str(random.choice(Constant.NUMERIC_ARRAY)) for _ in range(nonce_len)) def add_params_to_url(url, params): diff --git a/python/wedpr_ml_toolkit/config/wedpr_ml_config.py b/python/wedpr_ml_toolkit/config/wedpr_ml_config.py index 4783b07e..55c17e90 100644 --- a/python/wedpr_ml_toolkit/config/wedpr_ml_config.py +++ b/python/wedpr_ml_toolkit/config/wedpr_ml_config.py @@ -1,44 +1,42 @@ # -*- coding: utf-8 -*- import os -from typing import Any, Dict +from wedpr_ml_toolkit.common.utils.base_object import BaseObject from wedpr_ml_toolkit.common.utils.constant import Constant from wedpr_ml_toolkit.common.utils.properies_parser import Properties -class BaseConfig: - def set_params(self, **params: Any): - for key, value in params.items(): - setattr(self, key, value) - if hasattr(self, f"{key}"): - setattr(self, f"{key}", value) - return self - - -class AuthConfig(BaseConfig): +class AuthConfig(BaseObject): def __init__(self, access_key_id: str = None, access_key_secret: str = None, remote_entrypoints: str = None, nonce_len: int = 5): self.access_key_id = access_key_id self.access_key_secret = access_key_secret self.remote_entrypoints = remote_entrypoints self.nonce_len = nonce_len + def get_remote_entrypoints_list(self) -> []: + if self.remote_entrypoints is None: + return None + return self.remote_entrypoints.split(',') + -class JobConfig(BaseConfig): - def __init__(self, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5, +class JobConfig(BaseObject): + def __init__(self, polling_interval_s: int = 5, max_retries: int = 2, retry_delay_s: int = 5, submit_job_uri: str = Constant.DEFAULT_SUBMIT_JOB_URI, - query_job_status_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL): + query_job_status_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL, + query_job_detail_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL): self.polling_interval_s = polling_interval_s self.max_retries = max_retries self.retry_delay_s = retry_delay_s self.submit_job_uri = submit_job_uri self.query_job_status_uri = query_job_status_uri + self.query_job_detail_uri = query_job_detail_uri -class StorageConfig(BaseConfig): +class StorageConfig(BaseObject): def __init__(self, storage_endpoint: str = None): self.storage_endpoint = storage_endpoint -class UserConfig(BaseConfig): +class UserConfig(BaseObject): def __init__(self, agency_name: str = None, workspace_path: str = None, user_name: str = None): self.agency_name = agency_name self.workspace_path = workspace_path @@ -48,6 +46,11 @@ def get_workspace_path(self): return os.path.join(self.workspace_path, self.user) +class HttpConfig(BaseObject): + def __init__(self, timeout_seconds=3): + self.timeout_seconds = timeout_seconds + + class WeDPRMlConfig: def __init__(self, config_dict): self.auth_config = AuthConfig() @@ -58,6 +61,8 @@ def __init__(self, config_dict): self.storage_config.set_params(**config_dict) self.user_config = UserConfig() self.user_config.set_params(**config_dict) + self.http_config = HttpConfig() + self.http_config.set_params(**config_dict) class WeDPRMlConfigBuilder: diff --git a/python/wedpr_ml_toolkit/context/job_context.py b/python/wedpr_ml_toolkit/context/job_context.py index 51809697..90355003 100644 --- a/python/wedpr_ml_toolkit/context/job_context.py +++ b/python/wedpr_ml_toolkit/context/job_context.py @@ -6,15 +6,7 @@ from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo from abc import abstractmethod from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient -from enum import Enum - - -class JobType(Enum): - PSI = "PSI", - PREPROCESSING = "PREPROCESSING", - FEATURE_ENGINEERING = "FEATURE_ENGINEERING", - XGB_TRAINING = "XGB_TRAINING", - XGB_PREDICTING = "XGB_PREDICTING" +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobType class JobContext: @@ -99,7 +91,7 @@ def get_job_type(self) -> JobType: def build(self) -> JobParam: self.dataset_list = self.dataset.to_psi_format( self.merge_field, self.result_receiver_id_list) - job_info = JobInfo(self.get_job_type(), self.project_name, json.dumps( + job_info = JobInfo(job_type=self.get_job_type(), project_name=self.project_name, param=json.dumps( {'dataSetList': self.dataset_list}).replace('"', '\\"')) job_param = JobParam(job_info, self.task_parties, self.dataset_id_list) return job_param diff --git a/python/wedpr_ml_toolkit/test/config.properties b/python/wedpr_ml_toolkit/test/config.properties index c466500f..358d170d 100644 --- a/python/wedpr_ml_toolkit/test/config.properties +++ b/python/wedpr_ml_toolkit/test/config.properties @@ -1,9 +1,9 @@ -access_key_id="" -access_key_secret="" -remote_entrypoints="http://127.0.0.1:16000,http://127.0.0.1:16001" +access_key_id= +access_key_secret= +remote_entrypoints=http://127.0.0.1:16000,http://127.0.0.1:16001 -agency_name="SGD" -workspace_path="/user/wedpr/milestone2/sgd/" -user="test_user" -storage_endpoint="http://127.0.0.1:50070" +agency_name=SGD +workspace_path=/user/wedpr/milestone2/sgd/ +user=test_user +storage_endpoint=http://127.0.0.1:50070 diff --git a/python/wedpr_ml_toolkit/test/test_ml_toolkit.py b/python/wedpr_ml_toolkit/test/test_ml_toolkit.py index 98ab3ee8..9c88e530 100644 --- a/python/wedpr_ml_toolkit/test/test_ml_toolkit.py +++ b/python/wedpr_ml_toolkit/test/test_ml_toolkit.py @@ -10,63 +10,93 @@ from wedpr_ml_toolkit.context.job_context import JobType from wedpr_ml_toolkit.config.wedpr_model_setting import PreprocessingModelSetting -wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file( - "config.properties") -wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config) +class WeDPRMlToolkitTestWrapper: + def __init__(self, config_file_path): + self.wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file( + config_file_path) + self.wedpr_ml_toolkit = WeDPRMlToolkit(self.wedpr_config) -# 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path -df = pd.DataFrame({ - 'id': np.arange(0, 100), # id列,顺序整数 - 'y': np.random.randint(0, 2, size=100), - **{f'x{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 -}) + def test_submit_job(self): + # 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path + df = pd.DataFrame({ + 'id': np.arange(0, 100), # id列,顺序整数 + 'y': np.random.randint(0, 2, size=100), + # x1到x10列,随机数 + **{f'x{i}': np.random.rand(100) for i in range(1, 11)} + }) -dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(), - storage_workspace=wedpr_config.user_config.get_workspace_path(), - agency=wedpr_config.user_config.agency_name, - values=df, - is_label_holder=True) -dataset1.save_values(path='d-101') + dataset1 = DatasetToolkit(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(), + storage_workspace=self.wedpr_config.user_config.get_workspace_path(), + agency=self.wedpr_config.user_config.agency_name, + values=df, + is_label_holder=True) + dataset1.save_values(path='d-101') -# hdfs_path -dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(), - dataset_path="d-9606695119693829", agency="WeBank") + # hdfs_path + dataset2 = DatasetToolkit(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(), + dataset_path="d-9606695119693829", agency="WeBank") -dataset2.storage_client = None -# dataset2.load_values() -if dataset2.storage_client is None: - # 支持更新dataset的values数据 - df2 = pd.DataFrame({ - 'id': np.arange(0, 100), # id列,顺序整数 - **{f'z{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 - }) - dataset2.update_values(values=df2) -if dataset1.storage_client is not None: - dataset1.update_values( - path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') - dataset1.load_values() + dataset2.storage_client = None + # dataset2.load_values() + if dataset2.storage_client is None: + # 支持更新dataset的values数据 + df2 = pd.DataFrame({ + 'id': np.arange(0, 100), # id列,顺序整数 + # x1到x10列,随机数 + **{f'z{i}': np.random.rand(100) for i in range(1, 11)} + }) + dataset2.update_values(values=df2) + if dataset1.storage_client is not None: + dataset1.update_values( + path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') + dataset1.load_values() -# 构建 dataset context -dataset = DataContext(dataset1, dataset2) + # 构建 dataset context + dataset = DataContext(dataset1, dataset2) -# init the job context -project_name = "1" + # init the job context + project_name = "1" -psi_job_context = wedpr_ml_toolkit.build_job_context( - JobType.PSI, project_name, dataset, None, "id") -print(psi_job_context.participant_id_list, - psi_job_context.result_receiver_id_list) -# 执行psi任务 -psi_job_id = psi_job_context.submit() -psi_result = psi_job_context.fetch_job_result(psi_job_id, True) + psi_job_context = self.wedpr_ml_toolkit.build_job_context( + JobType.PSI, project_name, dataset, None, "id") + print(psi_job_context.participant_id_list, + psi_job_context.result_receiver_id_list) + # 执行psi任务 + psi_job_id = psi_job_context.submit() + psi_result = psi_job_context.fetch_job_result(psi_job_id, True) -# 初始化 -preprocessing_data = DataContext(dataset1) -preprocessing_job_context = wedpr_ml_toolkit.build_job_context( - JobType.PREPROCESSING, project_name, preprocessing_data, PreprocessingModelSetting()) -# 执行预处理任务 -fe_job_id = preprocessing_job_context.submit(dataset) -fe_result = preprocessing_job_context.fetch_job_result(fe_job_id, True) -print(preprocessing_job_context.participant_id_list, - preprocessing_job_context.result_receiver_id_list) + # 初始化 + preprocessing_data = DataContext(dataset1) + preprocessing_job_context = self.wedpr_ml_toolkit.build_job_context( + JobType.PREPROCESSING, project_name, preprocessing_data, PreprocessingModelSetting()) + # 执行预处理任务 + fe_job_id = preprocessing_job_context.submit(dataset) + fe_result = preprocessing_job_context.fetch_job_result(fe_job_id, True) + print(preprocessing_job_context.participant_id_list, + preprocessing_job_context.result_receiver_id_list) + + def test_query_job(self, job_id: str, block_until_finish): + job_result = self.wedpr_ml_toolkit.query_job_status( + job_id, block_until_finish) + print(f"#### job_result: {job_result}") + job_detail_result = self.wedpr_ml_toolkit.query_job_detail( + job_id, block_until_finish) + return (job_result, job_detail_result) + + +class TestMlToolkit(unittest.TestCase): + def test_query_jobs(self): + wrapper = WeDPRMlToolkitTestWrapper("config.properties") + # the success job case + success_job_id = "9630202187032582" + wrapper.test_query_job(success_job_id, False) + # wrapper.test_query_job(success_job_id, True) + # the fail job case + failed_job_id = "9630156365047814" + wrapper.test_query_job(success_job_id, False) + # wrapper.test_query_job(success_job_id, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/wedpr_ml_toolkit/transport/credential_generator.py b/python/wedpr_ml_toolkit/transport/credential_generator.py index 7800e74d..9702cfde 100644 --- a/python/wedpr_ml_toolkit/transport/credential_generator.py +++ b/python/wedpr_ml_toolkit/transport/credential_generator.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import hashlib -from wedpr_ml_toolkit.common import utils +from wedpr_ml_toolkit.common.utils import utils import time @@ -10,7 +10,7 @@ class CredentialInfo: TIMESTAMP_KEY = "timestamp" SIGNATURE_KEY = "signature" - def __init__(self, access_key_id: str, nonce: str, timestamp: str, signature: str): + def __init__(self, access_key_id: str, nonce: str, timestamp: int, signature: str): self.access_key_id = access_key_id self.nonce = nonce self.timestamp = timestamp @@ -18,14 +18,11 @@ def __init__(self, access_key_id: str, nonce: str, timestamp: str, signature: st def to_dict(self): result = {} - result.update(CredentialInfo.ACCESS_ID_KEY, self.access_key_id) - result.update(CredentialInfo.NONCE_KEY, self.nonce) - result.update(CredentialInfo.TIMESTAMP_KEY, self.timestamp) - result.update(CredentialInfo.SIGNATURE_KEY, self.signature) - - def update_url_with_auth_info(self, url): - auth_params = self.to_dict() - return utils.add_params_to_url(auth_params) + result.update({CredentialInfo.ACCESS_ID_KEY: self.access_key_id}) + result.update({CredentialInfo.NONCE_KEY: self.nonce}) + result.update({CredentialInfo.TIMESTAMP_KEY: self.timestamp}) + result.update({CredentialInfo.SIGNATURE_KEY: self.signature}) + return result class CredentialGenerator: @@ -46,10 +43,11 @@ def generate_credential(self) -> CredentialInfo: def generate_signature(access_key_id, access_key_secret, nonce, timestamp) -> str: anti_replay_info_hash = hashlib.sha3_256() # hash(access_key_id + nonce + timestamp) - anti_replay_info = f"{access_key_id}{nonce}{timestamp}" - anti_replay_info_hash.update(anti_replay_info) + anti_replay_info_hash.update( + bytes(access_key_id + nonce + str(timestamp), encoding='utf-8')) # hash(anti_replay_info + access_key_secret) signature_hash = hashlib.sha3_256() - signature_hash.update(anti_replay_info_hash.hexdigest()) - signature_hash.update(access_key_secret) + signature_hash.update( + bytes(anti_replay_info_hash.hexdigest(), encoding='utf-8')) + signature_hash.update(bytes(access_key_secret, encoding='utf-8')) return signature_hash.hexdigest() diff --git a/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py index 9ffe3671..8832f22f 100644 --- a/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py +++ b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- import requests from wedpr_ml_toolkit.transport.credential_generator import CredentialGenerator +from wedpr_ml_toolkit.config.wedpr_ml_config import HttpConfig from wedpr_ml_toolkit.common.utils.constant import Constant import json class LoadBanlancer: - def __init__(self, remote_entrypoints): + def __init__(self, remote_entrypoints: []): if remote_entrypoints == None or len(remote_entrypoints) == 0: raise Exception(f"Must define the wedpr entrypoints") self.remote_entrypoints = remote_entrypoints @@ -18,22 +19,33 @@ def select(self, uri_path: str): self.round_robin_idx += 1 selected_entrypoint = self.remote_entrypoints[selected_idx % len( self.remote_entrypoints)] - return f"{selected_entrypoint}/${uri_path}" + return f"{selected_entrypoint}/{uri_path}" class WeDPREntryPoint: - def __init__(self, access_key_id: str, access_key_secret: str, remote_entrypoints, nonce_len: int = 5): + def __init__(self, access_key_id: str, access_key_secret: str, remote_entrypoints: [], http_config: HttpConfig, nonce_len: int = 5): self.credential_generator = CredentialGenerator( access_key_id, access_key_secret, nonce_len) + self.http_config = http_config self.loadbalancer = LoadBanlancer(remote_entrypoints) - def send_post_request(self, uri, params, headers, data): + def send_request(self, is_post: bool, uri, params, headers, data): credential_info = self.credential_generator.generate_credential() - url = credential_info.update_url_with_auth_info( - self.loadbalancer.select(uri)) + if params is None: + params = {} + params.update(credential_info.to_dict()) + url = self.loadbalancer.select(uri) if not headers: headers = {'content-type': 'application/json'} - response = requests.post(url, data=data, json=params, headers=headers) + response = None + if is_post: + print( + f"#### url: {url}, params: {str(params)}, data: {data}, timeout_seconds: {self.http_config.timeout_seconds}") + response = requests.post(url, data=data, params=params, headers=headers, + timeout=self.http_config.timeout_seconds) + else: + requests.get(url, params=params, headers=headers, + timeout=self.http_config.timeout_seconds) if response.status_code != Constant.HTTP_STATUS_OK: raise Exception( f"send post request to {url} failed, response: {response.text}") diff --git a/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py b/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py index 1fb1aeb0..d67cca0e 100644 --- a/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py +++ b/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py @@ -2,15 +2,75 @@ from wedpr_ml_toolkit.common.utils.constant import Constant from wedpr_ml_toolkit.config.wedpr_ml_config import AuthConfig from wedpr_ml_toolkit.config.wedpr_ml_config import JobConfig +from wedpr_ml_toolkit.config.wedpr_ml_config import HttpConfig +from wedpr_ml_toolkit.common.utils.base_object import BaseObject import json import time - - -class JobInfo: - def __init__(self, job_type: str, project_name: str, param: str): - self.jobType = job_type +from enum import Enum +from typing import Any + + +class JobType(Enum): + PSI = "PSI", + PREPROCESSING = "PREPROCESSING", + FEATURE_ENGINEERING = "FEATURE_ENGINEERING", + XGB_TRAINING = "XGB_TRAINING", + XGB_PREDICTING = "XGB_PREDICTING" + + +class JobStatus(Enum): + Submitted = "Submitted", + Running = "Running", + RunFailed = "RunFailed", + RunSuccess = "RunSuccess", + WaitToRetry = "WaitToRetry", + WaitToKill = "WaitToKill", + Killing = "Killing", + Killed = "Killed", + ChainInProgress = "ChainInProgress" + + def run_success(self) -> bool: + if self.name == JobStatus.RunSuccess: + return True + return False + + def run_failed(self) -> bool: + if self.name == JobStatus.RunFailed: + return True + return False + + def scheduling(self) -> bool: + return (not self.run_success()) and (not self.run_failed()) + + @staticmethod + def get_job_status(job_status_tr: str): + try: + if job_status_tr is None or len(job_status_tr) == 0: + return None + return JobStatus[job_status_tr] + except: + return None + + +class JobInfo(BaseObject): + def __init__(self, job_id: str = None, job_type: JobType = None, project_name: str = None, param: str = None, **params: Any): + self.id = job_id + self.name = None + self.owner = None + self.ownerAgency = None + if job_type is not None: + self.jobType = job_type.name + else: + self.jobType = None + self.parties = None self.projectName = project_name self.param = param + self.status = None + self.result = None + self.createTime = None + self.lastUpdateTime = None + self.set_params(**params) + self.job_status = JobStatus.get_job_status(self.status) class JobParam: @@ -20,31 +80,63 @@ def __init__(self, job_info: JobInfo, task_parities, dataset_list): self.datasetList = dataset_list -class WeDPRResponse: - def __init__(self, code, msg, data): - self.code = code - self.msg = msg - self.data = data +class WeDPRResponse(BaseObject): + def __init__(self, **params: Any): + self.code = None + self.msg = None + self.data = None + self.set_params(**params) def success(self): - return self.code == 0 + return self.code is not None and self.code == 0 -class QueryJobRequest: - def __init__(self, job_id): - # the job condition - self.job = {} - self.job.update("id", job_id) - +class QueryJobRequest(BaseObject): + def __init__(self, job_info): + self.job = job_info -class WeDPRRemoteJobClient(WeDPREntryPoint): - def __init__(self, auth_config: AuthConfig, job_config: JobConfig): + def as_dict(self): + if self.job is None: + return {} + return self.job.as_dict() + + +class JobListResponse(BaseObject): + def __init__(self, **params: Any): + self.jobs = [] + self.job_object_list = [] + self.total = None + self.set_params(**params) + for job_item in self.jobs: + self.job_object_list.append(JobInfo(param=job_item)) + + def get_queried_job(self): + if len(self.job_object_list) == 0: + return None + return self.job_object_list[0] + + +class JobDetailResponse(BaseObject): + def __init__(self, job: JobInfo = None, **params: Any): + self.job = job + self.job_object = None + self.modelResultDetail = None + self.resultFileInfo = None + self.model = None + self.set_params(**params) + # deserialize the job_object + self.job_object = JobInfo(self.job) + # TODO: deserialize the result + + +class WeDPRRemoteJobClient(WeDPREntryPoint, BaseObject): + def __init__(self, http_config: HttpConfig, auth_config: AuthConfig, job_config: JobConfig): if auth_config is None: raise Exception("Must define the auth config!") if job_config is None: raise Exception("Must define the job config") super().__init__(auth_config.access_key_id, auth_config.access_key_secret, - auth_config.remote_entrypoints, auth_config.nonce_len) + auth_config.get_remote_entrypoints_list(), http_config, auth_config.nonce_len) self.auth_config = auth_config self.job_config = job_config @@ -55,8 +147,8 @@ def get_job_config(self): return self.job_config def submit_job(self, job_params: JobParam) -> WeDPRResponse: - wedpr_response = self.send_post_request( - self.job_config._submit_job_uri, None, None, json.dumps(job_params)) + wedpr_response = self.send_request(True, + self.job_config._submit_job_uri, None, None, json.dumps(job_params)) submit_result = WeDPRResponse(**wedpr_response) # return the job_id if submit_result.success(): @@ -64,27 +156,51 @@ def submit_job(self, job_params: JobParam) -> WeDPRResponse: raise Exception( f"submit_job failed, code: {submit_result.code}, msg: {submit_result.msg}") - def _poll_task_status(self, job_id, token): + def query_job_detail(self, job_id, block_until_finish) -> JobDetailResponse: + job_result = self.poll_job_result(job_id, block_until_finish) + # failed case + if job_result == None or job_result.job_status == None or (not job_result.job_status.run_success()): + return JobDetailResponse(job=job_result, params=None) + # success case + params = {} + params["jobID"] = job_id + response_dict = self.execute_with_retry(self.send_request, + self.job_config.max_retries, + self.job_config.retry_delay_s, + True, + self.job_config.query_job_detail_uri, + params, + None, None) + wedpr_response = WeDPRResponse(**response_dict) + if not wedpr_response.success(): + raise Exception( + f"query_job_detail exception, job: {job_id}, code: {wedpr_response.code}, msg: {wedpr_response.msg}") + return JobDetailResponse(**(wedpr_response.data)) + + def poll_job_result(self, job_id, block_until_finish) -> JobInfo: while True: - wedpr_response = WeDPRResponse(self._send_request_with_retry( - self.send_post_request, self.job_config.query_job_status_uri, None, None, json.dumps(QueryJobRequest(job_id)))) - # TODO: check with status - if wedpr_response.success(): - return wedpr_response - else: + query_condition = JobInfo(job_id=job_id) + response_dict = self.execute_with_retry(self.send_request, + self.job_config.max_retries, + self.job_config.retry_delay_s, + True, + self.job_config.query_job_status_uri, + None, None, json.dumps(QueryJobRequest(job_info=query_condition).as_dict())) + wedpr_response = WeDPRResponse(**response_dict) + if not wedpr_response.success(): + raise Exception( + f"poll_job_result failed, job_id: {job_id}, code: {wedpr_response.code}, msg: {wedpr_response.msg}") + # check the result + result = JobListResponse(**(wedpr_response.data)) + result_job = result.get_queried_job() + if result_job is None: raise Exception( - f"_poll_task_status for job {job_id} failed, code: {wedpr_response.code}, msg: {wedpr_response.msg}") - time.sleep(self.job_config.polling_interval_s) - - def _send_request_with_retry(self, request_func, *args, **kwargs): - attempt = 0 - while attempt < self.max_retries: - try: - response = request_func(*args, **kwargs) - return response - except Exception as e: - attempt += 1 - if attempt < self.max_retries: - time.sleep(self.retry_delay_s) - else: - raise e + f"poll_job_result for the queried job {job_id} not exists!") + # run finished + if result_job.job_status.run_success() or result_job.job_status.run_failed(): + return result_job + # wait to finish + if block_until_finish: + time.sleep(self.job_config.polling_interval_s) + else: + return None diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py index a4db7de2..6d8b008e 100644 --- a/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py +++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py @@ -2,6 +2,8 @@ from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfig from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobDetailResponse from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint from wedpr_ml_toolkit.context.job_context import JobType from wedpr_ml_toolkit.context.job_context import PSIJobContext @@ -16,7 +18,7 @@ class WeDPRMlToolkit: def __init__(self, config: WeDPRMlConfig): self.config = config self.remote_job_client = WeDPRRemoteJobClient( - self.config.auth_config, self.config.job_config) + self.config.http_config, self.config.auth_config, self.config.job_config) self.storage_entry_point = StorageEntryPoint(self.config.user_config, self.config.storage_config) @@ -32,24 +34,26 @@ def get_storage_entry_point(self) -> StorageEntryPoint: def submit(self, job_param): return self.remote_job_client.submit_job(job_param) - def query_job_status(self, job_id, block_until_success=False): - if not block_until_success: - return self.remote_job_client.query_job_result(job_id) - return self.remote_job_client.wait_job_finished() + def query_job_status(self, job_id, block_until_finish=False) -> JobInfo: + return self.remote_job_client.poll_job_result(job_id, block_until_finish) - def query_job_detail(self, job_id, block_until_success): - job_result = self.query_job_status(job_id, block_until_success) - # TODO: determine success or not here - return self.remote_job_client.query_job_detail(job_id) + def query_job_detail(self, job_id, block_until_finish=False) -> JobDetailResponse: + return self.remote_job_client.query_job_detail(job_id, block_until_finish) - def build_job_context(self, job_type: JobType, project_name: str, dataset: DataContext, model_setting=None, id_fields='id'): + def build_job_context(self, job_type: JobType, project_name: str, dataset: DataContext, model_setting=None, + id_fields='id'): if job_type == JobType.PSI: - return PSIJobContext(self.remote_job_client, project_name, dataset, self.config.agency_config.agency_name, id_fields) + return PSIJobContext(self.remote_job_client, project_name, dataset, self.config.agency_config.agency_name, + id_fields) if job_type == JobType.PREPROCESSING: - return PreprocessingJobContext(self.remote_job_client, project_name, model_setting, dataset, self.config.agency_config.agency_name) + return PreprocessingJobContext(self.remote_job_client, project_name, model_setting, dataset, + self.config.agency_config.agency_name) if job_type == JobType.FEATURE_ENGINEERING: - return FeatureEngineeringJobContext(self.remote_job_client, project_name, model_setting, dataset, self.config.agency_config.agency_name) + return FeatureEngineeringJobContext(self.remote_job_client, project_name, model_setting, dataset, + self.config.agency_config.agency_name) if job_type == JobType.XGB_TRAINING: - return SecureLGBMTrainingJobContext(self.remote_job_client, project_name, model_setting, dataset, self.config.agency_config.agency_name) + return SecureLGBMTrainingJobContext(self.remote_job_client, project_name, model_setting, dataset, + self.config.agency_config.agency_name) if job_type == JobType.XGB_PREDICTING: - return SecureLGBMPredictJobContext(self.remote_job_client, project_name, model_setting, dataset, self.config.agency_config.agency_name) + return SecureLGBMPredictJobContext(self.remote_job_client, project_name, model_setting, dataset, + self.config.agency_config.agency_name)