Skip to content

Commit

Permalink
complement queryJob and queryJobDetails
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Oct 18, 2024
1 parent e8d19cf commit bd3511c
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 169 deletions.
28 changes: 28 additions & 0 deletions python/wedpr_ml_toolkit/common/utils/base_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
from typing import Any
import time


class BaseObject:
    """Base class offering bulk attribute assignment, dict export and call retry."""

    def set_params(self, **params: Any):
        """Assign every keyword argument as an attribute of this object.

        Args:
            **params: attribute name / value pairs to set.

        Returns:
            This object, so calls can be chained.
        """
        # A single setattr() is enough: the former hasattr() check was always
        # true right after the assignment and only repeated the same write.
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def as_dict(obj):
        """Return the non-callable, non-dunder attributes of this object.

        Attributes are discovered with dir(), so class-level attributes are
        included alongside instance attributes.
        """
        result = {}
        for attr in dir(obj):
            if attr.startswith("__"):
                continue
            # Read the attribute once and reuse the value.
            value = getattr(obj, attr)
            if not callable(value):
                result[attr] = value
        return result

    def execute_with_retry(self, request_func, retry_times, retry_wait_seconds, *args, **kwargs):
        """Call ``request_func(*args, **kwargs)``, retrying when it raises.

        Args:
            request_func: callable to invoke.
            retry_times: total number of attempts; values below 1 are treated
                as 1 so the request is always attempted at least once.
            retry_wait_seconds: seconds to sleep between two attempts.

        Returns:
            Whatever ``request_func`` returns on the first successful attempt.

        Raises:
            Exception: the exception of the last attempt once all attempts failed.
        """
        attempts = max(1, retry_times)
        for attempt in range(1, attempts + 1):
            try:
                return request_func(*args, **kwargs)
            except Exception:
                if attempt >= attempts:
                    # Bare raise keeps the original traceback intact.
                    raise
                time.sleep(retry_wait_seconds)
8 changes: 5 additions & 3 deletions python/wedpr_ml_toolkit/common/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
class Constant:
    """Shared constants: HTTP status, WeDPR REST routes and result file names."""

    # Digits 0-9.
    NUMERIC_ARRAY = list(range(10))
    HTTP_STATUS_OK = 200

    # All REST routes are derived from one prefix so they cannot drift apart.
    WEDPR_API_PREFIX = '/api/wedpr/v3/'
    DEFAULT_SUBMIT_JOB_URI = f'{WEDPR_API_PREFIX}project/submitJob'
    DEFAULT_QUERY_JOB_STATUS_URL = f'{WEDPR_API_PREFIX}project/queryJobByCondition'
    DEFAULT_QUERY_JOB_DETAIL_URL = f'{WEDPR_API_PREFIX}scheduler/queryJobDetail'

    PSI_RESULT_FILE = "psi_result.csv"

    FEATURE_BIN_FILE = "feature_bin.json"
    TEST_MODEL_OUTPUT_FILE = "test_output.csv"
    TRAIN_MODEL_OUTPUT_FILE = "train_output.csv"

    FE_RESULT_FILE = "fe_result.csv"
2 changes: 1 addition & 1 deletion python/wedpr_ml_toolkit/common/utils/properies_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def getProperties(self):
properties = {}
for line in pro_file:
if line.find('=') > 0:
strs = line.replace('\n', '').split('=')
strs = line.strip("\"").replace('\n', '').split('=')
properties[strs[0].strip()] = strs[1].strip()
except Exception as e:
raise e
Expand Down
2 changes: 1 addition & 1 deletion python/wedpr_ml_toolkit/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def make_id(prefix):


def generate_nonce(nonce_len):
    """Return a random string made of ``nonce_len`` decimal digits.

    The nonce feeds the request signature, so digits are drawn from the
    ``secrets`` CSPRNG rather than ``random``. The alphabet is the digits 0-9,
    the same values as ``Constant.NUMERIC_ARRAY``.
    """
    # Imported locally so the function does not depend on file-level imports.
    import secrets
    import string

    # Join characters, not ints: str.join() raises TypeError on integers.
    return ''.join(secrets.choice(string.digits) for _ in range(nonce_len))


def add_params_to_url(url, params):
Expand Down
37 changes: 21 additions & 16 deletions python/wedpr_ml_toolkit/config/wedpr_ml_config.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,42 @@
# -*- coding: utf-8 -*-
import os
from typing import Any, Dict, List, Optional

from wedpr_ml_toolkit.common.utils.base_object import BaseObject
from wedpr_ml_toolkit.common.utils.constant import Constant
from wedpr_ml_toolkit.common.utils.properies_parser import Properties


class BaseConfig:
    """Base class for configuration sections filled from keyword arguments."""

    def set_params(self, **params: Any):
        """Assign every keyword argument as an attribute of this object.

        Returns:
            This object, so calls can be chained.
        """
        # A single setattr() is enough: the former hasattr() check was always
        # true right after the assignment and only repeated the same write.
        for key, value in params.items():
            setattr(self, key, value)
        return self


class AuthConfig(BaseObject):
    """Access-key credentials, remote entry points and nonce length."""

    def __init__(self, access_key_id: Optional[str] = None, access_key_secret: Optional[str] = None,
                 remote_entrypoints: Optional[str] = None, nonce_len: int = 5):
        self.access_key_id = access_key_id
        self.access_key_secret = access_key_secret
        # Comma-separated list of entry points, kept as a single string.
        self.remote_entrypoints = remote_entrypoints
        self.nonce_len = nonce_len

    def get_remote_entrypoints_list(self) -> Optional[List[str]]:
        """Split ``remote_entrypoints`` on commas.

        Returns:
            The list of entry points, or None when none were configured.
        """
        if self.remote_entrypoints is None:
            return None
        return self.remote_entrypoints.split(',')


class JobConfig(BaseObject):
    """Polling/retry settings and the REST routes used to submit and query jobs."""

    def __init__(self, polling_interval_s: int = 5, max_retries: int = 2, retry_delay_s: int = 5,
                 submit_job_uri: str = Constant.DEFAULT_SUBMIT_JOB_URI,
                 query_job_status_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL,
                 query_job_detail_uri: str = Constant.DEFAULT_QUERY_JOB_DETAIL_URL):
        self.polling_interval_s = polling_interval_s
        self.max_retries = max_retries
        self.retry_delay_s = retry_delay_s
        self.submit_job_uri = submit_job_uri
        self.query_job_status_uri = query_job_status_uri
        # Defaults to the job-detail route; it previously reused the status
        # route, so detail queries were sent to the wrong endpoint.
        self.query_job_detail_uri = query_job_detail_uri


class StorageConfig(BaseObject):
    """Storage settings: holds the storage endpoint."""

    def __init__(self, storage_endpoint: Optional[str] = None):
        # None until set_params() or the caller supplies an endpoint.
        self.storage_endpoint = storage_endpoint


class UserConfig(BaseConfig):
class UserConfig(BaseObject):
def __init__(self, agency_name: str = None, workspace_path: str = None, user_name: str = None):
self.agency_name = agency_name
self.workspace_path = workspace_path
Expand All @@ -48,6 +46,11 @@ def get_workspace_path(self):
return os.path.join(self.workspace_path, self.user)


class HttpConfig(BaseObject):
    """HTTP settings: holds the request timeout."""

    def __init__(self, timeout_seconds=3):
        # Timeout value, in seconds as the attribute name indicates.
        self.timeout_seconds = timeout_seconds


class WeDPRMlConfig:
def __init__(self, config_dict):
self.auth_config = AuthConfig()
Expand All @@ -58,6 +61,8 @@ def __init__(self, config_dict):
self.storage_config.set_params(**config_dict)
self.user_config = UserConfig()
self.user_config.set_params(**config_dict)
self.http_config = HttpConfig()
self.http_config.set_params(**config_dict)


class WeDPRMlConfigBuilder:
Expand Down
12 changes: 2 additions & 10 deletions python/wedpr_ml_toolkit/context/job_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo
from abc import abstractmethod
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient
from enum import Enum


class JobType(Enum):
    """Kinds of jobs that can be submitted; each value is the plain job-type string."""

    # No trailing commas: `PSI = "PSI",` would make the value the tuple ("PSI",).
    PSI = "PSI"
    PREPROCESSING = "PREPROCESSING"
    FEATURE_ENGINEERING = "FEATURE_ENGINEERING"
    XGB_TRAINING = "XGB_TRAINING"
    XGB_PREDICTING = "XGB_PREDICTING"
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobType


class JobContext:
Expand Down Expand Up @@ -99,7 +91,7 @@ def get_job_type(self) -> JobType:
def build(self) -> JobParam:
    """Build the job parameter for this context.

    Converts the dataset to PSI format, stores it on ``self.dataset_list`` and
    wraps it, JSON-serialized, into a JobInfo.

    Returns:
        A JobParam built from the job info, the task parties and the dataset ids.
    """
    self.dataset_list = self.dataset.to_psi_format(
        self.merge_field, self.result_receiver_id_list)
    # Double quotes in the serialized parameter are backslash-escaped before
    # being handed to JobInfo.
    param = json.dumps({'dataSetList': self.dataset_list}).replace('"', '\\"')
    # Keyword arguments keep the call independent of JobInfo's parameter order.
    job_info = JobInfo(job_type=self.get_job_type(),
                       project_name=self.project_name, param=param)
    job_param = JobParam(job_info, self.task_parties, self.dataset_id_list)
    return job_param
Expand Down
14 changes: 7 additions & 7 deletions python/wedpr_ml_toolkit/test/config.properties
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
access_key_id=""
access_key_secret=""
remote_entrypoints="http://127.0.0.1:16000,http://127.0.0.1:16001"
access_key_id=
access_key_secret=
remote_entrypoints=http://127.0.0.1:16000,http://127.0.0.1:16001

agency_name="SGD"
workspace_path="/user/wedpr/milestone2/sgd/"
user="test_user"
storage_endpoint="http://127.0.0.1:50070"
agency_name=SGD
workspace_path=/user/wedpr/milestone2/sgd/
user=test_user
storage_endpoint=http://127.0.0.1:50070

132 changes: 81 additions & 51 deletions python/wedpr_ml_toolkit/test/test_ml_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,63 +10,93 @@
from wedpr_ml_toolkit.context.job_context import JobType
from wedpr_ml_toolkit.config.wedpr_model_setting import PreprocessingModelSetting

wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file(
"config.properties")

wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)
class WeDPRMlToolkitTestWrapper:
def __init__(self, config_file_path):
    """Load the configuration from a properties file and create the toolkit."""
    config = WeDPRMlConfigBuilder.build_from_properties_file(config_file_path)
    self.wedpr_config = config
    self.wedpr_ml_toolkit = WeDPRMlToolkit(config)

# 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path
df = pd.DataFrame({
'id': np.arange(0, 100), # id列,顺序整数
'y': np.random.randint(0, 2, size=100),
**{f'x{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数
})
def test_submit_job(self):
# NOTE(review): this span is a rendered diff hunk. Statements that reference the
# bare names wedpr_ml_toolkit / wedpr_config are the pre-commit versions of the
# self.wedpr_ml_toolkit / self.wedpr_config statements that follow them.
# Register the dataset; two sources are supported: pd.DataFrame, hdfs_path
df = pd.DataFrame({
'id': np.arange(0, 100), # id column, sequential integers
'y': np.random.randint(0, 2, size=100),
# columns x1 to x10, random numbers
**{f'x{i}': np.random.rand(100) for i in range(1, 11)}
})

dataset1 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
storage_workspace=wedpr_config.user_config.get_workspace_path(),
agency=wedpr_config.user_config.agency_name,
values=df,
is_label_holder=True)
dataset1.save_values(path='d-101')
dataset1 = DatasetToolkit(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(),
storage_workspace=self.wedpr_config.user_config.get_workspace_path(),
agency=self.wedpr_config.user_config.agency_name,
values=df,
is_label_holder=True)
dataset1.save_values(path='d-101')

# hdfs_path
dataset2 = DatasetToolkit(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
dataset_path="d-9606695119693829", agency="WeBank")
# hdfs_path
dataset2 = DatasetToolkit(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(),
dataset_path="d-9606695119693829", agency="WeBank")

dataset2.storage_client = None
# dataset2.load_values()
if dataset2.storage_client is None:
# the values of a dataset can be replaced
df2 = pd.DataFrame({
'id': np.arange(0, 100), # id column, sequential integers
**{f'z{i}': np.random.rand(100) for i in range(1, 11)} # columns z1 to z10, random numbers
})
dataset2.update_values(values=df2)
if dataset1.storage_client is not None:
dataset1.update_values(
path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485')
dataset1.load_values()
dataset2.storage_client = None
# dataset2.load_values()
if dataset2.storage_client is None:
# the values of a dataset can be replaced
df2 = pd.DataFrame({
'id': np.arange(0, 100), # id column, sequential integers
# columns z1 to z10, random numbers
**{f'z{i}': np.random.rand(100) for i in range(1, 11)}
})
dataset2.update_values(values=df2)
if dataset1.storage_client is not None:
dataset1.update_values(
path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485')
dataset1.load_values()

# build the dataset context
dataset = DataContext(dataset1, dataset2)
# build the dataset context
dataset = DataContext(dataset1, dataset2)

# init the job context
project_name = "1"
# init the job context
project_name = "1"

psi_job_context = wedpr_ml_toolkit.build_job_context(
JobType.PSI, project_name, dataset, None, "id")
print(psi_job_context.participant_id_list,
psi_job_context.result_receiver_id_list)
# run the PSI job
psi_job_id = psi_job_context.submit()
psi_result = psi_job_context.fetch_job_result(psi_job_id, True)
psi_job_context = self.wedpr_ml_toolkit.build_job_context(
JobType.PSI, project_name, dataset, None, "id")
print(psi_job_context.participant_id_list,
psi_job_context.result_receiver_id_list)
# run the PSI job
psi_job_id = psi_job_context.submit()
psi_result = psi_job_context.fetch_job_result(psi_job_id, True)

# initialization
preprocessing_data = DataContext(dataset1)
preprocessing_job_context = wedpr_ml_toolkit.build_job_context(
JobType.PREPROCESSING, project_name, preprocessing_data, PreprocessingModelSetting())
# run the preprocessing job
fe_job_id = preprocessing_job_context.submit(dataset)
fe_result = preprocessing_job_context.fetch_job_result(fe_job_id, True)
print(preprocessing_job_context.participant_id_list,
preprocessing_job_context.result_receiver_id_list)
# initialization
preprocessing_data = DataContext(dataset1)
preprocessing_job_context = self.wedpr_ml_toolkit.build_job_context(
JobType.PREPROCESSING, project_name, preprocessing_data, PreprocessingModelSetting())
# run the preprocessing job
fe_job_id = preprocessing_job_context.submit(dataset)
fe_result = preprocessing_job_context.fetch_job_result(fe_job_id, True)
print(preprocessing_job_context.participant_id_list,
preprocessing_job_context.result_receiver_id_list)

def test_query_job(self, job_id: str, block_until_finish):
    """Query the status and then the detail of one job through the toolkit.

    Returns:
        A ``(job_result, job_detail_result)`` tuple.
    """
    toolkit = self.wedpr_ml_toolkit
    job_result = toolkit.query_job_status(job_id, block_until_finish)
    print(f"#### job_result: {job_result}")
    job_detail_result = toolkit.query_job_detail(job_id, block_until_finish)
    return job_result, job_detail_result


class TestMlToolkit(unittest.TestCase):
    """Query tests driven through WeDPRMlToolkitTestWrapper."""

    def test_query_jobs(self):
        """Query one successful and one failed job without blocking."""
        wrapper = WeDPRMlToolkitTestWrapper("config.properties")
        # the success job case
        success_job_id = "9630202187032582"
        wrapper.test_query_job(success_job_id, False)
        # the fail job case: query the failed job id, which was previously
        # defined but never used (the success id was queried twice)
        failed_job_id = "9630156365047814"
        wrapper.test_query_job(failed_job_id, False)


if __name__ == '__main__':
unittest.main()
26 changes: 12 additions & 14 deletions python/wedpr_ml_toolkit/transport/credential_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
import hashlib
from wedpr_ml_toolkit.common import utils
from wedpr_ml_toolkit.common.utils import utils
import time


Expand All @@ -10,22 +10,19 @@ class CredentialInfo:
TIMESTAMP_KEY = "timestamp"
SIGNATURE_KEY = "signature"

def __init__(self, access_key_id: str, nonce: str, timestamp: str, signature: str):
def __init__(self, access_key_id: str, nonce: str, timestamp: int, signature: str):
self.access_key_id = access_key_id
self.nonce = nonce
self.timestamp = timestamp
self.signature = signature

def to_dict(self):
    """Return the credential fields keyed by the CredentialInfo *_KEY constants."""
    # Built as a literal: dict.update(key, value) is not a valid call form,
    # and the result must be returned to the caller.
    return {
        CredentialInfo.ACCESS_ID_KEY: self.access_key_id,
        CredentialInfo.NONCE_KEY: self.nonce,
        CredentialInfo.TIMESTAMP_KEY: self.timestamp,
        CredentialInfo.SIGNATURE_KEY: self.signature,
    }

def update_url_with_auth_info(self, url):
    """Return ``url`` with the credential fields added as parameters."""
    auth_params = self.to_dict()
    # add_params_to_url takes the URL first, then the parameters.
    return utils.add_params_to_url(url, auth_params)


class CredentialGenerator:
Expand All @@ -46,10 +43,11 @@ def generate_credential(self) -> CredentialInfo:
def generate_signature(access_key_id, access_key_secret, nonce, timestamp) -> str:
    """Compute the request signature as a SHA3-256 hex digest.

    signature = sha3_256(hex(sha3_256(access_key_id + nonce + timestamp)) + access_key_secret)

    Args:
        access_key_id: access key identifier.
        access_key_secret: secret paired with the access key (a string).
        nonce: nonce for this request.
        timestamp: request timestamp; formatted with str() before hashing.

    Returns:
        The hexadecimal signature string.
    """
    # hash(access_key_id + nonce + timestamp)
    # hashlib only accepts bytes, so every text fragment is UTF-8 encoded first.
    anti_replay_info = f"{access_key_id}{nonce}{timestamp}"
    anti_replay_info_hash = hashlib.sha3_256(anti_replay_info.encode('utf-8'))
    # hash(anti_replay_info + access_key_secret)
    signature_hash = hashlib.sha3_256()
    signature_hash.update(anti_replay_info_hash.hexdigest().encode('utf-8'))
    signature_hash.update(access_key_secret.encode('utf-8'))
    return signature_hash.hexdigest()
Loading

0 comments on commit bd3511c

Please sign in to comment.