Skip to content

Commit

Permalink
fix wedpr_ml_toolkit (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull authored Oct 17, 2024
1 parent 60765ab commit e8d19cf
Show file tree
Hide file tree
Showing 16 changed files with 398 additions and 111 deletions.
19 changes: 19 additions & 0 deletions python/wedpr_ml_toolkit/common/utils/properies_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-

class Properties:
def __init__(self, file_path):
self.file_path = file_path

def getProperties(self):
try:
pro_file = open(self.file_path, 'r', encoding='utf-8')
properties = {}
for line in pro_file:
if line.find('=') > 0:
strs = line.replace('\n', '').split('=')
properties[strs[0].strip()] = strs[1].strip()
except Exception as e:
raise e
else:
pro_file.close()
return properties
5 changes: 4 additions & 1 deletion python/wedpr_ml_toolkit/common/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import uuid
from enum import Enum
import random
from common.utils.constant import Constant
from wedpr_ml_toolkit.common.utils.constant import Constant
from urllib.parse import urlencode, urlparse, parse_qs, quote


class IdPrefixEnum(Enum):
DATASET = "d-"
ALGORITHM = "a-"
Expand All @@ -14,9 +15,11 @@ class IdPrefixEnum(Enum):
def make_id(prefix):
return prefix + str(uuid.uuid4()).replace("-", "")


def generate_nonce(nonce_len):
return ''.join(random.choice(Constant.NUMERIC_ARRAY) for _ in range(nonce_len))


def add_params_to_url(url, params):
parsed_url = urlparse(url)
query_params = parse_qs(parsed_url.query)
Expand Down
Empty file.
74 changes: 74 additions & 0 deletions python/wedpr_ml_toolkit/config/wedpr_ml_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import os
from typing import Any, Dict
from wedpr_ml_toolkit.common.utils.constant import Constant
from wedpr_ml_toolkit.common.utils.properies_parser import Properties


class BaseConfig:
def set_params(self, **params: Any):
for key, value in params.items():
setattr(self, key, value)
if hasattr(self, f"{key}"):
setattr(self, f"{key}", value)
return self


class AuthConfig(BaseConfig):
def __init__(self, access_key_id: str = None, access_key_secret: str = None, remote_entrypoints: str = None, nonce_len: int = 5):
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.remote_entrypoints = remote_entrypoints
self.nonce_len = nonce_len


class JobConfig(BaseConfig):
def __init__(self, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5,
submit_job_uri: str = Constant.DEFAULT_SUBMIT_JOB_URI,
query_job_status_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL):
self.polling_interval_s = polling_interval_s
self.max_retries = max_retries
self.retry_delay_s = retry_delay_s
self.submit_job_uri = submit_job_uri
self.query_job_status_uri = query_job_status_uri


class StorageConfig(BaseConfig):
def __init__(self, storage_endpoint: str = None):
self.storage_endpoint = storage_endpoint


class UserConfig(BaseConfig):
def __init__(self, agency_name: str = None, workspace_path: str = None, user_name: str = None):
self.agency_name = agency_name
self.workspace_path = workspace_path
self.user = user_name

def get_workspace_path(self):
return os.path.join(self.workspace_path, self.user)


class WeDPRMlConfig:
def __init__(self, config_dict):
self.auth_config = AuthConfig()
self.auth_config.set_params(**config_dict)
self.job_config = JobConfig()
self.job_config.set_params(**config_dict)
self.storage_config = StorageConfig()
self.storage_config.set_params(**config_dict)
self.user_config = UserConfig()
self.user_config.set_params(**config_dict)


class WeDPRMlConfigBuilder:
@staticmethod
def build(config_dict) -> WeDPRMlConfig:
return WeDPRMlConfig(config_dict)

@staticmethod
def build_from_properties_file(config_file_path):
if not os.path.exists(config_file_path):
raise Exception(
f"build WeDPRMlConfig failed for the config file {config_file_path} not exits!")
properties = Properties(config_file_path)
return WeDPRMlConfigBuilder.build(properties.getProperties())
79 changes: 79 additions & 0 deletions python/wedpr_ml_toolkit/config/wedpr_model_setting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-

class PreprocessingModelSetting:
def __init__(self):
self.use_psi = False
self.fillna = False
self.na_select = 1.0
self.filloutlier = False
self.normalized = False
self.standardized = False
self.categorical = ''
self.psi_select_col = ''
self.psi_select_base = ''
self.psi_select_base = 0.3
self.psi_select_bins = 4
self.corr_select = 0
self.use_goss = False


class FeatureEngineeringEngineModelSetting:
def __init__(self):
self.use_iv = False
self.group_num = 4
self.iv_thresh = 0.1


class CommmonSecureModelSetting:
def __init__(self):
self.learning_rate = 0.1
self.eval_set_column = ''
self.train_set_value = ''
self.eval_set_value = ''
self.verbose_eval = 1
self.silent = False
self.train_features = ''
self.random_state = None
self.n_jobs = 0


class SecureLGBMModelSetting(CommmonSecureModelSetting):
def __init__(self):
super().__init__()
self.test_size = 0.3
self.num_trees = 6
self.max_depth = 3
self.max_bin = 4
self.subsample = 1.0
self.colsample_bytree = 1
self.colsample_bylevel = 1
self.reg_alpha = 0
self.reg_lambda = 1.0
self.gamma = 0.0
self.min_child_weight = 0.0
self.min_child_samples = 10
self.seed = 2024
self.early_stopping_rounds = 5
self.eval_metric = "auc"
self.threads = 8
self.one_hot = 0


class SecureLRModelSetting(CommmonSecureModelSetting):
def __init__(self):
super().__init__()
self.feature_rate = 1.0
self.batch_size = 16
self.epochs = 3


class ModelSetting(PreprocessingModelSetting, FeatureEngineeringEngineModelSetting, SecureLGBMModelSetting, SecureLRModelSetting):
def __init__(self):
# init PreprocessingSetting
super().__init__()
# init FeatureEngineeringEngineSetting
super(FeatureEngineeringEngineModelSetting, self).__init__(model_dict)
# init SecureLGBMSetting
super(SecureLGBMModelSetting, self).__init__(model_dict)
# init SecureLRSetting
super(SecureLRModelSetting, self).__init__(model_dict)
2 changes: 1 addition & 1 deletion python/wedpr_ml_toolkit/context/data_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from wedpr_ml_toolkit.utils import utils
from wedpr_ml_toolkit.common import utils


class DataContext:
Expand Down
44 changes: 31 additions & 13 deletions python/wedpr_ml_toolkit/context/job_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo
from abc import abstractmethod
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient
from enum import Enum


class JobType(Enum):
PSI = "PSI",
PREPROCESSING = "PREPROCESSING",
FEATURE_ENGINEERING = "FEATURE_ENGINEERING",
XGB_TRAINING = "XGB_TRAINING",
XGB_PREDICTING = "XGB_PREDICTING"


class JobContext:
Expand Down Expand Up @@ -63,20 +72,29 @@ def build(self) -> JobParam:
pass

@abstractmethod
def get_job_type(self) -> str:
def get_job_type(self) -> JobType:
pass

def submit(self):
return self.remote_job_client.submit_job(self.build())

@abstractmethod
def parse_result(self, result_detail):
pass

def submit(self, project_name):
return self.submit(self.build(project_name))
def fetch_job_result(self, job_id, block_until_success):
job_result = self.query_job_status(job_id, block_until_success)
# TODO: determine success or not here
return self.parse_result(self.remote_job_client.query_job_detail(job_id))


class PSIJobContext(JobContext):
def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'):
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.merge_field = merge_field

def get_job_type(self) -> str:
return "PSI"
def get_job_type(self) -> JobType:
return JobType.PSI

def build(self) -> JobParam:
self.dataset_list = self.dataset.to_psi_format(
Expand All @@ -92,8 +110,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "PREPROCESSING"
def get_job_type(self) -> JobType:
return JobType.PREPROCESSING

# TODO: build the request
def build(self) -> JobParam:
Expand All @@ -105,8 +123,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "FEATURE_ENGINEERING"
def get_job_type(self) -> JobType:
return JobType.FEATURE_ENGINEERING

# TODO: build the jobParam
def build(self) -> JobParam:
Expand All @@ -118,8 +136,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "XGB_TRAINING"
def get_job_type(self) -> JobType:
return JobType.XGB_TRAINING

# TODO: build the jobParam
def build(self) -> JobParam:
Expand All @@ -131,8 +149,8 @@ def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, m
super().__init__(remote_job_client, project_name, dataset, my_agency)
self.model_setting = model_setting

def get_job_type(self) -> str:
return "XGB_PREDICTING"
def get_job_type(self) -> JobType:
return JobType.XGB_PREDICTING

# TODO: build the jobParam
def build(self) -> JobParam:
Expand Down
9 changes: 9 additions & 0 deletions python/wedpr_ml_toolkit/test/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
access_key_id=""
access_key_secret=""
remote_entrypoints="http://127.0.0.1:16000,http://127.0.0.1:16001"

agency_name="SGD"
workspace_path="/user/wedpr/milestone2/sgd/"
user="test_user"
storage_endpoint="http://127.0.0.1:50070"

78 changes: 0 additions & 78 deletions python/wedpr_ml_toolkit/test/test_dev.py

This file was deleted.

Loading

0 comments on commit e8d19cf

Please sign in to comment.