-
Notifications
You must be signed in to change notification settings - Fork 130
/
experiment.py
executable file
·975 lines (873 loc) · 38.3 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
import copy
import json
import os
import re
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from numpy.lib.function_base import flip
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
from torch.cuda import amp
from torch.distributions import Categorical
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import ConcatDataset, TensorDataset
from torchvision.utils import make_grid, save_image
from config import *
from dataset import *
from dist_utils import *
from lmdb_writer import *
from metrics import *
from renderer import *
class LitModel(pl.LightningModule):
    """Lightning module wrapping the diffusion (auto-encoder) model.

    Owns the trainable ``model``, its EMA copy ``ema_model`` (used by the
    sampling / evaluation helpers below), the training and evaluation
    samplers and, for latent-diffusion training, the inferred latents
    ``conds`` together with their per-channel ``conds_mean`` / ``conds_std``.
    """

    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())

        self.conf = conf

        self.model = conf.make_model_conf().make_model()
        # the EMA copy is never optimized; it is only updated through ema()
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this is shared for both model and latent
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # initial variables for consistent sampling
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            # ``conds`` must exist so train_dataloader() can detect that the
            # latents still have to be inferred (it tests ``is None``).
            self.conds = None
            self.conds_mean = None
            self.conds_std = None

    def normalize(self, cond):
        """Z-normalize ``cond`` with the stored ``conds_mean`` / ``conds_std``."""
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        """Inverse of :meth:`normalize`."""
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

    def sample(self, N, device, T=None, T_latent=None):
        """Generate ``N`` unconditional images with the EMA model.

        Args:
            N: number of images.
            device: device on which the initial noise is drawn.
            T: image-sampler steps; ``None`` uses ``eval_sampler``.
            T_latent: latent-sampler steps; only consulted when ``T`` is given.

        Returns:
            ``(rendered + 1) / 2`` (i.e. mapped to [0, 1] if the renderer
            returns values in [-1, 1]).
        """
        if T is None:
            sampler = self.eval_sampler
            # NOTE(review): this uses the *training* latent sampler, not
            # eval_latent_sampler -- confirm this is intended.
            latent_sampler = self.latent_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
            latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()
        noise = torch.randn(N,
                            3,
                            self.conf.img_size,
                            self.conf.img_size,
                            device=device)
        pred_img = render_uncondition(
            self.conf,
            self.ema_model,
            noise,
            sampler=sampler,
            latent_sampler=latent_sampler,
            conds_mean=self.conds_mean,
            conds_std=self.conds_std,
        )
        pred_img = (pred_img + 1) / 2
        return pred_img

    def render(self, noise, cond=None, T=None):
        """Decode ``noise`` (optionally conditioned on ``cond``) with the EMA model.

        Without ``cond`` no latent sampler is used.  Returns
        ``(rendered + 1) / 2``.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        pred_img = (pred_img + 1) / 2
        return pred_img

    def encode(self, x):
        """Encode images ``x`` into latents with the EMA model's encoder.

        Only valid for model types with an auto-encoder.
        """
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        """Map ``x`` back to a noise map via ``ddim_reverse_sample_loop``.

        The reverse loop runs on the EMA model conditioned on ``cond``; the
        ``'sample'`` entry of its output is returned.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample']

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        """Sample with ``eval_sampler`` (autocast disabled).

        ``ema_model`` selects the EMA weights instead of the trained ones.
        """
        with amp.autocast(False):
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make datasets & seeding each worker separately
        """
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)
        ##############################################

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        # validation reuses the training set
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def _train_dataloader(self, drop_last=True):
        """
        really make the dataloader
        """
        # make sure to use the fraction of batch size
        # the batch size is global!
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader

    def train_dataloader(self):
        """
        return the dataloader, if diffusion mode => return image dataset
        if latent mode => return the inferred latent dataset

        In latent mode the latents are inferred on first use (when they were
        not loaded from ``conf.latent_infer_path``) and their per-channel
        mean / std are stored on ``conds_mean`` / ``conds_std``.
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # usually we load self.conds from a file
                # so we do not need to do this again!
                self.conds = self.infer_whole_dataset()
                # need to use float32! unless the mean & std will be off!
                # (1, c)
                # Assign (rather than writing ``.data``): on this path the
                # attributes are still None, which has no ``.data``.
                self.conds_mean = self.conds.float().mean(dim=0, keepdim=True)
                self.conds_std = self.conds.float().std(dim=0, keepdim=True)
                print('mean:', self.conds_mean.mean(), 'std:',
                      self.conds_std.mean())

            # return the dataset with pre-calculated conds
            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        # batch size here is global!
        # global_step already takes into account the accum batches
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is it the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def infer_whole_dataset(self,
                            with_render=False,
                            T_render=None,
                            render_save_path=None):
        """
        predicting the latents given images using the encoder (EMA model),
        in the original dataset order, without random flips.

        Args:
            with_render: whether to also render the images corresponding to that latent
            T_render: sampling steps used for rendering (defaults to conf.T_eval)
            render_save_path: lmdb output for the rendered images (written by rank 0)

        Returns:
            float32 CPU tensor of latents gathered from all workers.
        """
        data = self.conf.make_dataset()
        if isinstance(data, CelebAlmdb) and data.crop_d2c:
            # special case where we need the d2c crop
            data.transform = make_transform(self.conf.img_size,
                                            flip_prob=0,
                                            crop_d2c=True)
        else:
            data.transform = make_transform(self.conf.img_size, flip_prob=0)

        # data = SubsetDataset(data, 21)

        loader = self.conf.make_loader(
            data,
            shuffle=False,
            drop_last=False,
            batch_size=self.conf.batch_size_eval,
            parallel=True,
        )
        model = self.ema_model
        model.eval()
        conds = []

        if with_render:
            sampler = self.conf._make_diffusion_conf(
                T=T_render or self.conf.T_eval).make_sampler()

            if self.global_rank == 0:
                writer = LMDBImageWriter(render_save_path,
                                         format='webp',
                                         quality=100)
            else:
                writer = nullcontext()
        else:
            writer = nullcontext()

        with writer:
            for batch in tqdm(loader, total=len(loader), desc='infer'):
                with torch.no_grad():
                    # (n, c)
                    # print('idx:', batch['index'])
                    cond = model.encoder(batch['img'].to(self.device))

                    # used for reordering to match the original dataset
                    idx = batch['index']
                    idx = self.all_gather(idx)
                    if idx.dim() == 2:
                        idx = idx.flatten(0, 1)
                    argsort = idx.argsort()

                    if with_render:
                        noise = torch.randn(len(cond),
                                            3,
                                            self.conf.img_size,
                                            self.conf.img_size,
                                            device=self.device)
                        render = sampler.sample(model, noise=noise, cond=cond)
                        render = (render + 1) / 2
                        # (k, n, c, h, w)
                        render = self.all_gather(render)
                        if render.dim() == 5:
                            # (k*n, c, h, w)
                            render = render.flatten(0, 1)

                        if self.global_rank == 0:
                            writer.put_images(render[argsort])

                    # (k, n, c)
                    cond = self.all_gather(cond)

                    if cond.dim() == 3:
                        # (k*n, c)
                        cond = cond.flatten(0, 1)

                    conds.append(cond[argsort].cpu())

        model.train()
        # (N, c) cpu
        conds = torch.cat(conds).float()
        return conds

    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.

        Per-term losses are averaged over workers and logged to TensorBoard
        by rank 0. Returns ``{'loss': <this worker's mean loss>}``.
        """
        with amp.autocast(False):
            # batch size here is local!
            # forward
            if self.conf.train_mode.require_dataset_infer():
                # this mode has pre-calculated cond (TensorDataset batch)
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = self.normalize(cond)
            else:
                x_start = batch['img']

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                # with numpy seed we have the problem that the sample t's are related!
                t, _ = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.model,
                                                      x_start=x_start,
                                                      t=t)
            elif self.conf.train_mode.is_latent_diffusion():
                """
                training the latent variables!
                """
                # diffusion on the latent
                # NOTE(review): ``cond`` is only bound above when the mode
                # requires dataset inference -- assumed true for latent modes.
                t, _ = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                # train only do the latent diffusion
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            # No manual division by accum_batches here; presumably Lightning's
            # accumulate_grad_batches (configured in train()) normalizes the
            # loss -- confirm for the pinned Lightning version.
            for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                if key in losses:
                    losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                    if key in losses:
                        self.logger.experiment.add_scalar(
                            f'loss/{key}', losses[key], self.num_samples)

        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step ...

        Updates the EMA weights on the last accumulation step, then logs
        samples and evaluation scores when they are due.
        """
        if self.is_last_accum(batch_idx):
            # only apply ema on the last gradient accumulation step,
            # if it is the iteration that has optimizer.step()
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # it trains only the latent hence change only the latent
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

            # logging
            if self.conf.train_mode.require_dataset_infer():
                imgs = None
            else:
                imgs = batch['img']
            self.log_sample(x_start=imgs)
            self.evaluate_scores()

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        """Clip the gradient norm (if ``conf.grad_clip > 0``) before stepping."""
        # fix the fp16 + clip grad norm problem with pytorch lightning
        # this is the currently correct way to do it
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

    def log_sample(self, x_start):
        """
        put images to the tensorboard (and save them as PNGs under
        ``<logdir>/sample<postfix>``) when ``conf.sample_every_samples`` is due
        """
        def do(model,
               postfix,
               use_xstart,
               save_real=False,
               no_latent_diff=False,
               interpolate=False):
            # NOTE(review): ``no_latent_diff`` is accepted but never read, so
            # the '_enc_nodiff' output is produced exactly like '_enc'.
            model.eval()
            with torch.no_grad():
                all_x_T = self.split_tensor(self.x_T)
                batch_size = min(len(all_x_T), self.conf.batch_size_eval)
                # allow for superlarge models
                loader = DataLoader(all_x_T, batch_size=batch_size)

                Gen = []
                for x_T in loader:
                    if use_xstart:
                        _xstart = x_start[:len(x_T)]
                    else:
                        _xstart = None

                    if self.conf.train_mode.is_latent_diffusion(
                    ) and not use_xstart:
                        # diffusion of the latent first
                        gen = render_uncondition(
                            conf=self.conf,
                            model=model,
                            x_T=x_T,
                            sampler=self.eval_sampler,
                            latent_sampler=self.eval_latent_sampler,
                            conds_mean=self.conds_mean,
                            conds_std=self.conds_std)
                    else:
                        if not use_xstart and self.conf.model_type.has_noise_to_cond(
                        ):
                            model: BeatGANsAutoencModel
                            # special case, it may not be stochastic, yet can sample
                            cond = torch.randn(len(x_T),
                                               self.conf.style_ch,
                                               device=self.device)
                            cond = model.noise_to_cond(cond)
                        else:
                            if interpolate:
                                with amp.autocast(self.conf.fp16):
                                    cond = model.encoder(_xstart)
                                    i = torch.randperm(len(cond))
                                    cond = (cond + cond[i]) / 2
                            else:
                                cond = None
                        gen = self.eval_sampler.sample(model=model,
                                                       noise=x_T,
                                                       cond=cond,
                                                       x_start=_xstart)
                    Gen.append(gen)

                gen = torch.cat(Gen)
                gen = self.all_gather(gen)
                if gen.dim() == 5:
                    # (n, c, h, w)
                    gen = gen.flatten(0, 1)

                if save_real and use_xstart:
                    # save the original images to the tensorboard
                    real = self.all_gather(_xstart)
                    if real.dim() == 5:
                        real = real.flatten(0, 1)

                    if self.global_rank == 0:
                        grid_real = (make_grid(real) + 1) / 2
                        self.logger.experiment.add_image(
                            f'sample{postfix}/real', grid_real,
                            self.num_samples)

                if self.global_rank == 0:
                    # save samples to the tensorboard
                    grid = (make_grid(gen) + 1) / 2
                    sample_dir = os.path.join(self.conf.logdir,
                                              f'sample{postfix}')
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)
                    path = os.path.join(sample_dir,
                                        '%d.png' % self.num_samples)
                    save_image(grid, path)
                    self.logger.experiment.add_image(f'sample{postfix}', grid,
                                                     self.num_samples)
            model.train()

        if self.conf.sample_every_samples > 0 and is_time(
                self.num_samples, self.conf.sample_every_samples,
                self.conf.batch_size_effective):

            if self.conf.train_mode.require_dataset_infer():
                do(self.model, '', use_xstart=False)
                do(self.ema_model, '_ema', use_xstart=False)
            else:
                if self.conf.model_type.has_autoenc(
                ) and self.conf.model_type.can_sample():
                    do(self.model, '', use_xstart=False)
                    do(self.ema_model, '_ema', use_xstart=False)
                    # autoencoding mode
                    do(self.model, '_enc', use_xstart=True, save_real=True)
                    do(self.ema_model,
                       '_enc_ema',
                       use_xstart=True,
                       save_real=True)
                elif self.conf.train_mode.use_latent_net():
                    do(self.model, '', use_xstart=False)
                    do(self.ema_model, '_ema', use_xstart=False)
                    # autoencoding mode
                    do(self.model, '_enc', use_xstart=True, save_real=True)
                    do(self.model,
                       '_enc_nodiff',
                       use_xstart=True,
                       save_real=True,
                       no_latent_diff=True)
                    do(self.ema_model,
                       '_enc_ema',
                       use_xstart=True,
                       save_real=True)
                else:
                    do(self.model, '', use_xstart=True, save_real=True)
                    do(self.ema_model, '_ema', use_xstart=True, save_real=True)

    def evaluate_scores(self):
        """
        evaluate FID and other scores during training (put to the tensorboard)
        For FID, it is a fast version with 5k images (gold standard is 50k).
        Don't use its results in the paper!
        """
        def fid(model, postfix):
            score = evaluate_fid(self.eval_sampler,
                                 model,
                                 self.conf,
                                 device=self.device,
                                 train_data=self.train_data,
                                 val_data=self.val_data,
                                 latent_sampler=self.eval_latent_sampler,
                                 conds_mean=self.conds_mean,
                                 conds_std=self.conds_std)
            if self.global_rank == 0:
                self.logger.experiment.add_scalar(f'FID{postfix}', score,
                                                  self.num_samples)
                if not os.path.exists(self.conf.logdir):
                    os.makedirs(self.conf.logdir)
                with open(os.path.join(self.conf.logdir, 'eval.txt'),
                          'a') as f:
                    metrics = {
                        f'FID{postfix}': score,
                        'num_samples': self.num_samples,
                    }
                    f.write(json.dumps(metrics) + "\n")

        def lpips(model, postfix):
            if self.conf.model_type.has_autoenc(
            ) and self.conf.train_mode.is_autoenc():
                # {'lpips', 'ssim', 'mse'}
                score = evaluate_lpips(self.eval_sampler,
                                       model,
                                       self.conf,
                                       device=self.device,
                                       val_data=self.val_data,
                                       latent_sampler=self.eval_latent_sampler)

                if self.global_rank == 0:
                    for key, val in score.items():
                        self.logger.experiment.add_scalar(
                            f'{key}{postfix}', val, self.num_samples)

        if self.conf.eval_every_samples > 0 and self.num_samples > 0 and is_time(
                self.num_samples, self.conf.eval_every_samples,
                self.conf.batch_size_effective):
            print(f'eval fid @ {self.num_samples}')
            lpips(self.model, '')
            fid(self.model, '')

        if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and is_time(
                self.num_samples, self.conf.eval_ema_every_samples,
                self.conf.batch_size_effective):
            print(f'eval fid ema @ {self.num_samples}')
            fid(self.ema_model, '_ema')
            # it's too slow
            # lpips(self.ema_model, '_ema')

    def configure_optimizers(self):
        """Build Adam/AdamW over ``model`` plus an optional per-step warmup."""
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for a corresponding "worker" in the batch dimension

        Each rank gets ``len(x) // world_size`` rows; any remainder rows are
        not assigned to any rank.

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]

    def test_step(self, batch, *args, **kwargs):
        """
        for the "eval" mode.
        We first select what to do according to the "conf.eval_programs".
        test_step will only run for "one iteration" (it's a hack!).
        We just want the multi-gpu support.

        Programs (entries of ``conf.eval_programs``):
            "infer"                 encode the whole dataset, save latents
                                    and their stats (rank 0)
            "infer+render<T>"       same, and also render each latent
            "fid<T>"                unconditional FID (train_mode = diffusion)
            "fid(<T>,<T_latent>)"   FID for latent models
            "fidclip(<T>,<T_latent>)" same, with clipped latent noise
            "recon<T>"              reconstruction (no noise inversion)
            "inv<T>"                reconstruction with noise inversion
        """
        # make sure you seed each worker differently!
        self.setup()

        # it will run only one step!
        print('global step:', self.global_step)

        # "infer" = predict the latent variables using the encoder on the
        # whole dataset
        if 'infer' in self.conf.eval_programs:
            print('infer ...')
            conds = self.infer_whole_dataset().float()
            # NOTE: always use this path for the latent.pkl files
            save_path = f'checkpoints/{self.conf.name}/latent.pkl'

            if self.global_rank == 0:
                conds_mean = conds.mean(dim=0)
                conds_std = conds.std(dim=0)
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                torch.save(
                    {
                        'conds': conds,
                        'conds_mean': conds_mean,
                        'conds_std': conds_std,
                    }, save_path)

        # "infer+render" = predict the latent variables using the encoder on
        # the whole dataset; THIS ALSO GENERATES CORRESPONDING IMAGES
        # infer + reconstruction quality of the input
        for each in self.conf.eval_programs:
            if each.startswith('infer+render'):
                m = re.match(r'infer\+render([0-9]+)', each)
                if m is not None:
                    T = int(m[1])
                    self.setup()
                    print(f'infer + reconstruction T{T} ...')
                    conds = self.infer_whole_dataset(
                        with_render=True,
                        T_render=T,
                        render_save_path=
                        f'latent_infer_render{T}/{self.conf.name}.lmdb',
                    )
                    save_path = f'latent_infer_render{T}/{self.conf.name}.pkl'
                    # conds is gathered on every rank; only rank 0 writes the
                    # file to avoid several processes racing on one path
                    if self.global_rank == 0:
                        conds_mean = conds.mean(dim=0)
                        conds_std = conds.std(dim=0)
                        if not os.path.exists(os.path.dirname(save_path)):
                            os.makedirs(os.path.dirname(save_path))
                        torch.save(
                            {
                                'conds': conds,
                                'conds_mean': conds_mean,
                                'conds_std': conds_std,
                            }, save_path)

        # evals those "fidXX"
        # "fid<T>" = unconditional generation (conf.train_mode = diffusion).
        #     Note: Diff. autoenc will still receive real images in this mode.
        # "fid<T>,<T_latent>" = unconditional generation for latent models
        #     (conf.train_mode = latent_diffusion).
        #     Note: Diff. autoenc will still NOT receive real images in this
        #     mode, but you need to make sure that the train_mode is
        #     latent_diffusion.
        for each in self.conf.eval_programs:
            if each.startswith('fid'):
                m = re.match(r'fid\(([0-9]+),([0-9]+)\)', each)
                clip_latent_noise = False
                if m is not None:
                    # eval(T1,T2)
                    T = int(m[1])
                    T_latent = int(m[2])
                    print(f'evaluating FID T = {T}... latent T = {T_latent}')
                else:
                    m = re.match(r'fidclip\(([0-9]+),([0-9]+)\)', each)
                    if m is not None:
                        # fidclip(T1,T2)
                        T = int(m[1])
                        T_latent = int(m[2])
                        clip_latent_noise = True
                        print(
                            f'evaluating FID (clip latent noise) T = {T}... latent T = {T_latent}'
                        )
                    else:
                        # evalT
                        _, T = each.split('fid')
                        T = int(T)
                        T_latent = None
                        print(f'evaluating FID T = {T}...')

                # called for its side effect: infers conds and their stats
                # in latent modes when they are not loaded yet
                self.train_dataloader()
                sampler = self.conf._make_diffusion_conf(T=T).make_sampler()
                if T_latent is not None:
                    latent_sampler = self.conf._make_latent_diffusion_conf(
                        T=T_latent).make_sampler()
                else:
                    latent_sampler = None

                conf = self.conf.clone()
                conf.eval_num_images = 50_000
                score = evaluate_fid(
                    sampler,
                    self.ema_model,
                    conf,
                    device=self.device,
                    train_data=self.train_data,
                    val_data=self.val_data,
                    latent_sampler=latent_sampler,
                    conds_mean=self.conds_mean,
                    conds_std=self.conds_std,
                    remove_cache=False,
                    clip_latent_noise=clip_latent_noise,
                )
                if T_latent is None:
                    self.log(f'fid_ema_T{T}', score)
                else:
                    name = 'fid'
                    if clip_latent_noise:
                        name += '_clip'
                    name += f'_ema_T{T}_Tlatent{T_latent}'
                    self.log(name, score)

        # "recon<T>" = reconstruction & autoencoding (without noise inversion)
        for each in self.conf.eval_programs:
            if each.startswith('recon'):
                self.model: BeatGANsAutoencModel
                _, T = each.split('recon')
                T = int(T)
                print(f'evaluating reconstruction T = {T}...')

                sampler = self.conf._make_diffusion_conf(T=T).make_sampler()

                conf = self.conf.clone()
                # eval whole val dataset
                conf.eval_num_images = len(self.val_data)
                # {'lpips', 'mse', 'ssim'}
                score = evaluate_lpips(sampler,
                                       self.ema_model,
                                       conf,
                                       device=self.device,
                                       val_data=self.val_data,
                                       latent_sampler=None)
                for k, v in score.items():
                    self.log(f'{k}_ema_T{T}', v)

        # "inv<T>" = reconstruction with noise inversion
        for each in self.conf.eval_programs:
            if each.startswith('inv'):
                self.model: BeatGANsAutoencModel
                _, T = each.split('inv')
                T = int(T)
                print(
                    f'evaluating reconstruction with noise inversion T = {T}...'
                )

                sampler = self.conf._make_diffusion_conf(T=T).make_sampler()

                conf = self.conf.clone()
                # eval whole val dataset
                conf.eval_num_images = len(self.val_data)
                # {'lpips', 'mse', 'ssim'}
                score = evaluate_lpips(sampler,
                                       self.ema_model,
                                       conf,
                                       device=self.device,
                                       val_data=self.val_data,
                                       latent_sampler=None,
                                       use_inverted_noise=True)
                for k, v in score.items():
                    self.log(f'{k}_inv_ema_T{T}', v)
def ema(source, target, decay):
    """Update ``target`` in place towards ``source`` (exponential moving average).

    Every floating-point entry of the state dict becomes
    ``target * decay + source * (1 - decay)`` (``decay == 1`` leaves
    ``target`` unchanged, ``decay == 0`` copies ``source``). Entries that are
    not floating point (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot be blended. Without special handling they
    would be computed in float and truncated by ``copy_``, so they are copied
    from ``source`` verbatim.

    Args:
        source: module with the freshly trained weights.
        target: module with the EMA weights; must have the same state-dict
            keys as ``source``.
        decay: fraction of the old ``target`` value that is kept.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src in source_dict.items():
        tgt = target_dict[key]
        if tgt.is_floating_point():
            tgt.data.copy_(tgt.data * decay + src.data * (1 - decay))
        else:
            tgt.data.copy_(src.data)
class WarmupLR:
    """Learning-rate multiplier for ``LambdaLR`` with a linear warmup.

    The factor grows as ``step / warmup`` and is capped at ``1`` once
    ``step`` reaches ``warmup``.
    """

    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        capped_step = self.warmup if step > self.warmup else step
        return capped_step / self.warmup
def is_time(num_samples, every, step_size):
    """Tell whether a multiple of ``every`` was reached during the last step.

    ``num_samples`` advances by ``step_size`` per step, so the event is due
    when the distance past the most recent multiple of ``every`` is smaller
    than one step.
    """
    _, past_last_multiple = divmod(num_samples, every)
    return past_last_multiple < step_size
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
    """Build a ``LitModel`` and a Lightning ``Trainer`` from ``conf`` and run it.

    Training resumes from ``<logdir>/last.ckpt`` when it exists, otherwise
    from ``conf.continue_from`` if set. DDP is used unless there is a single
    GPU on a single node.

    Args:
        conf: training configuration.
        gpus: GPU list handed to ``pl.Trainer``.
        nodes: number of nodes.
        mode: ``'train'`` runs ``trainer.fit``. ``'eval'`` loads
            ``conf.eval_path`` (or the last checkpoint) and runs
            ``trainer.test`` once. On rank 0 it then logs the results to
            TensorBoard and appends them to ``evals/<conf.name>.txt``.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    print('conf:', conf.name)
    # assert not (conf.fp16 and conf.grad_clip > 0
    #             ), 'pytorch lightning has bug with amp + gradient clipping'
    model = LitModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    steps_between_saves = conf.save_every_samples // conf.batch_size_effective
    checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=steps_between_saves)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    print('ckpt path:', checkpoint_path)

    # resume priority: our own last.ckpt, then conf.continue_from, else scratch
    resume = None
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    elif conf.continue_from is not None:
        resume = conf.continue_from.path

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    accelerator = None
    if not (len(gpus) == 1 and nodes == 1):
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin
        # important for working with gradient checkpoint
        plugins.append(DDPPlugin(find_unused_parameters=False))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        # clip in the model instead
        # gradient_clip_val=conf.grad_clip,
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )

    if mode == 'train':
        trainer.fit(model)
        return
    if mode != 'eval':
        raise NotImplementedError()

    # dummy loader so that Lightning calls test_step (which runs the
    # configured eval programs) exactly once
    dummy_loader = DataLoader(
        TensorDataset(torch.tensor([0.] * conf.batch_size)),
        batch_size=conf.batch_size)
    eval_path = conf.eval_path or checkpoint_path
    print('loading from:', eval_path)
    state = torch.load(eval_path, map_location='cpu')
    print('step:', state['global_step'])
    model.load_state_dict(state['state_dict'])

    # first (and only) loader
    results = trainer.test(model, dataloaders=dummy_loader)[0]
    print(results)

    if get_rank() == 0:
        # save to tensorboard
        for metric_name, metric_value in results.items():
            tb_logger.experiment.add_scalar(
                metric_name, metric_value,
                state['global_step'] * conf.batch_size_effective)

        # append one JSON line per evaluation
        target_file = f'evals/{conf.name}.txt'
        target_dir = os.path.dirname(target_file)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        with open(target_file, 'a') as f:
            f.write(json.dumps(results) + "\n")