-
Notifications
You must be signed in to change notification settings - Fork 0
/
surrogate_train.py
84 lines (65 loc) · 2.31 KB
/
surrogate_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/bin/env python
# -*- coding: utf-8 -*-
#
# Created on 03.10.22
#
# Created for ddim_for_attractors
#
# @author: Tobias Sebastian Finn, [email protected]
#
# Copyright (C) {2022} {Tobias Sebastian Finn}
# System modules
import logging
# External modules
import hydra
from omegaconf import DictConfig, OmegaConf
main_logger = logging.getLogger(__name__)
def train_task(cfg: DictConfig) -> None:
    """Instantiate datamodule, model, callbacks, logger and trainer from the
    Hydra config, then run training followed by one validation pass.

    Parameters
    ----------
    cfg : DictConfig
        Hydra configuration with ``data``, ``surrogate`` and ``trainer``
        nodes, plus optional ``seed``, ``callbacks`` and ``logger`` nodes.
    """
    # Import within main loop to speed up training on jean zay
    import wandb
    from hydra.utils import instantiate
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)
    main_logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
    data_module: pl.LightningDataModule = instantiate(cfg.data)
    main_logger.info(f"Instantiating model <{cfg.surrogate._target_}>")
    model: pl.LightningModule = instantiate(
        cfg.surrogate,
    )
    # Instantiate every configured callback; None lets the trainer fall back
    # to its defaults when no "callbacks" node is present.
    callbacks = None
    if OmegaConf.select(cfg, "callbacks") is not None:
        callbacks = [
            instantiate(callback_cfg)
            for callback_cfg in cfg.callbacks.values()
        ]
    training_logger = None
    if OmegaConf.select(cfg, "logger") is not None:
        training_logger = instantiate(cfg.logger)
        if isinstance(training_logger, WandbLogger):
            main_logger.info("Watch gradients and parameters of model")
            training_logger.watch(model, log="all", log_freq=75)
    main_logger.info("Instantiating trainer")
    trainer: pl.Trainer = instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=training_logger
    )
    try:
        main_logger.info("Starting training")
        trainer.fit(model=model, datamodule=data_module)
        main_logger.info("Training finished")
        main_logger.info("Starting validation")
        val_loss = trainer.validate(model=model, datamodule=data_module)[0]
        main_logger.info("Validation finished")
        main_logger.info(f"Validation loss: {val_loss}")
    finally:
        # Always close the W&B run, even if fit/validate raised, so a
        # following run in the same process starts from a clean state.
        wandb.finish()
@hydra.main(version_base=None, config_path='configs/', config_name='surrogate')
def main_train(cfg: DictConfig) -> None:
    """Hydra entry point: run the training task, tolerating out-of-memory
    failures so that e.g. a multi-run sweep continues with the next config."""
    try:
        train_task(cfg)
    except MemoryError:
        # Deliberate best-effort: do not abort the whole launch on OOM,
        # but record the failure instead of swallowing it silently.
        main_logger.exception("Training task aborted due to MemoryError")
# Script entry point: delegate to the Hydra-decorated main function.
if __name__ == '__main__':
    main_train()