-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
124 lines (97 loc) · 3.53 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from call_methods import make_network, make_params
from utils.logs import Logger
from process.preprocessing import DataProcessor
from feature_analysis.calculate_chisquare import Chi2_Calculation
from feature_analysis.calculate_correlation import CorrelationCoefficient
from process.train_test_split import TrainTestProcessor
from options.train_options import TrainOptions
from utils.save_utils import save_model_and_logs
from feature_importance.calculate_sfs import CalculateSfsImportance
import warnings
# Silence FutureWarning for the whole process; this filter is installed as a
# side effect of importing this module.
# NOTE(review): presumably meant to hide deprecation noise from third-party
# libraries — confirm nothing downstream relies on seeing these warnings.
warnings.simplefilter("ignore", category=FutureWarning)
def _preprocess_data(opt, logger):
    """Impute, encode and save the raw dataset, then log its missing values.

    Parameters
    ----------
    opt
        Parsed training options (from ``TrainOptions().parse()``).
    logger
        The run's ``Logger`` instance.

    Returns
    -------
    tuple
        ``(processed_data, chi2_data)`` as produced by
        ``DataProcessor.process_and_save``; its third output (the
        missing-value report) is logged here instead of being returned.
    """
    processor = DataProcessor(opt.data_path, logger, opt)
    processed_data, missing_values, chi2_data = processor.process_and_save(
        opt.missing_values_imputation,  # imputation dictionary
        label_encode_columns=opt.label_encode_columns,
        one_hot_encode_columns=opt.one_hot_encode_columns,
        dtype_dict=opt.dtype_dict,
        feature_engg_names=opt.feature_engg_name,
        calculate_chi2=opt.calculate_chi2,
    )
    logger.update_log("data_processing", "missing_values", missing_values.to_dict())
    return processed_data, chi2_data


def _analyze_features(processed_data, chi2_data, opt, logger) -> None:
    """Run the correlation and chi-square analyses (their return values are unused)."""
    correlation_calc = CorrelationCoefficient(processed_data, logger, opt)
    correlation_calc.calculate_correlation()
    chi2_calc = Chi2_Calculation(chi2_data, processed_data, opt, logger)
    chi2_calc.get_chi2_scores()


def _split_data(processed_data, opt, logger):
    """Split the data into train/test sets and run ``TrainTestProcessor.final_checks``.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` from ``TrainTestProcessor.process``.
    """
    splitter = TrainTestProcessor(processed_data, logger, opt)
    X_train, X_test, y_train, y_test = splitter.process()
    splitter.final_checks(X_train, X_test, y_train, y_test)
    return X_train, X_test, y_train, y_test


def run() -> None:
    """
    Run the end-to-end training process.

    Parameters
    ----------
    None
        All settings are read from the command line via ``TrainOptions().parse()``.

    Returns
    -------
    None

    Process
    -------
    1. Parse the training options
    2. Initialize the logger
    3. Preprocess the data
    4. Calculate the correlation matrix
    5. Calculate the chi-square scores
    6. Split the data into training and testing sets
    7. Initialize the model
    8. Train the model
    9. Evaluate the model
    10. Calculate the SFS importance with default parameters ("before" tuning)
    11. Tune the model
    12. Calculate the SHAP values
    13. Calculate the SFS importance with tuned parameters ("after" tuning)
    14. Save the model and logs
    """
    opt = TrainOptions().parse()
    logger = Logger(opt)

    # Data preparation and feature analysis
    processed_data, chi2_data = _preprocess_data(opt, logger)
    _analyze_features(processed_data, chi2_data, opt, logger)
    X_train, X_test, y_train, y_test = _split_data(processed_data, opt, logger)

    # Baseline model
    model = make_network(opt.model_name, logger, opt)
    model.train(X_train, y_train)
    model.evaluate(X_test, y_test)

    # SFS importance with default parameters. This must run BEFORE tuning:
    # tuned parameters are read back via model.get_params() after
    # model_tuning, which indicates tuning updates the model in place. Run
    # after tuning (as before), the "before" phase would most likely have
    # evaluated the tuned parameters instead of the defaults.
    # NOTE(review): assumes perform_sfs leaves no state on `model` that
    # influences model_tuning — confirm in CalculateSfsImportance.
    sfs_default = CalculateSfsImportance(model, logger, opt)
    sfs_default.perform_sfs(X_train, y_train, tuning_phase="before")

    # Hyper-parameter tuning
    get_params_func = make_params(opt.model_name)
    model.model_tuning(
        get_params_func,
        X_train,
        y_train,
        X_test,
        y_test,
        n_trials=opt.n_trials,
    )

    # Explainability on the tuned model
    model.shap_calculation(X_test, model.model_type)

    # SFS importance with the tuned parameters
    tuned_params = model.get_params()
    sfs_tuned = CalculateSfsImportance(model, logger, opt)
    sfs_tuned.perform_sfs(
        X_train, y_train, model_params=tuned_params, tuning_phase="after"
    )

    save_model_and_logs(model, logger, opt)
if __name__ == "__main__":
    # Script entry point: run the training pipeline only when executed
    # directly, never as a side effect of importing this module.
    run()