-
Notifications
You must be signed in to change notification settings - Fork 115
/
dataLoader.py
executable file
·91 lines (85 loc) · 3.73 KB
/
dataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''
DataLoader for training
'''
import glob, numpy, os, random, soundfile, torch
from scipy import signal
class train_loader(object):
    '''
    Map-style dataset yielding fixed-length, randomly augmented utterances.

    Each item is a pair ``(waveform, speaker_label)`` where ``waveform`` is a
    1-D ``torch.FloatTensor`` of ``num_frames * 160 + 240`` samples and
    ``speaker_label`` is the integer index of the speaker id in the sorted
    list of distinct speaker ids found in ``train_list``.

    Augmentation is drawn uniformly from six options per item: none,
    reverberation, babble ('speech'), music, noise, or speech + music.
    '''

    def __init__(self, train_list, train_path, musan_path, rir_path, num_frames, **kwargs):
        '''
        Index the training utterances and the augmentation audio files.

        Args:
            train_list: Path to a text file with one utterance per line,
                formatted as ``<speaker_id> <relative_wav_path>``. Blank
                lines are ignored.
            train_path: Root directory joined in front of every relative
                wav path from ``train_list``.
            musan_path: Root of the noise corpus; files are discovered with
                the glob ``*/*/*/*.wav`` and grouped by their top-level
                directory name (expected: 'noise', 'speech', 'music').
            rir_path: Root of the impulse-response corpus; files are
                discovered with the glob ``*/*/*.wav``.
            num_frames: Number of frames per training segment; the segment
                length in samples is ``num_frames * 160 + 240``.
            **kwargs: Accepted and ignored, so callers can pass a shared
                argument dictionary.
        '''
        self.train_path = train_path
        self.num_frames = num_frames
        # Load and configure augmentation files
        self.noisetypes = ['noise', 'speech', 'music']
        # Per-category [min, max] signal-to-noise ratio in dB.
        self.noisesnr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        # Per-category [min, max] number of noise files mixed in at once.
        self.numnoise = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}
        augment_files = glob.glob(os.path.join(musan_path, '*/*/*/*.wav'))
        for path in augment_files:
            # The category is the directory four components from the end,
            # i.e. the first wildcard of the glob pattern. Normalising the
            # separators keeps this working when the OS separator is not '/'.
            category = os.path.normpath(path).split(os.sep)[-4]
            self.noiselist.setdefault(category, []).append(path)
        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))
        # Load data & labels
        with open(train_list) as handle:
            entries = [line.split() for line in handle.read().splitlines() if line.strip()]
        # Sorting makes the speaker -> label mapping deterministic across runs.
        speakers = sorted({fields[0] for fields in entries})
        label_of = {speaker: idx for idx, speaker in enumerate(speakers)}
        self.data_list = []
        self.data_label = []
        for fields in entries:
            self.data_label.append(label_of[fields[0]])
            self.data_list.append(os.path.join(train_path, fields[1]))

    def __getitem__(self, index):
        '''
        Load utterance ``index``, crop a random segment and augment it.

        Returns:
            Tuple ``(waveform, label)``: a 1-D ``torch.FloatTensor`` and the
            integer speaker label stored for this utterance.
        '''
        # Read the utterance and randomly select the segment
        audio, _ = soundfile.read(self.data_list[index])
        audio = self._random_segment(audio)
        # Data Augmentation
        augtype = random.randint(0, 5)  # inclusive on both ends; 0 keeps the original
        if augtype == 1:  # Reverberation
            audio = self.add_rev(audio)
        elif augtype == 2:  # Babble
            audio = self.add_noise(audio, 'speech')
        elif augtype == 3:  # Music
            audio = self.add_noise(audio, 'music')
        elif augtype == 4:  # Noise
            audio = self.add_noise(audio, 'noise')
        elif augtype == 5:  # Television noise
            audio = self.add_noise(audio, 'speech')
            audio = self.add_noise(audio, 'music')
        return torch.FloatTensor(audio[0]), self.data_label[index]

    def __len__(self):
        '''Return the number of utterances listed in the training list.'''
        return len(self.data_list)

    def _segment_length(self):
        '''Return the training segment length in samples.'''
        # NOTE(review): 160 / 240 presumably correspond to a 10 ms hop and
        # extra window context at 16 kHz -- confirm against the front-end.
        return self.num_frames * 160 + 240

    def _random_segment(self, audio):
        '''
        Crop a random window of ``_segment_length()`` samples from ``audio``.

        Signals that are not longer than the target are first extended by
        wrap-around padding. The result has a leading axis of size 1.
        '''
        length = self._segment_length()
        if audio.shape[0] <= length:
            shortage = length - audio.shape[0]
            audio = numpy.pad(audio, (0, shortage), 'wrap')
        start_frame = numpy.int64(random.random() * (audio.shape[0] - length))
        segment = audio[start_frame:start_frame + length]
        return numpy.stack([segment], axis=0)

    def add_rev(self, audio):
        '''
        Convolve ``audio`` with a randomly chosen, energy-normalised impulse
        response and truncate the result to the training segment length.
        '''
        rir_file = random.choice(self.rir_files)
        rir, _ = soundfile.read(rir_file)
        # numpy.float was removed in NumPy 1.24; float64 is the same dtype.
        rir = numpy.expand_dims(rir.astype(numpy.float64), 0)
        rir = rir / numpy.sqrt(numpy.sum(rir ** 2))
        return signal.convolve(audio, rir, mode='full')[:, :self._segment_length()]

    def add_noise(self, audio, noisecat):
        '''
        Mix randomly chosen noise files of category ``noisecat`` into ``audio``.

        The number of files and the per-file SNR (in dB) are drawn from the
        ranges configured in ``self.numnoise`` and ``self.noisesnr``.

        Raises:
            KeyError: If no files were found for ``noisecat``.
        '''
        # The 1e-4 floor keeps log10 finite for silent signals.
        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)
        low, high = self.numnoise[noisecat]
        chosen = random.sample(self.noiselist[noisecat], random.randint(low, high))
        noises = []
        for noise_file in chosen:
            noiseaudio, _ = soundfile.read(noise_file)
            noiseaudio = self._random_segment(noiseaudio)
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio ** 2) + 1e-4)
            noisesnr = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            # Scale the noise so that clean_db - scaled_noise_db == noisesnr.
            gain = numpy.sqrt(10 ** ((clean_db - noise_db - noisesnr) / 10))
            noises.append(gain * noiseaudio)
        noise = numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True)
        return noise + audio