Commit fdc2313: merge ezra pr

louis030195 committed Feb 12, 2025
2 parents 1978ff6 + 4be0923

Showing 12 changed files with 395 additions and 73 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
@@ -22,10 +22,10 @@ edition = "2021"

[workspace.dependencies]
# AI
-candle = { package = "candle-core", version = "0.7.2" }
-candle-nn = { package = "candle-nn", version = "0.7.2" }
-candle-transformers = { package = "candle-transformers", version = "0.7.2" }
-tokenizers = "0.20.0"
+candle = { package = "candle-core", version = "0.8.2" }
+candle-nn = { package = "candle-nn", version = "0.8.2" }
+candle-transformers = { package = "candle-transformers", version = "0.8.2" }
+tokenizers = "0.21.0"
hf-hub = { version = "0.3.2", git = "https://github.com/neo773/hf-hub", features = [
"native-tls",
] }
4 changes: 3 additions & 1 deletion screenpipe-audio/Cargo.toml
@@ -70,7 +70,7 @@ crossbeam = { workspace = true }
# Directories
dirs = "5.0.1"

-lazy_static = { version = "1.4.0" }
+lazy_static = "1.4.0"
realfft = "3.4.0"
regex = "1.11.0"
ndarray = "0.16"
@@ -80,6 +80,8 @@ ort-sys = "=2.0.0-rc.8"
futures = "0.3.31"
deepgram = { git = "https://github.com/EzraEllette/deepgram-rust-sdk.git" }
bytes = { version = "1.9.0", features = ["serde"] }
+lru = "0.13.0"
+num-traits = "0.2.19"

[target.'cfg(target_os = "windows")'.dependencies]
ort = { version = "=2.0.0-rc.6", features = [
278 changes: 278 additions & 0 deletions screenpipe-audio/src/audio_processing.rs
@@ -161,3 +161,281 @@ pub fn write_audio_to_file(
}
Ok(file_path_clone)
}

// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp

use candle::utils::get_num_threads;

pub trait Float:
num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
{
}

impl Float for f32 {}
impl Float for f64 {}
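
// `Float` keeps the DSP routines below generic over f32 and f64; its bounds
// come from the `num-traits` crate added to screenpipe-audio/Cargo.toml above.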

// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
let n = inp.len();
let zero = T::zero();
if n == 1 {
return vec![inp[0], zero];
}
if n % 2 == 1 {
return dft(inp);
}
let mut out = vec![zero; n * 2];

let mut even = Vec::with_capacity(n / 2);
let mut odd = Vec::with_capacity(n / 2);

for (i, &inp) in inp.iter().enumerate() {
if i % 2 == 0 {
even.push(inp)
} else {
odd.push(inp);
}
}

let even_fft = fft(&even);
let odd_fft = fft(&odd);

let two_pi = T::PI() + T::PI();
let n_t = T::from(n).unwrap();
for k in 0..n / 2 {
let k_t = T::from(k).unwrap();
let theta = two_pi * k_t / n_t;
let re = theta.cos();
let im = -theta.sin();

let re_odd = odd_fft[2 * k];
let im_odd = odd_fft[2 * k + 1];

out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;

out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
}
out
}
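
// A minimal usage sketch (illustrative, not part of this commit): the FFT of a
// unit impulse is 1 at every bin, which makes the interleaved layout visible —
// out[2 * k] holds the real part of bin k, out[2 * k + 1] the imaginary part.
//
//     let out = fft(&[1.0f32, 0.0, 0.0, 0.0]);
//     assert_eq!(out.len(), 8); // 4 bins stored as (Re, Im) pairs
//     for k in 0..4 {
//         assert!((out[2 * k] - 1.0).abs() < 1e-6); // Re = 1
//         assert!(out[2 * k + 1].abs() < 1e-6); // Im = 0
//     }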

// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
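// Naive O(n^2) DFT, used by `fft` as the fallback whenever the recursion hits
// an odd-length input.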
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
let zero = T::zero();
let n = inp.len();
let two_pi = T::PI() + T::PI();

let mut out = Vec::with_capacity(2 * n);
let n_t = T::from(n).unwrap();
for k in 0..n {
let k_t = T::from(k).unwrap();
let mut re = zero;
let mut im = zero;

for (j, &inp) in inp.iter().enumerate() {
let j_t = T::from(j).unwrap();
let angle = two_pi * k_t * j_t / n_t;
re += inp * angle.cos();
im -= inp * angle.sin();
}

out.push(re);
out.push(im);
}
out
}

#[allow(clippy::too_many_arguments)]
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
fn log_mel_spectrogram_w<T: Float>(
ith: usize,
hann: &[T],
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
speed_up: bool,
n_len: usize,
n_mel: usize,
n_threads: usize,
) -> Vec<T> {
let n_fft = if speed_up {
1 + fft_size / 4
} else {
1 + fft_size / 2
};

let zero = T::zero();
let half = T::from(0.5).unwrap();
let mut fft_in = vec![zero; fft_size];
let mut mel = vec![zero; n_len * n_mel];
let n_samples = samples.len();
let end = std::cmp::min(n_samples / fft_step + 1, n_len);

for i in (ith..end).step_by(n_threads) {
let offset = i * fft_step;

// apply Hanning window
for j in 0..std::cmp::min(fft_size, n_samples - offset) {
fft_in[j] = hann[j] * samples[offset + j];
}

// fill the rest with zeros
if n_samples - offset < fft_size {
fft_in[n_samples - offset..].fill(zero);
}

// FFT
let mut fft_out: Vec<T> = fft(&fft_in);

// Calculate modulus^2 of complex numbers
for j in 0..fft_size {
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
}
for j in 1..fft_size / 2 {
let v = fft_out[fft_size - j];
fft_out[j] += v;
}

if speed_up {
            // scaling down in the frequency domain results in a speed-up in the time domain
for j in 0..n_fft {
fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
}
}

// mel spectrogram
for j in 0..n_mel {
let mut sum = zero;
let mut k = 0;
// Unroll loop
while k < n_fft.saturating_sub(3) {
sum += fft_out[k] * filters[j * n_fft + k]
+ fft_out[k + 1] * filters[j * n_fft + k + 1]
+ fft_out[k + 2] * filters[j * n_fft + k + 2]
+ fft_out[k + 3] * filters[j * n_fft + k + 3];
k += 4;
}
// Handle remainder
while k < n_fft {
sum += fft_out[k] * filters[j * n_fft + k];
k += 1;
}
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
}
}
mel
}
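
// Illustrative note (not part of this commit): with n_threads = 4, the worker
// with ith = 1 processes frames 1, 5, 9, ... — each worker fills a disjoint
// stripe of columns in its own `mel` buffer (zeros elsewhere), so the
// per-thread buffers can simply be summed after the tasks are joined below.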

pub async fn log_mel_spectrogram_<T: Float + 'static>(
samples: &[T],
filters: &[T],
fft_size: usize,
fft_step: usize,
n_mel: usize,
speed_up: bool,
) -> Vec<T> {
let zero = T::zero();
let two_pi = T::PI() + T::PI();
let half = T::from(0.5).unwrap();
let one = T::from(1.0).unwrap();
let four = T::from(4.0).unwrap();
let fft_size_t = T::from(fft_size).unwrap();

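    // Hann window coefficients: 0.5 * (1 - cos(2π * i / fft_size)).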
let hann: Vec<T> = (0..fft_size)
.map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
.collect();
let n_len = samples.len() / fft_step;

// pad audio with at least one extra chunk of zeros
let pad = 100 * candle_transformers::models::whisper::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
n_len
};
let n_len = n_len + pad;
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded
};

// ensure that the number of threads is even and less than 12
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
let n_threads = std::cmp::max(n_threads, 2);

    // Create owned copies of the input data so they can be moved into the spawned tasks
let samples = samples.to_vec();
let filters = filters.to_vec();

let mut handles = vec![];
for thread_id in 0..n_threads {
// Clone the owned vectors for each task
let hann = hann.clone();
let samples = samples.clone();
let filters = filters.clone();

handles.push(tokio::task::spawn(async move {
log_mel_spectrogram_w(
thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len, n_mel,
n_threads,
)
}));
}

let all_outputs = futures::future::join_all(handles)
.await
.into_iter()
.map(|res| res.expect("Task failed"))
.collect::<Vec<_>>();

let l = all_outputs[0].len();
let mut mel = vec![zero; l];

// iterate over mel spectrogram segments, dividing work by threads.
for segment_start in (0..l).step_by(n_threads) {
// go through each thread's output.
for thread_output in all_outputs.iter() {
// add each thread's piece to our mel spectrogram.
for offset in 0..n_threads {
let mel_index = segment_start + offset; // find location in mel.
if mel_index < mel.len() {
// Make sure we don't go out of bounds.
mel[mel_index] += thread_output[mel_index];
}
}
}
}

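    // Whisper-style normalization: clamp every value to within 8 of the global
    // maximum, then map v -> v / 4 + 1 (i.e. (v + 4) / 4).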
let mmax = mel
.iter()
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
.copied()
.unwrap_or(zero)
- T::from(8).unwrap();
for m in mel.iter_mut() {
let v = T::max(*m, mmax);
*m = v / four + one
}
mel
}

pub async fn pcm_to_mel<T: Float + 'static>(
cfg: &candle_transformers::models::whisper::Config,
samples: &[T],
filters: &[T],
) -> Vec<T> {
log_mel_spectrogram_(
samples,
filters,
candle_transformers::models::whisper::N_FFT,
candle_transformers::models::whisper::HOP_LENGTH,
cfg.num_mel_bins,
false,
)
.await
}
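
// A minimal calling sketch (illustrative, not part of this commit). It assumes
// a whisper `Config` and a matching mel filter bank are already in memory; the
// `load_mel_filters` helper is hypothetical:
//
//     async fn example(cfg: &candle_transformers::models::whisper::Config) {
//         let pcm: Vec<f32> = vec![0.0; 16_000]; // one second of 16 kHz mono audio
//         let filters: Vec<f32> = load_mel_filters(cfg.num_mel_bins); // hypothetical helper
//         let mel = pcm_to_mel(cfg, &pcm, &filters).await;
//         // row-major layout: mel bin j of frame i lives at mel[j * n_len + i]
//     }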
25 changes: 20 additions & 5 deletions screenpipe-audio/src/core.rs
@@ -183,19 +183,23 @@ async fn run_record_and_transcribe(
);

const OVERLAP_SECONDS: usize = 2;
-    let mut collected_audio = Vec::new();
let sample_rate = audio_stream.device_config.sample_rate().0 as usize;
let overlap_samples = OVERLAP_SECONDS * sample_rate;
+    let duration_samples = (duration.as_secs_f64() * sample_rate as f64).ceil() as usize;
+    let max_samples = duration_samples + overlap_samples;

+    let mut collected_audio = Vec::with_capacity(max_samples);

while is_running.load(Ordering::Relaxed)
&& !audio_stream.is_disconnected.load(Ordering::Relaxed)
{
let start_time = tokio::time::Instant::now();

// Collect audio for the duration period
while start_time.elapsed() < duration && is_running.load(Ordering::Relaxed) {
match tokio::time::timeout(Duration::from_millis(100), receiver.recv()).await {
Ok(Ok(chunk)) => {
-                collected_audio.extend(chunk);
+                collected_audio.extend_from_slice(&chunk);
}
Ok(Err(e)) => {
error!("error receiving audio data: {}", e);
@@ -205,6 +209,12 @@ }
}
}

+        // Discard oldest samples if we exceed buffer capacity
+        if collected_audio.len() > max_samples {
+            let excess = collected_audio.len() - max_samples;
+            collected_audio.drain(0..excess);
+        }
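        // Illustrative arithmetic (editor's note, not part of the diff): at a
        // 48 kHz device rate with duration = 30 s, max_samples is
        // 30 * 48_000 + 2 * 48_000 = 1_536_000 samples, so the buffer stays
        // bounded even when the transcription channel backs up.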

if !collected_audio.is_empty() {
debug!("sending audio segment to audio model");
match whisper_sender.try_send(AudioInput {
@@ -215,9 +225,14 @@ }) {
}) {
Ok(_) => {
debug!("sent audio segment to audio model");
-                    if collected_audio.len() > overlap_samples {
-                        collected_audio =
-                            collected_audio.split_off(collected_audio.len() - overlap_samples);
+                    // Retain only overlap samples for next iteration
+                    let current_len = collected_audio.len();
+                    if current_len > overlap_samples {
+                        let keep_from = current_len - overlap_samples;
+                        collected_audio.drain(0..keep_from);
+                    } else {
+                        // If we don't have enough samples, keep all (unlikely case)
+                        collected_audio.truncate(current_len);
}
}
Err(e) => {