sample.py
"""
Train a diffusion model on images.
"""
import os
import hydra
import logging
import torch
import torchaudio
torch.cuda.empty_cache()
import soundfile as sf
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
import numpy as np
from datetime import date
#from learner import Learner
#from model import UNet
from tqdm import tqdm
import scipy.signal
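#note (added): `logging` is imported but no logger is ever created; the commented-out
#exception handler in main() below refers to a `logger` name, so re-enabling it would
#need something like:
#logger = logging.getLogger(__name__)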

def run(args):
    """Loads all the modules and starts the sampling.

    Args:
        args: Hydra dictionary
    """
    #some preparation of the hydra args
    args = OmegaConf.structured(OmegaConf.to_yaml(args))
    dirname = os.path.dirname(__file__)
    #define the path where weights will be loaded and audio samples and other logs will be saved
    args.model_dir = os.path.join(dirname, str(args.model_dir))
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)
    mode=args.inference.mode
    plot_animation=True

    if mode=="unconditional":
        from src.experimenters.exp_unconditional import Exp_Unconditional
        exp=Exp_Unconditional(args, plot_animation)
    elif mode=="bandwidth_extension":
        from src.experimenters.exp_bandwidth_extension import Exp_BWE
        exp=Exp_BWE(args, plot_animation)
    elif mode=="declipping":
        from src.experimenters.exp_declipping import Exp_Declipping
        exp=Exp_Declipping(args, plot_animation)
    elif mode=="inpainting":
        from src.experimenters.exp_inpainting import Exp_Inpainting
        exp=Exp_Inpainting(args, plot_animation)
    elif mode=="compressive_sensing":
        from src.experimenters.exp_comp_sens import Exp_CompSens
        exp=Exp_CompSens(args, plot_animation)
    elif mode=="phase_retrieval":
        from src.experimenters.exp_phase_retrieval import Exp_PhaseRetrieval
        exp=Exp_PhaseRetrieval(args, plot_animation)
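    else:
        #not in the original dispatch (added): fail early on an unknown mode instead of
        #hitting a NameError when `exp` is used further down
        raise ValueError(f"unknown inference mode: {mode}")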
    print(args.dset.name)
    import src.utils.setup as utils_setup

    if mode!="unconditional":
        print("load test set")
        test_set = utils_setup.get_test_set_for_sampling(args)
    else:
        test_set=None
    print(test_set)
if mode!="unconditional":
for i, (original, filename) in enumerate(tqdm(test_set)):
n=os.path.splitext(filename[0])[0]
#process the file if its output has not been already generated
if not(os.path.exists(os.path.join(exp.path_reconstructed,n+".wav"))):
print("Sampling on mode ",mode," with the file ",filename[0])
original=original.float()
#resampling if the sampling rate is not 22050, this information should be given on the resmaple_factor
print(original.shape)
if args.resample_factor!=1:
S=args.resample_factor
if S>2.1 and S<2.2:
#resampling 48k to 22.05k
resample=torchaudio.transforms.Resample(160*2,147) #I use 2**12 as an arbitrary number, as we don't care about the sampling frequency of the latents
else:
N=int(args.audio_len*S)
resample=torchaudio.transforms.Resample(N,args.audio_len)
original=resample(original)
#we need to split the track somehow, here we use only one chunk
#seg=original[...,10*args.sample_rate:(10*args.sample_rate+args.audio_len)]
seg=original
print(seg.shape, original.shape)
exp.conduct_experiment(seg, n)
else:
#unconditional mode
for i in range(args.inference.unconditional.num_samples):
exp.conduct_experiment(str(i))

def _main(args):
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    run(args)

@hydra.main(config_path="conf", config_name="conf")
def main(args):
    #try:
    _main(args)
    #except Exception:
    #    logger.exception("Some error happened")
    #    os._exit(1)

if __name__ == "__main__":
    main()
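#Example invocation (added; assumes the Hydra config in conf/conf.yaml defines the
#fields referenced above, e.g. inference.mode and model_dir):
#   python sample.py inference.mode=bandwidth_extension model_dir=experiments/my_run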