-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
65 lines (56 loc) · 2.19 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import hydra
import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from hydra.utils import instantiate
from src.utils import DataModule, save_results
from torch import device, no_grad, exp
# Simulator names (cfg.simulator.name) whose GDN posterior is checked in-process
# via dataset.evaluate() instead of being written out with save_results().
TOY_EXPERIMENTS = ("normal-normal", "bayes-linreg")
@hydra.main(config_path="configs", config_name="config.yaml", version_base=None)
def main(cfg):
    """Train the configured model on simulated data and report its posterior.

    Instantiates the simulator and model from the Hydra config, fits the model
    with a Lightning ``Trainer``, then post-processes according to
    ``model.name``:

    * ``"gdn"``: predicts posterior parameters for the simulator's observed
      data, then either evaluates them in-process (simulators listed in
      ``TOY_EXPERIMENTS``) or passes them to ``save_results`` together with
      ``model.val_losses``.
    * ``"flow"``: draws ``cfg.n_posterior_sample`` posterior samples on
      ``cfg.gpu_device``, exponentiates them when ``cfg.simulator.log_scale``
      is set, and prints their mean (results are not saved yet).

    Any other model name is trained but not post-processed.

    Args:
        cfg: Config composed by the ``@hydra.main`` decorator from
            ``configs/config.yaml``; the script calls ``main()`` without
            arguments.
    """
    # Parse the device string up front so a bad value fails before training.
    gpu = device(cfg.gpu_device)

    dataset = instantiate(cfg.simulator)
    observed_data = dataset.get_observed_data()

    # No configured batch size -> one batch of cfg.simulator.n_sample examples.
    batch_size = (
        cfg.simulator.n_sample
        if cfg.train.batch_size is None
        else cfg.train.batch_size
    )
    datamodule = DataModule(
        dataset, cfg.train.seed, batch_size, cfg.train.train_frac
    )
    model = instantiate(cfg.model, d_x=dataset.d_x, d_theta=dataset.d_theta)

    if cfg.log:
        wandb_project = "crkp"
        # Give the project to wandb.init itself: WandbLogger reuses an
        # already-active run, so a project passed only to the logger would be
        # ignored and the run would land in the default project.
        wandb.init(project=wandb_project, reinit=False)
        logger = WandbLogger(project=wandb_project)
    else:
        logger = None

    callbacks = (
        [
            EarlyStopping(
                monitor="val_loss", mode="min", patience=cfg.train.patience
            )
        ]
        if cfg.train.stop_early
        else None
    )

    trainer = L.Trainer(
        max_epochs=cfg.train.max_epochs,
        logger=logger,
        devices=cfg.train.devices,
        log_every_n_steps=cfg.train.log_freq,
        callbacks=callbacks,
        fast_dev_run=cfg.fast_dev_run,
    )
    trainer.fit(model, datamodule=datamodule)

    if model.name == "gdn":
        # Inference only: don't build an autograd graph or hand grad-tracking
        # tensors to the evaluation/saving code (matches the flow branch).
        with no_grad():
            posterior_params = model.predict_step(observed_data)
        if cfg.simulator.name in TOY_EXPERIMENTS:
            dataset.evaluate(posterior_params)
        else:
            save_results(posterior_params, model.val_losses, cfg)
    elif model.name == "flow":
        # TODO: decide how normalizing-flow results should be saved; for now
        # only the posterior sample mean is printed.
        with no_grad():
            n_draws = cfg.n_posterior_sample
            sample = model.to(gpu).sample(
                n_draws, dataset.get_observed_data(n_draws).to(gpu)
            )
            if cfg.simulator.log_scale:
                # Samples live on the log scale; map back to the natural scale.
                sample = exp(sample)
            print(sample.mean(0))

    # Closes the W&B run started above (if any).
    wandb.finish()
if __name__ == "__main__":
    # Called without a cfg argument: the @hydra.main decorator on main composes
    # the config from configs/config.yaml and passes it in.
    main()