From dc37a357a34396482f91a9921a3c470b17b17c9e Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Sat, 14 Sep 2024 17:18:22 +0100 Subject: [PATCH] feat(audio): implement MKL-accelerated speech-to-text for Mac Signed-off-by: David Anyatonwu --- .github/workflows/benchmark.yml | 35 +++++-------- screenpipe-audio/Cargo.toml | 7 +-- screenpipe-audio/benches/stt_benchmark.rs | 60 +++++++---------------- screenpipe-audio/src/stt.rs | 18 ++++++- 4 files changed, 49 insertions(+), 71 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 0bf15d9d..160475ee 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -72,59 +72,46 @@ jobs: stt_benchmark: name: Run STT benchmark - runs-on: ubuntu-latest + runs-on: macos-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y ffmpeg tesseract-ocr libtesseract-dev libavformat-dev libavfilter-dev libavdevice-dev ffmpeg libasound2-dev libgtk-3-dev libsoup-3.0-dev libjavascriptcoregtk-4.1-dev libwebkit2gtk-4.1-dev + brew install cmake openblas lapack - - name: Run STT benchmarks + - name: Run STT benchmarks (MKL) run: | - cargo bench --bench stt_benchmark -- --output-format bencher | tee -a stt_output.txt + cargo bench --bench stt_benchmark --features mkl -- --output-format bencher | tee -a stt_output_mkl.txt - name: Upload STT benchmark artifact uses: actions/upload-artifact@v3 with: - name: stt-benchmark-data - path: stt_output.txt + name: stt-benchmark-data-macos + path: stt_output_mkl.txt analyze_benchmarks: - needs: - [ - apple_ocr_benchmark, - tesseract_ocr_benchmark, - windows_ocr_benchmark, - stt_benchmark, - ] + needs: [stt_benchmark] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Download benchmark data - uses: actions/download-artifact@v3 - with: - name: ocr-benchmark-data - path: ./cache/ocr - - name: Download STT benchmark data uses: actions/download-artifact@v3 with: - name: stt-benchmark-data + name: stt-benchmark-data-macos path: ./cache/stt - name: List contents of cache directory run: ls -R ./cache - - name: Analyze OCR benchmarks + - name: Analyze STT benchmarks uses: benchmark-action/github-action-benchmark@v1 with: - name: OCR Benchmarks + name: STT Benchmarks tool: "cargo" - output-file-path: ./cache/ocr/ocr_output.txt + output-file-path: ./cache/stt/stt_output_mkl.txt github-token: ${{ secrets.GH_PAGES_TOKEN }} auto-push: true alert-threshold: "200%" diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 06e0ba9b..1c759ebe 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -31,9 +31,9 @@ chrono = { version = "0.4.31", features = ["serde"] } # Local Embeddings + STT # TODO: feature metal, cuda, etc. see https://github.com/huggingface/candle/blob/main/candle-core/Cargo.toml -candle = { workspace = true } -candle-nn = { workspace = true } -candle-transformers = { workspace = true } +candle = { workspace = true, features = ["mkl"] } +candle-nn = { workspace = true, features = ["mkl"] } +candle-transformers = { workspace = true, features = ["mkl"] } vad-rs = "0.1.3" tokenizers = { workspace = true } anyhow = "1.0.86" @@ -80,6 +80,7 @@ criterion = { workspace = true } memory-stats = "1.0" [features] +default = ["mkl"] metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/screenpipe-audio/benches/stt_benchmark.rs b/screenpipe-audio/benches/stt_benchmark.rs index 38da8044..66b57f59 100644 --- a/screenpipe-audio/benches/stt_benchmark.rs +++ b/screenpipe-audio/benches/stt_benchmark.rs @@ -1,12 +1,12 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use memory_stats::memory_stats; -use screenpipe_audio::vad_engine::SileroVad; use screenpipe_audio::{ - create_whisper_channel, stt, AudioTranscriptionEngine, VadEngineEnum, WhisperModel, + stt, AudioInput, AudioTranscriptionEngine, WhisperModel, vad_engine::SileroVad }; -use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use std::path::PathBuf; +use std::fs::File; +use std::io::Read; fn criterion_benchmark(c: &mut Criterion) { let audio_transcription_engine = Arc::new(AudioTranscriptionEngine::WhisperTiny); @@ -14,59 +14,35 @@ fn criterion_benchmark(c: &mut Criterion) { let test_file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("test_data") .join("selah.mp4"); + let mut audio_data = Vec::new(); + File::open(&test_file_path).unwrap().read_to_end(&mut audio_data).unwrap(); let mut group = c.benchmark_group("whisper_benchmarks"); group.sample_size(10); group.measurement_time(Duration::from_secs(60)); - group.bench_function("create_whisper_channel", |b| { - b.iter(|| { - let _ = create_whisper_channel( - black_box(audio_transcription_engine.clone()), - black_box(VadEngineEnum::Silero), - None, - ); - }) - }); - - group.bench_function("stt", |b| { + group.bench_function("stt_mkl", |b| { b.iter(|| { let mut vad_engine = Box::new(SileroVad::new().unwrap()); + let audio_input = AudioInput { + data: audio_data.clone().into_iter().map(|x| x as f32).collect(), + sample_rate: 16000, + channels: 1, + device: "test".to_string(), + }; let _ = stt( - black_box(test_file_path.to_string_lossy().as_ref()), + black_box(&audio_input), black_box(&whisper_model), black_box(audio_transcription_engine.clone()), - &mut *vad_engine, - None, + black_box(&mut *vad_engine), + black_box(None), + black_box(&PathBuf::from("test_output")), ); }) }); - group.bench_function("memory_usage_stt", |b| { - b.iter_custom(|iters| { - let mut total_duration = Duration::new(0, 0); - for _ in 0..iters { - let start = std::time::Instant::now(); - let before = memory_stats().unwrap().physical_mem; - let mut vad_engine = Box::new(SileroVad::new().unwrap()); - let _ = stt( - test_file_path.to_string_lossy().as_ref(), - &whisper_model, - audio_transcription_engine.clone(), - &mut *vad_engine, - None, - ); - let after = memory_stats().unwrap().physical_mem; - let duration = start.elapsed(); - total_duration += duration; - println!("Memory used: {} bytes", after - before); - } - total_duration - }) - }); - group.finish(); } criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); +criterion_main!(benches); \ No newline at end of file diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 87d5f0ee..587fe739 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -40,8 +40,8 @@ pub struct WhisperModel { impl WhisperModel { pub fn new(engine: Arc) -> Result { debug!("Initializing WhisperModel"); - let device = Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu)); - info!("device = {:?}", device); + let device = Self::get_optimal_device()?; + info!("Using device: {:?}", device); debug!("Fetching model files"); let (config_filename, tokenizer_filename, weights_filename) = { @@ -86,6 +86,20 @@ impl WhisperModel { device, }) } + + fn get_optimal_device() -> Result { + #[cfg(feature = "mkl")] + { + info!("Using MKL-accelerated CPU"); + Ok(Device::Cpu) + } + #[cfg(not(feature = "mkl"))] + { + info!("Using standard CPU"); + Ok(Device::Cpu) + } + } + } #[derive(Debug, Clone)]