forked from facebookresearch/denoiser
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·113 lines (91 loc) · 3.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# authors: adiyoss and adefossez
import logging
import os
import hydra
from denoiser.executor import start_ddp_workers
# Module-level logger, named after this module via __name__; used by every function below.
logger = logging.getLogger(__name__)
def run(args):
    """Build the model, data loaders and optimizer from `args`, then train.

    torch and the denoiser sub-modules are imported inside the function
    rather than at module import time.

    When `args.show` is true, the function only logs a model summary and
    returns without training. The summary is the model repr, its size in MB
    (counting 4 bytes per parameter) and, if the model exposes
    `valid_length`, the value of `valid_length(1)` expressed in milliseconds.

    Side effects:
        * `distrib.init(args)` is called and the torch RNG is seeded with
          `args.seed`.
        * `args.batch_size` is divided in place by `distrib.world_size`.
        * The process is terminated with `os._exit(1)` when `args.optim`
          is not "adam".

    Args:
        args: configuration object; the attributes read here are `seed`,
            `demucs`, `sample_rate`, `show`, `batch_size`, `segment`,
            `stride`, `pad`, `num_workers`, `optim`, `lr`, `beta2` and
            `dset` (`train`, `valid`, `test`, `matching`). It is also passed
            as a whole to `distrib.init` and `Solver`.

    Raises:
        ValueError: if `args.batch_size` is not a multiple of
            `distrib.world_size`.
    """
    import torch

    from denoiser import distrib
    from denoiser.data import NoisyCleanSet
    from denoiser.demucs import Demucs
    from denoiser.solver import Solver

    distrib.init(args)

    # torch.manual_seed also initializes the CUDA seed if CUDA is available.
    torch.manual_seed(args.seed)

    model = Demucs(**args.demucs, sample_rate=args.sample_rate)

    if args.show:
        # Summary-only mode: describe the model and stop before any data is loaded.
        logger.info(model)
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
        return

    # Explicit check instead of `assert`: asserts are removed under `python -O`,
    # which would let the floor division below silently shrink the batch size.
    if args.batch_size % distrib.world_size != 0:
        raise ValueError(
            "batch_size (%d) must be a multiple of the world size (%d)"
            % (args.batch_size, distrib.world_size))
    args.batch_size //= distrib.world_size

    length = int(args.segment * args.sample_rate)
    stride = int(args.stride * args.sample_rate)
    # Demucs requires a specific number of samples to avoid 0 padding during training
    if hasattr(model, 'valid_length'):
        length = model.valid_length(length)

    # Keyword arguments shared by the train, valid and test datasets.
    kwargs = {"matching": args.dset.matching, "sample_rate": args.sample_rate}

    # Building datasets and loaders
    tr_dataset = NoisyCleanSet(
        args.dset.train, length=length, stride=stride, pad=args.pad, **kwargs)
    tr_loader = distrib.loader(
        tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    # Validation and test sets are optional; their loaders use batch_size=1
    # and the datasets are built without length/stride/pad.
    if args.dset.valid:
        cv_dataset = NoisyCleanSet(args.dset.valid, **kwargs)
        cv_loader = distrib.loader(cv_dataset, batch_size=1, num_workers=args.num_workers)
    else:
        cv_loader = None
    if args.dset.test:
        tt_dataset = NoisyCleanSet(args.dset.test, **kwargs)
        tt_loader = distrib.loader(tt_dataset, batch_size=1, num_workers=args.num_workers)
    else:
        tt_loader = None
    data = {"tr_loader": tr_loader, "cv_loader": cv_loader, "tt_loader": tt_loader}

    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, args.beta2))
    else:
        logger.fatal('Invalid optimizer %s', args.optim)
        # Hard exit, consistent with the error path in `main`.
        os._exit(1)

    # Construct Solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
def _main(args):
    """Make config paths absolute, configure logging, then start training.

    Every string entry of `args.dset` except "matching" is rewritten in place
    with `hydra.utils.to_absolute_path`, and so is the module-level
    `__file__`. When `args.verbose` is set, this module's logger and the
    `denoiser` package logger are switched to DEBUG.

    Finally, `start_ddp_workers(args)` is called when `args.ddp` is true and
    `args.rank` is None; otherwise `run(args)` is called directly.

    Args:
        args: configuration object; reads `dset`, `verbose`, `ddp` and `rank`.
    """
    global __file__

    # Updating paths in config.
    # NOTE(review): presumably needed because Hydra runs the job from its own
    # working directory (see the os.getcwd() log below) -- confirm.
    for key, value in args.dset.items():
        if isinstance(value, str) and key != "matching":
            args.dset[key] = hydra.utils.to_absolute_path(value)
    __file__ = hydra.utils.to_absolute_path(__file__)

    if args.verbose:
        logger.setLevel(logging.DEBUG)
        # The package imported by this file is `denoiser`. Logger names form a
        # dot-separated hierarchy, so the previous name "denoise" was not an
        # ancestor of "denoiser.*" and did not affect the package's loggers.
        logging.getLogger("denoiser").setLevel(logging.DEBUG)

    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.debug(args)

    # No rank assigned yet in DDP mode: hand over to start_ddp_workers.
    # Otherwise train in this process.
    if args.ddp and args.rank is None:
        start_ddp_workers(args)
    else:
        run(args)
@hydra.main(config_path="conf/config.yaml")
def main(args):
    """Hydra entry point: run `_main` and force exit status 1 on any error.

    Any `Exception` raised by `_main(args)` is logged with its traceback,
    then the process is terminated immediately through `os._exit`.
    """
    try:
        _main(args)
    except Exception:
        # logger.exception records the message together with the active traceback.
        logger.exception("Some error happened")
        # Hydra intercepts exit code, fixed in beta but I could not get the beta to work
        exit_status = 1
        os._exit(exit_status)
# Script entry guard: call the Hydra-decorated `main` only when this file is executed directly.
if __name__ == "__main__":
    main()