train.py
import os
import argparse
import json
import yaml
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from asteroid.data.wham_dataset import WhamDataset
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from model import make_model_and_optimizer
# Keys which are not in the conf.yml file can be added here.
# In the hierarchical dictionary created when parsing, the key `key` can be
# found at dic['main_args'][key].
# By default, train.py will use all available GPUs. The `id` option in run.sh
# will limit the number of GPUs visible to train.py.
# This can be changed: `python train.py --gpus 0,1` will train on GPUs 0 and 1 only.
parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=str, help='list of GPUs', default='-1')
parser.add_argument('--exp_dir', default='exp/tmp',
                    help='Full path to save best validation model')
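# Example invocation (the experiment directory name is illustrative, not part
# of the recipe): run from the recipe folder as
#   python train.py --exp_dir exp/train_wham --gpus 0,1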

def main(conf):
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            segment=conf['data']['segment'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update the number of sources (it depends on the task).
    conf['masknet'].update({'n_src': train_set.n_src})
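    # For reference: for WHAM's two-speaker separation tasks (e.g. 'sep_clean'),
    # n_src is 2, unless overridden through `nondefault_nsrc`.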
    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages of this: re-instantiating the model and optimizer for
    # retraining or evaluation is straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler.
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)
    # Define the loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)
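    # Note: `System` is asteroid's LightningModule wrapper; it bundles the
    # model, loss, optimizer, scheduler and data loaders so the whole training
    # loop can be handed to `trainer.fit(system)` below.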
    # Define callbacks.
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)
    # Don't ask for GPUs if none are available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='ddp',
                         gradient_clip_val=conf['training']['gradient_clipping'])
    trainer.fit(system)
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(checkpoint.best_k_models, f, indent=0)
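    # For reference, best_k_models.json maps checkpoint paths to their
    # validation losses (exact paths depend on the PyTorch-Lightning version),
    # e.g. {"exp/tmp/checkpoints/_ckpt_epoch_42.ckpt": -15.3} -- values illustrative.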

if __name__ == '__main__':
    import yaml
    from pprint import pprint as print
    from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

    # We start by opening the config file conf.yml as a dictionary, from
    # which we can create parsers. Each top-level key in the dictionary defined
    # by the YAML file creates a group in the parser.
    with open('local/conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of
    # flat, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see main()).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to have
    # it too, so we include it here, but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
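    # For reference, arg_dic is nested by the top-level keys of conf.yml
    # (e.g. 'data', 'training', 'masknet') plus 'main_args', roughly
    # (shape only, values illustrative):
    # {'data': {...}, 'training': {...}, 'masknet': {...},
    #  'main_args': {'exp_dir': 'exp/tmp', 'gpus': '-1'}}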
    print(arg_dic)
    main(arg_dic)