#!/usr/bin/env python3
###
import sys, os, time, argparse, logging, yaml
import pyreadr, pickle
import numpy as np
import pandas as pd
import networkx as nx
from ProteinGraphML.DataAdapter import OlegDB, selectAsDF, TCRD
from ProteinGraphML.GraphTools import ProteinDiseaseAssociationGraph
from ProteinGraphML.MLTools.MetapathFeatures import metapathFeatures, ProteinInteractionNode, KeggNode, ReactomeNode, \
    GoNode, InterproNode, getMetapaths
from ProteinGraphML.MLTools.Data import BinaryLabel
from ProteinGraphML.MLTools.Models import XGBoostModel
from ProteinGraphML.MLTools.Procedures import *
from ProteinGraphML.Analysis import Visualize
def getSourceForStaticFeatures(idSource, folder, features):
    """
    Read the header of each static feature file and map each ID found there to
    its feature-set name in the idSource dictionary.
    """
    static_features = features.split(',')
    for feature in static_features:
        flname = folder + '/' + feature + '.tsv'
        with open(flname, 'r') as f:
            header = f.readline()
            vals = header.strip().split('\t')[1:]
            idSource.update({v: str(feature) for v in vals})
    return idSource
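# Illustrative note (assumed layout, not documented here): each static feature
# file <static_dir>/<feature>.tsv is expected to start with a tab-separated
# header whose first column is an index column; every remaining header value v
# is recorded as idSource[v] = feature. For example, a hypothetical header
# "id<TAB>colA<TAB>colB" in gtex.tsv would add {'colA': 'gtex', 'colB': 'gtex'}.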
# ************************ START OF THE CODE ********************************* #
if __name__ == '__main__':
    """
    This program uses the training data set generated by "GenTrainingAndTestFeatures.py" to train a machine
    learning model. Currently it trains an XGBoost-based model and saves it to predict the labels of the records
    in the prediction set.
    It uses grid search to find the optimal parameters for XGBoost; running this code with procedure
    "XGBGridSearch" can take several hours.
    XGBoost uses 5-fold cross-validation to assign the predicted labels to the training records. The
    classification results and the list of important features used by XGBoost to train the model are saved
    to files.
    """
    t0 = time.time()
    NUM_OF_ROUNDS = 10
    RSEED = 1234
    NTHREADS = 1
    PROCEDURES = ["XGBCrossValPred", "XGBKfoldsRunPred", "XGBGridSearch"]
    XGB_PARAMETERS_FILE = 'XGBparams.txt'
    DEFAULT_STATIC_FEATURES = "gtex,lincs,ccle,hpa"
    DBS = ['olegdb', 'tcrd']

    parser = argparse.ArgumentParser(description='Run ML Procedure',
                                     epilog='--trainingfile and --resultdir must be specified; available procedures: {0}'.format(str(PROCEDURES)))
    parser.add_argument('procedure', choices=PROCEDURES, help='ML procedure to run')
    parser.add_argument('--trainingfile', help='input file, pickled training data, e.g. "diabetesTrainData.pkl"')
    parser.add_argument('--resultdir', help='folder where results will be saved, e.g. "diabetes_no_lincs"')
    parser.add_argument('--rseed', type=int, default=RSEED, help='random seed for XGBoost')
    parser.add_argument('--nthreads', type=int, default=NTHREADS, help='number of CPU threads for grid search')
    parser.add_argument('--nrounds_for_avg', type=int, default=NUM_OF_ROUNDS,
                        help='number of iterations for average AUC, ACC, MCC (default: "{0}")'.format(NUM_OF_ROUNDS))
    parser.add_argument('--xgboost_param_file', default=XGB_PARAMETERS_FILE,
                        help='text file containing parameters for the XGBoost classifier (e.g. XGBparams.txt)')
    parser.add_argument("-v", "--verbose", action="count", default=0, help="verbosity")
    parser.add_argument('--db', choices=DBS, default="tcrd", help='{0}'.format(str(DBS)))
    parser.add_argument('--static_data', default=DEFAULT_STATIC_FEATURES,
                        help='comma-separated static feature sets (default: "{0}")'.format(DEFAULT_STATIC_FEATURES))
    parser.add_argument('--static_dir', default=os.getcwd() + "/static_tcrd/",
                        help='directory containing the static feature TSV files')
    args = parser.parse_args()
    # argData = vars(parser.parse_args())
    logging.basicConfig(format='%(levelname)s:%(message)s',
                        level=(logging.DEBUG if args.verbose > 1 else logging.INFO))

    # Get training data from file
    trainingDataFile = args.trainingfile
    if trainingDataFile is None:
        parser.error("--trainingfile must be specified.")
    else:
        try:
            with open(trainingDataFile, 'rb') as f:
                trainData = pickle.load(f)
        except Exception as e:
            logging.error('Could not load pickled training data "{0}": {1}'.format(trainingDataFile, e))
            logging.error('Generate the training data file with GenTrainingAndTestFeatures.py first.')
            sys.exit(1)
    Procedure = args.procedure
    logging.info('Procedure: {0}'.format(Procedure))

    # Get result directory
    if args.resultdir is not None:
        resultDir = args.resultdir  # folder where all results will be stored
        logging.info('Results will be saved in directory: {0}'.format(resultDir))
    else:
        logging.error('Result directory is needed')
        sys.exit(1)
    # nfolds = args.nrounds_for_avg  # applicable for average CV
    # Fetch the parameters for XGBoost from the text file
    paramVals = ""
    with open(args.xgboost_param_file, 'r') as fh:
        for line in fh:
            paramVals += line.strip()
    xgbParams = yaml.full_load(paramVals)
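    # Example XGBparams.txt contents (illustrative; the file is parsed as YAML
    # after newlines are stripped, so a single flow-style mapping of XGBoost
    # keyword arguments is assumed):
    #   {'max_depth': 6, 'eta': 0.1, 'objective': 'binary:logistic'}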
    # Access the DB adaptor; TCRD is the default DB
    dbAdapter = OlegDB() if args.db == "olegdb" else TCRD()
    idDescription = dbAdapter.fetchPathwayIdDescription()  # fetch pathway ID descriptions
    idNameSymbol = dbAdapter.fetchSymbolForProteinId()  # fetch name and symbol for each protein
    idSource = dbAdapter.addDatabaseSourceToProteinId()  # fetch the source database for each protein
    idSource = getSourceForStaticFeatures(idSource, args.static_dir, args.static_data)
    # Run the selected ML procedure. The procedure functions (XGBCrossValPred,
    # XGBKfoldsRunPred, XGBGridSearch) come from the wildcard import of
    # ProteinGraphML.MLTools.Procedures and are looked up by name via locals().
    d = BinaryLabel()
    d.loadData(trainData)
    if Procedure == "XGBKfoldsRunPred":
        locals()[Procedure](d, idDescription, idNameSymbol, idSource, resultDir, args.nrounds_for_avg, params=xgbParams)
    elif Procedure == "XGBCrossValPred":
        locals()[Procedure](d, idDescription, idNameSymbol, idSource, resultDir, params=xgbParams)
    elif Procedure == "XGBGridSearch":
        locals()[Procedure](d, idDescription, idNameSymbol, resultDir, args.rseed, args.nthreads)
    else:
        logging.error('Unknown procedure: {0}'.format(Procedure))

    logging.info('{0}: elapsed time: {1}'.format(os.path.basename(sys.argv[0]),
                 time.strftime('%Hh:%Mm:%Ss', time.gmtime(time.time() - t0))))