# run_deepvae.py
import argparse
import json
import logging
import os
import pickle
import numpy as np
import torch
from biva.datasets import get_binmnist_datasets, get_cifar10_datasets
from biva.evaluation import VariationalInference
from biva.model import DeepVae, get_deep_vae_mnist, get_deep_vae_cifar, VaeStage, LvaeStage, BivaStage
from biva.utils import LowerBoundedExponentialLR, training_step, test_step, summary2logger, save_model, load_model, \
    sample_model, DiscretizedMixtureLogits
from booster import Aggregator
from booster.utils import EMA, logging_sep, available_device
from torch.distributions import Bernoulli
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('--root', default='runs/', help='directory to store training logs')
parser.add_argument('--data_root', default='data/', help='directory to store the dataset')
parser.add_argument('--dataset', default='binmnist', help='dataset (binmnist | cifar10)')
parser.add_argument('--model_type', default='biva', help='model type (vae | lvae | biva)')
parser.add_argument('--device', default='auto', help='auto, cuda, cpu')
parser.add_argument('--num_workers', default=1, type=int, help='number of workers')
parser.add_argument('--bs', default=48, type=int, help='batch size')
parser.add_argument('--epochs', default=500, type=int, help='number of epochs')
parser.add_argument('--lr', default=2e-3, type=float, help='base learning rate')
parser.add_argument('--seed', default=42, type=int, help='random seed')
parser.add_argument('--freebits', default=2.0, type=float, help='freebits per latent variable')
parser.add_argument('--nr_mix', default=10, type=int, help='number of mixtures')
parser.add_argument('--ema', default=0.9995, type=float, help='ema')
parser.add_argument('--q_dropout', default=0.5, type=float, help='inference model dropout')
parser.add_argument('--p_dropout', default=0.5, type=float, help='generative model dropout')
parser.add_argument('--iw_samples', default=1000, type=int, help='number of importance weighted samples for testing')
parser.add_argument('--id', default='', type=str, help='run id suffix')
parser.add_argument('--no_skip', action='store_true', help='do not use skip connections')
parser.add_argument('--log_var_act', default='softplus', type=str, help='activation for the log variance')
parser.add_argument('--beta', default=1.0, type=float, help='Beta parameter (Beta-VAE)')
opt = parser.parse_args()
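# Example invocation (flags as defined above; the values are purely illustrative):
#   python run_deepvae.py --dataset binmnist --model_type biva --bs 48 --epochs 500 --device auto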
# set random seed, set run-id, init log directory and save config
torch.manual_seed(opt.seed)
np.random.seed(opt.seed)
run_id = f"{opt.dataset}-{opt.model_type}-seed{opt.seed}"
if len(opt.id):
    run_id += f"-{opt.id}"
if opt.beta != 1:
    run_id += f"-{opt.beta}"
logdir = os.path.join(opt.root, run_id)
if not os.path.exists(logdir):
    os.makedirs(logdir)
with open(os.path.join(logdir, 'config.json'), 'w') as fp:
    fp.write(json.dumps(vars(opt)))
# define tensorboard writers
train_writer = SummaryWriter(os.path.join(logdir, 'train'))
valid_writer = SummaryWriter(os.path.join(logdir, 'valid'))
# load data
if opt.dataset == 'binmnist':
    train_dataset, valid_dataset, test_dataset = get_binmnist_datasets(opt.data_root)
elif opt.dataset == 'cifar10':
    from torchvision.transforms import Lambda
    transform = Lambda(lambda x: x * 2 - 1)
    train_dataset, valid_dataset, test_dataset = get_cifar10_datasets(opt.data_root, transform=transform)
else:
    raise NotImplementedError
train_loader = DataLoader(train_dataset, batch_size=opt.bs, shuffle=True, pin_memory=False, num_workers=opt.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=2 * opt.bs, shuffle=True, pin_memory=False,
                          num_workers=opt.num_workers)
test_loader = DataLoader(test_dataset, batch_size=2 * opt.bs, shuffle=True, pin_memory=False,
                         num_workers=opt.num_workers)
tensor_shp = (-1, *train_dataset[0].shape)
# define likelihood
likelihood = {'cifar10': DiscretizedMixtureLogits(opt.nr_mix), 'binmnist': Bernoulli}[opt.dataset]
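# Bernoulli matches binarized MNIST pixels; DiscretizedMixtureLogits (with opt.nr_mix components)
# models 8-bit CIFAR-10 intensities with a discretized mixture likelihood (PixelCNN++-style output).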
# define model
if 'cifar' in opt.dataset:
    stages, latents = get_deep_vae_cifar()
    features_out = 10 * opt.nr_mix
else:
    stages, latents = get_deep_vae_mnist()
    features_out = tensor_shp[1]
Stage = {'vae': VaeStage, 'lvae': LvaeStage, 'biva': BivaStage}[opt.model_type]
log_var_act = {'none': None, 'softplus': torch.nn.Softplus, 'tanh': torch.nn.Tanh}[opt.log_var_act]
hyperparameters = {
    'Stage': Stage,
    'tensor_shp': tensor_shp,
    'stages': stages,
    'latents': latents,
    'nonlinearity': 'elu',
    'q_dropout': opt.q_dropout,
    'p_dropout': opt.p_dropout,
    'type': opt.model_type,
    'features_out': features_out,
    'no_skip': opt.no_skip,
    'log_var_act': log_var_act
}
# save hyper parameters for easy loading
pickle.dump(hyperparameters, open(os.path.join(logdir, "hyperparameters.p"), "wb"))
# instantiate the model and move to target device
model = DeepVae(**hyperparameters)
device = available_device() if opt.device == 'auto' else opt.device
model.to(device)
# define the evaluator
evaluator = VariationalInference(likelihood, iw_samples=1)
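# With iw_samples=1 the importance-weighted bound reduces to the standard ELBO, so this
# evaluator provides the training objective; the multi-sample bound is reserved for the final test.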
# define evaluation model with Exponential Moving Average
ema = EMA(model, opt.ema)
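# EMA keeps a shadow copy of the parameters, updated after every optimizer step below;
# ema.model (the averaged weights) is the one evaluated, sampled from, and checkpointed.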
# data dependent init for weight normalization (automatically done during the first forward pass)
with torch.no_grad():
    model.train()
    x = next(iter(train_loader)).to(device)
    model(x)
# print stages
print(logging_sep("=") + "\nGenerative model:\n" + logging_sep("-"))
for i, (convs, z) in reversed(list(enumerate(zip(stages, latents)))):
    print(f"Stage #{i + 1}")
    print("Stochastic layer:", z)
    print("Deterministic block:", convs)
print(logging_sep("="))
# define freebits
n_latents = len(latents)
if opt.model_type == 'biva':
    n_latents = 2 * n_latents - 1
freebits = [opt.freebits] * n_latents
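# BIVA runs a bottom-up and a top-down chain of stochastic layers that share the top-most
# variable, hence 2 * n - 1 latent variables in total, each given its own free-bits budget.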
# optimizer
optimizer = torch.optim.Adamax(model.parameters(), lr=opt.lr, betas=(0.9, 0.999,))
scheduler = LowerBoundedExponentialLR(optimizer, 0.999999, 0.0001)
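# Judging by its name and arguments, LowerBoundedExponentialLR (from biva.utils) decays the
# learning rate by a factor of 0.999999 per step down to a floor of 0.0001; see its
# implementation for the exact semantics of the bound.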
# logging utils
kwargs = {'beta': opt.beta, 'freebits': freebits}
best_elbo = (-1e20, 0, 0)
global_step = 1
train_agg = Aggregator()
val_agg = Aggregator()
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(name)-4s %(levelname)-4s %(message)s',
                    datefmt='%m-%d %H:%M',
                    handlers=[logging.FileHandler(os.path.join(logdir, 'run.log')),
                              logging.StreamHandler()])
train_logger = logging.getLogger('train')
eval_logger = logging.getLogger('eval')
M_parameters = (sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6)
logging.getLogger(run_id).info(f'# Total Number of Parameters: {M_parameters:.3f}M')
print(logging_sep() + f"\nLogging directory: {logdir}\n" + logging_sep())
# init sample
sample_model(ema.model, likelihood, logdir, writer=valid_writer, global_step=global_step, N=100)
# run
for epoch in range(1, opt.epochs + 1):
    # training
    train_agg.initialize()
    for x in tqdm(train_loader, desc='train epoch'):
        x = x.to(device)
        diagnostics = training_step(x, model, evaluator, optimizer, scheduler, **kwargs)
        train_agg.update(diagnostics)
        ema.update()
        global_step += 1
    train_summary = train_agg.data.to('cpu')
    # evaluation
    val_agg.initialize()
    for x in tqdm(valid_loader, desc='valid epoch'):
        x = x.to(device)
        diagnostics = test_step(x, ema.model, evaluator, **kwargs)
        val_agg.update(diagnostics)
    eval_summary = val_agg.data.to('cpu')
    # keep best model
    best_elbo = save_model(ema.model, eval_summary, global_step, epoch, best_elbo, logdir)
    # logging
    summary2logger(train_logger, train_summary, global_step, epoch)
    summary2logger(eval_logger, eval_summary, global_step, epoch, best_elbo)
    # tensorboard logging
    train_summary.log(train_writer, global_step)
    eval_summary.log(valid_writer, global_step)
    # sample model
    sample_model(ema.model, likelihood, logdir, writer=valid_writer, global_step=global_step, N=100)
# load best model
load_model(ema.model, logdir)
# sample model
sample_model(ema.model, likelihood, logdir, N=100)
# final test
iw_evaluator = VariationalInference(likelihood, iw_samples=opt.iw_samples)
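# The final test uses a multi-sample importance-weighted bound (opt.iw_samples per example),
# which is tighter than the single-sample ELBO and closer to the true log-likelihood.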
test_agg = Aggregator()
test_logger = logging.getLogger('test')
test_logger.info(f"best elbo at step {best_elbo[1]}, epoch {best_elbo[2]}: {best_elbo[0]:.3f} nats")
test_agg.initialize()
for x in tqdm(test_loader, desc='iw test epoch'):
    x = x.to(device)
    diagnostics = test_step(x, ema.model, iw_evaluator, **kwargs)
    test_agg.update(diagnostics)
test_summary = test_agg.data.to('cpu')
summary2logger(test_logger, test_summary, best_elbo[1], best_elbo[2], None)