# train.py — training entry point.
import os
import re
import json
import hydra
import torch
import utils.setup as setup
from training.trainer import Trainer
from testing.tester import Tester
def _main(args):
    """Build the dataset loaders, network, diffusion params, tester and
    trainer from the Hydra config, then run the training loop.

    Args:
        args: Hydra/OmegaConf config object with `exp`, `dset`, `network`,
            `diff_params` and `tester` sub-configs. Mutated in place:
            `args.model_dir`, `args.exp.model_dir` and
            `args.tester.sampling_params.same_as_training` are (re)assigned.

    Side effects:
        Creates the model output directory on disk and prints a summary of
        the training options before training starts.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hydra changes the working directory per run; resolve __file__ to an
    # absolute path so model_dir is anchored to the source tree, not the
    # Hydra run directory.
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    dirname = os.path.dirname(__file__)

    args.model_dir = os.path.join(dirname, str(args.model_dir))
    # exist_ok=True replaces the racy `if not os.path.exists(...)` guard:
    # no TOCTOU window, and no crash if the directory appears concurrently.
    os.makedirs(args.model_dir, exist_ok=True)
    args.exp.model_dir = args.model_dir

    # Training data. Wrapped in iter() because the trainer pulls batches
    # with next() rather than iterating epoch-by-epoch.
    train_set = hydra.utils.instantiate(args.dset.train)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.exp.batch_size,
        num_workers=args.exp.num_workers,
        pin_memory=True,
        worker_init_fn=setup.worker_init_fn,
        timeout=0,
        # NOTE(review): prefetch_factor requires num_workers > 0 — confirm
        # the configs never set num_workers=0, or this raises at runtime.
        prefetch_factor=20,
    )
    train_loader = iter(train_loader)

    # Evaluation data: batch size fixed at 1 for per-example testing.
    test_set = hydra.utils.instantiate(args.dset.test)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        num_workers=args.exp.num_workers,
        pin_memory=True,
        worker_init_fn=setup.worker_init_fn,
    )

    # Diffusion parameters.
    diff_params = hydra.utils.instantiate(args.diff_params)

    # Network: the octave-CQT UNet needs extra constructor arguments that
    # are not part of its config node, so it is special-cased here.
    if args.network._target_ == 'networks.unet_octCQT.UNet_octCQT':
        network = hydra.utils.instantiate(
            args.network,
            sample_rate=args.exp.sample_rate,
            audio_len=args.exp.audio_len,
            device=device,
        )
    else:
        network = hydra.utils.instantiate(args.network)
    network = network.to(device)

    # Tester: force sampling hyperparameters to match the training ones.
    args.tester.sampling_params.same_as_training = True
    tester = Tester(args, network, diff_params, test_set=test_loader,
                    device=device, in_training=True)

    # Trainer.
    trainer = hydra.utils.instantiate(args.exp.trainer, args, train_loader,
                                      network, diff_params, tester, device)

    # Print options.
    print()
    print('Training options:')
    print()
    print(f'Output directory: {args.model_dir}')
    print(f'Network architecture: {args.network._target_}')
    print(f'Dataset: {args.dset.train._target_}')
    print(f'Diffusion parameterization: {args.diff_params._target_}')
    print(f'Batch size: {args.exp.batch_size}')
    print()

    # Train.
    trainer.training_loop()
@hydra.main(config_path="conf", config_name="conf", version_base=str(hydra.__version__))
def main(args):
    """Hydra entry point: receives the composed config and delegates to `_main`."""
    return _main(args)
if __name__ == "__main__":
    # Hydra parses CLI overrides and invokes the decorated entry point.
    main()
#----------------------------------------------------------------------------