"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
import itertools
import logging
import os
import shutil

import hydra
import wandb

from src.ddp.executor import start_ddp_workers
from src.models import modelFactory
from src.utils import print_network
from src.wandb_logger import _init_wandb_run

logger = logging.getLogger(__name__)


def run(args):
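    """Build data loaders, models and optimizers, then hand training off to the Solver."""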
    import torch

    from src.ddp import distrib
    from src.data.datasets import LrHrSet
    from src.solver import Solver

    logger.info('calling distrib.init')
    distrib.init(args)

    _init_wandb_run(args)

    if distrib.rank == 0:
        if os.path.exists(args.samples_dir):
            shutil.rmtree(args.samples_dir)
        os.makedirs(args.samples_dir)

    # torch also initializes the CUDA seed if available
    torch.manual_seed(args.seed)

    models = modelFactory.get_model(args)
    for model_name, model in models.items():
        print_network(model_name, model, logger)

    wandb.watch(tuple(models.values()), log=args.wandb.log, log_freq=args.wandb.log_freq)

    if args.show:
        logger.info(models)
        mb = sum(p.numel() for model in models.values() for p in model.parameters()) * 4 / 2 ** 20
        logger.info('Size: %.1f MB', mb)
        return
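
    # args.experiment.batch_size is the global batch size: it must split evenly across the
    # DDP world, and each process keeps batch_size / world_size samples per step.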
    assert args.experiment.batch_size % distrib.world_size == 0
    args.experiment.batch_size //= distrib.world_size

    # Building datasets and loaders
    tr_dataset = LrHrSet(args.dset.train, args.experiment.lr_sr, args.experiment.hr_sr,
                         args.experiment.stride, args.experiment.segment, upsample=args.experiment.upsample)
    tr_loader = distrib.loader(tr_dataset, batch_size=args.experiment.batch_size, shuffle=True,
                               num_workers=args.num_workers)
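
    # Validation and test sets are evaluated on full-length signals (no segmenting or
    # striding), so their loaders use batch_size=1.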
    if args.dset.valid:
        args.valid_equals_test = args.dset.valid == args.dset.test

    if args.dset.valid:
        cv_dataset = LrHrSet(args.dset.valid, args.experiment.lr_sr, args.experiment.hr_sr,
                             stride=None, segment=None, upsample=args.experiment.upsample)
        cv_loader = distrib.loader(cv_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    else:
        cv_loader = None

    if args.dset.test:
        tt_dataset = LrHrSet(args.dset.test, args.experiment.lr_sr, args.experiment.hr_sr,
                             stride=None, segment=None, with_path=True, upsample=args.experiment.upsample)
        tt_loader = distrib.loader(tt_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    else:
        tt_loader = None
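
    # Bundle the loaders for the Solver; splits that are not configured are passed as None.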
    data = {"tr_loader": tr_loader, "cv_loader": cv_loader, "tt_loader": tt_loader}

    if torch.cuda.is_available() and args.device == 'cuda':
        for model in models.values():
            model.cuda()

    # optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(models['generator'].parameters(), lr=args.lr, betas=(0.9, args.beta2))
    else:
        logger.fatal('Invalid optimizer %s', args.optim)
        os._exit(1)

    optimizers = {'optimizer': optimizer}
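
    # When adversarial training is enabled, all discriminators listed in the config share a
    # single Adam optimizer over their chained parameters.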
    if 'adversarial' in args.experiment and args.experiment.adversarial:
        disc_optimizer = torch.optim.Adam(
            itertools.chain(*[models[disc_name].parameters() for disc_name in
                              args.experiment.discriminator_models]),
            args.lr, betas=(0.9, args.beta2))
        optimizers.update({'disc_optimizer': disc_optimizer})

    # Construct Solver
    solver = Solver(data, models, optimizers, args)
    solver.train()

    distrib.close()


def _main(args):
    global __file__
    print(args)
    # Updating paths in config
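    # (Hydra runs the script from a per-run output directory, so relative dataset paths
    # from the original launch directory are converted to absolute paths here.)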
    for key, value in args.dset.items():
        if isinstance(value, str):
            args.dset[key] = hydra.utils.to_absolute_path(value)
    __file__ = hydra.utils.to_absolute_path(__file__)

    if args.verbose:
        logger.setLevel(logging.DEBUG)
        logging.getLogger("src").setLevel(logging.DEBUG)
    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.debug(args)
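
    # With DDP enabled and no rank assigned yet, this is the launcher process: spawn the
    # workers (presumably one per visible GPU); each worker re-enters with a rank set and
    # goes straight to run().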
    if args.ddp and args.rank is None:
        start_ddp_workers(args)
    else:
        run(args)

    wandb.finish()


@hydra.main(config_path="conf", config_name="main_config")  # signature for hydra 1.0 and later
def main(args):
    try:
        _main(args)
    except Exception:
        logger.exception("Some error happened")
        # Hydra intercepts exit code, fixed in beta but I could not get the beta to work
        os._exit(1)


if __name__ == "__main__":
    main()
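

# Example invocation (a sketch, not taken from this repo's docs; the `dset` and `experiment`
# group names are hypothetical and depend on the YAML files under conf/):
#   python train.py dset=<dataset_config> experiment=<experiment_config> ddp=true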