Skip to content

Commit fdbb0a3

Browse files
authored
Merge pull request #482 from filtercodes/v5_cpp_support
cpp example
2 parents a395853 + 60ae7ab commit fdbb0a3

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

examples/cpp/silero-vad-onnx.cpp

+15-23
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ class VadIterator
120120
void reset_states()
121121
{
122122
// Call reset before each audio start
123-
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
124-
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
123+
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
125124
triggered = false;
126125
temp_end = 0;
127126
current_sample = 0;
@@ -139,19 +138,16 @@ class VadIterator
139138
input.assign(data.begin(), data.end());
140139
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
141140
memory_info, input.data(), input.size(), input_node_dims, 2);
141+
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
142+
memory_info, _state.data(), _state.size(), state_node_dims, 3);
142143
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
143144
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
144-
Ort::Value h_ort = Ort::Value::CreateTensor<float>(
145-
memory_info, _h.data(), _h.size(), hc_node_dims, 3);
146-
Ort::Value c_ort = Ort::Value::CreateTensor<float>(
147-
memory_info, _c.data(), _c.size(), hc_node_dims, 3);
148145

149146
// Clear and add inputs
150147
ort_inputs.clear();
151148
ort_inputs.emplace_back(std::move(input_ort));
149+
ort_inputs.emplace_back(std::move(state_ort));
152150
ort_inputs.emplace_back(std::move(sr_ort));
153-
ort_inputs.emplace_back(std::move(h_ort));
154-
ort_inputs.emplace_back(std::move(c_ort));
155151

156152
// Infer
157153
ort_outputs = session->Run(
@@ -161,10 +157,8 @@ class VadIterator
161157

162158
// Output probability & update the recurrent state
163159
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
164-
float *hn = ort_outputs[1].GetTensorMutableData<float>();
165-
std::memcpy(_h.data(), hn, size_hc * sizeof(float));
166-
float *cn = ort_outputs[2].GetTensorMutableData<float>();
167-
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
160+
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
161+
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
168162

169163
// Push forward sample index
170164
current_sample += window_size_samples;
@@ -376,27 +370,26 @@ class VadIterator
376370
// Inputs
377371
std::vector<Ort::Value> ort_inputs;
378372

379-
std::vector<const char *> input_node_names = {"input", "sr", "h", "c"};
373+
std::vector<const char *> input_node_names = {"input", "state", "sr"};
380374
std::vector<float> input;
375+
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
376+
std::vector<float> _state;
381377
std::vector<int64_t> sr;
382-
unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
383-
std::vector<float> _h;
384-
std::vector<float> _c;
385378

386-
int64_t input_node_dims[2] = {};
379+
int64_t input_node_dims[2] = {};
380+
const int64_t state_node_dims[3] = {2, 1, 128};
387381
const int64_t sr_node_dims[1] = {1};
388-
const int64_t hc_node_dims[3] = {2, 1, 64};
389382

390383
// Outputs
391384
std::vector<Ort::Value> ort_outputs;
392-
std::vector<const char *> output_node_names = {"output", "hn", "cn"};
385+
std::vector<const char *> output_node_names = {"output", "stateN"};
393386

394387
public:
395388
// Construction
396389
VadIterator(const std::wstring ModelPath,
397-
int Sample_rate = 16000, int windows_frame_size = 64,
390+
int Sample_rate = 16000, int windows_frame_size = 32,
398391
float Threshold = 0.5, int min_silence_duration_ms = 0,
399-
int speech_pad_ms = 64, int min_speech_duration_ms = 64,
392+
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
400393
float max_speech_duration_s = std::numeric_limits<float>::infinity())
401394
{
402395
init_onnx_model(ModelPath);
@@ -422,8 +415,7 @@ class VadIterator
422415
input_node_dims[0] = 1;
423416
input_node_dims[1] = window_size_samples;
424417

425-
_h.resize(size_hc);
426-
_c.resize(size_hc);
418+
_state.resize(size_state);
427419
sr.resize(1);
428420
sr[0] = sample_rate;
429421
};

0 commit comments

Comments
 (0)