-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
60765ab
commit e8d19cf
Showing
16 changed files
with
398 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
class Properties:
    """Minimal reader for ``key=value`` style properties files."""

    def __init__(self, file_path):
        """Store the path of the properties file; nothing is read yet.

        Args:
            file_path: Path of the file that ``getProperties`` will open.
        """
        self.file_path = file_path

    def getProperties(self):
        """Parse the file and return its entries as a dict of strings.

        A line is treated as an entry only when it contains ``=`` somewhere
        after its first character; every other line is ignored.  The line is
        split on the FIRST ``=`` only, so values may themselves contain ``=``
        (for example URLs with a query string).  Keys and values are stripped
        of surrounding whitespace; no unquoting is performed.

        Returns:
            dict mapping each key to its (string) value.

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        properties = {}
        # ``with`` closes the handle on success *and* on failure.
        with open(self.file_path, 'r', encoding='utf-8') as pro_file:
            for line in pro_file:
                # find() > 0 rejects both "no separator" (-1) and a
                # separator in the very first column (0).
                if line.find('=') > 0:
                    key, _, value = line.replace('\n', '').partition('=')
                    properties[key.strip()] = value.strip()
        return properties
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
from typing import Any, Dict | ||
from wedpr_ml_toolkit.common.utils.constant import Constant | ||
from wedpr_ml_toolkit.common.utils.properies_parser import Properties | ||
|
||
|
||
class BaseConfig:
    """Common base for config objects, providing a bulk attribute setter."""

    def set_params(self, **params: Any):
        """Assign every keyword argument as an attribute of this object.

        All keys are set, including ones the object did not define before
        the call.  (The previous implementation re-assigned the value a
        second time behind a ``hasattr`` check that was always true after
        the first assignment; that redundant step has been dropped.)

        Args:
            **params: Attribute names mapped to the values to store.

        Returns:
            This object, so calls can be chained.
        """
        for key, value in params.items():
            setattr(self, key, value)
        return self
|
||
|
||
class AuthConfig(BaseConfig):
    """Holds the access-key pair, remote entrypoints and nonce length."""

    def __init__(
            self,
            access_key_id: str = None,
            access_key_secret: str = None,
            remote_entrypoints: str = None,
            nonce_len: int = 5):
        """Store the given values unchanged on the instance."""
        self.access_key_id, self.access_key_secret = (
            access_key_id, access_key_secret)
        self.remote_entrypoints, self.nonce_len = (
            remote_entrypoints, nonce_len)
|
||
|
||
class JobConfig(BaseConfig):
    """Holds polling/retry settings and the job submit/query URIs."""

    def __init__(
            self,
            polling_interval_s: int = 5,
            max_retries: int = 5,
            retry_delay_s: int = 5,
            submit_job_uri: str = Constant.DEFAULT_SUBMIT_JOB_URI,
            query_job_status_uri: str = Constant.DEFAULT_QUERY_JOB_STATUS_URL):
        """Store the given values unchanged on the instance."""
        # Polling / retry settings.
        self.polling_interval_s, self.max_retries, self.retry_delay_s = (
            polling_interval_s, max_retries, retry_delay_s)
        # Request URIs.
        self.submit_job_uri, self.query_job_status_uri = (
            submit_job_uri, query_job_status_uri)
|
||
|
||
class StorageConfig(BaseConfig):
    """Holds the storage endpoint setting."""

    def __init__(self, storage_endpoint: str = None):
        """Store ``storage_endpoint`` unchanged on the instance."""
        setattr(self, "storage_endpoint", storage_endpoint)
|
||
|
||
class UserConfig(BaseConfig):
    """Holds the agency name, workspace root and user name."""

    def __init__(
            self,
            agency_name: str = None,
            workspace_path: str = None,
            user_name: str = None):
        """Store the given values; ``user_name`` is kept as ``self.user``."""
        self.agency_name, self.workspace_path = agency_name, workspace_path
        self.user = user_name

    def get_workspace_path(self):
        """Return ``workspace_path`` joined with ``user`` via ``os.path.join``."""
        root = self.workspace_path
        leaf = self.user
        return os.path.join(root, leaf)
|
||
|
||
class WeDPRMlConfig:
    """Aggregate of the auth, job, storage and user configs."""

    def __init__(self, config_dict):
        """Build each sub-config with defaults, then apply ``config_dict``.

        The same dictionary is passed in full to ``set_params`` of every
        sub-config.
        """
        # set_params returns the config object itself, so each sub-config
        # can be created and populated in a single expression.
        self.auth_config = AuthConfig().set_params(**config_dict)
        self.job_config = JobConfig().set_params(**config_dict)
        self.storage_config = StorageConfig().set_params(**config_dict)
        self.user_config = UserConfig().set_params(**config_dict)
|
||
|
||
class WeDPRMlConfigBuilder:
    """Factory helpers that create ``WeDPRMlConfig`` objects."""

    @staticmethod
    def build(config_dict) -> "WeDPRMlConfig":
        """Create a ``WeDPRMlConfig`` from an already-loaded dictionary."""
        return WeDPRMlConfig(config_dict)

    @staticmethod
    def build_from_properties_file(config_file_path) -> "WeDPRMlConfig":
        """Create a ``WeDPRMlConfig`` from a properties file on disk.

        Args:
            config_file_path: Path of the properties file to load.

        Returns:
            The config built from the entries parsed out of the file.

        Raises:
            FileNotFoundError: If ``config_file_path`` does not exist.  This
                is a subclass of ``Exception``, so callers that caught the
                previously raised bare ``Exception`` keep working.
        """
        if not os.path.exists(config_file_path):
            raise FileNotFoundError(
                f"build WeDPRMlConfig failed for the config file {config_file_path} does not exist!")
        properties = Properties(config_file_path)
        return WeDPRMlConfigBuilder.build(properties.getProperties())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
class PreprocessingModelSetting:
    """Default values for the preprocessing-related model settings."""

    def __init__(self):
        """Initialise every preprocessing setting to its default value."""
        self.use_psi = False
        self.fillna = False
        self.na_select = 1.0
        self.filloutlier = False
        self.normalized = False
        self.standardized = False
        self.categorical = ''
        self.psi_select_col = ''
        self.psi_select_base = ''
        # The original assigned ``psi_select_base`` twice ('' and then 0.3),
        # which made the first assignment a dead store.  The numeric default
        # is stored under its own name instead.
        # NOTE(review): ``psi_select_thresh`` is the presumed intended
        # attribute name -- confirm against the consumer of these settings.
        self.psi_select_thresh = 0.3
        self.psi_select_bins = 4
        self.corr_select = 0
        self.use_goss = False
|
||
|
||
class FeatureEngineeringEngineModelSetting:
    """Default values for the feature-engineering model settings."""

    def __init__(self):
        """Initialise ``use_iv``, ``group_num`` and ``iv_thresh`` defaults."""
        self.use_iv, self.group_num, self.iv_thresh = False, 4, 0.1
|
||
|
||
class CommmonSecureModelSetting:
    """Default values shared by the secure model settings classes."""

    def __init__(self):
        """Initialise the shared settings to their default values."""
        defaults = (
            ("learning_rate", 0.1),
            ("eval_set_column", ''),
            ("train_set_value", ''),
            ("eval_set_value", ''),
            ("verbose_eval", 1),
            ("silent", False),
            ("train_features", ''),
            ("random_state", None),
            ("n_jobs", 0),
        )
        # Assign in declaration order so the attribute order is unchanged.
        for attr_name, default in defaults:
            setattr(self, attr_name, default)
|
||
|
||
class SecureLGBMModelSetting(CommmonSecureModelSetting):
    """Default values for the secure LGBM model settings."""

    def __init__(self):
        """Run the parent initialiser, then set the LGBM-specific defaults."""
        super().__init__()
        defaults = (
            ("test_size", 0.3),
            ("num_trees", 6),
            ("max_depth", 3),
            ("max_bin", 4),
            ("subsample", 1.0),
            ("colsample_bytree", 1),
            ("colsample_bylevel", 1),
            ("reg_alpha", 0),
            ("reg_lambda", 1.0),
            ("gamma", 0.0),
            ("min_child_weight", 0.0),
            ("min_child_samples", 10),
            ("seed", 2024),
            ("early_stopping_rounds", 5),
            ("eval_metric", "auc"),
            ("threads", 8),
            ("one_hot", 0),
        )
        # Assign in declaration order so the attribute order is unchanged.
        for attr_name, default in defaults:
            setattr(self, attr_name, default)
|
||
|
||
class SecureLRModelSetting(CommmonSecureModelSetting):
    """Default values for the secure LR model settings."""

    def __init__(self):
        """Run the parent initialiser, then set the LR-specific defaults."""
        super().__init__()
        self.feature_rate, self.batch_size, self.epochs = 1.0, 16, 3
|
||
|
||
class ModelSetting(PreprocessingModelSetting, FeatureEngineeringEngineModelSetting, SecureLGBMModelSetting, SecureLRModelSetting):
    """Union of all model settings, each initialised to its defaults."""

    def __init__(self):
        """Run the initialiser of every base class.

        The bases are invoked explicitly rather than through
        ``super(Base, self)``: that form starts the lookup *after* ``Base``
        in the MRO, so the previous code never ran
        ``FeatureEngineeringEngineModelSetting.__init__``.  It also passed an
        undefined ``model_dict`` to initialisers that accept no arguments,
        so constructing a ``ModelSetting`` always raised.
        """
        # init PreprocessingSetting
        PreprocessingModelSetting.__init__(self)
        # init FeatureEngineeringEngineSetting
        FeatureEngineeringEngineModelSetting.__init__(self)
        # init SecureLGBMSetting
        SecureLGBMModelSetting.__init__(self)
        # init SecureLRSetting
        SecureLRModelSetting.__init__(self)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import os | ||
|
||
from wedpr_ml_toolkit.utils import utils | ||
from wedpr_ml_toolkit.common import utils | ||
|
||
|
||
class DataContext: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
access_key_id="" | ||
access_key_secret="" | ||
remote_entrypoints="http://127.0.0.1:16000,http://127.0.0.1:16001" | ||
|
||
agency_name="SGD" | ||
workspace_path="/user/wedpr/milestone2/sgd/" | ||
user="test_user" | ||
storage_endpoint="http://127.0.0.1:50070" | ||
|
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.