Skip to content

Commit addca5e

Browse files
authored
Merge pull request #93 from snakers4/adamnsandle
additional vad utils
2 parents 1707106 + 1fc6b72 commit addca5e

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

utils_vad.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,11 @@ def get_speech_ts(wav: torch.Tensor,
8686
min_speech_samples: int = 10000, #samples
8787
min_silence_samples: int = 500,
8888
run_function=validate,
89-
visualize_probs=False):
89+
visualize_probs=False,
90+
smoothed_prob_func='mean',
91+
device='cpu'):
9092

93+
assert smoothed_prob_func in ['mean', 'max'], 'smoothed_prob_func not in ["max", "mean"]'
9194
num_samples = num_samples_per_window
9295
assert num_samples % num_steps == 0
9396
step = int(num_samples / num_steps) # stride / hop
@@ -99,13 +102,13 @@ def get_speech_ts(wav: torch.Tensor,
99102
chunk = F.pad(chunk, (0, num_samples - len(chunk)))
100103
to_concat.append(chunk.unsqueeze(0))
101104
if len(to_concat) >= batch_size:
102-
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
105+
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
103106
out = run_function(model, chunks)
104107
outs.append(out)
105108
to_concat = []
106109

107110
if to_concat:
108-
chunks = torch.Tensor(torch.cat(to_concat, dim=0))
111+
chunks = torch.Tensor(torch.cat(to_concat, dim=0)).to(device)
109112
out = run_function(model, chunks)
110113
outs.append(out)
111114

@@ -123,7 +126,11 @@ def get_speech_ts(wav: torch.Tensor,
123126
temp_end = 0
124127
for i, predict in enumerate(speech_probs): # add name
125128
buffer.append(predict)
126-
smoothed_prob = (sum(buffer) / len(buffer))
129+
if smoothed_prob_func == 'mean':
130+
smoothed_prob = (sum(buffer) / len(buffer))
131+
elif smoothed_prob_func == 'max':
132+
smoothed_prob = max(buffer)
133+
127134
if visualize_probs:
128135
smoothed_probs.append(float(smoothed_prob))
129136
if (smoothed_prob >= trig_sum) and temp_end:

utils_vad_additional.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from utils_vad import *
2+
import sys
3+
import os
4+
from pathlib import Path
5+
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/pipelines/align/bin/')
6+
from align_utils import load_audio_norm
7+
import torch
8+
import pandas as pd
9+
import numpy as np
10+
sys.path.append('/home/keras/notebook/nvme_raid/adamnsandle/silero_mono/utils/')
11+
from open_stt import soundfile_opus as sf
12+
13+
def split_save_audio_chunks(audio_path, model_path, save_path=None, device='cpu', absolute=True, max_duration=10, adaptive=False, **kwargs):
14+
15+
if not save_path:
16+
save_path = str(Path(audio_path).with_name('after_vad'))
17+
print(f'No save path specified! Using {save_path} to save audio chunks!')
18+
19+
SAMPLE_RATE = 16000
20+
if type(model_path) == str:
21+
#print('Loading model...')
22+
model = init_jit_model(model_path, device)
23+
else:
24+
#print('Using loaded model')
25+
model = model_path
26+
save_name = Path(audio_path).stem
27+
audio, sr = load_audio_norm(audio_path)
28+
wav = torch.tensor(audio)
29+
if adaptive:
30+
speech_timestamps = get_speech_ts_adaptive(wav, model, device=device, **kwargs)
31+
else:
32+
speech_timestamps = get_speech_ts(wav, model, device=device, **kwargs)
33+
34+
full_save_path = Path(save_path, save_name)
35+
if not os.path.exists(full_save_path):
36+
os.makedirs(full_save_path, exist_ok=True)
37+
38+
chunks = []
39+
if not speech_timestamps:
40+
return pd.DataFrame()
41+
for ts in speech_timestamps:
42+
start_ts = int(ts['start'])
43+
end_ts = int(ts['end'])
44+
45+
for i in range(start_ts, end_ts, max_duration * SAMPLE_RATE):
46+
new_start = i
47+
new_end = min(end_ts, i + max_duration * SAMPLE_RATE)
48+
duration = round((new_end - new_start) / SAMPLE_RATE, 2)
49+
chunk_path = Path(full_save_path, f'{save_name}_{new_start}-{new_end}.opus')
50+
chunk_path = chunk_path.absolute() if absolute else chunk_path
51+
sf.write(str(chunk_path), audio[new_start: new_end], 16000, format='OGG', subtype='OPUS')
52+
chunks.append({'audio_path': chunk_path,
53+
'text': '',
54+
'duration': duration,
55+
'domain': ''})
56+
return pd.DataFrame(chunks)

0 commit comments

Comments
 (0)