import os
import argparse
import json
# yaml is used in main() (to dump the config), so import it at module level
import yaml
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import seed_everything
from pytorch_lightning.plugins import DDPPlugin
import sys
from PodcastMixDataloader import PodcastMixDataloader
from asteroid.engine.optimizers import make_optimizer
from asteroid.engine.system import System
from logl2 import LogL2Time
seed_everything(1, workers=True)
# Keys which are not in the conf.yml file can be added here.
# In the hierarchical dictionary created when parsing, the key `key` can be
# found at dic['main_args'][key]
# By default train.py will use all available GPUs. The `id` option in run.sh
# will limit the number of available GPUs for train.py.
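# Example invocation (paths are illustrative, not taken from this repo):
#   python train.py --config_model path/to/model_config.yml --exp_dir exp/my_run
# Note: --config_model has to be the first command-line option, because its
# value is read directly from sys.argv[2] in the __main__ block below.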
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir",
    default="exp/tmp",
    help="Full path to save best validation model"
)
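

# Rough sketch of the top-level structure the YAML config is expected to have,
# inferred from the keys read in main() below (key names are the ones actually
# accessed; which sections are required depends on the chosen model):
#   model:        name ("ConvTasNet" or "UNet")
#   data:         train_dir, valid_dir, sample_rate, original_sample_rate,
#                 segment, n_src, fft_size, hop_size, window_size
#   training:     batch_size, num_workers, multi_speakers, half_lr,
#                 early_stop, epochs
#   optim:        keyword arguments forwarded to asteroid's make_optimizer
#   filterbank:   ConvTasNet only
#   masknet:      ConvTasNet only (n_src is filled in from data)
#   convolution:  UNet only (kernel_size, stride)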
def main(conf):
    train_set = PodcastMixDataloader(
        csv_dir=conf["data"]["train_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"]
    )
    val_set = PodcastMixDataloader(
        csv_dir=conf["data"]["valid_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"]
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
        pin_memory=True
    )
    if conf["model"]["name"] == "ConvTasNet":
        sys.path.append('ConvTasNet_model')
        from conv_tasnet_norm import ConvTasNetNorm
        conf["masknet"].update({"n_src": conf["data"]["n_src"]})
        model = ConvTasNetNorm(
            **conf["filterbank"],
            **conf["masknet"],
            sample_rate=conf["data"]["sample_rate"]
        )
        loss_func = LogL2Time()
        plugins = None
    elif conf["model"]["name"] == "UNet":
        # UNet with LogL2 time loss and normalization inside the model
        sys.path.append('UNet_model')
        from unet_model import UNet
        model = UNet(
            conf["data"]["sample_rate"],
            conf["data"]["fft_size"],
            conf["data"]["hop_size"],
            conf["data"]["window_size"],
            conf["convolution"]["kernel_size"],
            conf["convolution"]["stride"]
        )
        loss_func = LogL2Time()
        plugins = DDPPlugin(find_unused_parameters=False)
    else:
        raise ValueError(
            f"Unknown model name: {conf['model']['name']} "
            "(expected 'ConvTasNet' or 'UNet')"
        )
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    # Default to no scheduler so System() still gets a valid argument
    # when half_lr is disabled.
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            factor=0.5,
            patience=5
        )
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["model"]["name"] + "_model/" + conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf
    )
    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="val_loss",
        mode="min",
        save_top_k=5,
        verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=100,
            verbose=True
        ))
    # Don't ask for GPUs if none are available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        gradient_clip_val=5.0,
        resume_from_checkpoint=conf["main_args"]["resume_from"],
        precision=32,
        plugins=plugins
    )
    trainer.fit(system)
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    print(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
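    # best_model.pth is a plain torch checkpoint bundling the serialized model
    # with the dataset metadata from train_set.get_infos(). As a sketch, it can
    # later be inspected with:
    #   saved = torch.load(os.path.join(exp_dir, "best_model.pth"), map_location="cpu")
    #   print(saved.keys())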


if __name__ == "__main__":
    from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

    # We start by opening the config file conf.yml as a dictionary from
    # which we can create parsers. Each top-level key in the dictionary
    # defined by the YAML file creates a group in the parser.
    parser.add_argument(
        "--config_model", type=str, required=True, help="Asteroid model to use"
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to the desired restore checkpoint, with .ckpt extension"
    )
    # The config path is read directly from sys.argv, so --config_model is
    # expected to be the first option on the command line.
    config_model = sys.argv[2]
    with open(config_model) as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of
    # the flat one returned by argparse) to facilitate calls to the different
    # asteroid methods (see main()).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to also
    # have it, so we include it here, but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    main(arg_dic)