Skip to content

Commit

Permalink
Refactor the whisper microphone example. (#2523)
Browse files Browse the repository at this point in the history
* Refactor the whisper microphone example.

* Tweak the whisper microphone example more.
  • Loading branch information
LaurentMazare authored Sep 30, 2024
1 parent aa35bf2 commit 6110ad8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 82 deletions.
2 changes: 1 addition & 1 deletion candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
Expand Down
154 changes: 73 additions & 81 deletions candle-examples/examples/whisper-microphone/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;

mod multilingual;

use candle_transformers::models::whisper::{self as m, audio, Config};

use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};

pub enum Model {
Normal(m::model::Whisper),
Expand Down Expand Up @@ -479,6 +477,10 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,

/// The input device to use.
#[arg(long)]
device: Option<String>,
}

pub fn main() -> Result<()> {
Expand Down Expand Up @@ -543,13 +545,12 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let language_token = None;
let mut dc = Decoder::new(
let mut decoder = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
language_token,
/* language_token */ None,
args.task,
args.timestamps,
args.verbose,
Expand All @@ -565,47 +566,69 @@ pub fn main() -> Result<()> {

// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
let audio_device = match args.device.as_ref() {
None => host.default_input_device(),
Some(device) => host
.input_devices()?
.find(|x| x.name().map_or(false, |y| &y == device)),
}
.expect("failed to find input device");
.expect("failed to find the audio input device");

let _config = _device
let audio_config = audio_device
.default_input_config()
.expect("Failed to get default input config");

let channel_count = _config.channels() as usize;

let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();

std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
println!("audio config {audio_config:?}");

let channel_count = audio_config.channels() as usize;
let in_sample_rate = audio_config.sample_rate().0 as usize;
let resample_ratio = 16000. / in_sample_rate as f64;
let mut resampler = rubato::FastFixedIn::new(
resample_ratio,
10.,
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let (tx, rx) = std::sync::mpsc::channel();
let stream = audio_device.build_input_stream(
&audio_config.config(),
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
let pcm = pcm
.iter()
.step_by(channel_count)
.copied()
.collect::<Vec<f32>>();
if !pcm.is_empty() {
tx.send(pcm).unwrap()
}
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;

// loop to process the audio data forever (until the user stops the program)
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
println!("transcribing audio...");
let mut buffered_pcm = vec![];
let mut language_token_set = false;
while let Ok(pcm) = rx.recv() {
use rubato::Resampler;

buffered_pcm.extend_from_slice(&pcm);
if buffered_pcm.len() < 10 * in_sample_rate {
continue;
}
let mut resampled_pcm = vec![];
for buffered_pcm in buffered_pcm.chunks(1024) {
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0])
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
buffered_pcm.clear();
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
Expand All @@ -614,9 +637,13 @@ pub fn main() -> Result<()> {
)?;

// on the first iteration, we detect the language and set the language token.
if i == 0 {
if !language_token_set {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(true, None) => Some(multilingual::detect_language(
decoder.model(),
&tokenizer,
&mel,
)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
Expand All @@ -627,47 +654,12 @@ pub fn main() -> Result<()> {
}
};
println!("language_token: {:?}", language_token);
dc.set_language_token(language_token);
decoder.set_language_token(language_token);
language_token_set = true;
}
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
decoder.run(&mel, None)?;
decoder.reset_kv_cache();
}

Ok(())
}

fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

0 comments on commit 6110ad8

Please sign in to comment.