-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
144 lines (112 loc) · 7.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import os
from dataManipulation import *
from utils import summary, summary_raw, mcmc_treeprob, get_support_from_mcmc, BitArray, tree_process
from vbpi import VBPI
import time
import numpy as np
import datetime
# Command-line interface: groups flags into data / model / optimizer sections.
# The parsed namespace `args` is consumed by the rest of the script at module level.
parser = argparse.ArgumentParser()
######### Data arguments
# Which benchmark dataset to run on (Hohna et al. DS1-DS8 alignments).
parser.add_argument('--dataset', required=True, help=' DS1 | DS2 | DS3 | DS4 | DS5 | DS6 | DS7 | DS8 ')
# Source of the tree topology support: ultrafast bootstrap samples or MCMC samples.
parser.add_argument('--supportType', type=str, default='ufboot', help=' ufboot | mcmc ')
# When set, load ground-truth posterior tree frequencies so KL divergence can be monitored.
parser.add_argument('--empFreq', default=False, action='store_true', help='emprical frequence for KL computation')
######### Model arguments
parser.add_argument('--psp', default=False, action='store_true', help=' turn on psp branch length feature')
parser.add_argument('--nf', type=int, default=2, help=' branch length feature embedding dimension ')
parser.add_argument('--hdim', type=int, default=100, help='hidden dimension for node embedding net')
parser.add_argument('--hL', type=int, default=2, help='number of hidden layers for node embedding net')
# Branch-length model family; 'gnn' enables the --gnn_type/--aggr/--proj options below.
parser.add_argument('--brlen_model', type=str, default='gnn', help='branch length models')
parser.add_argument('--gnn_type', type=str, default='gcn', help='gcn | sage | gin | ggnn')
parser.add_argument('--aggr', type=str, default='sum', help='sum | mean | max')
parser.add_argument('--proj', default=False, action='store_true', help='use projection first in SAGEConv')
# Test mode loads a previously saved checkpoint (selected via --datetime) instead of training.
parser.add_argument('--test', default=False, action='store_true', help='turn on the test mode')
parser.add_argument('--datetime', type=str, default='2022-01-01', help=' 2020-04-01 | 2020-04-02 | ...... ')
######### Optimizer arguments
parser.add_argument('--stepszTree', type=float, default=0.001, help=' step size for tree topology parameters ')
parser.add_argument('--stepszBranch', type=float, default=0.001, help=' stepsz for branch length parameters ')
parser.add_argument('--maxIter', type=int, default=200000, help=' number of iterations for training, default=400000')
parser.add_argument('--invT0', type=float, default=0.001, help=' initial inverse temperature for annealing schedule, default=0.001')
parser.add_argument('--nwarmStart', type=float, default=100000, help=' number of warm start iterations, default=100000')
parser.add_argument('--nParticle', type=int, default=10, help='number of particles for variational objectives, default=10')
parser.add_argument('--ar', type=float, default=0.75, help='step size anneal rate, default=0.75')
parser.add_argument('--af', type=int, default=20000, help='step size anneal frequency, default=20000')
parser.add_argument('--tf', type=int, default=1000, help='monitor frequency during training, default=1000')
parser.add_argument('--lbf', type=int, default=5000, help='lower bound test frequency, default=5000')
parser.add_argument('--gradMethod', type=str, default='vimco', help=' vimco | rws ')
args = parser.parse_args()
# Build the output locations from the run configuration.
# Every enabled option contributes a suffix so that concurrent runs with
# different settings never collide on disk.
args.result_folder = f'results/{args.dataset}/{args.brlen_model}'
os.makedirs(args.result_folder, exist_ok=True)

name_parts = [args.supportType, args.gradMethod, str(args.nParticle)]
if args.brlen_model == 'gnn':
    name_parts.extend([args.gnn_type, args.aggr])
if args.psp:
    name_parts.append('psp')
if args.proj:
    name_parts.append('proj')
args.save_to_path = args.result_folder + '/' + '_'.join(name_parts)

if args.test:
    # In test mode the checkpoint to load is identified by the --datetime flag.
    args.load_from_path = f'{args.save_to_path}_{args.datetime}.pt'
# Timestamp the save path so each run writes a fresh checkpoint file.
args.save_to_path = f'{args.save_to_path}_{datetime.datetime.now()}.pt'

if args.test:
    print('Testing with the following settings: {}'.format(args))
else:
    print('Training with the following settings: {}'.format(args))
# Locations of the support trees, sequence alignments, and ground-truth runs.
ufboot_support_path = 'data/ufboot_data_DS1-11/'
# FIX: mcmc_support_path was used below but never defined, so
# `--supportType mcmc` crashed with a NameError before this change.
# TODO(review): confirm this directory name against the actual data layout.
mcmc_support_path = 'data/short_run_data_DS1-11/'
data_path = 'data/hohna_datasets_fasta/'
ground_truth_path, samp_size = 'data/raw_data_DS1-11/', 750001

###### Load Data
print('\nLoading Data set: {} ......'.format(args.dataset))
run_time = -time.time()

# Topology support: candidate root-splits/subsplits come either from
# ultrafast bootstrap replicates or from short MCMC runs.
if args.supportType == 'ufboot':
    tree_dict_support, tree_names_support = summary_raw(args.dataset, ufboot_support_path)
elif args.supportType == 'mcmc':
    tree_dict_support, tree_names_support, _ = mcmc_treeprob(mcmc_support_path + args.dataset + '.trprobs', 'nexus', taxon='keep')
else:
    # Fail loudly on a typo instead of dying later on undefined variables.
    raise ValueError("Unrecognized supportType '{}': expected 'ufboot' or 'mcmc'".format(args.supportType))

data, taxa = loadData(data_path + args.dataset + '.fasta', 'fasta')
run_time += time.time()
print('Support loaded in {:.1f} seconds'.format(run_time))

if args.empFreq:
    # Ground-truth posterior tree frequencies (long MrBayes runs), used only
    # for monitoring KL divergence during training/evaluation.
    print('\nLoading empirical posterior estimates ......')
    run_time = -time.time()
    tree_dict_total, tree_names_total, tree_wts_total = summary(args.dataset, ground_truth_path, samp_size=samp_size)
    emp_tree_freq = {tree_dict_total[tree_name]:tree_wts_total[i] for i, tree_name in enumerate(tree_names_total)}
    run_time += time.time()
    print('Empirical estimates from MrBayes loaded in {:.1f} seconds'.format(run_time))
else:
    emp_tree_freq = None

rootsplit_supp_dict, subsplit_supp_dict = get_support_from_mcmc(taxa, tree_dict_support, tree_names_support)
# The raw support trees are no longer needed once the split dictionaries exist.
del tree_dict_support, tree_names_support
# Build the variational model over tree topologies and branch lengths
# (JC substitution model with uniform base frequencies).
model = VBPI(taxa, rootsplit_supp_dict, subsplit_supp_dict, data, pden=np.ones(4)/4., subModel=('JC', 1.0),
             emp_tree_freq=emp_tree_freq, feature_dim=args.nf, psp=args.psp, hidden_dim=args.hdim,
             num_layers=args.hL, branch_model=args.brlen_model, gnn_type=args.gnn_type,
             aggr=args.aggr, project=args.proj)

print('Parameter Info:')
for param in model.parameters():
    print(param.dtype, param.size())

if args.test:
    # Evaluation: restore a trained checkpoint and estimate lower bounds
    # and the marginal likelihood via importance sampling.
    print('Loading parameters from: {}\n'.format(args.load_from_path))
    model.load_from(args.load_from_path)

    print('Computing one sample lower bounds\n')
    lower_bound_1_sample = np.array([model.lower_bound(n_particles=1, n_runs=1000) for _ in range(100)])
    np.save(args.load_from_path.replace('.pt', '_lower_bound_1_' + str(datetime.datetime.now()) + '.npy'), lower_bound_1_sample)

    print('Computing ten sample lower bounds\n')
    lower_bound_10_sample = np.array([model.lower_bound(n_particles=10, n_runs=1000) for _ in range(100)])
    np.save(args.load_from_path.replace('.pt', '_lower_bound_10_' + str(datetime.datetime.now()) + '.npy'), lower_bound_10_sample)

    print('Computing marginal loglikelihood\n')
    marginal_likelihood_est = np.array([model.lower_bound(n_particles=1000, n_runs=1) for _ in range(1000)])
    np.save(args.load_from_path.replace('.pt', '_marginal_likelihood_est_' + str(datetime.datetime.now()) +'.npy'), marginal_likelihood_est)

    if args.empFreq:
        # Score the 42 highest-weight ground-truth trees individually.
        tree_ci_index = np.argsort(tree_wts_total)[::-1]
        print('Computing 95% confidence interval tree lower bound\n')
        lower_bound_ci = []
        toBitArr = BitArray(taxa)
        for rank in tree_ci_index[:42]:
            test_tree = tree_dict_total[tree_names_total[rank]].copy()
            tree_process(test_tree, toBitArr)
            lower_bound_ci.append(model.tree_lower_bound(test_tree, n_runs=10000))
        np.save(args.save_to_path.replace('.pt', '_tree_lower_bound.npy'), lower_bound_ci)
else:
    # Training: optimize the variational approximation and checkpoint to disk.
    print('\nVBPI running, results will be saved to: {}\n'.format(args.save_to_path))
    test_lb, test_kl_div = model.learn({'tree': args.stepszTree, 'branch': args.stepszBranch}, args.maxIter,
                                       test_freq=args.tf, n_particles=args.nParticle, anneal_freq=args.af,
                                       init_inverse_temp=args.invT0, warm_start_interval=args.nwarmStart,
                                       method=args.gradMethod, save_to_path=args.save_to_path)
    np.save(args.save_to_path.replace('.pt', '_test_lb.npy'), test_lb)
    if args.empFreq:
        np.save(args.save_to_path.replace('.pt', '_kl_div.npy'), test_kl_div)