@@ -120,8 +120,7 @@ class VadIterator
120
120
void reset_states ()
121
121
{
122
122
// Call reset before each audio start
123
- std::memset (_h.data (), 0 .0f , _h.size () * sizeof (float ));
124
- std::memset (_c.data (), 0 .0f , _c.size () * sizeof (float ));
123
+ std::memset (_state.data (), 0 .0f , _state.size () * sizeof (float ));
125
124
triggered = false ;
126
125
temp_end = 0 ;
127
126
current_sample = 0 ;
@@ -139,19 +138,16 @@ class VadIterator
139
138
input.assign (data.begin (), data.end ());
140
139
Ort::Value input_ort = Ort::Value::CreateTensor<float >(
141
140
memory_info, input.data (), input.size (), input_node_dims, 2 );
141
+ Ort::Value state_ort = Ort::Value::CreateTensor<float >(
142
+ memory_info, _state.data (), _state.size (), state_node_dims, 3 );
142
143
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t >(
143
144
memory_info, sr.data (), sr.size (), sr_node_dims, 1 );
144
- Ort::Value h_ort = Ort::Value::CreateTensor<float >(
145
- memory_info, _h.data (), _h.size (), hc_node_dims, 3 );
146
- Ort::Value c_ort = Ort::Value::CreateTensor<float >(
147
- memory_info, _c.data (), _c.size (), hc_node_dims, 3 );
148
145
149
146
// Clear and add inputs
150
147
ort_inputs.clear ();
151
148
ort_inputs.emplace_back (std::move (input_ort));
149
+ ort_inputs.emplace_back (std::move (state_ort));
152
150
ort_inputs.emplace_back (std::move (sr_ort));
153
- ort_inputs.emplace_back (std::move (h_ort));
154
- ort_inputs.emplace_back (std::move (c_ort));
155
151
156
152
// Infer
157
153
ort_outputs = session->Run (
@@ -161,10 +157,8 @@ class VadIterator
161
157
162
158
// Output probability & update the model state recursively
163
159
float speech_prob = ort_outputs[0 ].GetTensorMutableData <float >()[0 ];
164
- float *hn = ort_outputs[1 ].GetTensorMutableData <float >();
165
- std::memcpy (_h.data (), hn, size_hc * sizeof (float ));
166
- float *cn = ort_outputs[2 ].GetTensorMutableData <float >();
167
- std::memcpy (_c.data (), cn, size_hc * sizeof (float ));
160
+ float *stateN = ort_outputs[1 ].GetTensorMutableData <float >();
161
+ std::memcpy (_state.data (), stateN, size_state * sizeof (float ));
168
162
169
163
// Push forward sample index
170
164
current_sample += window_size_samples;
@@ -376,27 +370,26 @@ class VadIterator
376
370
// Inputs
377
371
std::vector<Ort::Value> ort_inputs;
378
372
379
- std::vector<const char *> input_node_names = {" input" , " sr " , " h " , " c " };
373
+ std::vector<const char *> input_node_names = {" input" , " state " , " sr " };
380
374
std::vector<float > input;
375
+ unsigned int size_state = 2 * 1 * 128 ; // It's FIXED.
376
+ std::vector<float > _state;
381
377
std::vector<int64_t > sr;
382
- unsigned int size_hc = 2 * 1 * 64 ; // It's FIXED.
383
- std::vector<float > _h;
384
- std::vector<float > _c;
385
378
386
- int64_t input_node_dims[2 ] = {};
379
+ int64_t input_node_dims[2 ] = {};
380
+ const int64_t state_node_dims[3 ] = {2 , 1 , 128 };
387
381
const int64_t sr_node_dims[1 ] = {1 };
388
- const int64_t hc_node_dims[3 ] = {2 , 1 , 64 };
389
382
390
383
// Outputs
391
384
std::vector<Ort::Value> ort_outputs;
392
- std::vector<const char *> output_node_names = {" output" , " hn " , " cn " };
385
+ std::vector<const char *> output_node_names = {" output" , " stateN " };
393
386
394
387
public:
395
388
// Construction
396
389
VadIterator (const std::wstring ModelPath,
397
- int Sample_rate = 16000 , int windows_frame_size = 64 ,
390
+ int Sample_rate = 16000 , int windows_frame_size = 32 ,
398
391
float Threshold = 0.5 , int min_silence_duration_ms = 0 ,
399
- int speech_pad_ms = 64 , int min_speech_duration_ms = 64 ,
392
+ int speech_pad_ms = 32 , int min_speech_duration_ms = 32 ,
400
393
float max_speech_duration_s = std::numeric_limits<float >::infinity())
401
394
{
402
395
init_onnx_model (ModelPath);
@@ -422,8 +415,7 @@ class VadIterator
422
415
input_node_dims[0 ] = 1 ;
423
416
input_node_dims[1 ] = window_size_samples;
424
417
425
- _h.resize (size_hc);
426
- _c.resize (size_hc);
418
+ _state.resize (size_state);
427
419
sr.resize (1 );
428
420
sr[0 ] = sample_rate;
429
421
};
0 commit comments