-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
83 lines (68 loc) · 2.52 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/bin/env python
# -*- coding: utf-8 -*-
#
# Created on 03.10.22
#
# Created for ddim_for_attractors
#
# @author: Tobias Sebastian Finn, [email protected]
#
# Copyright (C) {2022} {Tobias Sebastian Finn}
# System modules
import logging
import os.path
# External modules
import hydra
from omegaconf import DictConfig, OmegaConf
# Logger for this script, named after the module; `main_generate` reports its
# progress through it.
main_logger = logging.getLogger(name=__name__)
@hydra.main(version_base=None, config_path='configs/', config_name='generate')
def main_generate(cfg: DictConfig) -> None:
    """Generate predictions with a trained network and save them to disk.

    All inputs come from the Hydra ``generate`` config:

    1. Optionally seed all RNGs with ``cfg.seed``.
    2. Instantiate the datamodule (``cfg.data``) and the network
       (``cfg.network``).
    3. Restore the network weights from ``last.ckpt`` inside the ``dirpath``
       of the configured ``model_checkpoint`` callback. If ``cfg.ema`` is
       truthy, additionally apply the EMA weights stored under
       ``checkpoint["callbacks"]["EMA"]``; if that entry is missing a warning
       is logged and the plain weights are used.
    4. Build the sampler (``cfg.sampler``) from the network's ``head``,
       ``scheduler`` and ``denoising_network``, attach it to the network and
       run ``Trainer.predict`` over the datamodule.
    5. Concatenate the per-batch predictions along dim 0 and ``torch.save``
       the result to ``cfg.output_path``, creating parent directories.

    Parameters
    ----------
    cfg : DictConfig
        Composed Hydra configuration, injected by ``hydra.main``.

    Raises
    ------
    RuntimeError
        If ``Trainer.predict`` returns no predictions to concatenate.
    """
    # Heavy dependencies are imported locally so importing this module stays
    # lightweight; only hydra/omegaconf are needed at module level.
    from hydra.utils import instantiate
    import pytorch_lightning as pl
    import torch
    from dyn_ddim.callbacks.ema import EMA

    # Compare against None so that an explicit seed of 0 is honoured too.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    main_logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
    data_module: pl.LightningDataModule = instantiate(cfg.data)

    main_logger.info(f"Instantiating network <{cfg.network._target_}>")
    # `_recursive_=False`: nested configs are handed to the network as configs
    # rather than being instantiated by Hydra first.
    network: pl.LightningModule = instantiate(cfg.network, _recursive_=False)

    main_logger.info("Load network state dict")
    # The checkpoint callback is only instantiated to resolve its `dirpath`.
    model_checkpoint = instantiate(cfg.callbacks["model_checkpoint"])
    ckpt_path = os.path.join(model_checkpoint.dirpath, "last.ckpt")
    # Map to CPU so a checkpoint written on a GPU machine also loads on a
    # CPU-only host; the freshly built network is (presumably) still on CPU
    # here, and the Trainer handles device placement for prediction.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    incompatible_keys = network.load_state_dict(checkpoint["state_dict"])
    main_logger.info(f"{incompatible_keys}")

    if cfg.ema:
        # Only the lookup may legitimately be missing; errors raised while
        # loading/applying a *present* EMA state must not be mistaken for
        # "not found" (the weights could be half-applied).
        try:
            ema_state = checkpoint["callbacks"]["EMA"]
        except KeyError:
            main_logger.warning("EMA state dict not found, using without EMA!")
        else:
            ema_callback = EMA()
            ema_callback.load_state_dict(ema_state)
            ema_callback._set_ema_weights(network)
            main_logger.info("Loaded EMA state dict from checkpoint")

    main_logger.info(f"Instantiating sampler <{cfg.sampler._target_}>")
    sampler = instantiate(
        cfg.sampler,
        head=network.head,
        scheduler=network.scheduler,
        denoising_model=network.denoising_network,
    )
    network.sampler = sampler

    main_logger.info("Instantiating trainer")
    # Explicit None overrides any callbacks/logger set in `cfg.trainer`, so
    # none of them are attached for this prediction run.
    trainer: pl.Trainer = instantiate(cfg.trainer, callbacks=None, logger=None)

    main_logger.info("Starting prediction")
    predictions = trainer.predict(model=network, datamodule=data_module)
    if not predictions:
        # torch.cat on an empty list/None fails with an unhelpful message.
        raise RuntimeError(
            "Trainer.predict returned no predictions; nothing to store."
        )
    predictions = torch.cat(predictions, dim=0)

    main_logger.info(f"Store predictions to <{cfg.output_path}>")
    dir_path = os.path.dirname(os.path.abspath(cfg.output_path))
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(dir_path, exist_ok=True)
    torch.save(predictions, cfg.output_path)
if __name__ == "__main__":
    # No arguments: the `hydra.main` decorator composes the config from the
    # command line and passes it in as `cfg`.
    main_generate()