#!/usr/bin/env python
# coding: utf-8
# Import torch-related modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
# from torch2trt import torch2trt # only when running on Xavier
# Import other required modules
import os
import glob
import numpy as np
import argparse
import time
import soundfile as sf
# Suppress all warning messages
import warnings
warnings.simplefilter('ignore')
from models import MCComplexUnet, MCConvTasNet # denoising/dereverberation model and speaker separation model
from beamformer import estimate_covariance_matrix_sig, condition_covariance, estimate_steering_vector, mvdr_beamformer, mvdr_beamformer_two_speakers, gev_beamformer, ds_beamformer, mwf, localize_music # beamformer variants
from utils.utilities import AudioProcessForComplex, spec_plot, wave_plot, count_parameters # audio processing utilities
from utils.embedder import SpeechEmbedder # speaker identification
from utils.loss_func import solve_inter_channel_permutation_problem # used for multichannel speaker separation
from utils.asr import ASR # speech recognition
# from utils.asr import asr_julius # speech recognition (Julius)
from evaluation.evaluate import audio_eval, asr_eval # evaluation
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Real time voice separation')
parser.add_argument('-sr', '--sample_rate', type=int, default=16000, help='sampling rate') # sampling rate
parser.add_argument('-bl', '--batch_length', type=int, default=48000, help='batch length of denoising model input') # number of samples per batch when processing the audio in batches
parser.add_argument('-c', '--channels', type=int, default=8, help='number of input channels') # number of microphone channels
parser.add_argument('-fs', '--fft_size', type=int, default=512, help='size of fast fourier transform') # FFT frame size
parser.add_argument('-hl', '--hop_length', type=int, default=160, help='number of audio samples between adjacent STFT columns') # frame shift of the FFT
parser.add_argument('-dmt', '--denoising_model_type', type=str, default='complex_unet', help='type of denoising model (complex_unet)') # type of denoising (dereverberation) model
parser.add_argument('-ssmt', '--speaker_separation_model_type', type=str, default='conv_tasnet', help='type of speaker separation model (conv_tasnet)') # type of speaker separation model
parser.add_argument('-bt', '--beamformer_type', type=str, default='MVDR', help='type of beamformer (DS or MVDR or GEV or MWF)') # type of beamformer
parser.add_argument('-dt', '--dereverb_type', type=str, default='None', help='type of dereverb algorithm (None or WPE)') # type of dereverberation method
parser.add_argument('-ep', '--embedder_path', type=str, default="./utils/embedder.pt", help='path of pretrained embedder model') # path to the pretrained speaker identification model
parser.add_argument('-rsp', '--ref_speech_path', type=str, default="./utils/ref_speech/sample.wav", help='path of reference speech') # path to a speech sample of the speaker whose voice should be extracted
args = parser.parse_args()
######################### Sound source localization settings ########################
freq_range = [200, 3000] # frequency band [Hz] used to compute the spatial spectrum
# Spatial layout of each microphone in the TAMAGO-03 microphone array
mic_alignments = np.array(
[
[0.035, 0.0, 0.0],
[0.035/np.sqrt(2), 0.035/np.sqrt(2), 0.0],
[0.0, 0.035, 0.0],
[-0.035/np.sqrt(2), 0.035/np.sqrt(2), 0.0],
[-0.035, 0.0, 0.0],
[-0.035/np.sqrt(2), -0.035/np.sqrt(2), 0.0],
[0.0, -0.035, 0.0],
[0.035/np.sqrt(2), -0.035/np.sqrt(2), 0.0]
])
"""mic_alignments: (num_microphones, 3D coordinates [m])"""
# Array describing the spatial layout of the microphones
mic_alignments = mic_alignments.T # transpose to (3D coordinates, num_microphones)
"""mic_alignments: (3D coordinates [m], num_microphones)"""
#############################################################
# Input data
target_voice_file = "./test/p232_153_p257_120_noise_mix/p232_153_target.wav"
interference_audio_file = "./test/p232_153_p257_120_noise_mix/p232_153_p257_120_interference_azimuth45.wav"
noise_file = "./test/p232_153_p257_120_noise_mix/p232_153_p257_120_noise_azimuth180.wav"
mixed_audio_file = "./test/p232_153_p257_120_noise_mix/p232_153_p257_120_mixed.wav"
wave_dir = "./output/wave/"
os.makedirs(wave_dir, exist_ok=True)
# Save waveform plots corresponding to the audio files
wave_image_dir = "./output/wave_image/"
os.makedirs(wave_image_dir, exist_ok=True)
# Save spectrograms corresponding to the audio files
spec_dir = "./output/spectrogram/"
os.makedirs(spec_dir, exist_ok=True)
# Directory containing the reference transcripts for ASR accuracy evaluation
reference_label_dir = "../AudioDatasets/NoisySpeechDatabase/testset_txt/"
# Directory for saving the speech recognition results
recog_result_dir = "./recog_result/{}_{}_{}_{}_dereverb_type_{}/".format(target_voice_file.split('/')[-2], args.denoising_model_type, args.speaker_separation_model_type, args.beamformer_type, str(args.dereverb_type))
os.makedirs(recog_result_dir, exist_ok=True)
# Use the GPU if available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)
# Define the network models, how channels are selected, and whether inputs are padded
# Denoising (dereverberation) model
if args.denoising_model_type == 'complex_unet':
checkpoint_path = "./ckpt/ckpt_NoisySpeechDataset_multi_wav_test_original_length_ComplexUnet_ch_constant_SCM_loss_finetune_SNR_loss_model_multisteplr000001start_2021119/ckpt_epoch530.pt" # Complex U-Net speech and noise output ch constant snr loss fine-tuned with SCM loss (proposed model)
denoising_model = MCComplexUnet()
padding = True
else:
raise ValueError("Please specify the correct denoising model type")
# Speaker separation model
if args.speaker_separation_model_type == 'conv_tasnet':
checkpoint_path_for_speaker_separation_model = "./ckpt/ckpt_NoisySpeechDataset_multi_wav_for_ConvTasnet_snr_loss_multisteplr00001start_for_TensorRT_20211117/ckpt_epoch480.pt"
speaker_separation_model = MCConvTasNet()
else:
raise ValueError("Please specify the correct speaker separation model type")
# Create an instance of the audio processing class
# audio_processor = AudioProcess(args.sample_rate, args.fft_size, args.hop_length, channel_select_type, padding)
audio_processor = AudioProcessForComplex(args.sample_rate, args.fft_size, args.hop_length, padding)
# Load the trained parameters of the denoising model
denoising_model_params = torch.load(checkpoint_path, map_location=device)
denoising_model.load_state_dict(denoising_model_params['model_state_dict'])
denoising_model.to(device) # move the model to the CPU or GPU
denoising_model.eval() # set the network to inference mode
# print("Number of model parameters:", count_parameters(denoising_model))
# # Convert to TensorRT together with an input sample
# tmp = torch.ones((1, args.channels, int(args.fft_size/2)+1, 513, 2)).to(device)
# print(tmp.shape)
# # denoising_model = torch2trt(denoising_model, [tmp])
# denoising_model = torch2trt(denoising_model, [tmp], fp16_mode=True) # switch mode depending on precision
# Load the trained parameters of the speaker separation model
speaker_separation_model_params = torch.load(checkpoint_path_for_speaker_separation_model, map_location=device)
speaker_separation_model.load_state_dict(speaker_separation_model_params['model_state_dict'])
speaker_separation_model.to(device) # move the model to the CPU or GPU
speaker_separation_model.eval() # set the network to inference mode
# # Convert to TensorRT together with an input sample
# tmp_2 = torch.ones((1, args.channels, 48000)).to(device)
# # denoising_model = torch2trt(denoising_model, [tmp])
# speaker_separation_model = torch2trt(speaker_separation_model, [tmp_2], fp16_mode=True) # switch mode depending on precision
# Load the trained parameters of the speaker identification model (TODO: eventually allow these to be specified via hparams)
embedder = SpeechEmbedder()
embed_params = torch.load(args.embedder_path, map_location=device)
embedder.load_state_dict(embed_params)
embedder.to(device) # move the model to the CPU or GPU
embedder.eval()
# Load a speech sample of the speaker whose voice should be extracted, and save it for evaluation
ref_speech_data, _ = sf.read(args.ref_speech_path)
# Add a channel dimension for single-channel audio
if ref_speech_data.ndim == 1:
ref_speech_data = ref_speech_data[:, np.newaxis]
ref_speech_save_path = os.path.join(wave_dir, "reference_voice.wav")
sf.write(ref_speech_save_path, ref_speech_data, args.sample_rate)
# Convert the speech sample's features (log-mel spectrogram) into an embedding vector
ref_complex_spec = audio_processor.calc_complex_spec(ref_speech_data)
ref_log_mel_spec = audio_processor.calc_log_mel_spec(ref_complex_spec)
ref_log_mel_spec = torch.from_numpy(ref_log_mel_spec).float().to(device)
# Convert to TensorRT together with an input sample
# embedder = torch2trt(embedder, [torch.unsqueeze(ref_log_mel_spec[0], 0)])
ref_dvec = embedder(ref_log_mel_spec[0]) # input is a single channel
"""ref_dvec: (embed_dim=256,)"""
# Convert the PyTorch tensor to a numpy array
ref_dvec = ref_dvec.cpu().detach().numpy().copy() # CPU
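# NOTE (added): ref_dvec is a d-vector style speaker embedding. Later on,
# speaker_selector_sig_ver is expected to pick the separated stream whose
# embedding is closest to this reference, typically via cosine similarity;
# a minimal sketch of that criterion (hypothetical helper, the actual logic
# lives inside AudioProcessForComplex):
#   def cosine_similarity(a, b):
#       return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)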
# Create an instance for speech recognition
asr_ins = ASR(lang='eng')
# Processing start time
start_time = time.perf_counter()
# Load the audio data
mixed_audio_data, _ = sf.read(mixed_audio_file)
"""mixed_audio_data: (num_samples, num_channels)"""
# Convert the multichannel audio into a complex spectrogram
mixed_complex_spec = audio_processor.calc_complex_spec(mixed_audio_data)
"""mixed_complex_spec: (num_channels, freq_bins, time_frames)"""
# Run dereverberation if a dereverberation method is specified
if args.dereverb_type == 'WPE':
mixed_complex_spec, _ = audio_processor.dereverberation_wpe_multi(mixed_complex_spec)
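# NOTE (added): WPE (weighted prediction error) dereverberation estimates the
# late reverberation of each channel by multichannel linear prediction from
# delayed past STFT frames and subtracts it from the observation; the concrete
# implementation here is AudioProcessForComplex.dereverberation_wpe_multi.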
# Convert the audio to amplitude+phase spectrograms in mini-batches so it can be fed to the model
# When using torch.stft
mixed_audio_data_for_model_input = torch.transpose(torch.from_numpy(mixed_audio_data).float(), 0, 1)
mixed_audio_data_for_model_input = mixed_audio_data_for_model_input.to(device) # move the data to the CPU or GPU
"""mixed_audio_data_for_model_input: (num_channels, num_samples)"""
mixed_amp_phase_spec_batch = audio_processor.preprocess_mask_estimator(mixed_audio_data_for_model_input, args.batch_length)
"""amp_phase_spec_batch: (batch_size, num_channels, freq_bins, time_frames, real_imaginary)"""
# Estimate time-frequency masks for the speech and for the remaining noise
speech_amp_phase_spec_output, noise_amp_phase_spec_output = denoising_model(mixed_amp_phase_spec_batch)
"""speech_amp_phase_spec_output: (batch_size, num_channels, freq_bins, time_frames, real_imaginary),
noise_amp_phase_spec_output: (batch_size, num_channels, freq_bins, time_frames, real_imaginary)"""
# Concatenate the mini-batched amplitude+phase spectrograms along the time axis
multichannel_speech_amp_phase_spec = audio_processor.postprocess_mask_estimator(mixed_complex_spec, speech_amp_phase_spec_output, args.batch_length)
"""multichannel_speech_amp_phase_spec: (num_channels, freq_bins, time_frames, real_imaginary)"""
multichannel_noise_amp_phase_spec = audio_processor.postprocess_mask_estimator(mixed_complex_spec, noise_amp_phase_spec_output, args.batch_length)
"""multichannel_noise_amp_phase_spec: (num_channels, freq_bins, time_frames, real_imaginary)"""
# When using torch.istft
# Convert the multichannel speech spectrogram back to a waveform
multichannel_denoised_data = torch.istft(multichannel_speech_amp_phase_spec, n_fft=args.fft_size, hop_length=args.hop_length, \
normalized=True, length=mixed_audio_data.shape[0], return_complex=False)
"""multichannel_denoised_data: (num_channels, num_samples)"""
# Convert the multichannel noise spectrogram back to a waveform
multichannel_noise_data = torch.istft(multichannel_noise_amp_phase_spec, n_fft=args.fft_size, hop_length=args.hop_length, \
normalized=True, length=mixed_audio_data.shape[0], return_complex=False)
"""multichannel_noise_data: (num_channels, num_samples)"""
# Add a batch dimension so the data can be fed to the speaker separation model
multichannel_denoised_data = torch.unsqueeze(multichannel_denoised_data, 0)
"""multichannel_denoised_data: (batch_size, num_channels, num_samples)"""
# Speaker separation
separated_audio_data = speaker_separation_model(multichannel_denoised_data)
"""separated_audio_data: (batch_size, num_speakers, num_channels, num_samples)"""
# Align the speaker order, which may differ from channel to channel
separated_audio_data = solve_inter_channel_permutation_problem(separated_audio_data)
"""separated_audio_data: (batch_size, num_speakers, num_channels, num_samples)"""
# Convert the PyTorch tensor to a numpy array
separated_audio_data = separated_audio_data.cpu().detach().numpy().copy() # CPU
# Remove the batch dimension and transpose
separated_audio_data = np.transpose(np.squeeze(separated_audio_data, 0), (0, 2, 1))
"""separated_audio_data: (num_speakers, num_samples, num_channels)"""
# Select the target speaker's utterance from the separated streams (decide which stream belongs to the target speaker) (TODO: eventually unify into speaker_selector)
target_speaker_id, speech_complex_spec_all = audio_processor.speaker_selector_sig_ver(separated_audio_data, ref_dvec, embedder, device)
"""speech_complex_spec_all: (num_speakers, num_channels, freq_bins, time_frames)"""
# Get the complex spectrogram of the target speaker's utterance
multichannel_target_complex_spec = speech_complex_spec_all[target_speaker_id]
"""multichannel_target_complex_spec: (num_channels, freq_bins, time_frames)"""
multichannel_interference_complex_spec = np.zeros_like(multichannel_target_complex_spec)
# Get the complex spectrograms of the interfering speakers' utterances
for speaker_id in range(speech_complex_spec_all.shape[0]):
# Sum up the complex spectrograms of every speaker other than the target speaker
if speaker_id != target_speaker_id:
multichannel_interference_complex_spec += speech_complex_spec_all[speaker_id]
"""multichannel_interference_complex_spec: (num_channels, freq_bins, time_frames)"""
# Convert the PyTorch tensor to a numpy array
multichannel_noise_data = multichannel_noise_data.cpu().detach().numpy().copy() # CPU
"""multichannel_noise_data: (num_channels, num_samples)"""
# Compute the complex spectrogram of the noise
multichannel_noise_complex_spec = audio_processor.calc_complex_spec(multichannel_noise_data.T)
"""multichannel_noise_complex_spec: (num_channels, freq_bins, time_frames)"""
# Estimate the spatial covariance matrix of each source (target, interference, noise)
target_covariance_matrix = estimate_covariance_matrix_sig(multichannel_target_complex_spec)
interference_covariance_matrix = estimate_covariance_matrix_sig(multichannel_interference_complex_spec)
noise_covariance_matrix = estimate_covariance_matrix_sig(multichannel_noise_complex_spec)
noise_covariance_matrix = condition_covariance(noise_covariance_matrix, 1e-6) # without this, performance drops significantly (only the noise covariance matrix needs it)
# noise_covariance_matrix /= np.trace(noise_covariance_matrix, axis1=-2, axis2=-1)[..., None, None]
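# NOTE (added): the spatial covariance matrix of a source is typically
# estimated per frequency bin as R(f) = (1/T) * sum_t x(f, t) x(f, t)^H, and
# condition_covariance presumably regularizes it with a small scaled identity
# so that the noise covariance stays well conditioned before inversion; the
# exact definitions are in beamformer.py.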
# Run noise reduction with the beamformer
if args.beamformer_type == 'MVDR':
# target_steering_vectors = estimate_steering_vector(target_covariance_matrix)
# estimated_spec = mvdr_beamformer(mixed_complex_spec, target_steering_vectors, noise_covariance_matrix)
estimated_target_spec = mvdr_beamformer_two_speakers(mixed_complex_spec, target_covariance_matrix, interference_covariance_matrix, noise_covariance_matrix)
# estimated_interference_spec = mvdr_beamformer_two_speakers(mixed_complex_spec, interference_covariance_matrix, target_covariance_matrix, noise_covariance_matrix)
elif args.beamformer_type == 'GEV':
estimated_target_spec = gev_beamformer(mixed_complex_spec, target_covariance_matrix, noise_covariance_matrix)
elif args.beamformer_type == "DS":
target_steering_vectors = estimate_steering_vector(target_covariance_matrix)
estimated_target_spec = ds_beamformer(mixed_complex_spec, target_steering_vectors)
elif args.beamformer_type == "MWF":
estimated_target_spec = mwf(mixed_complex_spec, target_covariance_matrix, noise_covariance_matrix)
else:
raise ValueError("Please specify the correct beamformer type")
"""estimated_target_spec: (num_channels, freq_bins, time_frames)"""
# Convert the multichannel spectrogram back to a waveform
multichannel_estimated_target_voice_data = audio_processor.spec_to_wave(estimated_target_spec, mixed_audio_data)
# multichannel_estimated_interference_voice_data = audio_processor.spec_to_wave(estimated_interference_spec, mixed_audio_data)
"""multichannel_estimated_target_voice_data: (num_samples, num_channels)"""
# Processing end time
finish_time = time.perf_counter()
# Processing time
process_time = finish_time - start_time
print("Processing time: {:.3f} sec".format(process_time))
# Real time factor (RTF)
rtf = process_time / (mixed_audio_data.shape[0] / args.sample_rate)
print("Real time factor: {:.3f}".format(rtf))
# Sound source localization using the MUSIC method
speaker_azimuth = localize_music(estimated_target_spec, mic_alignments, args.sample_rate, args.fft_size)
print("Sound source localization result:", str(speaker_azimuth) + " deg")
# Save the audio data
estimated_target_voice_path = os.path.join(wave_dir, "estimated_target_voice.wav")
sf.write(estimated_target_voice_path, multichannel_estimated_target_voice_data, args.sample_rate)
# estimated_interference_voice_path = os.path.join(wave_dir, "estimated_interference_voice.wav")
# sf.write(estimated_interference_voice_path, multichannel_estimated_interference_voice_data, args.sample_rate)
# Save the denoised mixture
denoised_voice_path = os.path.join(wave_dir, "denoised_voice.wav")
# Convert the PyTorch tensor to a numpy array
multichannel_denoised_data = multichannel_denoised_data[0].cpu().detach().numpy().copy() # CPU
sf.write(denoised_voice_path, multichannel_denoised_data.T, args.sample_rate)
# Save the original audio and its spectrograms for debugging
# Target speaker's utterance
target_voice_path = os.path.join(wave_dir, "target_voice.wav")
target_voice_data, _ = sf.read(target_voice_file)
sf.write(target_voice_path, target_voice_data, args.sample_rate)
# Interfering speaker's utterance
interference_audio_path = os.path.join(wave_dir, "interference_audio.wav")
interference_audio_data, _ = sf.read(interference_audio_file)
sf.write(interference_audio_path, interference_audio_data, args.sample_rate)
# Noise
noise_path = os.path.join(wave_dir, "noise.wav")
noise_data, _ = sf.read(noise_file)
sf.write(noise_path, noise_data, args.sample_rate)
# Mixed audio
mixed_audio_path = os.path.join(wave_dir, "mixed_audio.wav")
sf.write(mixed_audio_path, mixed_audio_data, args.sample_rate)
# Save the waveforms as images (multichannel not supported)
# Waveform of the target speaker's utterance
target_voice_img_path = os.path.join(wave_image_dir, "target_voice.png")
wave_plot(target_voice_path, target_voice_img_path, ylim_min=-1.0, ylim_max=1.0)
# Waveform of the interfering speaker's utterance
interference_img_path = os.path.join(wave_image_dir, "interference_audio.png")
wave_plot(interference_audio_path, interference_img_path, ylim_min=-1.0, ylim_max=1.0)
# Waveform of the noise
noise_img_path = os.path.join(wave_image_dir, "noise.png")
wave_plot(noise_path, noise_img_path, ylim_min=-1.0, ylim_max=1.0)
# Waveform of the separated (estimated target) voice
estimated_voice_img_path = os.path.join(wave_image_dir, "estimated_target_voice.png")
wave_plot(estimated_target_voice_path, estimated_voice_img_path, ylim_min=-1.0, ylim_max=1.0)
# Waveform of the target speaker's reference speech sample
ref_speech_img_path = os.path.join(wave_image_dir, "ref_speech.png")
wave_plot(args.ref_speech_path, ref_speech_img_path, ylim_min=-1.0, ylim_max=1.0)
# Waveform of the mixed audio
mixed_audio_img_path = os.path.join(wave_image_dir, "mixed_audio.png")
wave_plot(mixed_audio_path, mixed_audio_img_path, ylim_min=-1.0, ylim_max=1.0)
# Save the spectrograms as images
# Get the current working directory
base_dir = os.getcwd()
# Spectrogram of the target speaker's utterance
target_voice_spec_path = os.path.join(spec_dir, "target_voice.png")
spec_plot(base_dir, target_voice_path, target_voice_spec_path)
# Spectrogram of the interfering speaker's utterance
interference_audio_spec_path = os.path.join(spec_dir, "interference_audio.png")
spec_plot(base_dir, interference_audio_path, interference_audio_spec_path)
# Spectrogram of the noise
noise_spec_path = os.path.join(spec_dir, "noise.png")
spec_plot(base_dir, noise_path, noise_spec_path)
# Spectrogram of the extracted target speaker's utterance
estimated_voice_spec_path = os.path.join(spec_dir, "estimated_target_voice.png")
spec_plot(base_dir, estimated_target_voice_path, estimated_voice_spec_path)
# Spectrogram of the mixed audio
mixed_audio_spec_path = os.path.join(spec_dir, "mixed_audio.png")
spec_plot(base_dir, mixed_audio_path, mixed_audio_spec_path)
# Evaluate source separation performance
sdr_mix, sir_mix, sar_mix, sdr_est, sir_est, sar_est = \
audio_eval(args.sample_rate, target_voice_path, interference_audio_path, mixed_audio_path, estimated_target_voice_path)
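# NOTE (added): audio_eval is expected to return BSS Eval style metrics:
# SDR (signal-to-distortion ratio), SIR (signal-to-interference ratio) and
# SAR (signal-to-artifacts ratio), each in dB (higher is better), for both the
# unprocessed mixture (*_mix) and the beamformer output (*_est); the actual
# computation lives in evaluation/evaluate.py.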
# Evaluate speech recognition performance
# When using ESPnet
target_voice_recog_text = asr_ins.speech_recognition(target_voice_path) # e.g. IT IS MARVELLOUS
target_voice_recog_text = target_voice_recog_text.replace('.', '').upper().split() # e.g. ['IT', 'IS', 'MARVELLOUS']
mixed_audio_recog_text = asr_ins.speech_recognition(mixed_audio_path)
mixed_audio_recog_text = mixed_audio_recog_text.replace('.', '').upper().split()
estimated_voice_recog_text = asr_ins.speech_recognition(estimated_target_voice_path)
estimated_voice_recog_text = estimated_voice_recog_text.replace('.', '').upper().split()
# Get the file name
file_num = os.path.basename(target_voice_file).split('.')[0].rsplit('_', maxsplit=1)[0] # e.g. p232_016
# Read the reference transcript
reference_label_path = os.path.join(reference_label_dir, file_num + '.txt')
with open(reference_label_path, 'r', encoding="utf8") as ref:
# Remove periods and commas, convert to upper case, then split on spaces
reference_label_text = ref.read().replace('.', '').replace(',', '').upper().split()
# Compute the word error rate (WER)
clean_recog_result_save_path = os.path.join(recog_result_dir, file_num + '_clean.txt')
mix_recog_result_save_path = os.path.join(recog_result_dir, file_num + '_mix.txt')
est_recog_result_save_path = os.path.join(recog_result_dir, file_num + '_est.txt')
wer_clean = asr_eval(reference_label_text, target_voice_recog_text, clean_recog_result_save_path)
wer_mix = asr_eval(reference_label_text, mixed_audio_recog_text, mix_recog_result_save_path)
wer_est = asr_eval(reference_label_text, estimated_voice_recog_text, est_recog_result_save_path)
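# NOTE (added): word error rate is conventionally defined as
#   WER = (S + D + I) / N,
# where S, D and I are the substitutions, deletions and insertions needed to
# turn the reference transcript of N words into the hypothesis; asr_eval
# presumably computes this and writes the recognition result to the given
# path (see evaluation/evaluate.py).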
# # When using Julius (supports Japanese single-channel audio only)
# # Target audio
# target_voice_data, _ = sf.read(target_voice_path)
# # For multichannel audio, take the first channel
# if target_voice_data.ndim == 2:
# target_voice_1ch_path = "./utils/target_voice_1ch.wav"
# sf.write(target_voice_1ch_path, target_voice_data[:, 0], args.sample_rate)
# target_voice_recog_text = asr_julius(target_voice_1ch_path) # e.g. IT IS MARVELLOUS
# os.remove(target_voice_1ch_path)
# else:
# target_voice_recog_text = asr_julius(target_voice_path) # e.g. IT IS MARVELLOUS
# target_voice_recog_text = target_voice_recog_text.split() # e.g. ['IT', 'IS', 'MARVELLOUS']
# # Mixed audio
# mixed_audio_data, _ = sf.read(mixed_audio_path)
# if mixed_audio_data.ndim == 2:
# mixed_audio_1ch_path = "./utils/mixed_audio_1ch.wav"
# sf.write(mixed_audio_1ch_path, mixed_audio_data[:, 0], args.sample_rate)
# mixed_audio_recog_text = asr_julius(mixed_audio_1ch_path)
# os.remove(mixed_audio_1ch_path)
# else:
# mixed_audio_recog_text = asr_julius(mixed_audio_path)
# mixed_audio_recog_text = mixed_audio_recog_text.split()
# # Processed target audio
# estimated_target_voice_data, _ = sf.read(estimated_target_voice_path)
# if estimated_target_voice_data.ndim == 2:
# estimated_target_voice_1ch_path = "./utils/estimated_target_voice_1ch.wav"
# sf.write(estimated_target_voice_1ch_path, estimated_target_voice_data[:, 0], args.sample_rate)
# estimated_voice_recog_text = asr_julius(estimated_target_voice_1ch_path)
# os.remove(estimated_target_voice_1ch_path)
# else:
# estimated_voice_recog_text = asr_julius(estimated_target_voice_path)
# estimated_voice_recog_text = estimated_voice_recog_text.split()
# # Compute WER
# clean_recog_result_save_path = os.path.join(recog_result_dir, 'clean.txt')
# mix_recog_result_save_path = os.path.join(recog_result_dir, 'mix.txt')
# est_recog_result_save_path = os.path.join(recog_result_dir, 'est.txt')
# wer_clean = asr_eval(reference_label_text, target_voice_recog_text, clean_recog_result_save_path)
# wer_mix = asr_eval(reference_label_text, mixed_audio_recog_text, mix_recog_result_save_path)
# wer_est = asr_eval(reference_label_text, estimated_voice_recog_text, est_recog_result_save_path)
print("============================音源分離性能===============================")
print("SDR_mix: {:.3f}, SIR_mix: {:.3f}, SAR_mix: {:.3f}".format(sdr_mix, sir_mix, sar_mix))
print("SDR_est: {:.3f}, SIR_est: {:.3f}, SAR_est: {:.3f}".format(sdr_est, sir_est, sar_est))
print("============================音声認識性能===============================")
print("WER_clean: {:.3f}".format(wer_clean))
print("WER_mix: {:.3f}".format(wer_mix))
print("WER_est: {:.3f}".format(wer_est))
if __name__ == "__main__":
main()