train.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Created on 03.10.22
#
# Created for ddim_for_attractors
#
# @author: Tobias Sebastian Finn, [email protected]
#
# Copyright (C) 2022 Tobias Sebastian Finn

# System modules
import logging

# External modules
import hydra
from omegaconf import DictConfig, OmegaConf

main_logger = logging.getLogger(__name__)
def train_task(cfg: DictConfig) -> float:
    # Import within the task to speed up start-up on jean zay
    import wandb
    from hydra.utils import instantiate
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger

    from dyn_ddim.train_utils import log_hyperparameters

    # Seed all random number generators for reproducibility
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    main_logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
    data_module: pl.LightningDataModule = instantiate(cfg.data)

    main_logger.info(f"Instantiating network <{cfg.network._target_}>")
    model: pl.LightningModule = instantiate(
        cfg.network,
        lr=cfg.learning_rate,
        _recursive_=False
    )

    # Instantiate all configured callbacks
    if OmegaConf.select(cfg, "callbacks") is not None:
        callbacks = []
        for _, callback_cfg in cfg.callbacks.items():
            curr_callback: pl.callbacks.Callback = instantiate(callback_cfg)
            callbacks.append(curr_callback)
    else:
        callbacks = None

    # Instantiate the training logger; a wandb logger additionally tracks
    # hyperparameters, gradients, and parameters
    training_logger = None
    if OmegaConf.select(cfg, "logger") is not None:
        training_logger = instantiate(cfg.logger)
        if isinstance(training_logger, WandbLogger):
            main_logger.info("Watching gradients and parameters of model")
            hydra_params = log_hyperparameters(config=cfg, model=model)
            training_logger.log_hyperparams(hydra_params)
            training_logger.watch(model, log="all", log_freq=75)

    main_logger.info("Instantiating trainer")
    trainer: pl.Trainer = instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=training_logger
    )

    main_logger.info("Starting training")
    trainer.fit(model=model, datamodule=data_module)
    main_logger.info("Training finished")

    # Return the validation loss, e.g. for hyperparameter optimization
    val_loss = trainer.callback_metrics.get('val/total_loss')
    main_logger.info(f"Validation loss: {val_loss}")
    wandb.finish()
    return val_loss
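
# A hypothetical sketch of the `configs/config.yaml` this script expects,
# inferred purely from the fields read in `train_task` (seed, data, network,
# learning_rate, callbacks, logger, trainer). The concrete `_target_` paths
# below are assumptions for illustration, not the repository's actual values:
#
#   seed: 42
#   learning_rate: 1e-3
#   data:
#     _target_: dyn_ddim.datamodule.AttractorDataModule    # assumed name
#   network:
#     _target_: dyn_ddim.network.DenoisingNetwork          # assumed name
#   callbacks:
#     checkpoint:
#       _target_: pytorch_lightning.callbacks.ModelCheckpoint
#       monitor: val/total_loss
#   logger:
#     _target_: pytorch_lightning.loggers.WandbLogger
#   trainer:
#     _target_: pytorch_lightning.Trainer
#     max_epochs: 100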
@hydra.main(version_base=None, config_path='configs/', config_name='config')
def main_train(cfg: DictConfig) -> float:
    import numpy as np

    try:
        val_loss = train_task(cfg)
    except MemoryError:
        # Report the worst possible score instead of crashing, so that
        # out-of-memory trials do not abort a hyperparameter sweep
        val_loss = np.inf
    return val_loss
if __name__ == '__main__':
    main_train()
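
# Example invocations using standard Hydra command-line overrides; the
# override keys mirror the config fields sketched above and the values are
# assumptions for illustration:
#
#   python train.py seed=42 learning_rate=1e-4 trainer.max_epochs=50
#
# Because `main_train` returns the validation loss, a sweep could also be
# launched with Hydra's multirun mode:
#
#   python train.py -m learning_rate=1e-3,1e-4,1e-5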