diff --git a/python-api-examples/compute-speaker-similarity.py b/python-api-examples/compute-speaker-similarity.py
new file mode 100755
index 00000000..e73fc0f6
--- /dev/null
+++ b/python-api-examples/compute-speaker-similarity.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 Xiaomi Corporation
+
+"""
+Please download model files from
+https://github.com/k2-fsa/sherpa/releases/
+
+E.g.
+
+wget https://github.com/k2-fsa/sherpa/releases/download/speaker-recognition-models/3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt
+
+Please download test files from
+https://github.com/csukuangfj/sr-data/tree/main/test/3d-speaker
+"""
+
+import time
+from typing import Tuple
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+
+import sherpa
+
+
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    samples = np.ascontiguousarray(data)
+    return samples, sample_rate
+
+
+def create_extractor():
+    config = sherpa.SpeakerEmbeddingExtractorConfig(
+        model="./3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt",
+    )
+    print(config)
+    return sherpa.SpeakerEmbeddingExtractor(config)
+
+
+def main():
+    extractor = create_extractor()
+
+    file1 = "./speaker1_a_cn_16k.wav"
+    file2 = "./speaker1_b_cn_16k.wav"
+    file3 = "./speaker2_a_cn_16k.wav"
+
+    samples1, sample_rate1 = load_audio(file1)
+    if sample_rate1 != 16000:
+        samples1 = librosa.resample(samples1, orig_sr=sample_rate1, target_sr=16000)
+        sample_rate1 = 16000
+
+    samples2, sample_rate2 = load_audio(file2)
+    if sample_rate2 != 16000:
+        samples2 = librosa.resample(samples2, orig_sr=sample_rate2, target_sr=16000)
+        sample_rate2 = 16000
+
+    samples3, sample_rate3 = load_audio(file3)
+    if sample_rate3 != 16000:
+        samples3 = librosa.resample(samples3, orig_sr=sample_rate3, target_sr=16000)
+        sample_rate3 = 16000
+
+    start = time.time()
+    stream1 = extractor.create_stream()
+    stream2 = extractor.create_stream()
+    stream3 = extractor.create_stream()
+
+    stream1.accept_waveform(samples1)
+    stream2.accept_waveform(samples2)
+    stream3.accept_waveform(samples3)
+
+    embeddings = extractor.compute([stream1, stream2, stream3])
+    # embeddings: (batch_size, dim)
+
+    x12 = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
+    x13 = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[2], dim=0)
+    x23 = torch.nn.functional.cosine_similarity(embeddings[1], embeddings[2], dim=0)
+
+    end = time.time()
+
+    elapsed_seconds = end - start
+
+    print(x12, x13, x23)
+
+    audio_duration = (
+        len(samples1) / sample_rate1
+        + len(samples2) / sample_rate2
+        + len(samples3) / sample_rate3
+    )
+    real_time_factor = elapsed_seconds / audio_duration
+    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
+    print(f"Audio duration in seconds: {audio_duration:.3f}")
+    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
+
+
+if __name__ == "__main__":
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._set_graph_executor_optimize(False)
+
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    main()
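For readers who want a decision rather than raw scores: a minimal sketch of thresholding the cosine similarity. The `same_speaker` helper and the 0.6 threshold are illustrative only and are not part of this patch; thresholds should be calibrated per model on held-out data.

```python
# same_speaker() and the 0.6 threshold are hypothetical, for illustration only.
import torch


def same_speaker(emb1: torch.Tensor, emb2: torch.Tensor,
                 threshold: float = 0.6) -> bool:
    # emb1/emb2: 1-D embedding tensors, e.g. rows of extractor.compute()
    score = torch.nn.functional.cosine_similarity(emb1, emb2, dim=0, eps=1e-6)
    return score.item() >= threshold


# With random vectors for demonstration; real embeddings come from the model.
print(same_speaker(torch.randn(192), torch.randn(192)))
```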
diff --git a/python-api-examples/vad-with-sense-voice.py b/python-api-examples/vad-with-sense-voice.py
index 3135f84a..336a2b91 100755
--- a/python-api-examples/vad-with-sense-voice.py
+++ b/python-api-examples/vad-with-sense-voice.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# Copyright (c) 2025 Xiaomi Corporation
 
 """
 Please download sense voice model from
diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt
index fed74a94..66573a42 100644
--- a/sherpa/csrc/CMakeLists.txt
+++ b/sherpa/csrc/CMakeLists.txt
@@ -44,6 +44,10 @@ set(sherpa_srcs
   vad-model-config.cc
   voice-activity-detector-impl.cc
   voice-activity-detector.cc
+  #
+  speaker-embedding-extractor-model.cc
+  speaker-embedding-extractor.cc
+  speaker-embedding-extractor-impl.cc
 )
 
 add_library(sherpa_core ${sherpa_srcs})
@@ -129,6 +133,9 @@ target_include_directories(sherpa-version PRIVATE ${CMAKE_BINARY_DIR})
 add_executable(sherpa-vad sherpa-vad.cc)
 target_link_libraries(sherpa-vad sherpa_core)
 
+add_executable(sherpa-compute-speaker-similarity sherpa-compute-speaker-similarity.cc)
+target_link_libraries(sherpa-compute-speaker-similarity sherpa_core)
+
 install(TARGETS sherpa_core
   DESTINATION lib
@@ -138,5 +145,6 @@ install(
   TARGETS
     sherpa-version
     sherpa-vad
+    sherpa-compute-speaker-similarity
   DESTINATION bin
 )
diff --git a/sherpa/csrc/sherpa-compute-speaker-similarity.cc b/sherpa/csrc/sherpa-compute-speaker-similarity.cc
new file mode 100644
index 00000000..4a0f35bc
--- /dev/null
+++ b/sherpa/csrc/sherpa-compute-speaker-similarity.cc
@@ -0,0 +1,93 @@
+// sherpa/csrc/sherpa-compute-speaker-similarity.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include <chrono>  // NOLINT
+#include <iostream>
+
+#include "sherpa/cpp_api/parse-options.h"
+#include "sherpa/csrc/fbank-features.h"
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+int32_t main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+This program uses a speaker embedding model to compute
+the similarity between two wave files.
+
+sherpa-compute-speaker-similarity \
+  --model=/path/to/model.pt \
+  ./foo.wav \
+  ./bar.wav
+)usage";
+
+  int32_t num_threads = 1;
+  sherpa::ParseOptions po(kUsageMessage);
+  sherpa::SpeakerEmbeddingExtractorConfig config;
+  config.Register(&po);
+  po.Register("num-threads", &num_threads, "Number of threads for PyTorch");
+  po.Read(argc, argv);
+
+  if (po.NumArgs() != 2) {
+    std::cerr << "Please provide exactly 2 test waves\n";
+    exit(-1);
+  }
+
+  std::cerr << config.ToString() << "\n";
+  if (!config.Validate()) {
+    std::cerr << "Please check your config\n";
+    return -1;
+  }
+
+  int32_t sr = 16000;
+  sherpa::SpeakerEmbeddingExtractor extractor(config);
+
+  const auto begin = std::chrono::steady_clock::now();
+
+  torch::Tensor samples1 = sherpa::ReadWave(po.GetArg(1), sr).first;
+
+  auto stream1 = extractor.CreateStream();
+  stream1->AcceptSamples(samples1.data_ptr<float>(), samples1.numel());
+
+  torch::Tensor samples2 = sherpa::ReadWave(po.GetArg(2), sr).first;
+
+  auto stream2 = extractor.CreateStream();
+  stream2->AcceptSamples(samples2.data_ptr<float>(), samples2.numel());
+
+  torch::Tensor embedding1;
+  torch::Tensor embedding2;
+  if (false) {  // set to true to compute the two embeddings one at a time
+    embedding1 = extractor.Compute(stream1.get()).squeeze(0);
+    embedding2 = extractor.Compute(stream2.get()).squeeze(0);
+  } else {
+    std::vector<sherpa::OfflineStream *> ss{stream1.get(), stream2.get()};
+    auto embeddings = extractor.Compute(ss.data(), ss.size());
+
+    embedding1 = embeddings.index({0});
+    embedding2 = embeddings.index({1});
+  }
+
+  auto score =
+      torch::nn::functional::cosine_similarity(
+          embedding1, embedding2,
+          torch::nn::functional::CosineSimilarityFuncOptions{}.dim(0).eps(1e-6))
+          .item()
+          .toFloat();
+
+  const auto end = std::chrono::steady_clock::now();
+
+  const float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+  float duration = (samples1.size(0) + samples2.size(0)) / static_cast<float>(sr);
+  const float rtf = elapsed_seconds / duration;
+
+  std::cout << "score: " << score << "\n";
+
+  fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds);
+  fprintf(stderr, "Audio duration: %.3f s\n", duration);
+  fprintf(stderr, "Real time factor (RTF): %.3f/%.3f = %.3f\n",
+          elapsed_seconds, duration, rtf);
+
+  return 0;
+}
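The `if (false)` toggle above switches between per-stream and batched computation. For reference, a minimal Python sketch of the same two paths through the API added in this patch; the model path and wave files are placeholders, and inputs are assumed to be 16 kHz.

```python
# Placeholder paths; requires a 3d-speaker TorchScript model and 16 kHz waves.
import numpy as np
import soundfile as sf

import sherpa

config = sherpa.SpeakerEmbeddingExtractorConfig(model="./model.pt")
extractor = sherpa.SpeakerEmbeddingExtractor(config)


def make_stream(filename: str):
    samples, _ = sf.read(filename, dtype="float32", always_2d=True)
    stream = extractor.create_stream()
    stream.accept_waveform(np.ascontiguousarray(samples[:, 0]))
    return stream


# Path 1: one stream at a time; each result has shape (1, dim)
e1 = extractor.compute(make_stream("./foo.wav")).squeeze(0)
e2 = extractor.compute(make_stream("./bar.wav")).squeeze(0)

# Path 2: a single batch; row i corresponds to stream i
batch = extractor.compute([make_stream("./foo.wav"), make_stream("./bar.wav")])

# The batched path zero-pads features to the longest utterance, so the
# shorter utterance's embedding may differ slightly between the two paths.
print((e1 - batch[0]).abs().max(), (e2 - batch[1]).abs().max())
```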
diff --git a/sherpa/csrc/sherpa-vad.cc b/sherpa/csrc/sherpa-vad.cc
index ce7244fb..38b963f9 100644
--- a/sherpa/csrc/sherpa-vad.cc
+++ b/sherpa/csrc/sherpa-vad.cc
@@ -17,7 +17,7 @@ This program uses a VAD models to add timestamps to a audio file
 
 sherpa-vad \
   --silero-vad-model=/path/to/model.pt \
-  --use-gpu=false \
+  --vad-use-gpu=false \
   --num-threads=1 \
   ./foo.wav
diff --git a/sherpa/csrc/speaker-embedding-extractor-general-impl.h b/sherpa/csrc/speaker-embedding-extractor-general-impl.h
new file mode 100644
index 00000000..7e63a7ea
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-general-impl.h
@@ -0,0 +1,98 @@
+// sherpa/csrc/speaker-embedding-extractor-general-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
+#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa/cpp_api/feature-config.h"
+#include "sherpa/cpp_api/macros.h"
+#include "sherpa/cpp_api/offline-stream.h"
+#include "sherpa/csrc/speaker-embedding-extractor-impl.h"
+#include "sherpa/csrc/speaker-embedding-extractor-model.h"
+
+namespace sherpa {
+
+class SpeakerEmbeddingExtractorGeneralImpl
+    : public SpeakerEmbeddingExtractorImpl {
+ public:
+  explicit SpeakerEmbeddingExtractorGeneralImpl(
+      const SpeakerEmbeddingExtractorConfig &config)
+      : model_(config) {
+    // TODO(fangjun): make it configurable
+    feat_config_.fbank_opts.frame_opts.dither = 0;
+    feat_config_.fbank_opts.frame_opts.snip_edges = true;
+    feat_config_.fbank_opts.frame_opts.samp_freq = 16000;
+    feat_config_.fbank_opts.mel_opts.num_bins = 80;
+    feat_config_.normalize_samples = true;
+
+    fbank_ = std::make_unique<kaldifeat::Fbank>(feat_config_.fbank_opts);
+
+    WarmUp();
+  }
+
+  int32_t Dim() const override { return model_.GetModelMetadata().output_dim; }
+
+  std::unique_ptr<OfflineStream> CreateStream() const override {
+    return std::make_unique<OfflineStream>(fbank_.get(), feat_config_);
+  }
+
+  torch::Tensor Compute(OfflineStream *s) const override {
+    InferenceMode no_grad;
+
+    auto features = s->GetFeatures();
+    features -= features.mean(0, true);
+    features = features.unsqueeze(0);
+    auto device = model_.Device();
+
+    return model_.Compute(features.to(device));
+  }
+
+  torch::Tensor Compute(OfflineStream **ss, int32_t n) const override {
+    InferenceMode no_grad;
+    if (n == 1) {
+      return Compute(ss[0]);
+    }
+
+    std::vector<torch::Tensor> features_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      auto f = ss[i]->GetFeatures();
+      f -= f.mean(0, true);
+      features_vec[i] = f;
+    }
+
+    auto device = model_.Device();
+
+    auto features =
+        torch::nn::utils::rnn::pad_sequence(features_vec, true, 0).to(device);
+
+    return model_.Compute(features);
+  }
+
+ private:
+  void WarmUp() {
+    InferenceMode no_grad;
+    SHERPA_LOG(INFO) << "WarmUp begins";
+    auto s = CreateStream();
+    float sample_rate = fbank_->GetFrameOptions().samp_freq;
+    std::vector<float> samples(2 * sample_rate, 0);  // 2 seconds of silence
+    s->AcceptSamples(samples.data(), samples.size());
+
+    auto embedding = Compute(s.get());
+
+    model_.GetModelMetadata().output_dim = embedding.size(1);
+
+    SHERPA_LOG(INFO) << "WarmUp ended";
+  }
+
+ private:
+  SpeakerEmbeddingExtractorModel model_;
+  std::unique_ptr<kaldifeat::Fbank> fbank_;
+  FeatureConfig feat_config_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
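For reference, a standalone PyTorch sketch of the feature batching performed in `Compute(OfflineStream **ss, int32_t n)`: per-utterance mean subtraction followed by zero-padding to the longest utterance. The shapes are made up for illustration.

```python
# Self-contained; shapes are made up (num_frames x 80 mel bins).
import torch
from torch.nn.utils.rnn import pad_sequence

f1 = torch.randn(200, 80)
f2 = torch.randn(150, 80)

features = []
for f in (f1, f2):
    # mirrors `f -= f.mean(0, true)`: per-utterance mean normalization
    features.append(f - f.mean(dim=0, keepdim=True))

# mirrors pad_sequence(features_vec, true, 0): batch_first, zero padding
batch = pad_sequence(features, batch_first=True, padding_value=0)
print(batch.shape)  # torch.Size([2, 200, 80])
```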
diff --git a/sherpa/csrc/speaker-embedding-extractor-impl.cc b/sherpa/csrc/speaker-embedding-extractor-impl.cc
new file mode 100644
index 00000000..05405da9
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-impl.cc
@@ -0,0 +1,17 @@
+// sherpa/csrc/speaker-embedding-extractor-impl.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/speaker-embedding-extractor-impl.h"
+
+#include "sherpa/csrc/speaker-embedding-extractor-general-impl.h"
+
+namespace sherpa {
+
+std::unique_ptr<SpeakerEmbeddingExtractorImpl>
+SpeakerEmbeddingExtractorImpl::Create(
+    const SpeakerEmbeddingExtractorConfig &config) {
+  // supports only 3d-speaker models for now
+  return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/speaker-embedding-extractor-impl.h b/sherpa/csrc/speaker-embedding-extractor-impl.h
new file mode 100644
index 00000000..adc285ee
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-impl.h
@@ -0,0 +1,34 @@
+// sherpa/csrc/speaker-embedding-extractor-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
+#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+namespace sherpa {
+
+class SpeakerEmbeddingExtractorImpl {
+ public:
+  virtual ~SpeakerEmbeddingExtractorImpl() = default;
+
+  static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
+      const SpeakerEmbeddingExtractorConfig &config);
+
+  virtual int32_t Dim() const = 0;
+
+  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
+
+  virtual torch::Tensor Compute(OfflineStream *s) const = 0;
+
+  virtual torch::Tensor Compute(OfflineStream **ss, int32_t n) const = 0;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
diff --git a/sherpa/csrc/speaker-embedding-extractor-model-meta-data.h b/sherpa/csrc/speaker-embedding-extractor-model-meta-data.h
new file mode 100644
index 00000000..962bb68c
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-model-meta-data.h
@@ -0,0 +1,28 @@
+// sherpa/csrc/speaker-embedding-extractor-model-meta-data.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
+#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
+
+#include <cstdint>
+#include <string>
+
+namespace sherpa {
+
+struct SpeakerEmbeddingExtractorModelMetaData {
+  int32_t output_dim = 0;
+  int32_t sample_rate = 0;
+
+  // for wespeaker models, it is 0;
+  // for 3d-speaker models, it is 1
+  int32_t normalize_samples = 1;
+
+  // Chinese, English, etc.
+  std::string language;
+
+  // for 3d-speaker, it is global-mean
+  std::string feature_normalize_type;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
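A hedged sketch of what `normalize_samples` likely distinguishes, based on the comment in the struct above; the exact convention is model-specific, and the 32768 factor is an assumption rather than something defined in this patch.

```python
# Assumption: wespeaker-style models (normalize_samples == 0) expect samples
# in the raw int16 range, hence the hypothetical 32768 factor below.
import numpy as np


def scale_samples(samples: np.ndarray, normalize_samples: int) -> np.ndarray:
    # samples: float32 in [-1, 1], as produced by soundfile
    if normalize_samples:
        return samples  # 3d-speaker style: keep [-1, 1]
    return samples * 32768.0  # wespeaker style: undo the [-1, 1] scaling


audio = np.array([0.5, -0.25], dtype=np.float32)
print(scale_samples(audio, 1), scale_samples(audio, 0))
```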
diff --git a/sherpa/csrc/speaker-embedding-extractor-model.cc b/sherpa/csrc/speaker-embedding-extractor-model.cc
new file mode 100644
index 00000000..0793dd5e
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-model.cc
@@ -0,0 +1,87 @@
+// sherpa/csrc/speaker-embedding-extractor-model.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/speaker-embedding-extractor-model.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa/csrc/macros.h"
+#include "sherpa/csrc/speaker-embedding-extractor-model-meta-data.h"
+
+namespace sherpa {
+
+class SpeakerEmbeddingExtractorModel::Impl {
+ public:
+  explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
+      : config_(config) {
+    // extra files to read from the TorchScript archive
+    torch::jit::ExtraFilesMap meta_data{
+        {"version", {}},
+        {"model_type", {}},
+    };
+
+    if (config.use_gpu) {
+      device_ = torch::Device{torch::kCUDA};
+    }
+
+    model_ = torch::jit::load(config.model, device_, meta_data);
+    model_.eval();
+
+    if (meta_data.at("model_type") != "3d-speaker") {
+      SHERPA_LOGE("Expect model_type '3d-speaker'. Given: '%s'\n",
+                  meta_data.at("model_type").c_str());
+      SHERPA_EXIT(-1);
+    }
+  }
+
+  torch::Tensor Compute(torch::Tensor x) {
+    return model_.run_method("forward", x).toTensor();
+  }
+
+  SpeakerEmbeddingExtractorModelMetaData &GetModelMetadata() {
+    return meta_data_;
+  }
+
+  const SpeakerEmbeddingExtractorModelMetaData &GetModelMetadata() const {
+    return meta_data_;
+  }
+
+  torch::Device Device() const { return device_; }
+
+ private:
+  SpeakerEmbeddingExtractorConfig config_;
+
+  torch::jit::Module model_;
+  torch::Device device_{torch::kCPU};
+
+  SpeakerEmbeddingExtractorModelMetaData meta_data_;
+};
+
+SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
+    const SpeakerEmbeddingExtractorConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default;
+
+SpeakerEmbeddingExtractorModelMetaData &
+SpeakerEmbeddingExtractorModel::GetModelMetadata() {
+  return impl_->GetModelMetadata();
+}
+
+torch::Tensor SpeakerEmbeddingExtractorModel::Compute(torch::Tensor x) const {
+  return impl_->Compute(x);
+}
+
+const SpeakerEmbeddingExtractorModelMetaData &
+SpeakerEmbeddingExtractorModel::GetModelMetadata() const {
+  return impl_->GetModelMetadata();
+}
+
+torch::Device SpeakerEmbeddingExtractorModel::Device() const {
+  return impl_->Device();
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/speaker-embedding-extractor-model.h b/sherpa/csrc/speaker-embedding-extractor-model.h
new file mode 100644
index 00000000..3efd30fc
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor-model.h
@@ -0,0 +1,40 @@
+// sherpa/csrc/speaker-embedding-extractor-model.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
+#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
+
+#include <memory>
+
+#include "sherpa/csrc/speaker-embedding-extractor-model-meta-data.h"
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+#include "torch/script.h"
+
+namespace sherpa {
+
+class SpeakerEmbeddingExtractorModel {
+ public:
+  explicit SpeakerEmbeddingExtractorModel(
+      const SpeakerEmbeddingExtractorConfig &config);
+
+  ~SpeakerEmbeddingExtractorModel();
+
+  SpeakerEmbeddingExtractorModelMetaData &GetModelMetadata();
+  const SpeakerEmbeddingExtractorModelMetaData &GetModelMetadata() const;
+
+  /** Run the model on a batch of feature frames.
+   *
+   * @param x A float32 tensor of shape (N, T, C)
+   * @return A float32 tensor of shape (N, embedding_dim)
+   */
+  torch::Tensor Compute(torch::Tensor x) const;
+
+  torch::Device Device() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
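The loader above reads `model_type` and `version` from the TorchScript archive via `torch::jit::ExtraFilesMap`. The same mechanism is available from Python through `_extra_files`; a standalone sketch with a dummy module:

```python
# Self-contained; uses a dummy scripted module instead of a real speaker model.
import torch


class Dummy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


m = torch.jit.script(Dummy())
torch.jit.save(
    m, "dummy.pt", _extra_files={"model_type": "3d-speaker", "version": "1"}
)

# Keys must be pre-populated; load fills in the stored values.
extra_files = {"model_type": "", "version": ""}
torch.jit.load("dummy.pt", _extra_files=extra_files)
print(extra_files)  # the stored model_type and version
```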
diff --git a/sherpa/csrc/speaker-embedding-extractor.cc b/sherpa/csrc/speaker-embedding-extractor.cc
new file mode 100644
index 00000000..9df3b66d
--- /dev/null
+++ b/sherpa/csrc/speaker-embedding-extractor.cc
@@ -0,0 +1,73 @@
+// sherpa/csrc/speaker-embedding-extractor.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+#include <sstream>
+
+#include "sherpa/cpp_api/macros.h"
+#include "sherpa/csrc/file-utils.h"
+#include "sherpa/csrc/macros.h"
+#include "sherpa/csrc/speaker-embedding-extractor-impl.h"
+
+namespace sherpa {
+
+void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
+  po->Register("model", &model, "Path to the speaker embedding model.");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+
+  po->Register("use-gpu", &use_gpu, "true to use GPU for computation.");
+}
+
+bool SpeakerEmbeddingExtractorConfig::Validate() const {
+  if (model.empty()) {
+    SHERPA_LOGE("Please provide a speaker embedding extractor model");
+    return false;
+  }
+
+  if (!FileExists(model)) {
+    SHERPA_LOGE("speaker embedding extractor model: '%s' does not exist",
+                model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string SpeakerEmbeddingExtractorConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "SpeakerEmbeddingExtractorConfig(";
+  os << "model=\"" << model << "\", ";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "use_gpu=" << (use_gpu ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
+    const SpeakerEmbeddingExtractorConfig &config)
+    : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
+
+SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
+
+int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
+
+std::unique_ptr<OfflineStream> SpeakerEmbeddingExtractor::CreateStream()
+    const {
+  return impl_->CreateStream();
+}
+
+torch::Tensor SpeakerEmbeddingExtractor::Compute(OfflineStream *s) const {
+  InferenceMode no_grad;
+  return impl_->Compute(s);
+}
+
+torch::Tensor SpeakerEmbeddingExtractor::Compute(OfflineStream **ss,
+                                                 int32_t n) const {
+  InferenceMode no_grad;
+  return impl_->Compute(ss, n);
+}
+
+}  // namespace sherpa
"True" : "False") << ")"; + + return os.str(); +} + +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {} + +SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default; + +int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); } + +std::unique_ptr SpeakerEmbeddingExtractor::CreateStream() const { + return impl_->CreateStream(); +} + +torch::Tensor SpeakerEmbeddingExtractor::Compute(OfflineStream *s) const { + InferenceMode no_grad; + return impl_->Compute(s); +} + +torch::Tensor SpeakerEmbeddingExtractor::Compute(OfflineStream **ss, + int32_t n) const { + InferenceMode no_grad; + return impl_->Compute(ss, n); +} + +} // namespace sherpa diff --git a/sherpa/csrc/speaker-embedding-extractor.h b/sherpa/csrc/speaker-embedding-extractor.h new file mode 100644 index 00000000..cf9d6f69 --- /dev/null +++ b/sherpa/csrc/speaker-embedding-extractor.h @@ -0,0 +1,65 @@ +// sherpa/csrc/speaker-embedding-extractor.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ +#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ + +#include +#include +#include + +#include "sherpa/cpp_api/offline-stream.h" +#include "sherpa/cpp_api/parse-options.h" + +namespace sherpa { + +struct SpeakerEmbeddingExtractorConfig { + std::string model; + bool use_gpu = false; + bool debug = false; + + SpeakerEmbeddingExtractorConfig() = default; + SpeakerEmbeddingExtractorConfig(const std::string &model, bool use_gpu, + bool debug) + : model(model), use_gpu(use_gpu), debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class SpeakerEmbeddingExtractorImpl; + +class SpeakerEmbeddingExtractor { + public: + explicit SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config); + + template + SpeakerEmbeddingExtractor(Manager *mgr, + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractor(); + + // Return the dimension of the embedding + int32_t Dim() const; + + // Create a stream to accept audio samples and compute features + std::unique_ptr CreateStream() const; + + // Compute the speaker embedding from the available unprocessed features + // of the given stream + // + // You have to ensure IsReady(s) returns true before you call this method. 
diff --git a/sherpa/python/csrc/CMakeLists.txt b/sherpa/python/csrc/CMakeLists.txt
index 53ee4656..d29e978c 100644
--- a/sherpa/python/csrc/CMakeLists.txt
+++ b/sherpa/python/csrc/CMakeLists.txt
@@ -16,6 +16,7 @@ pybind11_add_module(_sherpa
   resample.cc
   sherpa.cc
   silero-vad-model-config.cc
+  speaker-embedding-extractor.cc
   vad-model-config.cc
   voice-activity-detector-config.cc
   voice-activity-detector.cc
diff --git a/sherpa/python/csrc/sherpa.cc b/sherpa/python/csrc/sherpa.cc
index c6a132ef..4f083c52 100644
--- a/sherpa/python/csrc/sherpa.cc
+++ b/sherpa/python/csrc/sherpa.cc
@@ -30,6 +30,7 @@
 #include "sherpa/python/csrc/online-recognizer.h"
 #include "sherpa/python/csrc/online-stream.h"
 #include "sherpa/python/csrc/resample.h"
+#include "sherpa/python/csrc/speaker-embedding-extractor.h"
 #include "sherpa/python/csrc/voice-activity-detector.h"
 
 namespace sherpa {
@@ -52,6 +53,8 @@ PYBIND11_MODULE(_sherpa, m) {
   PybindOnlineStream(m);
   PybindOnlineRecognizer(m);
   PybindVoiceActivityDetector(&m);
+
+  PybindSpeakerEmbeddingExtractor(&m);
 }
 
 }  // namespace sherpa
diff --git a/sherpa/python/csrc/sherpa.h b/sherpa/python/csrc/sherpa.h
index 74171da4..3e9d923c 100644
--- a/sherpa/python/csrc/sherpa.h
+++ b/sherpa/python/csrc/sherpa.h
@@ -20,6 +20,7 @@
 
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
+#include "torch/torch.h"
 
 namespace py = pybind11;
diff --git a/sherpa/python/csrc/speaker-embedding-extractor.cc b/sherpa/python/csrc/speaker-embedding-extractor.cc
new file mode 100644
index 00000000..515eae3a
--- /dev/null
+++ b/sherpa/python/csrc/speaker-embedding-extractor.cc
@@ -0,0 +1,49 @@
+// sherpa/python/csrc/speaker-embedding-extractor.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/python/csrc/speaker-embedding-extractor.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+namespace sherpa {
+
+static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
+  using PyClass = SpeakerEmbeddingExtractorConfig;
+  py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, bool, bool>(), py::arg("model"),
+           py::arg("use_gpu") = false, py::arg("debug") = false)
+      .def_readwrite("model", &PyClass::model)
+      .def_readwrite("use_gpu", &PyClass::use_gpu)
+      .def_readwrite("debug", &PyClass::debug)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+void PybindSpeakerEmbeddingExtractor(py::module *m) {
+  PybindSpeakerEmbeddingExtractorConfig(m);
+
+  using PyClass = SpeakerEmbeddingExtractor;
+  py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
+      .def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
+           py::arg("config"), py::call_guard<py::gil_scoped_release>())
+      .def_property_readonly("dim", &PyClass::Dim)
+      .def("create_stream", &PyClass::CreateStream,
+           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "compute",
+          [](PyClass &self, OfflineStream *s) { return self.Compute(s); },
+          py::arg("s"), py::call_guard<py::gil_scoped_release>())
+      .def(
+          "compute",
+          [](PyClass &self, std::vector<OfflineStream *> &ss) {
+            return self.Compute(ss.data(), ss.size());
+          },
+          py::arg("ss"), py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace sherpa
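All of the bindings above use `py::call_guard<py::gil_scoped_release>`, so heavy calls drop the GIL. A minimal sketch of what that enables, assuming placeholder model and wave paths and that the underlying TorchScript module tolerates concurrent inference:

```python
# Placeholder paths; assumes concurrent inference is safe for the loaded model.
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import soundfile as sf

import sherpa

config = sherpa.SpeakerEmbeddingExtractorConfig(model="./model.pt")
extractor = sherpa.SpeakerEmbeddingExtractor(config)


def embed(filename: str):
    samples, _ = sf.read(filename, dtype="float32", always_2d=True)
    stream = extractor.create_stream()
    stream.accept_waveform(np.ascontiguousarray(samples[:, 0]))
    return extractor.compute(stream)


# compute() releases the GIL, so the two extractions can overlap.
with ThreadPoolExecutor(max_workers=2) as pool:
    emb_a, emb_b = pool.map(embed, ["./foo.wav", "./bar.wav"])
```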
diff --git a/sherpa/python/csrc/speaker-embedding-extractor.h b/sherpa/python/csrc/speaker-embedding-extractor.h
new file mode 100644
index 00000000..d72123ed
--- /dev/null
+++ b/sherpa/python/csrc/speaker-embedding-extractor.h
@@ -0,0 +1,16 @@
+// sherpa/python/csrc/speaker-embedding-extractor.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
+#define SHERPA_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
+
+#include "sherpa/python/csrc/sherpa.h"
+
+namespace sherpa {
+
+void PybindSpeakerEmbeddingExtractor(py::module *m);
+
+}  // namespace sherpa
+
+#endif  // SHERPA_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
diff --git a/sherpa/python/sherpa/__init__.py b/sherpa/python/sherpa/__init__.py
index 6abf2731..13fb0c27 100644
--- a/sherpa/python/sherpa/__init__.py
+++ b/sherpa/python/sherpa/__init__.py
@@ -27,6 +27,8 @@
     OnlineRecognizerConfig,
     OnlineStream,
     SileroVadModelConfig,
+    SpeakerEmbeddingExtractor,
+    SpeakerEmbeddingExtractorConfig,
     VadModelConfig,
     VoiceActivityDetector,
     VoiceActivityDetectorConfig,