forked from jx1100370217/DFCNN-master
-
Notifications
You must be signed in to change notification settings - Fork 0
/
read_data.py
307 lines (260 loc) · 10.7 KB
/
read_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import platform as plat
import os
import numpy as np
from general_function.file_wav import *
from general_function.file_dict import *
import random
#import scipy.io.wavfile as wav
from scipy.fftpack import fft
class DataSpeech():
    """Speech-corpus reader for the THCHS-30, ST-CMDS and AISHELL datasets.

    The three corpora of one split are concatenated, in that order, into a
    single index space: indices ``[0, n_thchs30)`` address THCHS-30, the next
    ``n_stcmds`` indices address ST-CMDS and the remaining ones AISHELL, where
    ``n_*`` are the lengths of the loaded wav lists.
    """

    def __init__(self, path, type, LoadToMem = False, MemWavCount = 10000):
        """Initialize the reader and load the file lists of one data split.

        Args:
            path: root directory containing the ``thchs30``, ``st-cmds`` and
                ``aishell`` sub-directories.
            type: split to load: ``'train'``, ``'dev'`` or ``'test'``.
            LoadToMem: stored on the instance; no method of this class uses it.
            MemWavCount: stored on the instance; no method of this class uses it.

        Raises:
            ValueError: if ``type`` is not one of the supported splits.
            OSError: if ``dict_2.txt`` cannot be opened.
        """
        system_type = plat.system()  # path separators differ between systems
        self.datapath = path  # root directory of the data
        self.type = type  # split: 'train', 'dev' or 'test'
        if system_type == 'Windows':
            self.slash = '\\'  # backslash
        elif system_type == 'Linux':
            self.slash = '/'  # forward slash
        else:
            print('*[Message] Unknown System\n')
            self.slash = '/'  # forward slash
        # Ensure the root ends with a separator; endswith() also copes with an
        # empty path, where indexing [-1] would raise IndexError.
        if not self.datapath.endswith(self.slash):
            self.datapath = self.datapath + self.slash
        self.dic_wavlist_thchs30 = {}
        self.dic_symbollist_thchs30 = {}
        self.dic_wavlist_stcmds = {}
        self.dic_symbollist_stcmds = {}
        self.dic_wavlist_aishell = {}
        self.dic_symbollist_aishell = {}
        self.SymbolNum = 0  # number of symbols; set by GetSymbolList()
        self.list_symbol = self.GetSymbolList()  # all pinyin symbols plus a trailing '_'
        self.list_wavnum = []  # not used by any method of this class
        self.list_symbolnum = []  # not used by any method of this class
        self.DataNum = 0  # total number of utterances; set by LoadDataList()
        self.LoadDataList()
        self.wavs_data = []
        self.LoadToMem = LoadToMem
        self.MemWavCount = MemWavCount

    def LoadDataList(self):
        """Load the wav-file and label lists of every corpus for ``self.type``.

        For each corpus, the wav list is read with ``get_wav_list`` into
        ``dic_wavlist_<corpus>`` / ``list_wavnum_<corpus>``. The syllable-label
        list is read with ``get_wav_symbol`` into ``dic_symbollist_<corpus>`` /
        ``list_symbolnum_<corpus>``. ``self.DataNum`` is then recomputed.

        Raises:
            ValueError: if ``self.type`` is not 'train', 'dev' or 'test'.
        """
        if self.type not in ('train', 'dev', 'test'):
            # Without a known split there are no list files to read; fail
            # early with a clear message instead of an UnboundLocalError.
            raise ValueError(
                "type must be 'train', 'dev' or 'test', got %r" % (self.type,))
        s = self.slash
        split = self.type
        filename_wavlist_thchs30 = 'thchs30' + s + split + '.wav.lst'
        filename_wavlist_stcmds = 'st-cmds' + s + split + '.wav.txt'
        filename_wavlist_aishell = 'aishell' + s + split + '.wav.lst'
        filename_symbollist_thchs30 = 'thchs30' + s + split + '.syllabel.txt'
        filename_symbollist_stcmds = 'st-cmds' + s + split + '.syllabel.txt'
        filename_symbollist_aishell = 'aishell' + s + split + '.syllabel.txt'

        # Read the wav lists and their matching label lists.
        self.dic_wavlist_thchs30, self.list_wavnum_thchs30 = get_wav_list(self.datapath + filename_wavlist_thchs30)
        self.dic_wavlist_stcmds, self.list_wavnum_stcmds = get_wav_list(self.datapath + filename_wavlist_stcmds)
        self.dic_wavlist_aishell, self.list_wavnum_aishell = get_wav_list(self.datapath + filename_wavlist_aishell)
        self.dic_symbollist_thchs30, self.list_symbolnum_thchs30 = get_wav_symbol(self.datapath + filename_symbollist_thchs30)
        self.dic_symbollist_stcmds, self.list_symbolnum_stcmds = get_wav_symbol(self.datapath + filename_symbollist_stcmds)
        self.dic_symbollist_aishell, self.list_symbolnum_aishell = get_wav_symbol(self.datapath + filename_symbollist_aishell)
        self.DataNum = self.GetDataNum()

    def GetDataNum(self):
        """Return the total number of utterances over all three corpora.

        Returns -1 when, for any corpus, the number of wav entries differs from
        the number of label entries (the two lists cannot be paired).
        """
        pairs = (
            (self.dic_wavlist_thchs30, self.dic_symbollist_thchs30),
            (self.dic_wavlist_stcmds, self.dic_symbollist_stcmds),
            (self.dic_wavlist_aishell, self.dic_symbollist_aishell),
        )
        if any(len(wavs) != len(symbols) for wavs, symbols in pairs):
            return -1
        return sum(len(wavs) for wavs, _ in pairs)

    def _locate(self, n_start):
        """Map a global utterance index to ``(wav filename, label list)``.

        Corpus boundaries come from the loaded list lengths rather than
        hard-coded split sizes, so the index space always matches the data.

        Raises:
            IndexError: if ``n_start`` is outside ``[0, total)``.
        """
        corpora = (
            (self.dic_wavlist_thchs30, self.list_wavnum_thchs30,
             self.dic_symbollist_thchs30, self.list_symbolnum_thchs30),
            (self.dic_wavlist_stcmds, self.list_wavnum_stcmds,
             self.dic_symbollist_stcmds, self.list_symbolnum_stcmds),
            (self.dic_wavlist_aishell, self.list_wavnum_aishell,
             self.dic_symbollist_aishell, self.list_symbolnum_aishell),
        )
        if n_start < 0:
            # Negative indices would otherwise wrap around inside a corpus.
            raise IndexError('data index out of range: %r' % (n_start,))
        offset = n_start
        for dic_wav, list_wav, dic_symbol, list_sym in corpora:
            if offset < len(list_wav):
                return dic_wav[list_wav[offset]], dic_symbol[list_sym[offset]]
            offset -= len(list_wav)
        raise IndexError('data index out of range: %r' % (n_start,))

    def GetData(self, n_start, n_amount=1):
        """Read one utterance and return its network input and label vector.

        Args:
            n_start: global index of the utterance (see the class docstring for
                how indices map onto the corpora).
            n_amount: accepted for API compatibility but ignored; exactly one
                utterance is always returned.

        Returns:
            ``(data_input, data_label)``. ``data_input`` is the output of
            ``GetFrequencyFeature3`` reshaped to ``(shape[0], shape[1], 1)``
            (presumably time frames x frequency bins x channel). ``data_label``
            is a 1-D numpy array of symbol indices.

        Raises:
            IndexError: if ``n_start`` is out of range.
        """
        filename, list_symbol = self._locate(n_start)
        if 'Windows' == plat.system():
            # Convert '/' in the listed path to '\\' for Windows.
            filename = filename.replace('/', '\\')
        wavsignal, fs = read_wav_data(self.datapath + filename)

        # Output: index of every non-empty symbol.
        feat_out = [self.SymbolToNum(symbol) for symbol in list_symbol if symbol != '']

        # Input: features with a trailing channel axis added.
        data_input = GetFrequencyFeature3(wavsignal, fs)
        data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
        data_label = np.array(feat_out)
        return data_input, data_label

    def data_genetator(self, batch_size=32, audio_length = 1600):
        """Yield randomly sampled batches forever (for Keras generator training).

        Each item is ``([X, y, input_length, label_length], labels)``:

        * ``X``: float64 ``(batch_size, audio_length, 200, 1)``, zero-padded inputs;
        * ``y``: int16 ``(batch_size, 64)``, zero-padded label indices;
        * ``input_length``: 1-D int array, one value per sample;
        * ``label_length``: int array ``(batch_size, 1)`` with label counts;
        * ``labels``: float64 zeros ``(batch_size, 1)`` (presumably a placeholder
          target for a loss computed inside the model).

        Args:
            batch_size: number of samples per batch.
            audio_length: padded length of the first input axis.

        Raises:
            ValueError: if no usable data is loaded (``DataNum <= 0``).
        """
        if self.DataNum <= 0:
            raise ValueError('no data available (DataNum=%r)' % (self.DataNum,))
        # np.float was removed in NumPy 1.24; np.float64 is the identical dtype.
        labels = np.zeros((batch_size, 1), dtype=np.float64)
        while True:
            X = np.zeros((batch_size, audio_length, 200, 1), dtype=np.float64)
            y = np.zeros((batch_size, 64), dtype=np.int16)
            input_length = []
            label_length = []
            for i in range(batch_size):
                ran_num = random.randint(0, self.DataNum - 1)  # random sample index
                data_input, data_labels = self.GetData(ran_num)
                # NOTE(review): frames // 8 + frames % 8 is kept unchanged; if the
                # model downsamples time by 8, ceil(frames / 8) may be intended.
                input_length.append(data_input.shape[0] // 8 + data_input.shape[0] % 8)
                # Inputs longer than audio_length (or with a second axis other
                # than 200) and labels longer than 64 do not fit the buffers;
                # numpy raises ValueError on these assignments.
                X[i, 0:len(data_input)] = data_input
                y[i, 0:len(data_labels)] = data_labels
                label_length.append([len(data_labels)])
            # np.matrix is discouraged; a 2-D ndarray has the same (batch, 1) shape.
            label_length = np.array(label_length)
            input_length = np.array(input_length)
            yield [X, y, input_length, label_length], labels

    def GetSymbolList(self):
        """Load the pinyin symbol list from ``dict_2.txt``.

        Each non-empty line contributes its first tab-separated field, and a
        trailing ``'_'`` entry is appended. This also sets ``self.SymbolNum``.
        The path is relative to the current working directory.

        Returns:
            list of str: the symbols in file order, followed by ``'_'``.
        """
        # 'with' guarantees the file is closed even if reading fails.
        with open('dict_2.txt', 'r', encoding='UTF-8') as txt_obj:
            txt_text = txt_obj.read()
        list_symbol = [line.split('\t')[0] for line in txt_text.split('\n') if line != '']
        list_symbol.append('_')
        self.SymbolNum = len(list_symbol)
        return list_symbol

    def GetSymbolNum(self):
        """Return the number of symbols in ``self.list_symbol``."""
        return len(self.list_symbol)

    def SymbolToNum(self, symbol):
        """Return the index of ``symbol`` in ``self.list_symbol``.

        The empty string maps to ``self.SymbolNum``. Any other symbol not in the
        list raises ``ValueError`` (from ``list.index``).
        """
        if symbol != '':
            return self.list_symbol.index(symbol)
        return self.SymbolNum

    def NumToVector(self, num):
        """Return a one-hot int numpy vector of length ``len(self.list_symbol)``.

        Position ``num`` holds 1. If ``num`` matches no position, the vector is
        all zeros.
        """
        return np.array([1 if i == num else 0 for i in range(len(self.list_symbol))])
if __name__ == '__main__':
    # Manual smoke test, intentionally disabled. Example usage:
    #   data = DataSpeech('/path/to/dataset', 'train')
    #   print(data.GetDataNum())
    #   print(data.GetData(0))
    #   for inputs, outputs in data.data_genetator():
    #       print(inputs, outputs)
    pass