Skip to content

Commit

Permalink
Support specifying number of threads in sherpa-vad (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jan 22, 2025
1 parent f46678e commit 2a88cf8
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sherpa/csrc/sherpa-vad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "sherpa/cpp_api/parse-options.h"
#include "sherpa/csrc/fbank-features.h"
#include "sherpa/csrc/voice-activity-detector.h"
#include "torch/torch.h"

int32_t main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Expand All @@ -17,13 +18,16 @@ This program uses a VAD models to add timestamps to a audio file
sherpa-vad \
--silero-vad-model=/path/to/model.pt \
--use-gpu=false \
--num-threads=1 \
./foo.wav
)usage";

int32_t num_threads = 1;
sherpa::ParseOptions po(kUsageMessage);
sherpa::VoiceActivityDetectorConfig config;
config.Register(&po);
po.Register("num-threads", &num_threads, "Number of threads for PyTorch");
po.Read(argc, argv);

if (po.NumArgs() != 1) {
Expand All @@ -34,6 +38,9 @@ sherpa-vad \
std::cerr << config.ToString() << "\n";
config.Validate();

torch::set_num_threads(num_threads);
torch::set_num_interop_threads(num_threads);

sherpa::VoiceActivityDetector vad(config);

torch::Tensor samples = sherpa::ReadWave(po.GetArg(1), 16000).first;
Expand All @@ -55,6 +62,7 @@ sherpa-vad \
fprintf(stderr, "%.3f -- %.3f\n", s.start, s.end);
}

fprintf(stderr, "Number of threads: %d\n", num_threads);
fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds);
fprintf(stderr, "Audio duration: %.3f s\n", duration);
fprintf(stderr, "Real time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,
Expand Down

0 comments on commit 2a88cf8

Please sign in to comment.