update resuming
ArnovanHilten committed Jun 30, 2022
1 parent dc47f51 commit aebdca0
Showing 4 changed files with 93 additions and 105 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -22,3 +22,4 @@ results/*
 *.swp
 .log
 *.log
+examples/A_to_Z/processed_data/
10 changes: 10 additions & 0 deletions GenNet_utils/Create_plots.py
@@ -8,6 +8,16 @@
 
 from GenNet_utils.Utility_functions import get_paths
 
+def plot_loss_function(resultpath):
+    log_file = pd.read_csv(resultpath + "/train_log.csv")
+    plt.plot(log_file['loss'])
+    plt.plot(log_file['val_loss'])
+    plt.title('Loss curve')
+    plt.ylabel('loss')
+    plt.xlabel('epoch')
+    plt.legend(['train', 'val'], loc='upper right')
+    plt.savefig(resultpath + "train_val_loss.png")
+    plt.show()
 
 def sunburst_plot(resultpath, importance_csv, num_layers=3, plot_threshold=0.01, add_end_node=True):
     csv_file = importance_csv.copy()
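Note on the new helper: plot_loss_function rebuilds the loss curve from the per-epoch metrics that the CSVLogger callback appends to train_log.csv, so the plot no longer depends on the in-memory History object and also works after a resumed run. A minimal sketch of how it can be exercised against a synthetic log (the temporary directory and loss values are made up; 'loss' and 'val_loss' are the column names Keras' CSVLogger writes, and pandas/matplotlib are assumed to be imported at the top of Create_plots.py, which this hunk truncates):

    import os
    import tempfile

    import pandas as pd

    from GenNet_utils.Create_plots import plot_loss_function  # run from the repo root

    # Fake the train_log.csv that CSVLogger would have produced.
    resultpath = tempfile.mkdtemp() + "/"
    pd.DataFrame({"epoch": [0, 1, 2],
                  "loss": [0.90, 0.72, 0.61],
                  "val_loss": [0.95, 0.80, 0.78]}).to_csv(
        os.path.join(resultpath, "train_log.csv"), index=False)

    plot_loss_function(resultpath)  # writes train_val_loss.png into resultpath

Since the .png path is built as resultpath + "train_val_loss.png" without a separator, resultpath is evidently expected to end in a slash, as in the sketch above.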
176 changes: 71 additions & 105 deletions GenNet_utils/Train_network.py
@@ -3,7 +3,7 @@
 import warnings
 import shutil
 import matplotlib
-
+import datetime
 warnings.filterwarnings('ignore')
 matplotlib.use('agg')
 sys.path.insert(1, os.path.dirname(os.getcwd()) + "/GenNet_utils/")
@@ -27,6 +27,7 @@ def weighted_binary_crossentropy(y_true, y_pred):
 
 
 def train_classification(args):
+    SlURM_JOB_ID = get_SLURM_id()
     model = None
     masks = None
     datapath = args.path
@@ -98,17 +99,24 @@ def train_classification(args):
         model.summary(print_fn=lambda x: fh.write(x + '\n'))
 
     csv_logger = K.callbacks.CSVLogger(resultpath + 'train_log.csv', append=True)
 
+    early_stop = K.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=1, mode='auto',
+                                           restore_best_weights=True)
     save_best_model = K.callbacks.ModelCheckpoint(resultpath + "bestweights_job.h5", monitor='val_loss',
-                                                  verbose=1, save_best_only=False, mode='auto')
+                                                  verbose=1, save_best_only=True, mode='auto')
 
 
     if os.path.exists(resultpath + '/bestweights_job.h5') and not(args.resume):
-        print('Model already Trained')
+        print('Model already Trained')
     elif os.path.exists(resultpath + '/bestweights_job.h5') and args.resume:
         print("load and save weights before resuming")
-        shutil.copyfile(resultpath + '/bestweights_job.h5', resultpath + '/weights_before_resuming.h5') # save old weights
+        shutil.copyfile(resultpath + '/bestweights_job.h5', resultpath + '/weights_before_resuming_'
+                        + datetime.datetime.now().strftime("%Y_%m_%d-%I_%M_%p")+'.h5') # save old weights
+        log_file = pd.read_csv(resultpath + "/train_log.csv")
+        save_best_model = K.callbacks.ModelCheckpoint(resultpath + "bestweights_job.h5", monitor='val_loss',
+                                                      verbose=1, save_best_only=True, mode='auto',
+                                                      initial_value_threshold=log_file.val_loss.min())
 
         print("Resuming training")
         model.load_weights(resultpath + '/bestweights_job.h5')
         train_generator = TrainDataGenerator(datapath=datapath,
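Note on the resume logic above: initial_value_threshold (a ModelCheckpoint argument in recent tf.keras releases) seeds the callback with the best val_loss recorded so far, so a resumed run can only overwrite bestweights_job.h5 once validation loss improves on the pre-resume minimum taken from train_log.csv. A self-contained sketch of the mechanism, with a toy model and made-up log values:

    import numpy as np
    import pandas as pd
    import tensorflow.keras as K

    x, y = np.random.rand(64, 4), np.random.rand(64)
    model = K.models.Sequential([K.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")

    # Pretend an earlier run logged these epochs via CSVLogger.
    pd.DataFrame({"val_loss": [0.40, 0.31, 0.35]}).to_csv("train_log.csv", index=False)
    best_so_far = pd.read_csv("train_log.csv").val_loss.min()  # 0.31

    checkpoint = K.callbacks.ModelCheckpoint("bestweights_job.h5", monitor="val_loss",
                                             verbose=1, save_best_only=True, mode="auto",
                                             initial_value_threshold=best_so_far)
    # The callback starts from 0.31 instead of inf: epochs whose val_loss
    # stays at or above 0.31 leave the checkpoint file untouched.
    model.fit(x, y, validation_split=0.25, epochs=2, callbacks=[checkpoint], verbose=0)

Without the threshold, the first epoch after a resume always "improves" on infinity and replaces the checkpoint, which appears to be the failure mode this change guards against.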
@@ -130,15 +138,6 @@
                                              inputsize=inputsize, evalset="validation")
 
                               )
-
-        plt.plot(history.history['loss'])
-        plt.plot(history.history['val_loss'])
-        plt.title('model loss')
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['train', 'val'], loc='upper left')
-        plt.savefig(resultpath + "train_val_loss.png")
-        plt.show()
     else:
         print("Start training from scratch")
         train_generator = TrainDataGenerator(datapath=datapath,
@@ -161,15 +160,7 @@
 
                               )
 
-        plt.plot(history.history['loss'])
-        plt.plot(history.history['val_loss'])
-        plt.title('model loss')
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['train', 'val'], loc='upper left')
-        plt.savefig(resultpath + "train_val_loss.png")
-        plt.show()
-
+    plot_loss_function(resultpath)
     model.load_weights(resultpath + '/bestweights_job.h5')
     print("Finished")
     print("Analysis over the validation set")
@@ -189,46 +180,37 @@
     auc_test, confusionmatrix_test = evaluate_performance(ytest, ptest)
     np.save(resultpath + "/ptest.npy", ptest)
 
-    data = {'Jobid': [args.ID],
-            'Datapath': [str(args.path)],
-            'genotype_path': [str(genotype_path)],
-            'Batchsize': [args.batch_size],
-            'Learning rate': [args.learning_rate],
-            'L1 value': [args.L1],
-            'patience': [args.patience],
-            'epoch size': [args.epoch_size],
-            'epochs': [args.epochs],
-            'Weight positive class': str(args.wpc),
-            'AUC validation': [auc_val],
-            'AUC test': [auc_test]}
-    pd_summary_row = pd.DataFrame(data)
-    pd_summary_row.to_csv(resultpath + "/Summary_results.csv")
-
-    with open(resultpath + '/Results_' + str(jobid) + '.txt', 'a') as f:
-        f.write('\n Jobid = ' + str(jobid))
-        f.write('\n Batchsize = ' + str(batch_size))
-        f.write('\n Weight positive class = ' + str(weight_positive_class))
-        f.write('\n Weight negative class= ' + str(weight_negative_class))
-        f.write('\n Learningrate = ' + str(lr_opt))
-        f.write('\n Optimizer = ' + str(optimizer_model))
-        f.write('\n L1 value = ' + str(l1_value))
-        f.write('\n')
-        f.write("Validation set")
-        f.write('\n Score auc = ' + str(auc_val))
-        f.write('\n Confusion matrix:')
-        f.write(str(confusionmatrix_val))
-        f.write('\n')
-        f.write("Test set")
-        f.write('\n Score auc = ' + str(auc_test))
-        f.write('\n Confusion matrix ')
-        f.write(str(confusionmatrix_test))
+    data = {'Jobid': args.ID,
+            'Datapath': str(args.path),
+            'genotype_path': str(genotype_path),
+            'Batchsize': args.batch_size,
+            'Learning rate': args.learning_rate,
+            'L1 value': args.L1,
+            'patience': args.patience,
+            'epoch size': args.epoch_size,
+            'epochs': args.epochs,
+            'Weight positive class': args.wpc,
+            'AUC validation': auc_val,
+            'AUC test': auc_test,
+            'SlURM_JOB_ID': SlURM_JOB_ID}
+
+    pd_summary_row = pd.Series(data)
+    pd_summary_row.to_csv(resultpath + "/pd_summary_results.csv")
+
+    data['confusionmatrix_val'] = confusionmatrix_val
+    data['confusionmatrix_test'] = confusionmatrix_test
+
+    with open(resultpath + "results_summary.txt", 'w') as f:
+        for key, value in data.items():
+            f.write('%s:%s\n' % (key, value))
 
     if os.path.exists(datapath + "/topology.csv"):
         importance_csv = create_importance_csv(datapath, model, masks)
         importance_csv.to_csv(resultpath + "connection_weights.csv")
 
 
 def train_regression(args):
+    SlURM_JOB_ID = get_SLURM_id()
     model = None
     masks = None
     datapath = args.path
@@ -291,18 +273,25 @@ def train_regression(args):
 
     with open(resultpath + '/model_architecture.txt', 'w') as fh:
         model.summary(print_fn=lambda x: fh.write(x + '\n'))
 
 
-    csv_logger = K.callbacks.CSVLogger(resultpath + 'train_log.csv', append=True)
+    early_stop = K.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=patience, verbose=1, mode='auto',
+                                           restore_best_weights=True)
     save_best_model = K.callbacks.ModelCheckpoint(resultpath + "bestweights_job.h5", monitor='val_loss',
                                                   verbose=1, save_best_only=True, mode='auto')
+    csv_logger = K.callbacks.CSVLogger(resultpath + 'train_log.csv', append=True)
 
 
     if os.path.exists(resultpath + '/bestweights_job.h5') and not(args.resume):
         print('Model already trained')
     elif os.path.exists(resultpath + '/bestweights_job.h5') and args.resume:
         print("load and save weights before resuming")
-        shutil.copyfile(resultpath + '/bestweights_job.h5', resultpath + '/weights_before_resuming.h5') # save old weights
+        shutil.copyfile(resultpath + '/bestweights_job.h5', resultpath + '/weights_before_resuming_'
+                        + datetime.datetime.now().strftime("%Y_%m_%d-%I_%M_%p")+'.h5') # save old weights
+        log_file = pd.read_csv(resultpath + "/train_log.csv")
+        save_best_model = K.callbacks.ModelCheckpoint(resultpath + "bestweights_job.h5", monitor='val_loss',
+                                                      verbose=1, save_best_only=True, mode='auto',
+                                                      initial_value_threshold=log_file.val_loss.min())
         print("Resuming training")
         model.load_weights(resultpath + '/bestweights_job.h5')
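Note: copying the checkpoint to a timestamped file means repeated resumes each keep their own backup instead of overwriting a single weights_before_resuming.h5. A quick sketch of what the strftime pattern produces:

    import datetime

    stamp = datetime.datetime.now().strftime("%Y_%m_%d-%I_%M_%p")
    print("weights_before_resuming_" + stamp + ".h5")
    # e.g. weights_before_resuming_2022_06_30-02_41_PM.h5

%I is the 12-hour clock, so the %p (AM/PM) field is what keeps morning and evening backups distinct; two resumes within the same minute would still collide.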

@@ -322,14 +311,6 @@
                             validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
                                                           setsize=val_size_train, inputsize=inputsize, evalset="validation")
                             )
-        plt.plot(history.history['loss'])
-        plt.plot(history.history['val_loss'])
-        plt.title('model loss')
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['train', 'val'], loc='upper left')
-        plt.savefig(resultpath + "train_val_loss.png")
-        plt.show()
 
     else:
         print("Start training from scratch")
@@ -349,15 +330,8 @@
                             validation_data=EvalGenerator(datapath=datapath, genotype_path=genotype_path, batch_size=batch_size,
                                                           setsize=val_size_train, inputsize=inputsize, evalset="validation")
                             )
-        plt.plot(history.history['loss'])
-        plt.plot(history.history['val_loss'])
-        plt.title('model loss')
-        plt.ylabel('loss')
-        plt.xlabel('epoch')
-        plt.legend(['train', 'val'], loc='upper left')
-        plt.savefig(resultpath + "train_val_loss.png")
-        plt.show()
-
+    plot_loss_function(resultpath)
     model.load_weights(resultpath + '/bestweights_job.h5')
     print("Finished")
     print("Analysis over the validation set")
@@ -378,37 +352,29 @@
     np.save(resultpath + "/ptest.npy", ptest)
     fig.savefig(resultpath + "/test_predictions.png", bbox_inches='tight', pad_inches=0)
 
-    data = {'Jobid': [args.ID],
-            'Datapath': [str(args.path)],
-            'genotype_path': [str(genotype_path)],
-            'Batchsize': [args.batch_size],
-            'Learning rate': [args.learning_rate],
-            'L1 value': [args.L1],
-            'patience': [args.patience],
-            'epoch size': [args.epoch_size],
-            'epochs': [args.epochs],
-            'MSE validation': [mse_val],
-            'MSE test': [mse_test],
-            'Explained variance val': [explained_variance_val],
-            'Explained variance test': [explained_variance_test]}
-    pd_summary_row = pd.DataFrame(data)
-    pd_summary_row.to_csv(resultpath + "/Summary_results.csv")
-
-    with open(resultpath + '/Results_' + str(jobid) + '.txt', 'a') as f:
-        f.write('\n Jobid = ' + str(jobid))
-        f.write('\n Batchsize = ' + str(batch_size))
-        f.write('\n Learningrate = ' + str(lr_opt))
-        f.write('\n Optimizer = ' + str(optimizer_model))
-        f.write('\n L1 value = ' + str(l1_value))
-        f.write('\n')
-        f.write("Validation set")
-        f.write('\n Mean squared error = ' + str(mse_val))
-        f.write('\n Explained variance = ' + str(explained_variance_val))
-        f.write('\n R2 = ' + str(r2_val))
-        f.write("Test set")
-        f.write('\n Mean squared error = ' + str(mse_test))
-        f.write('\n Explained variance = ' + str(explained_variance_val))
-        f.write('\n R2 = ' + str(r2_test))
+    data = {'Jobid': args.ID,
+            'Datapath': str(args.path),
+            'genotype_path': str(genotype_path),
+            'Batchsize': args.batch_size,
+            'Learning rate': args.learning_rate,
+            'L1 value': args.L1,
+            'patience': args.patience,
+            'epoch size': args.epoch_size,
+            'epochs': args.epochs,
+            'MSE validation': mse_val,
+            'MSE test': mse_test,
+            'Explained variance val': explained_variance_val,
+            'Explained variance test': explained_variance_test,
+            'R2_validation': r2_val,
+            'R2_test': r2_test,
+            'SlURM_JOB_ID': SlURM_JOB_ID}
+
+    pd_summary_row = pd.Series(data)
+    pd_summary_row.to_csv(resultpath + "/pd_summary_results.csv")
+
+    with open(resultpath + "results_summary.txt", 'w') as f:
+        for key, value in data.items():
+            f.write('%s:%s\n' % (key, value))
 
     if os.path.exists(datapath + "/topology.csv"):
         importance_csv = create_importance_csv(datapath, model, masks)
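Note on the summary change: each run now writes a pandas Series as one key,value row per entry (long format) to pd_summary_results.csv, instead of a single-row DataFrame to Summary_results.csv, and the same dict (extended with the confusion matrices in the classification case) is dumped to a human-readable results_summary.txt. A round-trip sketch with made-up values (file name as in the diff, run in any scratch directory):

    import pandas as pd

    data = {"Jobid": 1, "AUC validation": 0.81, "SlURM_JOB_ID": "unknown"}
    pd.Series(data).to_csv("pd_summary_results.csv")

    # Each entry comes back as a key,value row; squeeze() turns the
    # one-column frame back into a Series (pandas >= 1.4 idiom).
    summary = pd.read_csv("pd_summary_results.csv", index_col=0).squeeze("columns")
    print(summary["AUC validation"])  # prints 0.81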
11 changes: 11 additions & 0 deletions GenNet_utils/Utility_functions.py
@@ -12,6 +12,17 @@
 import seaborn as sns
 from sklearn.metrics import mean_squared_error, explained_variance_score, r2_score
 
+
+def get_SLURM_id():
+    SlURM_JOB_ID = "unknown"
+    try:
+        print('SlURM_JOB_ID', os.environ["SLURM_JOB_ID"])
+        SlURM_JOB_ID = os.environ["SLURM_JOB_ID"]
+    except:
+        print("no slurm id")
+    return SlURM_JOB_ID
+
+
 def use_mixed_precision():
     from tensorflow.keras import mixed_precision
     policy = mixed_precision.Policy('mixed_float16')
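Note: get_SLURM_id simply reports the job id that the SLURM scheduler exports into the environment, so a result folder can be traced back to the cluster job that produced it; outside SLURM it falls back to "unknown". A small usage sketch:

    import os

    from GenNet_utils.Utility_functions import get_SLURM_id

    os.environ.pop("SLURM_JOB_ID", None)
    print(get_SLURM_id())                  # prints "no slurm id", returns 'unknown'

    os.environ["SLURM_JOB_ID"] = "123456"  # what sbatch/srun would set
    print(get_SLURM_id())                  # returns '123456'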
