# eval.py (forked from baofff/U-ViT)
from tools.fid_score import calculate_fid_given_paths
import ml_collections
import torch
from torch import multiprocessing as mp
import accelerate
import utils
import sde
from datasets import get_dataset
import tempfile
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from absl import logging
import builtins
import os
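# Sampling + FID evaluation script: loads a trained nnet checkpoint, draws
# config.sample.n_samples images with the configured sampler, and reports FID
# against the precomputed statistics referenced by dataset.fid_stat.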
def evaluate(config):
if config.get('benchmark', False):
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
mp.set_start_method('spawn')
accelerator = accelerate.Accelerator()
device = accelerator.device
accelerate.utils.set_seed(config.seed, device_specific=True)
logging.info(f'Process {accelerator.process_index} using device: {device}')
config.mixed_precision = accelerator.mixed_precision
config = ml_collections.FrozenConfigDict(config)
if accelerator.is_main_process:
utils.set_logger(log_level='info', fname=config.output_path)
else:
utils.set_logger(log_level='error')
        builtins.print = lambda *args, **kwargs: None  # silence stray prints on non-main processes
dataset = get_dataset(**config.dataset)
nnet = utils.get_nnet(**config.nnet)
nnet = accelerator.prepare(nnet)
logging.info(f'load nnet from {config.nnet_path}')
accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
nnet.eval()
    if 'cfg' in config.sample and config.sample.cfg and config.sample.scale > 0:  # classifier-free guidance
        logging.info(f'Use classifier-free guidance with scale={config.sample.scale}')
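        # Guided prediction: cond + scale * (cond - uncond); scale = 0 falls back to the purely conditional model.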
def cfg_nnet(x, timesteps, y):
_cond = nnet(x, timesteps, y=y)
_uncond = nnet(x, timesteps, y=torch.tensor([dataset.K] * x.size(0), device=device))
return _cond + config.sample.scale * (_cond - _uncond)
score_model = sde.ScoreModel(cfg_nnet, pred=config.pred, sde=sde.VPSDE())
else:
score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
logging.info(config.sample)
assert os.path.exists(dataset.fid_stat)
logging.info(f'sample: n_samples={config.sample.n_samples}, mode={config.train.mode}, mixed_precision={config.mixed_precision}')
def sample_fn(_n_samples):
x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
if config.train.mode == 'uncond':
kwargs = dict()
elif config.train.mode == 'cond':
kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
else:
raise NotImplementedError
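        # Dispatch on the configured sampler: Euler-Maruyama on the reverse SDE,
        # Euler-Maruyama on the probability-flow ODE, or DPM-Solver on the noise prediction.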
if config.sample.algorithm == 'euler_maruyama_sde':
rsde = sde.ReverseSDE(score_model)
return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
elif config.sample.algorithm == 'euler_maruyama_ode':
rsde = sde.ODE(score_model)
return sde.euler_maruyama(rsde, x_init, config.sample.sample_steps, verbose=accelerator.is_main_process, **kwargs)
elif config.sample.algorithm == 'dpm_solver':
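            # Wrap the noise-prediction function for DPM-Solver and integrate with a fixed number of steps.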
noise_schedule = NoiseScheduleVP(schedule='linear')
model_fn = model_wrapper(
score_model.noise_pred,
noise_schedule,
time_input_type='0',
model_kwargs=kwargs
)
dpm_solver = DPM_Solver(model_fn, noise_schedule)
return dpm_solver.sample(
x_init,
steps=config.sample.sample_steps,
eps=1e-4,
adaptive_step_size=False,
fast_version=True,
)
else:
raise NotImplementedError
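    # Samples go to config.sample.path if set, otherwise to a temporary directory;
    # FID is computed on the main process against the precomputed dataset statistics.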
with tempfile.TemporaryDirectory() as temp_path:
path = config.sample.path or temp_path
if accelerator.is_main_process:
os.makedirs(path, exist_ok=True)
utils.sample2dir(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
if accelerator.is_main_process:
fid = calculate_fid_given_paths((dataset.fid_stat, path))
logging.info(f'nnet_path={config.nnet_path}, fid={fid}')
from absl import flags
from absl import app
from ml_collections import config_flags
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")
def main(argv):
config = FLAGS.config
config.nnet_path = FLAGS.nnet_path
config.output_path = FLAGS.output_path
evaluate(config)
if __name__ == "__main__":
app.run(main)
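# Example invocation (a sketch; the config and checkpoint paths below are illustrative):
#   accelerate launch eval.py --config=configs/cifar10_uvit_small.py \
#       --nnet_path=workdir/cifar10_uvit_small/ckpts/nnet.pth \
#       --output_path=eval_cifar10.log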