From 8c35988df79052cbbb1214c4fbc0de1fe4d40391 Mon Sep 17 00:00:00 2001 From: Louis Beaumont Date: Fri, 14 Feb 2025 13:26:25 -0800 Subject: [PATCH 1/4] v0 --- screenpipe-audio/Cargo.toml | 10 +- .../record_and_transcribe_benchmark.rs | 9 +- screenpipe-audio/examples/stt.rs | 9 +- screenpipe-audio/src/audio_processing.rs | 278 ------ .../src/bin/screenpipe-audio-forever.rs | 7 +- screenpipe-audio/src/bin/screenpipe-audio.rs | 7 +- screenpipe-audio/src/core.rs | 162 ++-- screenpipe-audio/src/deepgram/mod.rs | 4 +- screenpipe-audio/src/deepgram/realtime.rs | 86 +- screenpipe-audio/src/lib.rs | 3 +- screenpipe-audio/src/pyannote/embedding.rs | 19 +- screenpipe-audio/src/pyannote/identify.rs | 23 +- screenpipe-audio/src/pyannote/models.rs | 8 - screenpipe-audio/src/pyannote/segment.rs | 40 +- screenpipe-audio/src/pyannote/session.rs | 5 +- screenpipe-audio/src/realtime.rs | 16 +- screenpipe-audio/src/segments.rs | 6 +- screenpipe-audio/src/stt.rs | 165 ++-- screenpipe-audio/src/whisper/decoder.rs | 1 - screenpipe-audio/src/whisper/model.rs | 37 +- screenpipe-audio/src/whisper/process_chunk.rs | 25 +- screenpipe-audio/tests/accuracy_test.rs | 8 +- screenpipe-audio/tests/core_tests.rs | 26 +- screenpipe-audio/tests/realtime_test.rs | 34 +- .../tests/speaker_identification.rs | 4 +- .../src/unstructured_ocr.rs | 3 +- screenpipe-server/Cargo.toml | 2 + screenpipe-server/benches/db_benchmarks.rs | 4 +- screenpipe-server/src/add.rs | 12 +- screenpipe-server/src/auto_destruct.rs | 2 + .../src/bin/screenpipe-server.rs | 757 ++++------------ screenpipe-server/src/cli.rs | 79 +- screenpipe-server/src/core.rs | 815 ++++++++---------- screenpipe-server/src/db.rs | 8 +- screenpipe-server/src/db_types.rs | 4 +- screenpipe-server/src/filtering.rs | 43 +- screenpipe-server/src/lib.rs | 1 - screenpipe-server/src/resource_monitor.rs | 234 +++-- screenpipe-server/src/server.rs | 772 ++++++----------- screenpipe-server/src/video.rs | 329 +++---- screenpipe-server/tests/db.rs | 29 +- screenpipe-server/tests/endpoint_test.rs | 38 +- screenpipe-server/tests/tags_test.rs | 15 +- screenpipe-server/tests/video_utils_test.rs | 18 +- screenpipe-vision/Cargo.toml | 2 +- screenpipe-vision/benches/apple_leak_bench.rs | 4 +- screenpipe-vision/benches/vision_benchmark.rs | 1 - screenpipe-vision/examples/websocket.rs | 17 +- screenpipe-vision/src/apple.rs | 7 +- .../src/bin/screenpipe-vision.rs | 9 +- screenpipe-vision/src/core.rs | 286 +++--- screenpipe-vision/src/custom_ocr.rs | 4 +- screenpipe-vision/src/monitor.rs | 1 - .../src/run_ui_monitoring_macos.rs | 8 +- screenpipe-vision/src/tesseract.rs | 5 +- screenpipe-vision/src/utils.rs | 3 +- screenpipe-vision/tests/apple_vision_test.rs | 6 +- screenpipe-vision/tests/custom_ocr_test.rs | 9 +- .../tests/windows_vision_test.rs | 33 +- 59 files changed, 1627 insertions(+), 2925 deletions(-) diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 3ef8f6cbcb..dd15aaa55d 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -62,26 +62,24 @@ webrtc-vad = "0.4.0" reqwest = { workspace = true } screenpipe-core = { path = "../screenpipe-core" } -screenpipe-events = { path = "../screenpipe-events" } # crossbeam crossbeam = { workspace = true } +dashmap = { workspace = true } # Directories dirs = "5.0.1" -lazy_static = "1.4.0" +lazy_static = { version = "1.4.0" } realfft = "3.4.0" regex = "1.11.0" ndarray = "0.16" ort = "=2.0.0-rc.6" -knf-rs = { git = "https://github.com/Neptune650/knf-rs.git", branch = "main" } +knf-rs = { git = 
"https://github.com/Neptune650/knf-rs.git" } ort-sys = "=2.0.0-rc.8" futures = "0.3.31" -deepgram = { git = "https://github.com/EzraEllette/deepgram-rust-sdk.git" } +deepgram = "0.6.4" bytes = { version = "1.9.0", features = ["serde"] } -lru = "0.13.0" -num-traits = "0.2.19" [target.'cfg(target_os = "windows")'.dependencies] ort = { version = "=2.0.0-rc.6", features = [ diff --git a/screenpipe-audio/benches/record_and_transcribe_benchmark.rs b/screenpipe-audio/benches/record_and_transcribe_benchmark.rs index dbc5dc40b7..ae81a7db39 100644 --- a/screenpipe-audio/benches/record_and_transcribe_benchmark.rs +++ b/screenpipe-audio/benches/record_and_transcribe_benchmark.rs @@ -1,10 +1,9 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use screenpipe_audio::vad_engine::VadSensitivity; use screenpipe_audio::{ - create_whisper_channel, default_input_device, record_and_transcribe, AudioInput, AudioStream, - AudioTranscriptionEngine, + create_whisper_channel, default_input_device, record_and_transcribe, AudioDevice, AudioInput, + AudioStream, AudioTranscriptionEngine, }; -use screenpipe_core::{AudioDevice, DeviceManager}; use std::path::PathBuf; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -19,14 +18,14 @@ async fn setup_test() -> ( let audio_device = default_input_device().unwrap(); // TODO feed voice in automatically somehow let output_path = PathBuf::from("/tmp/test_audio.mp4"); // let (whisper_sender, _) = mpsc::unbounded_channel(); - let (whisper_sender, _) = create_whisper_channel( + let (whisper_sender, _, _) = create_whisper_channel( Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3), screenpipe_audio::VadEngineEnum::Silero, None, &output_path, VadSensitivity::High, vec![], - Arc::new(DeviceManager::default()), + None, ) .await .unwrap(); diff --git a/screenpipe-audio/examples/stt.rs b/screenpipe-audio/examples/stt.rs index 3c5ac21845..be53a6199a 100644 --- a/screenpipe-audio/examples/stt.rs +++ b/screenpipe-audio/examples/stt.rs @@ -8,7 +8,6 @@ use screenpipe_audio::{AudioInput, AudioTranscriptionEngine}; use screenpipe_core::Language; use std::path::PathBuf; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use strsim::levenshtein; use tokio::sync::Mutex; use tracing::debug; @@ -91,15 +90,16 @@ async fn main() { }; let mut segments = prepare_segments( - audio_input.data, + &audio_input.data, vad_engine.clone(), &segmentation_model_path, - Arc::new(StdMutex::new(embedding_manager)), + embedding_manager, embedding_extractor, &audio_input.device.to_string(), ) .await .unwrap(); + let mut whisper_model_guard = whisper_model.lock().await; let mut transcription = String::new(); while let Some(segment) = segments.recv().await { @@ -107,7 +107,7 @@ async fn main() { &segment.samples, audio_input.sample_rate, &audio_input.device.to_string(), - whisper_model.clone(), + &mut whisper_model_guard, Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo), None, vec![Language::English], @@ -117,6 +117,7 @@ async fn main() { transcription.push_str(&transcript); } + drop(whisper_model_guard); let distance = levenshtein(expected_transcription, &transcription.to_lowercase()); let accuracy = 1.0 - (distance as f64 / expected_transcription.len() as f64); diff --git a/screenpipe-audio/src/audio_processing.rs b/screenpipe-audio/src/audio_processing.rs index 51040c0715..5f243033ec 100644 --- a/screenpipe-audio/src/audio_processing.rs +++ b/screenpipe-audio/src/audio_processing.rs @@ -161,281 +161,3 @@ pub fn write_audio_to_file( } Ok(file_path_clone) } 
- -// Audio processing code, adapted from whisper.cpp -// https://github.com/ggerganov/whisper.cpp - -use candle::utils::get_num_threads; - -pub trait Float: - num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync -{ -} - -impl Float for f32 {} -impl Float for f64 {} - -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 -fn fft(inp: &[T]) -> Vec { - let n = inp.len(); - let zero = T::zero(); - if n == 1 { - return vec![inp[0], zero]; - } - if n % 2 == 1 { - return dft(inp); - } - let mut out = vec![zero; n * 2]; - - let mut even = Vec::with_capacity(n / 2); - let mut odd = Vec::with_capacity(n / 2); - - for (i, &inp) in inp.iter().enumerate() { - if i % 2 == 0 { - even.push(inp) - } else { - odd.push(inp); - } - } - - let even_fft = fft(&even); - let odd_fft = fft(&odd); - - let two_pi = T::PI() + T::PI(); - let n_t = T::from(n).unwrap(); - for k in 0..n / 2 { - let k_t = T::from(k).unwrap(); - let theta = two_pi * k_t / n_t; - let re = theta.cos(); - let im = -theta.sin(); - - let re_odd = odd_fft[2 * k]; - let im_odd = odd_fft[2 * k + 1]; - - out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; - out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; - - out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; - out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; - } - out -} - -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 -fn dft(inp: &[T]) -> Vec { - let zero = T::zero(); - let n = inp.len(); - let two_pi = T::PI() + T::PI(); - - let mut out = Vec::with_capacity(2 * n); - let n_t = T::from(n).unwrap(); - for k in 0..n { - let k_t = T::from(k).unwrap(); - let mut re = zero; - let mut im = zero; - - for (j, &inp) in inp.iter().enumerate() { - let j_t = T::from(j).unwrap(); - let angle = two_pi * k_t * j_t / n_t; - re += inp * angle.cos(); - im -= inp * angle.sin(); - } - - out.push(re); - out.push(im); - } - out -} - -#[allow(clippy::too_many_arguments)] -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 -fn log_mel_spectrogram_w( - ith: usize, - hann: &[T], - samples: &[T], - filters: &[T], - fft_size: usize, - fft_step: usize, - speed_up: bool, - n_len: usize, - n_mel: usize, - n_threads: usize, -) -> Vec { - let n_fft = if speed_up { - 1 + fft_size / 4 - } else { - 1 + fft_size / 2 - }; - - let zero = T::zero(); - let half = T::from(0.5).unwrap(); - let mut fft_in = vec![zero; fft_size]; - let mut mel = vec![zero; n_len * n_mel]; - let n_samples = samples.len(); - let end = std::cmp::min(n_samples / fft_step + 1, n_len); - - for i in (ith..end).step_by(n_threads) { - let offset = i * fft_step; - - // apply Hanning window - for j in 0..std::cmp::min(fft_size, n_samples - offset) { - fft_in[j] = hann[j] * samples[offset + j]; - } - - // fill the rest with zeros - if n_samples - offset < fft_size { - fft_in[n_samples - offset..].fill(zero); - } - - // FFT - let mut fft_out: Vec = fft(&fft_in); - - // Calculate modulus^2 of complex numbers - for j in 0..fft_size { - fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; - } - for j in 1..fft_size / 2 { - let v = fft_out[fft_size - j]; - fft_out[j] += v; - } - - if speed_up { - // scale down in the frequency domain results in a speed up in the time domain - for j in 0..n_fft { - fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); - } - } - - // mel 
spectrogram - for j in 0..n_mel { - let mut sum = zero; - let mut k = 0; - // Unroll loop - while k < n_fft.saturating_sub(3) { - sum += fft_out[k] * filters[j * n_fft + k] - + fft_out[k + 1] * filters[j * n_fft + k + 1] - + fft_out[k + 2] * filters[j * n_fft + k + 2] - + fft_out[k + 3] * filters[j * n_fft + k + 3]; - k += 4; - } - // Handle remainder - while k < n_fft { - sum += fft_out[k] * filters[j * n_fft + k]; - k += 1; - } - mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); - } - } - mel -} - -pub async fn log_mel_spectrogram_( - samples: &[T], - filters: &[T], - fft_size: usize, - fft_step: usize, - n_mel: usize, - speed_up: bool, -) -> Vec { - let zero = T::zero(); - let two_pi = T::PI() + T::PI(); - let half = T::from(0.5).unwrap(); - let one = T::from(1.0).unwrap(); - let four = T::from(4.0).unwrap(); - let fft_size_t = T::from(fft_size).unwrap(); - - let hann: Vec = (0..fft_size) - .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) - .collect(); - let n_len = samples.len() / fft_step; - - // pad audio with at least one extra chunk of zeros - let pad = 100 * candle_transformers::models::whisper::CHUNK_LENGTH / 2; - let n_len = if n_len % pad != 0 { - (n_len / pad + 1) * pad - } else { - n_len - }; - let n_len = n_len + pad; - let samples = { - let mut samples_padded = samples.to_vec(); - let to_add = n_len * fft_step - samples.len(); - samples_padded.extend(std::iter::repeat(zero).take(to_add)); - samples_padded - }; - - // ensure that the number of threads is even and less than 12 - let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12); - let n_threads = std::cmp::max(n_threads, 2); - - // Create owned copies of the input data - let samples = samples.to_vec(); - let filters = filters.to_vec(); - - let mut handles = vec![]; - for thread_id in 0..n_threads { - // Clone the owned vectors for each task - let hann = hann.clone(); - let samples = samples.clone(); - let filters = filters.clone(); - - handles.push(tokio::task::spawn(async move { - log_mel_spectrogram_w( - thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len, n_mel, - n_threads, - ) - })); - } - - let all_outputs = futures::future::join_all(handles) - .await - .into_iter() - .map(|res| res.expect("Task failed")) - .collect::>(); - - let l = all_outputs[0].len(); - let mut mel = vec![zero; l]; - - // iterate over mel spectrogram segments, dividing work by threads. - for segment_start in (0..l).step_by(n_threads) { - // go through each thread's output. - for thread_output in all_outputs.iter() { - // add each thread's piece to our mel spectrogram. - for offset in 0..n_threads { - let mel_index = segment_start + offset; // find location in mel. - if mel_index < mel.len() { - // Make sure we don't go out of bounds. 
- mel[mel_index] += thread_output[mel_index]; - } - } - } - } - - let mmax = mel - .iter() - .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) - .copied() - .unwrap_or(zero) - - T::from(8).unwrap(); - for m in mel.iter_mut() { - let v = T::max(*m, mmax); - *m = v / four + one - } - mel -} - -pub async fn pcm_to_mel( - cfg: &candle_transformers::models::whisper::Config, - samples: &[T], - filters: &[T], -) -> Vec { - log_mel_spectrogram_( - samples, - filters, - candle_transformers::models::whisper::N_FFT, - candle_transformers::models::whisper::HOP_LENGTH, - cfg.num_mel_bins, - false, - ) - .await -} diff --git a/screenpipe-audio/src/bin/screenpipe-audio-forever.rs b/screenpipe-audio/src/bin/screenpipe-audio-forever.rs index fbf5ad13c2..073919061b 100644 --- a/screenpipe-audio/src/bin/screenpipe-audio-forever.rs +++ b/screenpipe-audio/src/bin/screenpipe-audio-forever.rs @@ -8,11 +8,10 @@ use screenpipe_audio::list_audio_devices; use screenpipe_audio::parse_audio_device; use screenpipe_audio::record_and_transcribe; use screenpipe_audio::vad_engine::VadSensitivity; +use screenpipe_audio::AudioDevice; use screenpipe_audio::AudioStream; use screenpipe_audio::AudioTranscriptionEngine; use screenpipe_audio::VadEngineEnum; -use screenpipe_core::AudioDevice; -use screenpipe_core::DeviceManager; use screenpipe_core::Language; use std::path::PathBuf; use std::sync::atomic::AtomicBool; @@ -89,14 +88,14 @@ async fn main() -> Result<()> { } let chunk_duration = Duration::from_secs_f32(args.audio_chunk_duration); - let (whisper_sender, whisper_receiver) = create_whisper_channel( + let (whisper_sender, whisper_receiver, _) = create_whisper_channel( Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3), VadEngineEnum::Silero, // Or VadEngineEnum::WebRtc, hardcoded for now args.deepgram_api_key, &PathBuf::from("output.mp4"), VadSensitivity::Medium, languages, - Arc::new(DeviceManager::default()), + None, ) .await?; diff --git a/screenpipe-audio/src/bin/screenpipe-audio.rs b/screenpipe-audio/src/bin/screenpipe-audio.rs index 733780e643..22c3222b7f 100644 --- a/screenpipe-audio/src/bin/screenpipe-audio.rs +++ b/screenpipe-audio/src/bin/screenpipe-audio.rs @@ -8,11 +8,10 @@ use screenpipe_audio::list_audio_devices; use screenpipe_audio::parse_audio_device; use screenpipe_audio::record_and_transcribe; use screenpipe_audio::vad_engine::VadSensitivity; +use screenpipe_audio::AudioDevice; use screenpipe_audio::AudioStream; use screenpipe_audio::AudioTranscriptionEngine; use screenpipe_audio::VadEngineEnum; -use screenpipe_core::AudioDevice; -use screenpipe_core::DeviceManager; use screenpipe_core::Language; use std::path::PathBuf; use std::sync::atomic::AtomicBool; @@ -93,14 +92,14 @@ async fn main() -> Result<()> { let chunk_duration = Duration::from_secs(10); let output_path = PathBuf::from("output.mp4"); - let (whisper_sender, whisper_receiver) = create_whisper_channel( + let (whisper_sender, whisper_receiver, _) = create_whisper_channel( Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3), VadEngineEnum::WebRtc, // Or VadEngineEnum::WebRtc, hardcoded for now deepgram_api_key, &output_path, VadSensitivity::Medium, languages, - Arc::new(DeviceManager::default()), + None ) .await?; // Spawn threads for each device diff --git a/screenpipe-audio/src/core.rs b/screenpipe-audio/src/core.rs index 88a93c50bd..1d33ac9888 100644 --- a/screenpipe-audio/src/core.rs +++ b/screenpipe-audio/src/core.rs @@ -1,12 +1,13 @@ use crate::audio_processing::audio_to_mono; -use 
crate::realtime::realtime_stt; +use crate::realtime::{realtime_stt, RealtimeTranscriptionEvent}; use crate::AudioInput; use anyhow::{anyhow, Result}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::StreamError; use lazy_static::lazy_static; use log::{debug, error, info, warn}; -use screenpipe_core::{AudioDevice, AudioDeviceType, Language}; +use screenpipe_core::Language; +use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::mpsc; use std::sync::Arc; @@ -44,6 +45,68 @@ impl fmt::Display for AudioTranscriptionEngine { } } +#[derive(Clone, Debug)] +pub struct DeviceControl { + pub is_running: bool, + pub is_paused: bool, +} + +#[derive(Clone, Eq, PartialEq, Hash, Serialize, Debug, Deserialize)] +pub enum DeviceType { + Input, + Output, +} + +#[derive(Clone, Eq, PartialEq, Hash, Serialize, Debug)] +pub struct AudioDevice { + pub name: String, + pub device_type: DeviceType, +} + +impl AudioDevice { + pub fn new(name: String, device_type: DeviceType) -> Self { + AudioDevice { name, device_type } + } + + pub fn from_name(name: &str) -> Result { + if name.trim().is_empty() { + return Err(anyhow!("Device name cannot be empty")); + } + + let (name, device_type) = if name.to_lowercase().ends_with("(input)") { + ( + name.trim_end_matches("(input)").trim().to_string(), + DeviceType::Input, + ) + } else if name.to_lowercase().ends_with("(output)") { + ( + name.trim_end_matches("(output)").trim().to_string(), + DeviceType::Output, + ) + } else { + return Err(anyhow!( + "Device type (input/output) not specified in the name" + )); + }; + + Ok(AudioDevice::new(name, device_type)) + } +} + +impl fmt::Display for AudioDevice { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} ({})", + self.name, + match self.device_type { + DeviceType::Input => "input", + DeviceType::Output => "output", + } + ) + } +} + pub fn parse_audio_device(name: &str) -> Result { AudioDevice::from_name(name) } @@ -53,18 +116,18 @@ pub async fn get_device_and_config( ) -> Result<(cpal::Device, cpal::SupportedStreamConfig)> { let host = cpal::default_host(); - let is_output_device = audio_device.device_type == AudioDeviceType::Output; + let is_output_device = audio_device.device_type == DeviceType::Output; let is_display = audio_device.to_string().contains("Display"); let cpal_audio_device = if audio_device.to_string() == "default" { match audio_device.device_type { - AudioDeviceType::Input => host.default_input_device(), - AudioDeviceType::Output => host.default_output_device(), + DeviceType::Input => host.default_input_device(), + DeviceType::Output => host.default_output_device(), } } else { let mut devices = match audio_device.device_type { - AudioDeviceType::Input => host.input_devices()?, - AudioDeviceType::Output => host.output_devices()?, + DeviceType::Input => host.input_devices()?, + DeviceType::Output => host.output_devices()?, }; #[cfg(target_os = "macos")] @@ -89,7 +152,7 @@ pub async fn get_device_and_config( .unwrap_or(false) }) } - .ok_or_else(|| anyhow!("audio device not found"))?; + .ok_or_else(|| anyhow!("Audio device not found"))?; // if output device and windows, using output config let config = if is_output_device && !is_display { @@ -136,14 +199,16 @@ pub async fn record_and_transcribe( pub async fn start_realtime_recording( audio_stream: Arc, - languages: Arc<[Language]>, + languages: Vec, is_running: Arc, + realtime_transcription_sender: Arc>, deepgram_api_key: Option, ) -> Result<()> { while 
is_running.load(Ordering::Relaxed) { match realtime_stt( audio_stream.clone(), languages.clone(), + realtime_transcription_sender.clone(), is_running.clone(), deepgram_api_key.clone(), ) @@ -183,23 +248,26 @@ async fn run_record_and_transcribe( ); const OVERLAP_SECONDS: usize = 2; + let mut collected_audio = Vec::new(); let sample_rate = audio_stream.device_config.sample_rate().0 as usize; let overlap_samples = OVERLAP_SECONDS * sample_rate; - let duration_samples = (duration.as_secs_f64() * sample_rate as f64).ceil() as usize; - let max_samples = duration_samples + overlap_samples; - - let mut collected_audio = Vec::with_capacity(max_samples); while is_running.load(Ordering::Relaxed) && !audio_stream.is_disconnected.load(Ordering::Relaxed) { let start_time = tokio::time::Instant::now(); - // Collect audio for the duration period while start_time.elapsed() < duration && is_running.load(Ordering::Relaxed) { match tokio::time::timeout(Duration::from_millis(100), receiver.recv()).await { Ok(Ok(chunk)) => { - collected_audio.extend_from_slice(&chunk); + collected_audio.extend(chunk); + LAST_AUDIO_CAPTURE.store( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + Ordering::Relaxed, + ); } Ok(Err(e)) => { error!("error receiving audio data: {}", e); @@ -209,12 +277,6 @@ async fn run_record_and_transcribe( } } - // Discard oldest samples if we exceed buffer capacity - if collected_audio.len() > max_samples { - let excess = collected_audio.len() - max_samples; - collected_audio.drain(0..excess); - } - if !collected_audio.is_empty() { debug!("sending audio segment to audio model"); match whisper_sender.try_send(AudioInput { @@ -225,14 +287,9 @@ async fn run_record_and_transcribe( }) { Ok(_) => { debug!("sent audio segment to audio model"); - // Retain only overlap samples for next iteration - let current_len = collected_audio.len(); - if current_len > overlap_samples { - let keep_from = current_len - overlap_samples; - collected_audio.drain(0..keep_from); - } else { - // If we don't have enough samples, keep all (unlikely case) - collected_audio.truncate(current_len); + if collected_audio.len() > overlap_samples { + collected_audio = + collected_audio.split_off(collected_audio.len() - overlap_samples); } } Err(e) => { @@ -258,7 +315,7 @@ pub async fn list_audio_devices() -> Result> { for device in host.input_devices()? { if let Ok(name) = device.name() { - devices.push(AudioDevice::new(name, AudioDeviceType::Input)); + devices.push(AudioDevice::new(name, DeviceType::Input)); } } @@ -285,7 +342,7 @@ pub async fn list_audio_devices() -> Result> { for device in host.input_devices()? { if let Ok(name) = device.name() { if should_include_output_device(&name) { - devices.push(AudioDevice::new(name, AudioDeviceType::Output)); + devices.push(AudioDevice::new(name, DeviceType::Output)); } } } @@ -296,7 +353,7 @@ pub async fn list_audio_devices() -> Result> { for device in host.output_devices()? 
{ if let Ok(name) = device.name() { if should_include_output_device(&name) { - devices.push(AudioDevice::new(name, AudioDeviceType::Output)); + devices.push(AudioDevice::new(name, DeviceType::Output)); } } } @@ -308,10 +365,7 @@ pub async fn list_audio_devices() -> Result> { && should_include_output_device(&device.name().unwrap()) { // TODO: not sure if it can be input, usually aggregate or multi output - devices.push(AudioDevice::new( - device.name().unwrap(), - AudioDeviceType::Output, - )); + devices.push(AudioDevice::new(device.name().unwrap(), DeviceType::Output)); } } @@ -323,7 +377,7 @@ pub fn default_input_device() -> Result { let device = host .default_input_device() .ok_or(anyhow!("No default input device detected"))?; - Ok(AudioDevice::new(device.name()?, AudioDeviceType::Input)) + Ok(AudioDevice::new(device.name()?, DeviceType::Input)) } // this should be optional ? pub fn default_output_device() -> Result { @@ -333,7 +387,7 @@ pub fn default_output_device() -> Result { if let Ok(host) = cpal::host_from_id(cpal::HostId::ScreenCaptureKit) { if let Some(device) = host.default_input_device() { if let Ok(name) = device.name() { - return Ok(AudioDevice::new(name, AudioDeviceType::Output)); + return Ok(AudioDevice::new(name, DeviceType::Output)); } } } @@ -341,7 +395,7 @@ pub fn default_output_device() -> Result { let device = host .default_output_device() .ok_or_else(|| anyhow!("No default output device found"))?; - Ok(AudioDevice::new(device.name()?, AudioDeviceType::Output)) + Ok(AudioDevice::new(device.name()?, DeviceType::Output)) } #[cfg(not(target_os = "macos"))] @@ -350,7 +404,7 @@ pub fn default_output_device() -> Result { let device = host .default_output_device() .ok_or_else(|| anyhow!("No default output device found"))?; - return Ok(AudioDevice::new(device.name()?, AudioDeviceType::Output)); + return Ok(AudioDevice::new(device.name()?, DeviceType::Output)); } } @@ -446,13 +500,6 @@ impl AudioStream { move |data: &[f32], _: &_| { let mono = audio_to_mono(data, channels); let _ = tx.send(mono); - LAST_AUDIO_CAPTURE.store( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - Ordering::Relaxed, - ); }, error_callback, None, @@ -464,13 +511,6 @@ impl AudioStream { move |data: &[i16], _: &_| { let mono = audio_to_mono(bytemuck::cast_slice(data), channels); let _ = tx.send(mono); - LAST_AUDIO_CAPTURE.store( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - Ordering::Relaxed, - ); }, error_callback, None, @@ -482,13 +522,6 @@ impl AudioStream { move |data: &[i32], _: &_| { let mono = audio_to_mono(bytemuck::cast_slice(data), channels); let _ = tx.send(mono); - LAST_AUDIO_CAPTURE.store( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - Ordering::Relaxed, - ); }, error_callback, None, @@ -500,13 +533,6 @@ impl AudioStream { move |data: &[i8], _: &_| { let mono = audio_to_mono(bytemuck::cast_slice(data), channels); let _ = tx.send(mono); - LAST_AUDIO_CAPTURE.store( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - Ordering::Relaxed, - ); }, error_callback, None, diff --git a/screenpipe-audio/src/deepgram/mod.rs b/screenpipe-audio/src/deepgram/mod.rs index 62a8ee0220..c81843cd43 100644 --- a/screenpipe-audio/src/deepgram/mod.rs +++ b/screenpipe-audio/src/deepgram/mod.rs @@ -9,8 +9,8 @@ use std::env; lazy_static! 
{ pub(crate) static ref DEEPGRAM_API_URL: String = env::var("DEEPGRAM_API_URL") .unwrap_or_else(|_| "https://api.deepgram.com/v1/listen".to_string()); - pub(crate) static ref DEEPGRAM_WEBSOCKET_URL: String = - env::var("DEEPGRAM_WEBSOCKET_URL").unwrap_or_else(|_| "".to_string()); + pub(crate) static ref DEEPGRAM_WEBSOCKET_URL: String = env::var("DEEPGRAM_WEBSOCKET_URL") + .unwrap_or_else(|_| "wss://api.deepgram.com/v1/listen".to_string()); pub(crate) static ref CUSTOM_DEEPGRAM_API_TOKEN: String = env::var("CUSTOM_DEEPGRAM_API_TOKEN").unwrap_or_else(|_| String::new()); } diff --git a/screenpipe-audio/src/deepgram/realtime.rs b/screenpipe-audio/src/deepgram/realtime.rs index 7666e8e20e..a5949fd5a1 100644 --- a/screenpipe-audio/src/deepgram/realtime.rs +++ b/screenpipe-audio/src/deepgram/realtime.rs @@ -1,7 +1,7 @@ use crate::{ - deepgram::CUSTOM_DEEPGRAM_API_TOKEN, deepgram::DEEPGRAM_WEBSOCKET_URL, - realtime::RealtimeTranscriptionEvent, AudioStream, + deepgram::CUSTOM_DEEPGRAM_API_TOKEN, realtime::RealtimeTranscriptionEvent, AudioStream, }; +use crate::{AudioDevice, DeviceType}; use anyhow::Result; use bytes::BufMut; use bytes::Bytes; @@ -11,19 +11,16 @@ use deepgram::common::options::Encoding; use deepgram::common::stream_response::StreamResponse; use futures::channel::mpsc::{self, Receiver as FuturesReceiver}; use futures::{SinkExt, TryStreamExt}; -use screenpipe_core::AudioDevice; -use screenpipe_core::AudioDeviceType; use screenpipe_core::Language; -use screenpipe_events::send_event; use std::sync::{atomic::AtomicBool, Arc}; use std::time::Duration; use tokio::sync::broadcast::Receiver; -use tokio::sync::oneshot; -use tracing::info; +use tracing::error; pub async fn stream_transcription_deepgram( stream: Arc, - languages: Arc<[Language]>, + realtime_transcription_sender: Arc>, + languages: Vec, is_running: Arc, deepgram_api_key: Option, ) -> Result<()> { @@ -31,6 +28,7 @@ pub async fn stream_transcription_deepgram( stream.subscribe().await, stream.device.clone(), stream.device_config.sample_rate().0, + realtime_transcription_sender, is_running, languages, deepgram_api_key, @@ -44,8 +42,9 @@ pub async fn start_deepgram_stream( stream: Receiver>, device: Arc, sample_rate: u32, + realtime_transcription_sender: Arc>, is_running: Arc, - _languages: Arc<[Language]>, + _languages: Vec, deepgram_api_key: Option, ) -> Result<()> { let api_key = deepgram_api_key.unwrap_or(CUSTOM_DEEPGRAM_API_TOKEN.to_string()); @@ -54,28 +53,7 @@ pub async fn start_deepgram_stream( return Err(anyhow::anyhow!("Deepgram API key not found")); } - // create shutdown rx from is_running - let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); - - tokio::spawn(async move { - loop { - let running = is_running.load(std::sync::atomic::Ordering::SeqCst); - if !running { - shutdown_tx.send(()).unwrap(); - break; - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - }); - - info!("Starting deepgram stream for device: {}", device); - - let deepgram = match DEEPGRAM_WEBSOCKET_URL.as_str().is_empty() { - true => deepgram::Deepgram::new(api_key)?, - false => { - deepgram::Deepgram::with_base_url_and_api_key(DEEPGRAM_WEBSOCKET_URL.as_str(), api_key)? 
- } - }; + let deepgram = deepgram::Deepgram::new(api_key)?; let deepgram_transcription = deepgram.transcription(); @@ -93,18 +71,20 @@ pub async fn start_deepgram_stream( let mut handle = req.clone().handle().await?; let mut results = req.stream(get_stream(stream)).await?; + let realtime_transcription_sender_clone = realtime_transcription_sender.clone(); let device_clone = device.clone(); loop { + if !is_running.load(std::sync::atomic::Ordering::SeqCst) { + break; + } + tokio::select! { - _ = &mut shutdown_rx => { - info!("Shutting down deepgram stream for device: {}", device); - break; - } result = results.try_next() => { if let Ok(Some(result)) = result { handle_transcription( result, + realtime_transcription_sender_clone.clone(), device_clone.clone(), ).await; } @@ -128,35 +108,41 @@ fn get_stream(mut stream: Receiver>) -> FuturesReceiver) { +async fn handle_transcription( + result: StreamResponse, + realtime_transcription_sender: Arc>, + device: Arc, +) { if let StreamResponse::TranscriptResponse { channel, is_final, .. } = result { let res = channel.alternatives.first().unwrap(); let text = res.transcript.clone(); - let is_input = device.device_type == AudioDeviceType::Input; + let is_input = device.device_type == DeviceType::Input; if !text.is_empty() { - let _ = send_event( - "transcription", - RealtimeTranscriptionEvent { - timestamp: chrono::Utc::now(), - device: device.to_string(), - transcription: text.to_string(), - is_final, - is_input, - }, - ); + match realtime_transcription_sender.send(RealtimeTranscriptionEvent { + timestamp: chrono::Utc::now(), + device: device.to_string(), + transcription: text.to_string(), + is_final, + is_input, + }) { + Ok(_) => {} + Err(e) => { + if !e.to_string().contains("channel closed") { + error!("Error sending transcription event: {}", e); + } + } + } } } } diff --git a/screenpipe-audio/src/lib.rs b/screenpipe-audio/src/lib.rs index 0cd10da334..6251b31003 100644 --- a/screenpipe-audio/src/lib.rs +++ b/screenpipe-audio/src/lib.rs @@ -14,7 +14,8 @@ pub use audio_processing::resample; pub use core::{ default_input_device, default_output_device, get_device_and_config, list_audio_devices, parse_audio_device, record_and_transcribe, start_realtime_recording, trigger_audio_permission, - AudioStream, AudioTranscriptionEngine, LAST_AUDIO_CAPTURE, + AudioDevice, AudioStream, AudioTranscriptionEngine, DeviceControl, DeviceType, + LAST_AUDIO_CAPTURE, }; pub mod realtime; pub use encode::encode_single_audio; diff --git a/screenpipe-audio/src/pyannote/embedding.rs b/screenpipe-audio/src/pyannote/embedding.rs index 788dfb40ac..6f16ca245d 100644 --- a/screenpipe-audio/src/pyannote/embedding.rs +++ b/screenpipe-audio/src/pyannote/embedding.rs @@ -2,31 +2,24 @@ use crate::pyannote::session; use anyhow::{Context, Result}; use ndarray::Array2; use ort::Session; -use std::{path::Path, sync::Mutex}; +use std::path::Path; #[derive(Debug)] -pub struct EmbeddingExtractor {} - -lazy_static::lazy_static! 
{ - static ref EMBEDDING_SESSION: Mutex> = Mutex::new(None); +pub struct EmbeddingExtractor { + session: Session, } impl EmbeddingExtractor { pub fn new>(model_path: P) -> Result { - let mut session = EMBEDDING_SESSION.lock().unwrap(); - if session.is_none() { - *session = Some(session::create_session(model_path.as_ref(), false)?); - } - Ok(Self {}) + let session = session::create_session(model_path.as_ref())?; + Ok(Self { session }) } pub fn compute(&mut self, samples: &[f32]) -> Result> { - let session = EMBEDDING_SESSION.lock().unwrap(); - let session = session.as_ref().unwrap(); let features: Array2 = knf_rs::compute_fbank(samples).map_err(anyhow::Error::msg)?; let features = features.insert_axis(ndarray::Axis(0)); // Add batch dimension let inputs = ort::inputs! ["feats" => features.view()]?; - let ort_outs = session.run(inputs).context("Failed to run the session")?; + let ort_outs = self.session.run(inputs)?; let ort_out = ort_outs .get("embs") .context("Output tensor not found")? diff --git a/screenpipe-audio/src/pyannote/identify.rs b/screenpipe-audio/src/pyannote/identify.rs index 3e26b6dfe0..d76e2cc361 100644 --- a/screenpipe-audio/src/pyannote/identify.rs +++ b/screenpipe-audio/src/pyannote/identify.rs @@ -1,11 +1,11 @@ use anyhow::{bail, Result}; use ndarray::Array1; -use std::num::NonZeroUsize; +use std::collections::HashMap; #[derive(Debug, Clone)] pub struct EmbeddingManager { max_speakers: usize, - speakers: lru::LruCache>, + speakers: HashMap>, next_speaker_id: usize, } @@ -13,7 +13,7 @@ impl EmbeddingManager { pub fn new(max_speakers: usize) -> Self { Self { max_speakers, - speakers: lru::LruCache::new(NonZeroUsize::new(max_speakers).unwrap()), + speakers: HashMap::new(), next_speaker_id: 1, } } @@ -68,26 +68,13 @@ impl EmbeddingManager { fn add_speaker(&mut self, embedding: Array1) -> usize { let speaker_id = self.next_speaker_id; - self.speakers.push(speaker_id, embedding); + self.speakers.insert(speaker_id, embedding); self.next_speaker_id += 1; speaker_id } - pub fn add_embedding(&mut self, speaker_id: usize, embedding: Array1) { - self.speakers.put(speaker_id, embedding); - tracing::info!("Speaker cache size: {}", self.speakers.len()); - } - - pub fn get_embedding(&mut self, speaker_id: usize) -> Option<&Array1> { - self.speakers.get(&speaker_id) - } - - pub fn prune(&mut self) { - self.speakers.pop_lru(); - } - #[allow(unused)] - pub fn get_all_speakers(&self) -> &lru::LruCache> { + pub fn get_all_speakers(&self) -> &HashMap> { &self.speakers } } diff --git a/screenpipe-audio/src/pyannote/models.rs b/screenpipe-audio/src/pyannote/models.rs index 82e9d1d21e..42539845f3 100644 --- a/screenpipe-audio/src/pyannote/models.rs +++ b/screenpipe-audio/src/pyannote/models.rs @@ -118,14 +118,6 @@ async fn download_model(model_type: PyannoteModel) -> Result<()> { tokio::io::AsyncWriteExt::write_all(&mut file, &model_data).await?; info!("{} model successfully downloaded and saved", filename); - // debug!("optimizing {} model", filename); - // let session = ort::SessionBuilder::new()? - // .with_optimization_level(ort::GraphOptimizationLevel::Level3)? - // .with_intra_threads(1)? - // .with_inter_threads(1)? 
- // .with_optimized_model_path(path.to_str().unwrap())?; - // session.commit_from_file(path.to_str().unwrap())?; - Ok(()) } diff --git a/screenpipe-audio/src/pyannote/segment.rs b/screenpipe-audio/src/pyannote/segment.rs index 5fd7a58543..78c72b0f9a 100644 --- a/screenpipe-audio/src/pyannote/segment.rs +++ b/screenpipe-audio/src/pyannote/segment.rs @@ -1,15 +1,10 @@ use crate::pyannote::session; use anyhow::{Context, Result}; use ndarray::{ArrayBase, Axis, IxDyn, ViewRepr}; -use ort::Session; use std::{cmp::Ordering, path::Path, sync::Arc, sync::Mutex}; use super::{embedding::EmbeddingExtractor, identify::EmbeddingManager}; -lazy_static::lazy_static! { - static ref SEGMENTATION_SESSION: Mutex> = Mutex::new(None); -} - #[derive(Debug, Clone)] #[repr(C)] pub struct SpeechSegment { @@ -41,7 +36,7 @@ fn create_speech_segment( samples: &[f32], padded_samples: &[f32], embedding_extractor: Arc>, - embedding_manager: Arc>, + embedding_manager: &mut EmbeddingManager, ) -> Result { let start = start_offset / sample_rate as f64; let end = offset as f64 / sample_rate as f64; @@ -98,8 +93,9 @@ fn handle_new_segment( pub struct SegmentIterator { samples: Vec, sample_rate: u32, + session: ort::Session, embedding_extractor: Arc>, - embedding_manager: Arc>, + embedding_manager: EmbeddingManager, current_position: usize, frame_size: i32, window_size: usize, @@ -116,12 +112,9 @@ impl SegmentIterator { sample_rate: u32, model_path: P, embedding_extractor: Arc>, - embedding_manager: Arc>, + embedding_manager: EmbeddingManager, ) -> Result { - let mut session = SEGMENTATION_SESSION.lock().unwrap(); - if session.is_none() { - *session = Some(session::create_session(model_path, true)?); - } + let session = session::create_session(model_path.as_ref())?; let window_size = (sample_rate * 10) as usize; let padded_samples = { @@ -133,6 +126,7 @@ impl SegmentIterator { Ok(Self { samples, sample_rate, + session, embedding_extractor, embedding_manager, current_position: 0, @@ -155,9 +149,10 @@ impl SegmentIterator { .to_owned(); let inputs = ort::inputs![array].context("Failed to prepare inputs")?; - let session = SEGMENTATION_SESSION.lock().unwrap(); - let session = session.as_ref().unwrap(); - let ort_outs = session.run(inputs).context("Failed to run the session")?; + let ort_outs = self + .session + .run(inputs) + .context("Failed to run the session")?; let ort_out = ort_outs.get("output").context("Output tensor not found")?; let ort_out = ort_out @@ -183,7 +178,7 @@ impl SegmentIterator { &self.samples, &self.padded_samples, self.embedding_extractor.clone(), - self.embedding_manager.clone(), + &mut self.embedding_manager, ) { Ok(segment) => segment, Err(_) => { @@ -254,7 +249,7 @@ pub fn get_segments>( sample_rate: u32, model_path: P, embedding_extractor: Arc>, - embedding_manager: Arc>, + embedding_manager: EmbeddingManager, ) -> Result { SegmentIterator::new( samples.to_vec(), @@ -276,21 +271,14 @@ fn get_speaker_embedding( } pub fn get_speaker_from_embedding( - embedding_manager: Arc>, + embedding_manager: &mut EmbeddingManager, embedding: Vec, ) -> String { let search_threshold = 0.5; embedding_manager - .lock() - .unwrap() .search_speaker(embedding.clone(), search_threshold) - .ok_or_else(|| { - embedding_manager - .lock() - .unwrap() - .search_speaker(embedding, 0.0) - }) // Ensure always to return speaker + .ok_or_else(|| embedding_manager.search_speaker(embedding, 0.0)) // Ensure always to return speaker .map(|r| r.to_string()) .unwrap_or("?".into()) } diff --git a/screenpipe-audio/src/pyannote/session.rs 
b/screenpipe-audio/src/pyannote/session.rs index 77a218f7d5..ae7fe7c33d 100644 --- a/screenpipe-audio/src/pyannote/session.rs +++ b/screenpipe-audio/src/pyannote/session.rs @@ -3,12 +3,11 @@ use std::path::Path; use anyhow::Result; use ort::{GraphOptimizationLevel, Session}; -pub fn create_session>(path: P, enable_memory_pattern: bool) -> Result { +pub fn create_session>(path: P) -> Result { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(1)? .with_inter_threads(1)? - .with_memory_pattern(enable_memory_pattern)? .commit_from_file(path.as_ref())?; Ok(session) -} +} \ No newline at end of file diff --git a/screenpipe-audio/src/realtime.rs b/screenpipe-audio/src/realtime.rs index 2cae0b4aba..5baf1a6e8f 100644 --- a/screenpipe-audio/src/realtime.rs +++ b/screenpipe-audio/src/realtime.rs @@ -2,21 +2,29 @@ use crate::{deepgram::stream_transcription_deepgram, AudioStream}; use anyhow::Result; use chrono::{DateTime, Utc}; use screenpipe_core::Language; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use std::sync::{atomic::AtomicBool, Arc}; pub async fn realtime_stt( stream: Arc, - languages: Arc<[Language]>, + languages: Vec, + realtime_transcription_sender: Arc>, is_running: Arc, deepgram_api_key: Option, ) -> Result<()> { - stream_transcription_deepgram(stream, languages, is_running, deepgram_api_key).await?; + stream_transcription_deepgram( + stream, + realtime_transcription_sender, + languages, + is_running, + deepgram_api_key, + ) + .await?; Ok(()) } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Clone)] #[serde(rename_all = "camelCase")] pub struct RealtimeTranscriptionEvent { pub timestamp: DateTime, diff --git a/screenpipe-audio/src/segments.rs b/screenpipe-audio/src/segments.rs index 577586a41b..c864eb7c0e 100644 --- a/screenpipe-audio/src/segments.rs +++ b/screenpipe-audio/src/segments.rs @@ -15,14 +15,14 @@ use tokio::sync::Mutex; use vad_rs::VadStatus; pub async fn prepare_segments( - audio_data: Arc>, + audio_data: &[f32], vad_engine: Arc>>, segmentation_model_path: &PathBuf, - embedding_manager: Arc>, + embedding_manager: EmbeddingManager, embedding_extractor: Arc>, device: &str, ) -> Result> { - let audio_data = normalize_v2(audio_data.as_ref()); + let audio_data = normalize_v2(audio_data); let frame_size = 1600; let vad_engine = vad_engine.clone(); diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 181e7d64e4..077e5bfd28 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -2,21 +2,22 @@ use crate::audio_processing::write_audio_to_file; use crate::deepgram::transcribe_with_deepgram; use crate::pyannote::models::{get_or_download_model, PyannoteModel}; use crate::pyannote::segment::SpeechSegment; -use crate::resample; pub use crate::segments::prepare_segments; use crate::{ pyannote::{embedding::EmbeddingExtractor, identify::EmbeddingManager}, vad_engine::{SileroVad, VadEngine, VadEngineEnum, VadSensitivity, WebRtcVad}, whisper::{process_with_whisper, WhisperModel}, - AudioTranscriptionEngine, + AudioDevice, AudioTranscriptionEngine, }; +use crate::{resample, DeviceControl}; use anyhow::{anyhow, Result}; use candle_transformers::models::whisper as m; -use log::{debug, error}; +use dashmap::DashMap; +use log::{debug, error, info}; #[cfg(target_os = "macos")] use objc::rc::autoreleasepool; -use screenpipe_core::{AudioDevice, DeviceManager, Language}; -// use std::time::Duration; +use screenpipe_core::Language; +use 
std::sync::atomic::{AtomicBool, Ordering}; use std::{ path::Path, sync::Arc, @@ -25,38 +26,80 @@ use std::{ }; use tokio::sync::Mutex; +pub fn stt_sync( + audio: &[f32], + sample_rate: u32, + device: &str, + whisper_model: &mut WhisperModel, + audio_transcription_engine: Arc, + deepgram_api_key: Option, + languages: Vec, +) -> Result { + let mut whisper_model = whisper_model.clone(); + let audio = audio.to_vec(); + + let device = device.to_string(); + let handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + + rt.block_on(stt( + &audio, + sample_rate, + &device, + &mut whisper_model, + audio_transcription_engine, + deepgram_api_key, + languages, + )) + }); + + handle.join().unwrap() +} + #[allow(clippy::too_many_arguments)] pub async fn stt( audio: &[f32], sample_rate: u32, device: &str, - whisper_model: Arc>, + whisper_model: &mut WhisperModel, audio_transcription_engine: Arc, deepgram_api_key: Option, languages: Vec, ) -> Result { - let transcription: Result = - if audio_transcription_engine == AudioTranscriptionEngine::Deepgram.into() { - // Deepgram implementation - let api_key = deepgram_api_key.unwrap_or_default(); + let model = &whisper_model.model; - match transcribe_with_deepgram(&api_key, audio, device, sample_rate, languages.clone()) - .await - { - Ok(transcription) => Ok(transcription), - Err(e) => { - error!( - "device: {}, deepgram transcription failed, falling back to Whisper: {:?}", - device, e - ); - // Fallback to Whisper - process_with_whisper(whisper_model, audio, languages.clone()).await - } + debug!("Loading mel filters"); + let mel_bytes = match model.config().num_mel_bins { + 80 => include_bytes!("../models/whisper/melfilters.bytes").as_slice(), + 128 => include_bytes!("../models/whisper/melfilters128.bytes").as_slice(), + nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"), + }; + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + ::read_f32_into(mel_bytes, &mut mel_filters); + + let transcription: Result = if audio_transcription_engine + == AudioTranscriptionEngine::Deepgram.into() + { + // Deepgram implementation + let api_key = deepgram_api_key.unwrap_or_default(); + + match transcribe_with_deepgram(&api_key, audio, device, sample_rate, languages.clone()) + .await + { + Ok(transcription) => Ok(transcription), + Err(e) => { + error!( + "device: {}, deepgram transcription failed, falling back to Whisper: {:?}", + device, e + ); + // Fallback to Whisper + process_with_whisper(&mut *whisper_model, audio, &mel_filters, languages.clone()) } - } else { - // Existing Whisper implementation - process_with_whisper(whisper_model, audio, languages.clone()).await - }; + } + } else { + // Existing Whisper implementation + process_with_whisper(&mut *whisper_model, audio, &mel_filters, languages) + }; transcription } @@ -83,11 +126,11 @@ pub struct TranscriptionResult { impl TranscriptionResult { // TODO --optimize - pub fn cleanup_overlap(&mut self, previous_transcript: &str) -> Option<(String, String)> { + pub fn cleanup_overlap(&mut self, previous_transcript: String) -> Option<(String, String)> { if let Some(transcription) = &self.transcription { let transcription = transcription.to_string(); if let Some((prev_idx, cur_idx)) = - longest_common_word_substring(previous_transcript, transcription.as_str()) + longest_common_word_substring(previous_transcript.as_str(), transcription.as_str()) { // strip old transcript from prev_idx word pos let new_prev = previous_transcript @@ -113,13 +156,13 @@ pub async fn create_whisper_channel( 
output_path: &Path, vad_sensitivity: VadSensitivity, languages: Vec, - device_manager: Arc, + audio_devices_control: Option>>, ) -> Result<( crossbeam::channel::Sender, crossbeam::channel::Receiver, + Arc, // Shutdown flag )> { - let whisper_model = WhisperModel::new(&audio_transcription_engine)?; - let whisper_model = Arc::new(Mutex::new(whisper_model)); + let mut whisper_model = WhisperModel::new(&audio_transcription_engine)?; let (input_sender, input_receiver): ( crossbeam::channel::Sender, crossbeam::channel::Receiver, @@ -134,6 +177,8 @@ pub async fn create_whisper_channel( }; vad_engine.set_sensitivity(vad_sensitivity); let vad_engine = Arc::new(Mutex::new(vad_engine)); + let shutdown_flag = Arc::new(AtomicBool::new(false)); + let shutdown_flag_clone = shutdown_flag.clone(); let output_path = output_path.to_path_buf(); let embedding_model_path = get_or_download_model(PyannoteModel::Embedding).await?; @@ -145,20 +190,29 @@ pub async fn create_whisper_channel( .ok_or_else(|| anyhow!("Invalid embedding model path"))?, )?)); - let embedding_manager = Arc::new(StdMutex::new(EmbeddingManager::new(25))); + let embedding_manager = EmbeddingManager::new(usize::MAX); tokio::spawn(async move { loop { + if shutdown_flag_clone.load(Ordering::Relaxed) { + info!("Whisper channel shutting down"); + break; + } + debug!("Waiting for input from input_receiver"); + crossbeam::select! { recv(input_receiver) -> input_result => { match input_result { Ok(mut audio) => { - // Check device state - if let Some(device) = device_manager.get_active_devices().await.get(&audio.device.to_string()) { - if !device.is_running { + // Check if device should be recording + if let Some(control) = audio_devices_control.as_ref().unwrap().get(&audio.device) { + if !control.is_running { debug!("Skipping audio processing for stopped device: {}", audio.device); continue; } + } else { + debug!("Device not found in control list: {}", audio.device); + continue; } debug!("Received input from input_receiver"); @@ -183,10 +237,10 @@ pub async fn create_whisper_channel( audio.data.as_ref().to_vec() }; - audio.data = Arc::new(audio_data); + audio.data = Arc::new(audio_data.clone()); audio.sample_rate = m::SAMPLE_RATE as u32; - let mut segments = match prepare_segments(audio.data.clone(), vad_engine.clone(), &segmentation_model_path, embedding_manager.clone(), embedding_extractor.clone(), &audio.device.to_string()).await { + let mut segments = match prepare_segments(&audio_data, vad_engine.clone(), &segmentation_model_path, embedding_manager.clone(), embedding_extractor.clone(), &audio.device.to_string()).await { Ok(segments) => segments, Err(e) => { error!("Error preparing segments: {:?}", e); @@ -195,7 +249,7 @@ pub async fn create_whisper_channel( }; let path = match write_audio_to_file( - audio.data.as_ref(), + &audio.data.to_vec(), audio.sample_rate, &output_path, &audio.device.to_string(), @@ -210,25 +264,20 @@ pub async fn create_whisper_channel( while let Some(segment) = segments.recv().await { let path = path.clone(); - let device = audio.device.clone(); let transcription_result = if cfg!(target_os = "macos") { #[cfg(target_os = "macos")] { - let whisper_model = whisper_model.clone(); - let audio_transcription_engine = audio_transcription_engine.clone(); - let deepgram_api_key = deepgram_api_key.clone(); - let languages = languages.clone(); let timestamp = timestamp + segment.start.round() as u64; - autoreleasepool(|| async move { - run_stt(segment, device.clone(), whisper_model.clone(), audio_transcription_engine.clone(), 
deepgram_api_key.clone(), languages.clone(), path, timestamp).await - }).await + autoreleasepool(|| { + run_stt(segment, audio.device.clone(), &mut whisper_model, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp) + }) } #[cfg(not(target_os = "macos"))] { unreachable!("This code should not be reached on non-macOS platforms") } } else { - run_stt(segment, device, whisper_model.clone(), audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp).await + run_stt(segment, audio.device.clone(), &mut whisper_model, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp) }; if output_sender.send(transcription_result).is_err() { @@ -238,23 +287,25 @@ pub async fn create_whisper_channel( }, Err(e) => { error!("Error receiving input: {:?}", e); + // Depending on the error type, you might want to break the loop or continue + // For now, we'll continue to the next iteration break; } } }, - // default(Duration::from_millis(100)) => {} } } + // Cleanup code here (if needed) }); - Ok((input_sender, output_receiver)) + Ok((input_sender, output_receiver, shutdown_flag)) } #[allow(clippy::too_many_arguments)] -pub async fn run_stt( +pub fn run_stt( segment: SpeechSegment, device: Arc, - whisper_model: Arc>, + whisper_model: &mut WhisperModel, audio_transcription_engine: Arc, deepgram_api_key: Option, languages: Vec, @@ -263,17 +314,15 @@ pub async fn run_stt( ) -> TranscriptionResult { let audio = segment.samples.clone(); let sample_rate = segment.sample_rate; - match stt( + match stt_sync( &audio, sample_rate, &device.to_string(), whisper_model, - audio_transcription_engine, - deepgram_api_key, - languages, - ) - .await - { + audio_transcription_engine.clone(), + deepgram_api_key.clone(), + languages.clone(), + ) { Ok(transcription) => TranscriptionResult { input: AudioInput { data: Arc::new(audio), diff --git a/screenpipe-audio/src/whisper/decoder.rs b/screenpipe-audio/src/whisper/decoder.rs index b8cf2e6d9f..baa041a65d 100644 --- a/screenpipe-audio/src/whisper/decoder.rs +++ b/screenpipe-audio/src/whisper/decoder.rs @@ -284,7 +284,6 @@ impl<'a> Decoder<'a> { Ok(segments) } } - pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result { match tokenizer.token_to_id(token) { None => candle::bail!("no token-id for {token}"), diff --git a/screenpipe-audio/src/whisper/model.rs b/screenpipe-audio/src/whisper/model.rs index a994e5b519..ad6d22952c 100644 --- a/screenpipe-audio/src/whisper/model.rs +++ b/screenpipe-audio/src/whisper/model.rs @@ -3,22 +3,21 @@ use candle::{Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::whisper::{self as m, Config}; use hf_hub::{api::sync::Api, Repo, RepoType}; -use log::debug; +use log::{debug, info}; use tokenizers::Tokenizer; +#[derive(Clone)] pub struct WhisperModel { pub model: Model, - pub tokenizer: Box, + pub tokenizer: Tokenizer, pub device: Device, - pub mel_filters: Vec, - // pub weights_filename: String, } impl WhisperModel { pub fn new(engine: &crate::AudioTranscriptionEngine) -> Result { debug!("Initializing WhisperModel"); let device = Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu)); - debug!("device = {:?}", device); + info!("device = {:?}", device); debug!("Fetching model files"); let (config_filename, tokenizer_filename, weights_filename) = { @@ -54,7 +53,7 @@ impl WhisperModel { debug!("Parsing config and tokenizer"); let config: Config = 
serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; - let tokenizer = Box::new(Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; // tokenizer.with_pre_tokenizer(PreT) debug!("Loading model weights"); let vb = @@ -63,39 +62,13 @@ impl WhisperModel { let model = Model::Normal(whisper); - debug!("Loading mel filters"); - let mel_bytes = match config.num_mel_bins { - 80 => include_bytes!("../../models/whisper/melfilters.bytes").as_slice(), - 128 => include_bytes!("../../models/whisper/melfilters128.bytes").as_slice(), - nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"), - }; - let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; - ::read_f32_into( - mel_bytes, - &mut mel_filters, - ); - debug!("WhisperModel initialization complete"); Ok(Self { model, tokenizer, device, - mel_filters, }) } - - pub fn reset(&mut self) -> Result<()> { - match &mut self.model { - Model::Normal(m) => { - m.reset_kv_cache(); - } - Model::Quantized(m) => { - m.reset_kv_cache(); - } - } - - Ok(()) - } } #[derive(Debug, Clone)] diff --git a/screenpipe-audio/src/whisper/process_chunk.rs b/screenpipe-audio/src/whisper/process_chunk.rs index fe41c6d913..ecb251913d 100644 --- a/screenpipe-audio/src/whisper/process_chunk.rs +++ b/screenpipe-audio/src/whisper/process_chunk.rs @@ -1,37 +1,33 @@ use super::Segment; use crate::{ - audio_processing::pcm_to_mel, multilingual, whisper::{Decoder, WhisperModel}, }; use anyhow::Result; use candle::Tensor; +use candle_transformers::models::whisper::audio; use lazy_static::lazy_static; use log::debug; use regex::Regex; use screenpipe_core::Language; -use std::{collections::HashSet, sync::Arc}; -use tokio::sync::Mutex; +use std::collections::HashSet; lazy_static! 
{ static ref TOKEN_REGEX: Regex = Regex::new(r"<\|\d{1,2}\.\d{1,2}\|>").unwrap(); } -pub async fn process_with_whisper( - whisper_model: Arc>, +pub fn process_with_whisper( + whisper_model: &mut WhisperModel, audio: &[f32], + mel_filters: &[f32], languages: Vec, ) -> Result { - let mut whisper = whisper_model.lock().await; - let WhisperModel { - model, - tokenizer, - device, - mel_filters, - } = &mut *whisper; + let model = &mut whisper_model.model; + let tokenizer = &whisper_model.tokenizer; + let device = &whisper_model.device; debug!("converting pcm to mel spectrogram"); - let mel = pcm_to_mel(model.config(), audio, mel_filters).await; + let mel = audio::pcm_to_mel(model.config(), audio, mel_filters); let mel_len = mel.len(); debug!("creating tensor from mel spectrogram"); @@ -55,12 +51,11 @@ pub async fn process_with_whisper( debug!("initializing decoder"); let mut dc = Decoder::new(model, tokenizer, 42, device, language_token, true, false)?; - dc.reset_kv_cache(); + debug!("starting decoding process"); let segments = dc.run(&mel)?; debug!("decoding complete"); - dc.reset_kv_cache(); process_segments(segments) } diff --git a/screenpipe-audio/tests/accuracy_test.rs b/screenpipe-audio/tests/accuracy_test.rs index 80a1cd1768..5dae1c4918 100644 --- a/screenpipe-audio/tests/accuracy_test.rs +++ b/screenpipe-audio/tests/accuracy_test.rs @@ -107,15 +107,16 @@ async fn test_transcription_accuracy() { }; let mut segments = prepare_segments( - Arc::new(audio_data), + &audio_data, vad_engine.clone(), &segmentation_model_path, - Arc::new(std::sync::Mutex::new(embedding_manager)), + embedding_manager, embedding_extractor, &audio_input.device.name, ) .await .unwrap(); + let mut whisper_model_guard = whisper_model.lock().await; let mut transcription = String::new(); while let Some(segment) = segments.recv().await { @@ -123,7 +124,7 @@ async fn test_transcription_accuracy() { &segment.samples, audio_input.sample_rate, &audio_input.device.to_string(), - whisper_model.clone(), + &mut whisper_model_guard, Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo), None, vec![Language::English], @@ -133,6 +134,7 @@ async fn test_transcription_accuracy() { transcription.push_str(&transcript); } + drop(whisper_model_guard); let distance = levenshtein(expected_transcription, &transcription.to_lowercase()); let accuracy = 1.0 - (distance as f64 / expected_transcription.len() as f64); diff --git a/screenpipe-audio/tests/core_tests.rs b/screenpipe-audio/tests/core_tests.rs index 5957d4ce5a..ac698fefbc 100644 --- a/screenpipe-audio/tests/core_tests.rs +++ b/screenpipe-audio/tests/core_tests.rs @@ -13,7 +13,7 @@ mod tests { AudioTranscriptionEngine, }; use screenpipe_audio::{parse_audio_device, record_and_transcribe}; - use screenpipe_core::{DeviceManager, Language}; + use screenpipe_core::Language; use std::path::{Path, PathBuf}; use std::process::Command; use std::str::FromStr; @@ -209,14 +209,14 @@ mod tests { let output_path = PathBuf::from(format!("test_output_{}.mp4", Utc::now().timestamp_millis())); let output_path_2 = output_path.clone(); - let (whisper_sender, whisper_receiver) = create_whisper_channel( + let (whisper_sender, whisper_receiver, _) = create_whisper_channel( Arc::new(AudioTranscriptionEngine::WhisperTiny), VadEngineEnum::WebRtc, None, &output_path_2.clone(), VadSensitivity::High, vec![], - Arc::new(DeviceManager::default()), + None, ) .await .unwrap(); @@ -331,15 +331,16 @@ mod tests { let embedding_manager = EmbeddingManager::new(usize::MAX); let mut segments = prepare_segments( - 
audio_input.data, + &audio_input.data, vad_engine.clone(), &segmentation_model_path, - Arc::new(std::sync::Mutex::new(embedding_manager)), + embedding_manager, embedding_extractor, &audio_input.device.to_string(), ) .await .unwrap(); + let mut whisper_model_guard = whisper_model.lock().await; let mut transcription_result = String::new(); while let Some(segment) = segments.recv().await { @@ -347,7 +348,7 @@ mod tests { &segment.samples, audio_input.sample_rate, &audio_input.device.to_string(), - whisper_model.clone(), + &mut whisper_model_guard, Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo), None, vec![Language::Arabic], @@ -358,6 +359,7 @@ mod tests { transcription_result.push_str(&transcript); transcription_result.push('\n'); } + drop(whisper_model_guard); debug!("Received transcription: {:?}", transcription_result); // Check if we received a valid transcription @@ -414,10 +416,8 @@ mod tests { let embedding_manager = EmbeddingManager::new(usize::MAX); // Initialize the WhisperModel - let whisper_model = Arc::new(Mutex::new( - WhisperModel::new(&AudioTranscriptionEngine::WhisperLargeV3Turbo) - .expect("Failed to initialize WhisperModel"), - )); + let mut whisper_model = WhisperModel::new(&AudioTranscriptionEngine::WhisperLargeV3Turbo) + .expect("Failed to initialize WhisperModel"); // Initialize VAD engine let vad_engine: Box = Box::new(SileroVad::new().await.unwrap()); @@ -427,10 +427,10 @@ mod tests { let start_time = Instant::now(); let mut segments = prepare_segments( - audio_input.data, + &audio_input.data, vad_engine.clone(), &segmentation_model_path, - Arc::new(std::sync::Mutex::new(embedding_manager)), + embedding_manager, embedding_extractor, &audio_input.device.to_string(), ) @@ -443,7 +443,7 @@ mod tests { &segment.samples, audio_input.sample_rate, &audio_input.device.to_string(), - whisper_model.clone(), + &mut whisper_model, Arc::new(AudioTranscriptionEngine::WhisperLargeV3Turbo), None, vec![Language::English], diff --git a/screenpipe-audio/tests/realtime_test.rs b/screenpipe-audio/tests/realtime_test.rs index c8427263fc..f9a6736037 100644 --- a/screenpipe-audio/tests/realtime_test.rs +++ b/screenpipe-audio/tests/realtime_test.rs @@ -1,9 +1,5 @@ -use futures::StreamExt; use screenpipe_audio::deepgram::start_deepgram_stream; -use screenpipe_audio::pcm_decode; -use screenpipe_audio::realtime::RealtimeTranscriptionEvent; -use screenpipe_core::{AudioDevice, AudioDeviceType}; -use screenpipe_events::subscribe_to_event; +use screenpipe_audio::{pcm_decode, AudioDevice, DeviceType}; use std::{ sync::{atomic::AtomicBool, Arc}, time::Duration, @@ -17,16 +13,17 @@ use tokio::sync::broadcast; #[tokio::test] #[ignore] async fn test_realtime_transcription() { - let (samples, sample_rate) = pcm_decode("test_data/accuracy1.wav").unwrap_or_else(|e| { - panic!("Failed to decode audio: {}", e); - }); + let (samples, sample_rate) = pcm_decode("test_data/accuracy1.wav").unwrap(); let (stream_tx, stream_rx) = broadcast::channel(sample_rate as usize * 3); - let device = AudioDevice::new("test".to_string(), AudioDeviceType::Output); + let device = AudioDevice::new("test".to_string(), DeviceType::Output); let is_running = Arc::new(AtomicBool::new(true)); let deepgram_api_key = std::env::var("CUSTOM_DEEPGRAM_API_KEY").unwrap(); + let (realtime_transcription_sender, realtime_transcription_receiver) = + broadcast::channel(10000); + let is_running_clone = is_running.clone(); tokio::spawn(async move { @@ -34,8 +31,9 @@ async fn test_realtime_transcription() { stream_rx, Arc::new(device), 
sample_rate, + Arc::new(realtime_transcription_sender), is_running_clone, - Arc::new([].to_vec()), + vec![], Some(deepgram_api_key), ) .await; @@ -44,17 +42,17 @@ async fn test_realtime_transcription() { }); let transcription_receiver_handle = tokio::spawn(async move { - let mut receiver = subscribe_to_event::("transcription"); + let mut receiver = realtime_transcription_receiver; loop { tokio::select! { - event = receiver.next() => { - if let Some(event) = event { - println!("Received event: {:?}", event.data.transcription); + event = receiver.recv() => { + if let Ok(event) = event { + println!("Received event: {:?}", event.transcription); } else { println!("Receiver closed"); } }, - _ = tokio::time::sleep(Duration::from_secs(15)) => { + _ = tokio::time::sleep(Duration::from_secs(10)) => { println!("Timeout"); return; } @@ -62,15 +60,13 @@ async fn test_realtime_transcription() { } }); - tokio::time::sleep(Duration::from_secs(10)).await; + tokio::time::sleep(Duration::from_secs(5)).await; let tx = stream_tx.clone(); let samples = samples.clone(); for sample in samples.chunks(sample_rate as usize * 5) { - tx.send(sample.to_vec()).unwrap_or_else(|e| { - panic!("Failed to send sample: {}", e); - }); + tx.send(sample.to_vec()).unwrap(); tokio::time::sleep(Duration::from_millis(10)).await; } diff --git a/screenpipe-audio/tests/speaker_identification.rs b/screenpipe-audio/tests/speaker_identification.rs index 91a7e9ab07..2199643717 100644 --- a/screenpipe-audio/tests/speaker_identification.rs +++ b/screenpipe-audio/tests/speaker_identification.rs @@ -37,7 +37,7 @@ mod tests { .ok_or_else(|| anyhow::anyhow!("Invalid embedding model path")) .unwrap(), ) - .unwrap(), + .unwrap() )); let embedding_manager = screenpipe_audio::pyannote::identify::EmbeddingManager::new(usize::MAX); @@ -76,7 +76,7 @@ mod tests { 16000, &segmentation_model_path, embedding_extractor, - Arc::new(Mutex::new(embedding_manager)), + embedding_manager.clone(), ) .unwrap() .collect::>(); diff --git a/screenpipe-integrations/src/unstructured_ocr.rs b/screenpipe-integrations/src/unstructured_ocr.rs index 3472ba35d5..d67b4bdeef 100644 --- a/screenpipe-integrations/src/unstructured_ocr.rs +++ b/screenpipe-integrations/src/unstructured_ocr.rs @@ -12,13 +12,12 @@ use std::io::Cursor; use std::io::Read; use std::io::Write; use std::path::PathBuf; -use std::sync::Arc; use tempfile::NamedTempFile; use tokio::time::{timeout, Duration}; pub async fn perform_ocr_cloud( image: &DynamicImage, - languages: Arc<[Language]>, + languages: Vec, ) -> Result<(String, String, Option)> { let api_key = match env::var("UNSTRUCTURED_API_KEY") { Ok(key) => key, diff --git a/screenpipe-server/Cargo.toml b/screenpipe-server/Cargo.toml index b72230dad4..d289634998 100644 --- a/screenpipe-server/Cargo.toml +++ b/screenpipe-server/Cargo.toml @@ -118,6 +118,8 @@ regex = "1.10.0" lru = "0.13.0" tokio-util = { version = "0.7", features = ["io"] } + +dashmap = "6.1.0" [dev-dependencies] env_logger = "0.10" tempfile = "3.3.0" diff --git a/screenpipe-server/benches/db_benchmarks.rs b/screenpipe-server/benches/db_benchmarks.rs index 0c1f3f685e..9c7d391d38 100644 --- a/screenpipe-server/benches/db_benchmarks.rs +++ b/screenpipe-server/benches/db_benchmarks.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use rand::Rng; -use screenpipe_core::AudioDevice; +use screenpipe_audio::AudioDevice; use screenpipe_server::{db_types::ContentType, DatabaseManager}; use screenpipe_vision::OcrEngine; use std::sync::Arc; @@ -41,7 +41,7 @@ async fn 
setup_large_db(size: usize) -> DatabaseManager { "test_engine", &AudioDevice::new( "test_device".to_string(), - screenpipe_core::AudioDeviceType::Input, + screenpipe_audio::DeviceType::Input, ), None, None, diff --git a/screenpipe-server/src/add.rs b/screenpipe-server/src/add.rs index 5d62542268..3a15ca7955 100644 --- a/screenpipe-server/src/add.rs +++ b/screenpipe-server/src/add.rs @@ -140,7 +140,7 @@ pub async fn handle_index_command( ) .await?; - let mut previous_image: Option> = None; + let mut previous_image: Option = None; let mut frame_counter: i64 = 0; for (idx, frame) in frames.iter().enumerate() { @@ -162,7 +162,7 @@ pub async fn handle_index_command( continue; } - previous_image = Some(Arc::new(frame.clone())); + previous_image = Some(frame.clone()); // Use specified OCR engine or fall back to platform default let engine = match ocr_engine { @@ -183,15 +183,13 @@ pub async fn handle_index_command( // Do OCR processing directly let (text, _, confidence): (String, String, Option) = match engine.clone() { #[cfg(target_os = "macos")] - OcrEngine::AppleNative => perform_ocr_apple(frame, Arc::new([])), + OcrEngine::AppleNative => perform_ocr_apple(frame, &[]), #[cfg(target_os = "windows")] OcrEngine::WindowsNative => perform_ocr_windows(&frame).await.unwrap(), _ => { #[cfg(not(any(target_os = "macos", target_os = "windows")))] - { - perform_ocr_tesseract(&frame, Arc::new([])) - } - #[cfg(any(target_os = "macos", target_os = "windows"))] + perform_ocr_tesseract(&frame, vec![]); + panic!("unsupported ocr engine"); } }; diff --git a/screenpipe-server/src/auto_destruct.rs b/screenpipe-server/src/auto_destruct.rs index 174bd95c2e..0ac8562648 100644 --- a/screenpipe-server/src/auto_destruct.rs +++ b/screenpipe-server/src/auto_destruct.rs @@ -67,6 +67,8 @@ pub async fn watch_pid(pid: u32) -> bool { let pid_alive = String::from_utf8_lossy(&pid_output.stdout).contains(&pid.to_string()); let app_alive = !String::from_utf8_lossy(&app_output.stdout).is_empty(); + info!("pid alive: {}, app alive: {}", pid_alive, app_alive); + if !pid_alive || !app_alive { return true; } diff --git a/screenpipe-server/src/bin/screenpipe-server.rs b/screenpipe-server/src/bin/screenpipe-server.rs index 4f68f7fcc3..741ca4d729 100644 --- a/screenpipe-server/src/bin/screenpipe-server.rs +++ b/screenpipe-server/src/bin/screenpipe-server.rs @@ -1,27 +1,17 @@ use clap::Parser; #[allow(unused_imports)] use colored::Colorize; +use dashmap::DashMap; use dirs::home_dir; use futures::pin_mut; -use futures::StreamExt; use port_check::is_local_ipv4_port_free; -use screenpipe_audio::vad_engine::VadSensitivity; use screenpipe_audio::{ - create_whisper_channel, default_input_device, default_output_device, list_audio_devices, - parse_audio_device, VadEngineEnum, + default_input_device, default_output_device, list_audio_devices, parse_audio_device, + AudioDevice, DeviceControl, }; - -use screenpipe_audio::{AudioInput, TranscriptionResult}; -use screenpipe_core::DeviceControl; -use screenpipe_core::DeviceType; -use screenpipe_core::{find_ffmpeg_path, DeviceManager}; -use screenpipe_server::VisionDeviceControlRequest; +use screenpipe_core::find_ffmpeg_path; use screenpipe_server::{ - cli::{ - AudioCommand, Cli, CliAudioTranscriptionEngine, CliOcrEngine, Command, OutputFormat, - PipeCommand, VisionCommand, - }, - core::{AudioConfig, RecordingConfig, VisionConfig}, + cli::{Cli, CliAudioTranscriptionEngine, CliOcrEngine, Command, OutputFormat, PipeCommand}, handle_index_command, pipe_manager::PipeInfo, start_continuous_recording, 
watch_pid, DatabaseManager, PipeManager, ResourceMonitor, Server, @@ -29,16 +19,15 @@ use screenpipe_server::{ use screenpipe_vision::monitor::list_monitors; #[cfg(target_os = "macos")] use screenpipe_vision::run_ui; -use serde::Deserialize; use serde_json::{json, Value}; use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, env, fs, io::Write, net::SocketAddr, ops::Deref, path::PathBuf, - sync::Arc, + sync::{atomic::AtomicBool, Arc}, time::Duration, }; use tokio::{runtime::Runtime, signal, sync::broadcast}; @@ -47,15 +36,16 @@ use tracing_appender::non_blocking::WorkerGuard; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::Layer; use tracing_subscriber::{fmt, EnvFilter}; -// Add this struct to handle the server response structure -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct ApiResponse { - data: T, - success: bool, +fn print_devices(devices: &[AudioDevice]) { + println!("available audio devices:"); + for device in devices.iter() { + println!(" {}", device); + } + + #[cfg(target_os = "macos")] + println!("on macos, it's not intuitive but output devices are your displays"); } const DISPLAY: &str = r" @@ -93,67 +83,48 @@ fn setup_logging(local_data_dir: &PathBuf, cli: &Cli) -> anyhow::Result filter.add_directive(directive), - Err(e) => { - eprintln!( - "warning: invalid log directive '{}': {}", - module_directive, e - ); - filter - } + let env_filter = EnvFilter::from_default_env() + .add_directive("info".parse().unwrap()) + .add_directive("tokenizers=error".parse().unwrap()) + .add_directive("rusty_tesseract=error".parse().unwrap()) + .add_directive("symphonia=error".parse().unwrap()) + .add_directive("hf_hub=error".parse().unwrap()); + + // filtering out xcap::platform::impl_window - Access is denied. 
(0x80070005) + // which is noise + #[cfg(target_os = "windows")] + let env_filter = env_filter.add_directive("xcap::platform::impl_window=off".parse().unwrap()); + + let env_filter = env::var("SCREENPIPE_LOG") + .unwrap_or_default() + .split(',') + .filter(|s| !s.is_empty()) + .fold( + env_filter, + |filter, module_directive| match module_directive.parse() { + Ok(directive) => filter.add_directive(directive), + Err(e) => { + eprintln!( + "warning: invalid log directive '{}': {}", + module_directive, e + ); + filter } - }); + }, + ); - if cli.debug { - filter.add_directive("screenpipe=debug".parse().unwrap()) - } else { - filter - } + let env_filter = if cli.debug { + env_filter.add_directive("screenpipe=debug".parse().unwrap()) + } else { + env_filter }; - let tracing_registry = tracing_subscriber::registry() - .with( - fmt::layer() - .with_writer(std::io::stdout) - .with_filter(make_env_filter()), - ) - .with( - fmt::layer() - .with_writer(non_blocking) - .with_filter(make_env_filter()), - ); - - #[cfg(feature = "debug-console")] - let tracing_registry = tracing_registry.with( - console_subscriber::spawn().with_filter( - EnvFilter::from_default_env() - .add_directive("tokio=trace".parse().unwrap()) - .add_directive("runtime=trace".parse().unwrap()), - ), - ); + tracing_subscriber::registry() + .with(env_filter) + .with(fmt::layer().with_writer(std::io::stdout)) + .with(fmt::layer().with_writer(non_blocking)) + .init(); - tracing_registry.init(); Ok(guard) } @@ -186,7 +157,8 @@ async fn main() -> anyhow::Result<()> { let local_data_dir = get_base_dir(&cli.data_dir)?; let local_data_dir_clone = local_data_dir.clone(); - match &cli.command { + // Only set up logging if we're not running a pipe command with JSON output + let should_log = match &cli.command { Some(Command::Pipe { subcommand }) => { matches!( subcommand, @@ -210,110 +182,24 @@ async fn main() -> anyhow::Result<()> { output: OutputFormat::Text, .. }) => true, - Some(Command::Doctor { - output: OutputFormat::Text, - .. - }) => true, - Some(Command::Vision { subcommand }) => { - matches!( - subcommand, - VisionCommand::List { - output: OutputFormat::Text, - .. - } - ) - } - Some(Command::Audio { subcommand }) => { - matches!( - subcommand, - AudioCommand::List { - output: OutputFormat::Text, - .. - } - ) - } _ => true, }; // Store the guard in a variable that lives for the entire main function - let _log_guard = Some(setup_logging(&local_data_dir, &cli)?); + let _log_guard = if should_log { + Some(setup_logging(&local_data_dir, &cli)?) + } else { + None + }; let pipe_manager = Arc::new(PipeManager::new(local_data_dir_clone.clone())); - if let Some(Command::Completions { shell }) = &cli.command { - cli.handle_completions(*shell)?; - return Ok(()); - } if let Some(command) = cli.command { match command { Command::Pipe { subcommand } => { handle_pipe_command(subcommand, &pipe_manager).await?; return Ok(()); } - Command::Vision { subcommand } => match subcommand { - VisionCommand::List { output } => { - let monitors = list_monitors().await; - match output { - OutputFormat::Json => { - println!( - "{}", - serde_json::to_string_pretty(&json!({ - "data": monitors.iter().map(|m| { - json!({ - "id": m.id(), - "name": m.name(), - "x": m.x(), - "y": m.y(), - "width": m.width(), - "height": m.height(), - "rotation": m.rotation(), - "scale_factor": m.scale_factor(), - "frequency": m.frequency(), - "is_primary": m.is_primary() - }) - }).collect::>(), - "success": true - }))? 
- ); - } - OutputFormat::Text => { - println!("available monitors:"); - for monitor in monitors.iter() { - println!(" {}. {:?}", monitor.id(), monitor); - } - } - } - return Ok(()); - } - }, - Command::Audio { subcommand } => match subcommand { - AudioCommand::List { output } => { - let devices = list_audio_devices().await?; - match output { - OutputFormat::Json => { - println!( - "{}", - serde_json::to_string_pretty(&json!({ - "data": devices.iter().map(|d| { - json!({ - "name": d.name, - "device_type": d.device_type, - }) - }).collect::>(), - "success": true - }))? - ); - } - OutputFormat::Text => { - println!("available audio devices:"); - for device in devices.iter() { - println!(" {:?}", device); - } - } - } - return Ok(()); - } - }, #[allow(unused_variables)] Command::Setup { enable_beta } => { #[cfg(feature = "beta")] @@ -445,37 +331,20 @@ async fn main() -> anyhow::Result<()> { .await?; return Ok(()); } - Command::Doctor { output, fix } => { - handle_doctor_command(output, fix).await?; - return Ok(()); - } - Command::AddToPath { yes } => { - if !yes { - print!("add screenpipe to system PATH? [y/N] "); - std::io::stdout().flush()?; - let mut input = String::new(); - std::io::stdin().read_line(&mut input)?; - if !input.trim().eq_ignore_ascii_case("y") { - println!("operation cancelled"); - return Ok(()); - } - } + } + } - match ensure_screenpipe_in_path().await { - Ok(_) => { - println!("✓ screenpipe successfully added to PATH"); - println!("note: you may need to restart your terminal for changes to take effect"); - } - Err(e) => { - eprintln!("failed to add screenpipe to PATH: {}", e); - return Err(e); - } - } - return Ok(()); - } - Command::Completions { .. } => todo!(), + // Check if Screenpipe is present in PATH + // TODO: likely should not force user to install in PATH (eg brew, powershell, or button in UI) + match ensure_screenpipe_in_path().await { + Ok(_) => info!("screenpipe is available and properly set in the PATH"), + Err(e) => { + warn!("screenpipe PATH check failed: {}", e); + warn!("please ensure screenpipe is installed correctly and is in your PATH"); + // do not crash } } + if find_ffmpeg_path().is_none() { eprintln!("ffmpeg not found. please install ffmpeg and ensure it is in your path."); std::process::exit(1); @@ -490,16 +359,29 @@ async fn main() -> anyhow::Result<()> { let all_audio_devices = list_audio_devices().await?; let mut devices_status = HashMap::new(); + if cli.list_audio_devices { + print_devices(&all_audio_devices); + return Ok(()); + } let all_monitors = list_monitors().await; + if cli.list_monitors { + println!("available monitors:"); + for monitor in all_monitors.iter() { + println!(" {}. 
{:?}", monitor.id(), monitor); + } + return Ok(()); + } let mut audio_devices = Vec::new(); - let device_manager = Arc::new(DeviceManager::default()); + + let audio_devices_control = Arc::new(DashMap::new()); + let audio_devices_control_recording = audio_devices_control.clone(); + let mut realtime_audio_devices = Vec::new(); // Add all available audio devices to the controls for device in &all_audio_devices { let device_control = DeviceControl { - device: screenpipe_core::DeviceType::Audio(device.clone()), is_running: false, is_paused: false, }; @@ -512,16 +394,16 @@ async fn main() -> anyhow::Result<()> { if let Ok(input_device) = default_input_device() { audio_devices.push(Arc::new(input_device.clone())); let device_control = DeviceControl { - device: screenpipe_core::DeviceType::Audio(input_device.clone()), is_running: true, is_paused: false, }; devices_status.insert(input_device, device_control); } + // audio output only on macos <15.0 atm ? + // see https://github.com/mediar-ai/screenpipe/pull/106 if let Ok(output_device) = default_output_device() { audio_devices.push(Arc::new(output_device.clone())); let device_control = DeviceControl { - device: screenpipe_core::DeviceType::Audio(output_device.clone()), is_running: true, is_paused: false, }; @@ -533,7 +415,6 @@ async fn main() -> anyhow::Result<()> { let device = parse_audio_device(d).expect("failed to parse audio device"); audio_devices.push(Arc::new(device.clone())); let device_control = DeviceControl { - device: screenpipe_core::DeviceType::Audio(device.clone()), is_running: true, is_paused: false, }; @@ -544,29 +425,17 @@ async fn main() -> anyhow::Result<()> { if audio_devices.is_empty() { eprintln!("no audio devices available. audio recording will be disabled."); } else { - for (index, device) in audio_devices.iter().enumerate() { - if cli.disable_audio { - break; - } + for device in &audio_devices { + let device_control = DeviceControl { + is_running: true, + is_paused: false, + }; let device_clone = device.deref().clone(); - let sender_clone = device_manager.clone(); - // send signal after everything started, with staggered delays + let sender_clone = audio_devices_control.clone(); + // send signal after everything started tokio::spawn(async move { - // Base delay of 15s + 5s for each device index - let delay = Duration::from_secs(15) + Duration::from_secs(5 * index as u64); - tokio::time::sleep(delay).await; - info!( - "initializing audio device control for device: {}", - device_clone.name - ); - - let _ = sender_clone - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Audio(device_clone), - is_running: true, - is_paused: false, - }) - .await; + tokio::time::sleep(Duration::from_secs(15)).await; + sender_clone.insert(device_clone, device_control); }); } } @@ -575,15 +444,17 @@ async fn main() -> anyhow::Result<()> { if cli.realtime_audio_device.is_empty() { // Use default devices if let Ok(input_device) = default_input_device() { - realtime_audio_devices.push(input_device.clone()); + realtime_audio_devices.push(Arc::new(input_device.clone())); } + // audio output only on macos <15.0 atm ? 
+ // see https://github.com/mediar-ai/screenpipe/pull/106 if let Ok(output_device) = default_output_device() { - realtime_audio_devices.push(output_device.clone()); + realtime_audio_devices.push(Arc::new(output_device.clone())); } } else { for d in &cli.realtime_audio_device { let device = parse_audio_device(d).expect("failed to parse audio device"); - realtime_audio_devices.push(device.clone()); + realtime_audio_devices.push(Arc::new(device.clone())); } } @@ -593,7 +464,7 @@ async fn main() -> anyhow::Result<()> { } } - let resource_monitor = ResourceMonitor::new(!cli.disable_telemetry); + let resource_monitor = ResourceMonitor::new(); resource_monitor.start_monitoring(Duration::from_secs(10)); let db = Arc::new( @@ -607,6 +478,11 @@ async fn main() -> anyhow::Result<()> { let db_server = db.clone(); + // Channel for controlling the recorder ! TODO RENAME SHIT + let vision_control = Arc::new(AtomicBool::new(true)); + + let vision_control_server_clone = vision_control.clone(); + let warning_ocr_engine_clone = cli.ocr_engine.clone(); let warning_audio_transcription_engine_clone = cli.audio_transcription_engine.clone(); let monitor_ids = if cli.monitor_id.is_empty() { @@ -615,34 +491,6 @@ async fn main() -> anyhow::Result<()> { cli.monitor_id.clone() }; - // Initialize vision devices control based on user selected monitors - { - for (index, monitor_id) in monitor_ids.clone().into_iter().enumerate() { - if cli.disable_vision { - break; - } - let device_manager = device_manager.clone(); - // Send signal after everything started - tokio::spawn(async move { - // Base delay of 15s + 5s for each monitor index - let delay = Duration::from_secs(15) + Duration::from_secs(5 * index as u64); - tokio::time::sleep(delay).await; - info!( - "initializing vision device control for monitor: {}", - monitor_id - ); - let device_control = DeviceControl { - device: DeviceType::Vision(monitor_id), - is_running: true, - is_paused: false, - }; - if let Err(e) = device_manager.update_device(device_control).await { - warn!("failed to initialize vision device control: {}", e); - } - }); - } - } - let languages = cli.unique_languages().unwrap(); let languages_clone = languages.clone(); @@ -660,7 +508,9 @@ async fn main() -> anyhow::Result<()> { let db_clone = Arc::clone(&db); let output_path_clone = Arc::new(local_data_dir.join("data").to_string_lossy().into_owned()); + let vision_control_clone = Arc::clone(&vision_control); let shutdown_tx_clone = shutdown_tx.clone(); + let monitor_ids_clone = monitor_ids.clone(); let ignored_windows_clone = cli.ignored_windows.clone(); let included_windows_clone = cli.included_windows.clone(); let realtime_audio_devices_clone = realtime_audio_devices.clone(); @@ -673,79 +523,46 @@ async fn main() -> anyhow::Result<()> { }; let audio_chunk_duration = Duration::from_secs(cli.audio_chunk_duration); - let dm_clone = device_manager.clone(); - let device_manager_clone = device_manager.clone(); - let device_manager_clone_2 = device_manager.clone(); - - let (whisper_sender, whisper_receiver) = if cli.disable_audio { - // Create a dummy channel if no audio devices are available, e.g. 
audio disabled - let (input_sender, _): ( - crossbeam::channel::Sender, - crossbeam::channel::Receiver, - ) = crossbeam::channel::bounded(100); - let (_, output_receiver): ( - crossbeam::channel::Sender, - crossbeam::channel::Receiver, - ) = crossbeam::channel::bounded(100); - (input_sender, output_receiver) - } else { - create_whisper_channel( - Arc::new(cli.audio_transcription_engine.clone().into()), - VadEngineEnum::from(cli.vad_engine), - cli.deepgram_api_key.clone(), - &PathBuf::from(output_path_clone.as_ref()), - VadSensitivity::from(cli.vad_sensitivity.clone()), - languages.clone(), - device_manager.clone(), - ) - .await? - }; - + let (realtime_transcription_sender, _) = tokio::sync::broadcast::channel(1000); + let realtime_transcription_sender_clone = realtime_transcription_sender.clone(); + let (realtime_vision_sender, _) = tokio::sync::broadcast::channel(1000); + let realtime_vision_sender = Arc::new(realtime_vision_sender.clone()); + let realtime_vision_sender_clone = realtime_vision_sender.clone(); let handle = { let runtime = &tokio::runtime::Handle::current(); runtime.spawn(async move { loop { + let realtime_vision_sender_clone = realtime_vision_sender.clone(); let vad_engine_clone = vad_engine.clone(); // Clone it here for each iteration let mut shutdown_rx = shutdown_tx_clone.subscribe(); - - // Create the configs - let recording_config = RecordingConfig { - output_path: output_path_clone.clone(), - fps, - audio_chunk_duration, - video_chunk_duration: Duration::from_secs(cli.video_chunk_duration), - use_pii_removal: cli.use_pii_removal, - capture_unfocused_windows: cli.capture_unfocused_windows, - languages: Arc::from(languages.clone()), - }; - - let audio_config = AudioConfig { - disabled: cli.disable_audio, - transcription_engine: Arc::new(cli.audio_transcription_engine.clone().into()), - vad_engine: vad_engine_clone, - vad_sensitivity: cli.vad_sensitivity.clone(), - deepgram_api_key: cli.deepgram_api_key.clone(), - realtime_enabled: cli.enable_realtime_audio_transcription, - realtime_devices: Arc::from(realtime_audio_devices.clone()), - whisper_sender: whisper_sender.clone(), - whisper_receiver: whisper_receiver.clone(), - }; - - let vision_config = VisionConfig { - disabled: cli.disable_vision, - ocr_engine: Arc::new(cli.ocr_engine.clone().into()), - ignored_windows: Arc::from(cli.ignored_windows.clone()), - include_windows: Arc::from(cli.included_windows.clone()), - }; - + let realtime_transcription_sender_clone = realtime_transcription_sender.clone(); let recording_future = start_continuous_recording( db_clone.clone(), - recording_config, - audio_config, - vision_config, + output_path_clone.clone(), + fps, + audio_chunk_duration, + Duration::from_secs(cli.video_chunk_duration), + vision_control_clone.clone(), + audio_devices_control_recording.clone(), + cli.disable_audio, + Arc::new(cli.audio_transcription_engine.clone().into()), + Arc::new(cli.ocr_engine.clone().into()), + monitor_ids_clone.clone(), + cli.use_pii_removal, + cli.disable_vision, + vad_engine_clone, &vision_handle, &audio_handle, - device_manager.clone(), + &cli.ignored_windows, + &cli.included_windows, + cli.deepgram_api_key.clone(), + cli.vad_sensitivity.clone(), + languages.clone(), + cli.capture_unfocused_windows, + realtime_audio_devices.clone(), + cli.enable_realtime_audio_transcription, + Arc::new(realtime_transcription_sender_clone), // Use the cloned sender + realtime_vision_sender_clone, ); let result = tokio::select! 
{ @@ -786,10 +603,14 @@ async fn main() -> anyhow::Result<()> { } }; + let (audio_devices_tx, _) = broadcast::channel(100); + let audio_devices_tx_clone = Arc::new(audio_devices_tx.clone()); + + let realtime_vision_sender_clone = realtime_vision_sender_clone.clone(); + // TODO: Add SSE stream for realtime audio transcription let server = Server::new( db_server, SocketAddr::from(([127, 0, 0, 1], cli.port)), - dm_clone, local_data_dir_clone_2, pipe_manager.clone(), cli.disable_vision, @@ -797,76 +618,43 @@ async fn main() -> anyhow::Result<()> { cli.enable_ui_monitoring, ); + let mut rx = audio_devices_tx.subscribe(); + let audio_devices_control_for_spawn = audio_devices_control.clone(); tokio::spawn(async move { - // Watch for device changes using the new stream API - let mut device_changes = device_manager_clone.watch_devices().await; - - while let Some(change) = device_changes.next().await { - // ignore if vision disabled and its a vision device - if cli.disable_vision && change.control.device.is_vision() { - continue; - } - if cli.disable_audio && change.control.device.is_audio() { - continue; - } - debug!("received device update: {:?}", change); - if let Err(e) = handle_device_update( - &change.control.device, - change.control.clone(), - &device_manager_clone, - ) - .await + while let Ok((device, control)) = rx.recv().await { + if let Err(e) = + handle_device_update(&device, control, &audio_devices_control_for_spawn).await { error!("Device update failed: {}", e); continue; } } - info!("device control task stopped"); + info!("Device monitoring task completed"); }); async fn handle_device_update( - device: &DeviceType, + device: &AudioDevice, control: DeviceControl, - devices_control: &Arc, + devices_control: &DashMap, ) -> anyhow::Result<()> { - debug!("received device update"); - - match device { - DeviceType::Audio(device) => { - match list_audio_devices().await { - Ok(available_devices) => { - if !available_devices.contains(device) { - warn!("attempted to control non-existent device: {}", device.name); - return Err(anyhow::anyhow!( - "attempted to control non-existent device: {}", - device.name - )); - } - - // Update the device state using new DeviceManager API - devices_control.update_device(control).await?; - Ok(()) - } - Err(e) => { - warn!("failed to list audio devices: {}", e); - Err(anyhow::anyhow!("failed to list audio devices: {}", e)) - } - } - } - DeviceType::Vision(monitor_id) => { - let monitors = list_monitors().await; - if !monitors.iter().any(|m| m.id() == *monitor_id) { - warn!("attempted to control non-existent device: {}", monitor_id); + match list_audio_devices().await { + Ok(available_devices) => { + if !available_devices.contains(device) { return Err(anyhow::anyhow!( "attempted to control non-existent device: {}", - monitor_id + device.name )); } - // Update the device state using new DeviceManager API - devices_control.update_device(control).await?; + // Update the device state + devices_control.insert(device.clone(), control.clone()); + info!( + "Device state changed: {} - running: {}", + device.name, control.is_running + ); Ok(()) } + Err(e) => Err(anyhow::anyhow!("failed to list audio devices: {}", e)), } } @@ -946,7 +734,6 @@ async fn main() -> anyhow::Result<()> { "│ frame cache │ {:<34} │", cli.enable_frame_cache ); - println!("│ use all monitors │ {:<34} │", cli.use_all_monitors); const VALUE_WIDTH: usize = 34; @@ -1057,7 +844,7 @@ async fn main() -> anyhow::Result<()> { .enumerate() .take(MAX_ITEMS_TO_DISPLAY) { - let device_str = device.to_string(); + 
let device_str = device.deref().to_string(); let formatted_device = format_cell(&device_str, VALUE_WIDTH); println!("│ {:<22} │ {:<34} │", "", formatted_device); @@ -1165,41 +952,11 @@ async fn main() -> anyhow::Result<()> { if let Some(pid) = cli.auto_destruct_pid { info!("watching pid {} for auto-destruction", pid); let shutdown_tx_clone = shutdown_tx.clone(); - let pipe_manager = pipe_manager.clone(); tokio::spawn(async move { - // sleep for 1 seconds - tokio::time::sleep(std::time::Duration::from_secs(1)).await; + // sleep for 5 seconds + tokio::time::sleep(std::time::Duration::from_secs(5)).await; if watch_pid(pid).await { info!("Watched pid ({}) has stopped, initiating shutdown", pid); - - // Get list of enabled pipes - let pipes = pipe_manager.list_pipes().await; - let enabled_pipes: Vec<_> = pipes.into_iter().filter(|p| p.enabled).collect(); - - // Stop all enabled pipes in parallel - let stop_futures = enabled_pipes.iter().map(|pipe| { - let pipe_manager = pipe_manager.clone(); - let pipe_id = pipe.id.clone(); - tokio::spawn(async move { - if let Err(e) = pipe_manager.stop_pipe(&pipe_id).await { - error!("failed to stop pipe {}: {}", pipe_id, e); - } - }) - }); - - // Wait for all pipes to stop with timeout - let timeout = tokio::time::sleep(Duration::from_secs(10)); - tokio::pin!(timeout); - - tokio::select! { - _ = futures::future::join_all(stop_futures) => { - info!("all pipes stopped successfully"); - } - _ = &mut timeout => { - warn!("timeout waiting for pipes to stop"); - } - } - let _ = shutdown_tx_clone.send(()); } }); @@ -1243,7 +1000,7 @@ async fn main() -> anyhow::Result<()> { loop { tokio::select! { - result = run_ui() => { + result = run_ui(realtime_vision_sender_clone.clone()) => { match result { Ok(_) => break, Err(e) => { @@ -1262,78 +1019,6 @@ async fn main() -> anyhow::Result<()> { }); } - if cli.use_all_monitors && !cli.disable_vision { - let client = reqwest::Client::new(); - let port = cli.port; - let mut shutdown_rx = shutdown_tx.subscribe(); - - tokio::spawn(async move { - // wait 10 seconds - tokio::time::sleep(Duration::from_secs(10)).await; - // Start all available monitors immediately - let initial_monitors: HashSet = - list_monitors().await.into_iter().map(|m| m.id()).collect(); - - for monitor_id in &initial_monitors { - info!("starting monitor: {}", monitor_id); - let _ = client - .post(format!("http://127.0.0.1:{}/vision/start", port)) - .json(&VisionDeviceControlRequest::new(*monitor_id)) - .send() - .await - .map_err(|e| error!("failed to start monitor {}: {}", monitor_id, e)); - } - - let mut previous_monitors = initial_monitors; - - loop { - tokio::select! 
{ - _ = shutdown_rx.recv() => { - info!("stopping monitor polling due to shutdown signal"); - break; - } - _ = tokio::time::sleep(Duration::from_secs(5)) => { - let current_monitors: HashSet = list_monitors() - .await - .into_iter() - .map(|m| m.id()) - .collect(); - - // Handle new monitors - for monitor_id in current_monitors.difference(&previous_monitors) { - info!("new monitor detected: {}", monitor_id); - - // Start recording the new monitor using the API - let _ = client - .post(format!("http://127.0.0.1:{}/vision/start", port)) - .json(&VisionDeviceControlRequest::new(*monitor_id)) - .send() - .await - .map_err(|e| error!("failed to start new monitor {}: {}", monitor_id, e)); - } - - // Handle removed monitors - for monitor_id in previous_monitors.difference(¤t_monitors) { - info!("monitor removed: {}", monitor_id); - - // Stop recording the removed monitor using the API - let _ = client - .post(format!("http://127.0.0.1:{}/vision/stop", port)) - .json(&VisionDeviceControlRequest::new(*monitor_id)) - .send() - .await - .map_err(|e| { - error!("failed to stop removed monitor {}: {}", monitor_id, e) - }); - } - - previous_monitors = current_monitors; - } - } - } - }); - } - tokio::select! { _ = handle => info!("recording completed"), result = &mut server_future => { @@ -1348,8 +1033,6 @@ async fn main() -> anyhow::Result<()> { } } - device_manager_clone_2.shutdown().await; - tokio::task::block_in_place(|| { drop(vision_runtime); drop(audio_runtime); @@ -1466,24 +1149,24 @@ async fn handle_pipe_command( } PipeCommand::Info { id, output, port } => { - let info: ApiResponse = match client + let info = match client .get(format!("{}:{}/pipes/info/{}", server_url, port, id)) .send() .await { Ok(response) if response.status().is_success() => response.json().await?, - _ => ApiResponse { - data: pipe_manager + _ => { + println!("note: server not running, showing pipe configuration"); + pipe_manager .get_pipe_info(&id) .await - .ok_or_else(|| anyhow::anyhow!("pipe not found"))?, - success: true, - }, + .ok_or_else(|| anyhow::anyhow!("pipe not found"))? 
+ } }; match output { - OutputFormat::Json => println!("{}", serde_json::to_string_pretty(&info.data)?), - OutputFormat::Text => println!("pipe info: {:?}", info.data), + OutputFormat::Json => println!("{}", serde_json::to_string_pretty(&info)?), + OutputFormat::Text => println!("pipe info: {:?}", info), } } PipeCommand::Enable { id, port } => { @@ -1761,93 +1444,3 @@ async fn check_ffmpeg() -> anyhow::Result<()> { Ok(()) } - -async fn handle_doctor_command(output: OutputFormat, fix: bool) -> anyhow::Result<()> { - let mut checks = Vec::new(); - - // Check ffmpeg - let ffmpeg_status = match find_ffmpeg_path() { - Some(path) => ("ffmpeg", true, format!("found at {}", path.display())), - None => ("ffmpeg", false, "not found in PATH".to_string()), - }; - checks.push(ffmpeg_status); - - // Check data directory - let data_dir = get_base_dir(&None)?; - let data_dir_status = ( - "data directory", - data_dir.exists(), - format!("{}", data_dir.display()), - ); - checks.push(data_dir_status); - - // Check database - let db_path = data_dir.join("db.sqlite"); - let db_exists = db_path.exists(); - let db_status = ("database", db_exists, format!("{}", db_path.display())); - checks.push(db_status); - - // Check audio devices - let audio_devices = match list_audio_devices().await { - Ok(devices) => { - let count = devices.len(); - ( - "audio devices", - count > 0, - format!("{} devices found", count), - ) - } - Err(e) => ("audio devices", false, format!("error: {}", e)), - }; - checks.push(audio_devices); - - // Check monitors - let monitors = list_monitors().await; - let monitor_status = ( - "monitors", - !monitors.is_empty(), - format!("{} found", monitors.len()), - ); - checks.push(monitor_status); - - // Output results - match output { - OutputFormat::Json => { - let json = serde_json::json!({ - "checks": checks.iter().map(|(name, status, msg)| { - serde_json::json!({ - "name": name, - "status": status, - "message": msg, - }) - }).collect::>(), - "healthy": checks.iter().all(|(_, status, _)| *status), - }); - println!("{}", serde_json::to_string_pretty(&json)?); - } - OutputFormat::Text => { - println!("\n🔍 screenpipe system diagnostics\n"); - - for (name, status, msg) in &checks { - let symbol = if *status { "✅" } else { "❌" }; - println!("{} {}: {}", symbol, name, msg); - } - - if fix { - println!("\n🔧 attempting to fix issues..."); - // Add auto-fix logic here if needed - let db = DatabaseManager::new(&db_path.to_string_lossy()).await?; - db.repair_database().await?; - } - - let healthy = checks.iter().all(|(_, status, _)| *status); - println!( - "\n{} overall health: {}\n", - if healthy { "✅" } else { "❌" }, - if healthy { "healthy" } else { "issues found" } - ); - } - } - - Ok(()) -} diff --git a/screenpipe-server/src/cli.rs b/screenpipe-server/src/cli.rs index e0467ce574..b158e14219 100644 --- a/screenpipe-server/src/cli.rs +++ b/screenpipe-server/src/cli.rs @@ -1,11 +1,9 @@ use std::path::PathBuf; -use clap::{CommandFactory, Parser, Subcommand, ValueHint}; -use clap_complete::{generate, Shell}; +use clap::{Parser, Subcommand}; use screenpipe_audio::{vad_engine::VadSensitivity, AudioTranscriptionEngine as CoreAudioTranscriptionEngine}; use screenpipe_vision::{custom_ocr::CustomOcrConfig, utils::OcrEngine as CoreOcrEngine}; use clap::ValueEnum; -use anyhow::Result; use screenpipe_audio::vad_engine::VadEngineEnum; use screenpipe_core::Language; @@ -147,13 +145,12 @@ pub struct Cli { #[arg(short = 'r', long)] pub realtime_audio_device: Vec, - /// List available audio devices (deprecated: use 'audio 
list' instead) + /// List available audio devices #[arg(long)] - #[deprecated(since = "0.2.30", note = "please use 'audio list' instead")] pub list_audio_devices: bool, /// Data directory. Default to $HOME/.screenpipe - #[arg(long, value_hint = ValueHint::DirPath)] + #[arg(long)] pub data_dir: Option, /// Enable debug logging for screenpipe modules @@ -191,9 +188,8 @@ pub struct Cli { )] pub ocr_engine: CliOcrEngine, - /// List available monitors (deprecated: use 'vision list' instead) + /// List available monitors, then you can use --monitor-id to select one (with the ID) #[arg(long)] - #[deprecated(since = "0.2.30", note = "please use 'vision list' instead")] pub list_monitors: bool, /// Monitor IDs to use, these will be used to select the monitors to record @@ -268,10 +264,6 @@ pub struct Cli { #[arg(long, default_value_t = false)] pub capture_unfocused_windows: bool, - /// Automatically detect and use all monitors, including newly connected ones - #[arg(long, default_value_t = false)] - pub use_all_monitors: bool, - #[command(subcommand)] pub command: Option, @@ -289,12 +281,6 @@ impl Cli { } Ok(unique_langs.into_iter().collect()) } - - pub fn handle_completions(&self, shell: Shell) -> anyhow::Result<()> { - let mut cmd = Self::command(); - generate(shell, &mut cmd, "screenpipe", &mut std::io::stdout()); - Ok(()) - } } #[derive(Subcommand)] @@ -304,22 +290,12 @@ pub enum Command { #[command(subcommand)] subcommand: PipeCommand, }, - /// Vision device management commands - Vision { - #[command(subcommand)] - subcommand: VisionCommand, - }, - /// Audio device management commands - Audio { - #[command(subcommand)] - subcommand: AudioCommand, - }, /// Add video files to existing screenpipe data (OCR only) - DOES NOT SUPPORT AUDIO Add { /// Path to folder containing video files path: String, /// Data directory. 
Default to $HOME/.screenpipe - #[arg(long, value_hint = ValueHint::DirPath)] + #[arg(long)] data_dir: Option, /// Output format #[arg(short = 'o', long, value_enum, default_value_t = OutputFormat::Text)] @@ -331,7 +307,7 @@ pub enum Command { #[arg(short = 'o', long, value_enum)] ocr_engine: Option, /// Path to JSON file containing metadata overrides - #[arg(long, value_hint = ValueHint::FilePath)] + #[arg(long)] metadata_override: Option, /// Copy videos to screenpipe data directory #[arg(long, default_value_t = true)] @@ -348,33 +324,12 @@ pub enum Command { /// Enable beta features #[arg(long, default_value_t = false)] enable_beta: bool, - }, - /// Generate shell completions - Completions { - /// The shell to generate completions for - #[arg(value_enum)] - shell: Shell, }, /// Run database migrations Migrate, - /// Run system diagnostics and health checks - Doctor { - /// Output format - #[arg(short, long, value_enum, default_value_t = OutputFormat::Text)] - output: OutputFormat, - - /// Fix issues automatically when possible - #[arg(short, long, default_value_t = false)] - fix: bool, - }, - /// Add screenpipe to system PATH - AddToPath { - /// Skip confirmation prompt - #[arg(short = 'y', long)] - yes: bool, - }, } + #[derive(Subcommand)] pub enum PipeCommand { /// List all pipes @@ -468,26 +423,6 @@ pub enum PipeCommand { }, } -#[derive(Subcommand)] -pub enum VisionCommand { - /// List available vision devices (monitors) - List { - /// Output format - #[arg(short, long, value_enum, default_value_t = OutputFormat::Text)] - output: OutputFormat, - }, -} - -#[derive(Subcommand)] -pub enum AudioCommand { - /// List available audio devices - List { - /// Output format - #[arg(short, long, value_enum, default_value_t = OutputFormat::Text)] - output: OutputFormat, - }, -} - #[derive(Clone, Debug, ValueEnum, PartialEq)] pub enum OutputFormat { Text, diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs index 6a0d4c8e49..bb472819d5 100644 --- a/screenpipe-server/src/core.rs +++ b/screenpipe-server/src/core.rs @@ -2,138 +2,144 @@ use crate::cli::{CliVadEngine, CliVadSensitivity}; use crate::db_types::Speaker; use crate::{DatabaseManager, VideoCapture}; use anyhow::Result; +use dashmap::DashMap; use futures::future::join_all; -use futures::StreamExt; +use tracing::{debug, error, info, warn}; +use screenpipe_audio::realtime::RealtimeTranscriptionEvent; +use screenpipe_audio::vad_engine::VadSensitivity; use screenpipe_audio::{ - record_and_transcribe, AudioInput, AudioTranscriptionEngine, TranscriptionResult, + create_whisper_channel, record_and_transcribe, vad_engine::VadEngineEnum, AudioDevice, + AudioInput, AudioTranscriptionEngine, DeviceControl, TranscriptionResult, }; use screenpipe_audio::{start_realtime_recording, AudioStream}; use screenpipe_core::pii_removal::remove_pii; -use screenpipe_core::{AudioDevice, DeviceManager, DeviceType, Language}; -use screenpipe_events::{poll_meetings_events, send_event}; -use screenpipe_vision::core::WindowOcr; -use screenpipe_vision::{CaptureResult, OcrEngine}; +use screenpipe_core::Language; +use screenpipe_vision::core::{RealtimeVisionEvent, WindowOcr}; +use screenpipe_vision::OcrEngine; use std::collections::HashMap; -use std::str::FromStr; +use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::sync::Weak; use std::time::Duration; use tokio::runtime::Handle; use tokio::task::JoinHandle; -use tracing::{debug, error, info, instrument, warn}; - -#[derive(Clone)] -pub struct 
RecordingConfig { - pub output_path: Arc<String>, - pub fps: f64, - pub audio_chunk_duration: Duration, - pub video_chunk_duration: Duration, - pub use_pii_removal: bool, - pub capture_unfocused_windows: bool, - pub languages: Arc<[Language]>, -} - -#[derive(Clone)] -pub struct AudioConfig { - pub disabled: bool, - pub transcription_engine: Arc<AudioTranscriptionEngine>, - pub vad_engine: CliVadEngine, - pub vad_sensitivity: CliVadSensitivity, - pub deepgram_api_key: Option<String>, - pub realtime_enabled: bool, - pub realtime_devices: Arc<[AudioDevice]>, - pub whisper_sender: crossbeam::channel::Sender<AudioInput>, - pub whisper_receiver: crossbeam::channel::Receiver<TranscriptionResult>, -} - -#[derive(Clone)] -pub struct VisionConfig { - pub disabled: bool, - pub ocr_engine: Arc<OcrEngine>, - pub ignored_windows: Arc<[String]>, - pub include_windows: Arc<[String]>, -} - -#[derive(Clone)] -pub struct VideoRecordingConfig { - pub db: Arc<DatabaseManager>, - pub output_path: Arc<String>, - pub fps: f64, - pub ocr_engine: Weak<OcrEngine>, - pub monitor_id: u32, - pub use_pii_removal: bool, - pub ignored_windows: Arc<[String]>, - pub include_windows: Arc<[String]>, - pub video_chunk_duration: Duration, - pub languages: Arc<[Language]>, - pub capture_unfocused_windows: bool, -} -#[instrument(skip(device_manager, db, recording, audio, vision))] +#[allow(clippy::too_many_arguments)] pub async fn start_continuous_recording( db: Arc<DatabaseManager>, - recording: RecordingConfig, - audio: AudioConfig, - vision: VisionConfig, + output_path: Arc<String>, + fps: f64, + audio_chunk_duration: Duration, + video_chunk_duration: Duration, + vision_control: Arc<AtomicBool>, + audio_devices_control: Arc<DashMap<AudioDevice, DeviceControl>>, + audio_disabled: bool, + audio_transcription_engine: Arc<AudioTranscriptionEngine>, + ocr_engine: Arc<OcrEngine>, + monitor_ids: Vec<u32>, + use_pii_removal: bool, + vision_disabled: bool, + vad_engine: CliVadEngine, vision_handle: &Handle, audio_handle: &Handle, - device_manager: Arc<DeviceManager>, + ignored_windows: &[String], + include_windows: &[String], + deepgram_api_key: Option<String>, + vad_sensitivity: CliVadSensitivity, + languages: Vec<Language>, + capture_unfocused_windows: bool, + realtime_audio_devices: Vec<Arc<AudioDevice>>, + realtime_audio_enabled: bool, + realtime_transcription_sender: Arc<tokio::sync::broadcast::Sender<RealtimeTranscriptionEvent>>, + realtime_vision_sender: Arc<tokio::sync::broadcast::Sender<RealtimeVisionEvent>>, ) -> Result<()> { - let recording_config = recording; - - let output_path_clone = recording_config.output_path.clone(); - let languages_clone = recording_config.languages.clone(); - let db_clone = db.clone(); - let ocr_engine_clone = vision.ocr_engine.clone(); - let device_manager_vision = device_manager.clone(); - let device_manager_audio = device_manager.clone(); - - let video_task = if !vision.disabled { - vision_handle.spawn(async move { - record_vision( - device_manager_vision, - ocr_engine_clone, - db_clone, - output_path_clone, - recording_config.fps, - languages_clone, - recording_config.capture_unfocused_windows, - vision.ignored_windows, - vision.include_windows, - recording_config.video_chunk_duration, - recording_config.use_pii_removal, - ) - .await - }) + debug!("Starting video recording for monitor {:?}", monitor_ids); + let video_tasks = if !vision_disabled { + monitor_ids + .iter() + .map(|&monitor_id| { + let db_manager_video = Arc::clone(&db); + let output_path_video = Arc::clone(&output_path); + let is_running_video = Arc::clone(&vision_control); + let ocr_engine = Arc::clone(&ocr_engine); + let ignored_windows_video = ignored_windows.to_vec(); + let include_windows_video = include_windows.to_vec(); + let realtime_vision_sender_clone = realtime_vision_sender.clone(); + + let languages = languages.clone(); + + debug!("Starting video recording for monitor {}", monitor_id);
vision_handle.spawn(async move { + record_video( + db_manager_video, + output_path_video, + fps, + is_running_video, + ocr_engine, + monitor_id, + use_pii_removal, + &ignored_windows_video, + &include_windows_video, + video_chunk_duration, + languages.clone(), + capture_unfocused_windows, + realtime_vision_sender_clone, + ) + .await + }) + }) + .collect::>() } else { - vision_handle.spawn(async move { + vec![vision_handle.spawn(async move { tokio::time::sleep(Duration::from_secs(60)).await; Ok(()) - }) + })] }; - let whisper_sender_clone = audio.whisper_sender.clone(); - let whisper_receiver_clone = audio.whisper_receiver.clone(); + let (whisper_sender, whisper_receiver, whisper_shutdown_flag) = if audio_disabled { + // Create a dummy channel if no audio devices are available, e.g. audio disabled + let (input_sender, _): ( + crossbeam::channel::Sender, + crossbeam::channel::Receiver, + ) = crossbeam::channel::bounded(100); + let (_, output_receiver): ( + crossbeam::channel::Sender, + crossbeam::channel::Receiver, + ) = crossbeam::channel::bounded(100); + ( + input_sender, + output_receiver, + Arc::new(AtomicBool::new(false)), + ) + } else { + create_whisper_channel( + audio_transcription_engine.clone(), + VadEngineEnum::from(vad_engine), + deepgram_api_key.clone(), + &PathBuf::from(output_path.as_ref()), + VadSensitivity::from(vad_sensitivity), + languages.clone(), + Some(audio_devices_control.clone()), + ) + .await? + }; + let whisper_sender_clone = whisper_sender.clone(); let db_manager_audio = Arc::clone(&db); - tokio::spawn(async move { - let _ = poll_meetings_events().await; - }); - - let audio_task = if !audio.disabled { + let audio_task = if !audio_disabled { audio_handle.spawn(async move { record_audio( - device_manager_audio, db_manager_audio, - recording_config.audio_chunk_duration, - whisper_sender_clone, - whisper_receiver_clone, - audio.transcription_engine, - audio.realtime_enabled, - audio.realtime_devices, - recording_config.languages, - audio.deepgram_api_key, + audio_chunk_duration, + whisper_sender, + whisper_receiver, + audio_devices_control, + audio_transcription_engine, + realtime_audio_enabled, + realtime_audio_devices, + languages, + realtime_transcription_sender, + deepgram_api_key, ) .await }) @@ -144,15 +150,22 @@ pub async fn start_continuous_recording( }) }; - if let Err(e) = video_task.await { - error!("Video recording error: {:?}", e); + // Join all video tasks + let video_results = join_all(video_tasks); + + // Handle any errors from the tasks + for (i, result) in video_results.await.into_iter().enumerate() { + if let Err(e) = result { + error!("Video recording error for monitor {}: {:?}", i, e); + } } if let Err(e) = audio_task.await { error!("Audio recording error: {:?}", e); } // Shutdown the whisper channel - drop(audio.whisper_sender); // Close the sender channel + whisper_shutdown_flag.store(true, Ordering::Relaxed); + drop(whisper_sender_clone); // Close the sender channel // TODO: process any remaining audio chunks // TODO: wait a bit for whisper to finish processing @@ -163,115 +176,25 @@ pub async fn start_continuous_recording( } #[allow(clippy::too_many_arguments)] -async fn record_vision( - device_manager: Arc, - ocr_engine: Arc, +async fn record_video( db: Arc, output_path: Arc, fps: f64, - languages: Arc<[Language]>, - capture_unfocused_windows: bool, - ignored_windows: Arc<[String]>, - include_windows: Arc<[String]>, - video_chunk_duration: Duration, + is_running: Arc, + ocr_engine: Arc, + monitor_id: u32, use_pii_removal: bool, + 
ignored_windows: &[String], + include_windows: &[String], + video_chunk_duration: Duration, + languages: Vec, + capture_unfocused_windows: bool, + realtime_vision_sender: Arc>, ) -> Result<()> { - let mut handles: HashMap> = HashMap::new(); - let mut device_states = device_manager.watch_devices().await; - - // Create weak reference to device_manager - let device_manager_weak = Arc::downgrade(&device_manager); - - loop { - tokio::select! { - Some(state_change) = device_states.next() => { - // Clean up finished handles first - handles.retain(|monitor_id, handle| { - if handle.is_finished() { - info!("handle for monitor {} has finished", monitor_id); - false - } else { - true - } - }); - - match DeviceType::from_str(&state_change.device) { - Ok(DeviceType::Vision(monitor_id)) => { - debug!("record_vision: vision state change: {:?}", state_change); - if !state_change.control.is_running { - if let Some(handle) = handles.remove(&monitor_id) { - let _ = handle.await; - info!("stopped thread for monitor {}", monitor_id); - } - continue; - } - - if handles.contains_key(&monitor_id) { - continue; - } - - info!("starting vision capture thread for monitor: {}", monitor_id); - - let db_manager_video = Arc::clone(&db); - let output_path_video = Arc::clone(&output_path); - let ocr_engine_weak = Arc::downgrade(&ocr_engine); - - let languages = languages.clone(); - // Use weak reference for the device manager - let device_manager_weak = device_manager_weak.clone(); - let ignored_windows = ignored_windows.clone(); - let include_windows = include_windows.clone(); - let handle = tokio::spawn(async move { - let config = VideoRecordingConfig { - db: db_manager_video, - output_path: output_path_video, - fps, - ocr_engine: ocr_engine_weak, - monitor_id, - use_pii_removal, - ignored_windows, - include_windows, - video_chunk_duration, - languages, - capture_unfocused_windows, - }; - - // Upgrade weak reference when needed - let device_manager = match device_manager_weak.upgrade() { - Some(dm) => dm, - None => { - warn!("device manager no longer exists"); - return; - } - }; - - if let Err(e) = record_video(device_manager, config).await { - error!( - "Error in video recording thread for monitor {}: {}", - monitor_id, e - ); - } - }); - - handles.insert(monitor_id, handle); - } - _ => continue, // Ignore non-vision devices - } - } - _ = tokio::time::sleep(Duration::from_millis(100)) => { - - } - } - } -} - -async fn record_video( - device_manager: Arc, - config: VideoRecordingConfig, -) -> Result<()> { - let db_chunk_callback = Arc::clone(&config.db); + debug!("record_video: Starting"); + let db_chunk_callback = Arc::clone(&db); let rt = Handle::current(); - let device_name = Arc::new(format!("monitor_{}", config.monitor_id)); + let device_name = Arc::new(format!("monitor_{}", monitor_id)); let new_chunk_callback = { let db_chunk_callback = Arc::clone(&db_chunk_callback); @@ -293,330 +216,272 @@ async fn record_video( }; let video_capture = VideoCapture::new( - &config.output_path, - config.fps, - config.video_chunk_duration, + &output_path, + fps, + video_chunk_duration, new_chunk_callback, - config.ocr_engine.clone(), - config.monitor_id, - config.ignored_windows, - config.include_windows, - config.languages, - config.capture_unfocused_windows, + Arc::clone(&ocr_engine), + monitor_id, + ignored_windows, + include_windows, + languages, + capture_unfocused_windows, ); - let mut device_states = device_manager.watch_devices().await; - - loop { - tokio::select! 
{ - Some(state_change) = device_states.next() => { - match DeviceType::from_str(&state_change.device) { - Ok(DeviceType::Vision(monitor_id)) if monitor_id == config.monitor_id => { - debug!("record_video: vision state change: {:?}", state_change); - if !state_change.control.is_running { - info!("vision thread for monitor {} received stop signal", monitor_id); - let _ = video_capture.shutdown().await; - info!("vision thread for monitor {} shutdown complete", monitor_id); - return Ok(()); + while is_running.load(Ordering::SeqCst) { + if let Some(frame) = video_capture.ocr_frame_queue.pop() { + for window_result in &frame.window_ocr_results { + match db.insert_frame(&device_name, None).await { + Ok(frame_id) => { + let text_json = + serde_json::to_string(&window_result.text_json).unwrap_or_default(); + + let text = if use_pii_removal { + &remove_pii(&window_result.text) + } else { + &window_result.text + }; + + let _ = realtime_vision_sender.send(RealtimeVisionEvent::Ocr(WindowOcr { + image: Some(frame.image.clone()), + text: text.clone(), + text_json: window_result.text_json.clone(), + app_name: window_result.app_name.clone(), + window_name: window_result.window_name.clone(), + focused: window_result.focused, + confidence: window_result.confidence, + timestamp: frame.timestamp, + })); + if let Err(e) = db + .insert_ocr_text( + frame_id, + text, + &text_json, + &window_result.app_name, + &window_result.window_name, + Arc::clone(&ocr_engine), + window_result.focused, // Add this line + ) + .await + { + error!( + "Failed to insert OCR text: {}, skipping window {} of frame {}", + e, window_result.window_name, frame_id + ); + continue; } } - _ => continue, // Ignore other devices or monitors + Err(e) => { + warn!("Failed to insert frame: {}", e); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } } } - // we should process faster than the fps we use to do OCR - _ = tokio::time::sleep(Duration::from_secs_f64(1.0 / (config.fps * 2.0))) => { - let frame = match video_capture.ocr_frame_queue.pop() { - Some(f) => f, - None => continue, - }; - - process_ocr_frame( - frame, - &config.db, - &device_name, - config.use_pii_removal, - config.ocr_engine.clone(), - ).await; - } } + tokio::time::sleep(Duration::from_secs_f64(1.0 / fps)).await; } -} - -async fn process_ocr_frame( - frame: Arc, - db: &DatabaseManager, - device_name: &str, - use_pii_removal: bool, - ocr_engine: Weak, -) { - let ocr_engine = match ocr_engine.upgrade() { - Some(engine) => engine, - None => { - warn!("OCR engine no longer exists"); - return; - } - }; - - for window_result in &frame.window_ocr_results { - let frame_id = match db.insert_frame(device_name, None).await { - Ok(id) => id, - Err(e) => { - warn!("Failed to insert frame: {}", e); - tokio::time::sleep(Duration::from_millis(100)).await; - continue; - } - }; - - let text_json = serde_json::to_string(&window_result.text_json).unwrap_or_default(); - - let text = if use_pii_removal { - remove_pii(&window_result.text) - } else { - window_result.text.clone() - }; - - let _ = send_event( - "ocr_result", - WindowOcr { - image: Some(frame.image.clone()), - text: text.clone(), - text_json: window_result.text_json.clone(), - app_name: window_result.app_name.clone(), - window_name: window_result.window_name.clone(), - focused: window_result.focused, - confidence: window_result.confidence, - timestamp: frame.timestamp, - }, - ); - if let Err(e) = db - .insert_ocr_text( - frame_id, - &text, - &text_json, - &window_result.app_name, - &window_result.window_name, - 
ocr_engine.clone(), - window_result.focused, - ) - .await - { - error!( - "Failed to insert OCR text: {}, skipping window {} of frame {}", - e, window_result.window_name, frame_id - ); - } - } + Ok(()) } #[allow(clippy::too_many_arguments)] async fn record_audio( - device_manager: Arc, db: Arc, chunk_duration: Duration, whisper_sender: crossbeam::channel::Sender, whisper_receiver: crossbeam::channel::Receiver, + audio_devices_control: Arc>, audio_transcription_engine: Arc, realtime_audio_enabled: bool, - realtime_audio_devices: Arc<[AudioDevice]>, - languages: Arc<[Language]>, + realtime_audio_devices: Vec>, + languages: Vec, + realtime_transcription_sender: Arc>, deepgram_api_key: Option, ) -> Result<()> { let mut handles: HashMap> = HashMap::new(); - let mut device_states = device_manager.watch_devices().await; - let mut previous_transcript = String::new(); + let mut previous_transcript = "".to_string(); let mut previous_transcript_id: Option = None; - - // Create a weak reference to device_manager - let device_manager_weak = Arc::downgrade(&device_manager); - let mut prev_handles_len = 0; + let realtime_transcription_sender_clone = realtime_transcription_sender.clone(); loop { - if handles.len() != prev_handles_len { - prev_handles_len = handles.len(); - info!("handles length: {}", prev_handles_len); - } - tokio::select! { - Some(state_change) = device_states.next() => { - // Handle cleanup of finished handles - handles.retain(|device_id, handle| { - if handle.is_finished() { - info!("handle for device {} has finished", device_id); - false - } else { - true - } - }); + // Iterate over DashMap entries and process each device + for entry in audio_devices_control.iter() { + let audio_device = entry.key().clone(); + let device_control = entry.value().clone(); + let device_id = audio_device.to_string(); + + // Skip if we're already handling this device + if handles.contains_key(&device_id) { + continue; + } - match DeviceType::from_str(&state_change.device) { - Ok(DeviceType::Audio(audio_device)) => { - let device_id = audio_device.to_string(); + info!("Received audio device: {}", &audio_device); - if !state_change.control.is_running { - if let Some(handle) = handles.remove(&device_id) { - handle.abort(); - info!("stopped thread for device {}", &audio_device); - } - continue; - } + if !device_control.is_running { + info!("Device control signaled stop for device {}", &audio_device); + if let Some(handle) = handles.remove(&device_id) { + handle.abort(); + info!("Stopped thread for device {}", &audio_device); + } + // Remove from DashMap + audio_devices_control.remove(&audio_device); + continue; + } - if handles.contains_key(&device_id) { - continue; - } + let whisper_sender_clone = whisper_sender.clone(); - info!("starting audio capture thread for device: {}", &audio_device); - - let audio_device = Arc::new(audio_device); - let is_running = Arc::new(AtomicBool::new(true)); - - // Use weak reference for the spawned task - let device_manager_weak = device_manager_weak.clone(); - let whisper_sender = whisper_sender.clone(); - let languages = Arc::clone(&languages); - let deepgram_api_key = deepgram_api_key.clone(); - - let device_id_for_handle = device_id.clone(); - let handle = tokio::spawn({ - let audio_device = Arc::clone(&audio_device); - let is_running = Arc::clone(&is_running); - let realtime_devices = realtime_audio_devices.clone(); - - async move { - info!("starting audio capture thread for device: {}", &audio_device); - let mut did_warn = false; - - // Move the device state monitoring 
outside the main loop - let device_states = device_manager_weak.upgrade().unwrap().watch_devices().await; - let is_running_clone = Arc::clone(&is_running); - - let monitor_handle = tokio::spawn(async move { - let mut device_states = device_states; - while let Some(state_change) = device_states.next().await { - if state_change.device == device_id_for_handle && !state_change.control.is_running { - is_running_clone.store(false, Ordering::Relaxed); - break; - } - } - }); - - while is_running.load(Ordering::Relaxed) { - - - let audio_stream = match AudioStream::from_device( - Arc::clone(&audio_device), - Arc::clone(&is_running), - ).await { - Ok(stream) => Arc::new(stream), - Err(e) => { - if e.to_string().contains("audio device not found") { - if !did_warn { - warn!("audio device not found: {}", audio_device.name); - did_warn = true; - } - tokio::time::sleep(Duration::from_secs(1)).await; - continue; - } else { - error!("failed to create audio stream: {}", e); - return; - } - } - }; - - let mut recording_handles = vec![]; - - // Spawn record and transcribe task - recording_handles.push(tokio::spawn({ - let audio_stream = Arc::clone(&audio_stream); - let is_running = Arc::clone(&is_running); - let whisper_sender = whisper_sender.clone(); - - async move { - let _ = record_and_transcribe( - audio_stream, - chunk_duration, - whisper_sender, - is_running, - ).await; - } - })); - - // Spawn realtime recording task if enabled - if realtime_audio_enabled && realtime_devices.contains(&audio_device) { - recording_handles.push(tokio::spawn({ - let audio_stream = Arc::clone(&audio_stream); - let is_running = Arc::clone(&is_running); - let languages = Arc::clone(&languages); - let deepgram_api_key = deepgram_api_key.clone(); - - async move { - let _ = start_realtime_recording( - audio_stream, - languages, - is_running, - deepgram_api_key, - ).await; - } - })); - } - - join_all(recording_handles).await; - } + let audio_device = Arc::new(audio_device); + let device_control = Arc::new(device_control); - // Clean up the monitor task - monitor_handle.abort(); + let realtime_audio_devices_clone = realtime_audio_devices.clone(); + let languages_clone = languages.clone(); + let realtime_transcription_sender_clone = realtime_transcription_sender_clone.clone(); + let deepgram_api_key_clone = deepgram_api_key.clone(); + let handle = tokio::spawn(async move { + let audio_device_clone = Arc::clone(&audio_device); + let deepgram_api_key = deepgram_api_key_clone.clone(); + debug!( + "Starting audio capture thread for device: {}", + &audio_device + ); + + let mut did_warn = false; + let is_running = Arc::new(AtomicBool::new(device_control.is_running)); + + while is_running.load(Ordering::Relaxed) { + let deepgram_api_key = deepgram_api_key.clone(); + let is_running_loop = Arc::clone(&is_running); // Create separate reference for the loop + let audio_stream = match AudioStream::from_device( + audio_device_clone.clone(), + Arc::clone(&is_running_loop), // Clone from original Arc + ) + .await + { + Ok(stream) => stream, + Err(e) => { + if e.to_string().contains("Audio device not found") { + if !did_warn { + warn!("Audio device not found: {}", audio_device.name); + did_warn = true; + } + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } else { + error!("Failed to create audio stream: {}", e); + return; } - }); + } + }; + + let mut recording_handles: Vec> = vec![]; + + let audio_stream = Arc::new(audio_stream); + let whisper_sender_clone = whisper_sender_clone.clone(); + let audio_stream_clone = 
audio_stream.clone(); + let is_running_loop_clone = is_running_loop.clone(); + let record_handle = Some(tokio::spawn(async move { + let _ = record_and_transcribe( + audio_stream, + chunk_duration, + whisper_sender_clone.clone(), + is_running_loop_clone.clone(), + ) + .await; + })); + + if let Some(handle) = record_handle { + recording_handles.push(handle); + } - handles.insert(device_id, handle); + let audio_device_clone = audio_device_clone.clone(); + let realtime_audio_devices_clone = realtime_audio_devices_clone.clone(); + let languages_clone = languages_clone.clone(); + let is_running_loop = is_running_loop.clone(); + let realtime_transcription_sender_clone = + realtime_transcription_sender_clone.clone(); + let live_transcription_handle = Some(tokio::spawn(async move { + if realtime_audio_enabled + && realtime_audio_devices_clone.contains(&audio_device_clone) + { + let _ = start_realtime_recording( + audio_stream_clone, + languages_clone.clone(), + is_running_loop.clone(), + realtime_transcription_sender_clone.clone(), + deepgram_api_key.clone(), + ) + .await; + } + })); + + if let Some(handle) = live_transcription_handle { + recording_handles.push(handle); } - _ => continue, + + join_all(recording_handles).await; } - } - _ = tokio::time::sleep(Duration::from_millis(100)) => {} + + info!("exiting audio capture thread for device: {}", &audio_device); + }); + + handles.insert(device_id, handle); } - // Process transcription results + handles.retain(|device_id, handle| { + if handle.is_finished() { + info!("Handle for device {} has finished", device_id); + false + } else { + true + } + }); + while let Ok(mut transcription) = whisper_receiver.try_recv() { info!( "device {} received transcription {:?}", transcription.input.device, transcription.transcription ); - let mut current_transcript = transcription.transcription.clone(); - let mut processed_previous = None; - - if let Some((previous, current)) = transcription.cleanup_overlap(&previous_transcript) { + // Insert the new transcript after fetching + let mut current_transcript: Option = transcription.transcription.clone(); + let mut processed_previous: Option = None; + if let Some((previous, current)) = + transcription.cleanup_overlap(previous_transcript.clone()) + { if !previous.is_empty() && !current.is_empty() { if previous != previous_transcript { processed_previous = Some(previous); } - if current_transcript.as_ref() != Some(¤t) { + if current_transcript.is_some() + && current != current_transcript.clone().unwrap_or_default() + { current_transcript = Some(current); } } } transcription.transcription = current_transcript.clone(); - - if let Some(transcript) = current_transcript { - previous_transcript = transcript; + if current_transcript.is_some() { + previous_transcript = current_transcript.unwrap(); } else { continue; } - // Process the audio result match process_audio_result( &db, transcription, - Arc::clone(&audio_transcription_engine), + audio_transcription_engine.clone(), processed_previous, previous_transcript_id, ) .await { + Err(e) => error!("Error processing audio result: {}", e), Ok(id) => previous_transcript_id = id, - Err(e) => error!("error processing audio result: {}", e), } } + + tokio::time::sleep(Duration::from_millis(100)).await; } } @@ -629,7 +494,7 @@ async fn process_audio_result( ) -> Result, anyhow::Error> { if result.error.is_some() || result.transcription.is_none() { error!( - "error in audio recording: {}. not inserting audio result", + "Error in audio recording: {}. 
Not inserting audio result", result.error.unwrap_or_default() ); return Ok(None); @@ -637,7 +502,7 @@ async fn process_audio_result( let speaker = get_or_create_speaker_from_embedding(db, &result.speaker_embedding).await?; - info!("detected speaker: {:?}", speaker); + info!("Detected speaker: {:?}", speaker); let transcription = result.transcription.unwrap(); let transcription_engine = audio_transcription_engine.to_string(); @@ -655,7 +520,7 @@ async fn process_audio_result( { Ok(_) => {} Err(e) => error!( - "failed to update transcription for {}: audio_chunk_id {}", + "Failed to update transcription for {}: audio_chunk_id {}", result.input.device, e ), } @@ -681,20 +546,20 @@ async fn process_audio_result( .await { error!( - "failed to insert audio transcription for device {}: {}", + "Failed to insert audio transcription for device {}: {}", result.input.device, e ); return Ok(Some(audio_chunk_id)); } else { debug!( - "inserted audio transcription for chunk {} from device {} using {}", + "Inserted audio transcription for chunk {} from device {} using {}", audio_chunk_id, result.input.device, transcription_engine ); chunk_id = Some(audio_chunk_id); } } Err(e) => error!( - "failed to insert audio chunk for device {}: {}", + "Failed to insert audio chunk for device {}: {}", result.input.device, e ), } @@ -729,6 +594,6 @@ pub async fn merge_speakers( .await { Ok(speaker) => Ok(speaker), - Err(e) => Err(anyhow::anyhow!("failed to merge speakers: {}", e)), + Err(e) => Err(anyhow::anyhow!("Failed to merge speakers: {}", e)), } } diff --git a/screenpipe-server/src/db.rs b/screenpipe-server/src/db.rs index b2202880bb..3128637601 100644 --- a/screenpipe-server/src/db.rs +++ b/screenpipe-server/src/db.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Utc}; use image::DynamicImage; use libsqlite3_sys::sqlite3_auto_extension; -use screenpipe_core::{AudioDevice, AudioDeviceType}; +use screenpipe_audio::{AudioDevice, DeviceType}; use screenpipe_vision::OcrEngine; use sqlite_vec::sqlite3_vec_init; use sqlx::migrate::MigrateDatabase; @@ -147,7 +147,7 @@ impl DatabaseManager { .bind(Utc::now()) .bind(transcription_engine) .bind(&device.name) - .bind(device.device_type == AudioDeviceType::Input) + .bind(device.device_type == DeviceType::Input) .bind(speaker_id) .bind(start_time) .bind(end_time) @@ -906,9 +906,9 @@ impl DatabaseManager { .unwrap_or_default(), device_name: raw.device_name, device_type: if raw.is_input_device { - AudioDeviceType::Input + DeviceType::Input } else { - AudioDeviceType::Output + DeviceType::Output }, speaker, start_time: raw.start_time, diff --git a/screenpipe-server/src/db_types.rs b/screenpipe-server/src/db_types.rs index 8a8d066cb8..6a1275c88f 100644 --- a/screenpipe-server/src/db_types.rs +++ b/screenpipe-server/src/db_types.rs @@ -1,5 +1,5 @@ use chrono::{DateTime, Utc}; -use screenpipe_core::AudioDeviceType; +use screenpipe_audio::DeviceType; use serde::{Deserialize, Serialize}; use sqlx::FromRow; use std::error::Error as StdError; @@ -105,7 +105,7 @@ pub struct AudioResult { pub transcription_engine: String, pub tags: Vec, pub device_name: String, - pub device_type: AudioDeviceType, + pub device_type: DeviceType, pub speaker: Option, pub start_time: Option, pub end_time: Option, diff --git a/screenpipe-server/src/filtering.rs b/screenpipe-server/src/filtering.rs index 3566bbaf95..66ccd07963 100644 --- a/screenpipe-server/src/filtering.rs +++ b/screenpipe-server/src/filtering.rs @@ -1,6 +1,6 @@ -use sqlx::{sqlite::SqlitePool, query_as}; use ndarray::{Array2, Axis}; use 
rust_stemmers::{Algorithm, Stemmer}; +use sqlx::{query_as, sqlite::SqlitePool}; use std::collections::{HashMap, HashSet}; use std::error::Error; use tracing::debug; @@ -16,13 +16,17 @@ fn keep_least_similar(chunks: &[String], percentage: f64) -> Vec { // Calculate TF-IDF for (doc_id, chunk) in chunks.iter().enumerate() { - let words: Vec<_> = chunk.split_whitespace() + let words: Vec<_> = chunk + .split_whitespace() .map(|word| en_stemmer.stem(word).to_string()) .collect(); for word in words.iter() { *term_freq.entry((doc_id, word.clone())).or_insert(0) += 1; - doc_freq.entry(word.clone()).or_insert_with(HashSet::new).insert(doc_id); + doc_freq + .entry(word.clone()) + .or_insert_with(HashSet::new) + .insert(doc_id); } } @@ -55,7 +59,8 @@ fn keep_least_similar(chunks: &[String], percentage: f64) -> Vec { sorted_similarities.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); let keep_count = (percentage * n_docs as f64) as usize; - sorted_similarities.into_iter() + sorted_similarities + .into_iter() .take(keep_count) .map(|(i, _)| i) .collect() @@ -65,15 +70,21 @@ fn word_count(text: &str) -> usize { text.split_whitespace().count() } -pub async fn filter_texts(timestamp: &str, memory_source: &str, pool: &SqlitePool) -> Result> { - let texts: Vec = query_as(" +pub async fn filter_texts( + timestamp: &str, + memory_source: &str, + pool: &SqlitePool, +) -> Result> { + let texts: Vec = query_as( + " SELECT cti.text FROM chunked_text_index cti JOIN chunked_text_entries cte ON cti.text_id = cte.text_id WHERE cte.timestamp > ? AND cte.source = ? GROUP BY cti.text_id ORDER BY MAX(cte.timestamp) DESC - ") + ", + ) .bind(timestamp) .bind(memory_source) .fetch_all(pool) @@ -84,7 +95,10 @@ pub async fn filter_texts(timestamp: &str, memory_source: &str, pool: &SqlitePoo let initial_text_count = texts.len(); let initial_word_count: usize = texts.iter().map(|text| word_count(text)).sum(); - debug!("Initial: {} texts, {} words", initial_text_count, initial_word_count); + debug!( + "Initial: {} texts, {} words", + initial_text_count, initial_word_count + ); if texts.is_empty() { return Ok("".to_string()); @@ -95,9 +109,16 @@ pub async fn filter_texts(timestamp: &str, memory_source: &str, pool: &SqlitePoo let final_text_count = kept_texts.len(); let final_word_count: usize = kept_texts.iter().map(|&text| word_count(text)).sum(); - debug!("Kept: {} texts, {} words", final_text_count, final_word_count); - - let output = kept_texts.iter().map(|s| s.as_str()).collect::>().join("\n"); + debug!( + "Kept: {} texts, {} words", + final_text_count, final_word_count + ); + + let output = kept_texts + .iter() + .map(|s| s.as_str()) + .collect::>() + .join("\n"); Ok(output) } diff --git a/screenpipe-server/src/lib.rs b/screenpipe-server/src/lib.rs index 0a77b0b7e9..ef51ff969e 100644 --- a/screenpipe-server/src/lib.rs +++ b/screenpipe-server/src/lib.rs @@ -31,5 +31,4 @@ pub use server::ContentItem; pub use server::HealthCheckResponse; pub use server::PaginatedResponse; pub use server::Server; -pub use server::VisionDeviceControlRequest; pub use video::VideoCapture; diff --git a/screenpipe-server/src/resource_monitor.rs b/screenpipe-server/src/resource_monitor.rs index f4e2d5142b..e9672c1a7e 100644 --- a/screenpipe-server/src/resource_monitor.rs +++ b/screenpipe-server/src/resource_monitor.rs @@ -1,5 +1,4 @@ use chrono::Local; -use reqwest::Client; use serde_json::json; use serde_json::Value; use std::env; @@ -12,15 +11,14 @@ use std::io::Write; use std::sync::Arc; use std::time::{Duration, 
Instant}; use sysinfo::{PidExt, ProcessExt, System, SystemExt}; -use tracing::{error, info, warn}; -use uuid; +use tracing::{error, info}; + +#[cfg(target_os = "macos")] +use std::process::Command; pub struct ResourceMonitor { start_time: Instant, resource_log_file: Option, // analyse output here: https://colab.research.google.com/drive/1zELlGdzGdjChWKikSqZTHekm5XRxY-1r?usp=sharing - posthog_client: Option, - posthog_enabled: bool, - distinct_id: String, } pub enum RestartSignal { @@ -28,7 +26,7 @@ pub enum RestartSignal { } impl ResourceMonitor { - pub fn new(telemetry_enabled: bool) -> Arc { + pub fn new() -> Arc { let resource_log_file = if env::var("SAVE_RESOURCE_USAGE").is_ok() { let now = Local::now(); let filename = format!("resource_usage_{}.json", now.format("%Y%m%d_%H%M%S")); @@ -48,152 +46,98 @@ impl ResourceMonitor { None }; - // Create client once and reuse instead of Option - let posthog_client = telemetry_enabled.then(Client::new); - - // Generate a unique ID for this installation - let distinct_id = uuid::Uuid::new_v4().to_string(); - Arc::new(Self { start_time: Instant::now(), resource_log_file, - posthog_client, - posthog_enabled: telemetry_enabled, - distinct_id, }) } - async fn send_to_posthog( - &self, - total_memory_gb: f64, - system_total_memory: f64, - total_cpu: f32, - ) { - let Some(client) = &self.posthog_client else { - return; - }; - - // Create System only when needed - let sys = System::new(); - - // Avoid unnecessary cloning by using references - let payload = json!({ - "api_key": "phc_6TUWxXM2NQGPuLhkwgRHxPfXMWqhGGpXqWNIw0GRpMD", - "event": "resource_usage", - "properties": { - "distinct_id": &self.distinct_id, - "$lib": "rust-reqwest", - "total_memory_gb": total_memory_gb, - "system_total_memory_gb": system_total_memory, - "memory_usage_percent": (total_memory_gb / system_total_memory) * 100.0, - "total_cpu_percent": total_cpu, - "runtime_seconds": self.start_time.elapsed().as_secs(), - "os_name": sys.name().unwrap_or_default(), - "os_version": sys.os_version().unwrap_or_default(), - "kernel_version": sys.kernel_version().unwrap_or_default(), - "cpu_count": sys.cpus().len(), - "release": env!("CARGO_PKG_VERSION"), - } - }); - - // Send the event to PostHog - if let Err(e) = client - .post("https://eu.i.posthog.com/capture/") - .json(&payload) - .send() - .await - { - error!("Failed to send resource usage to PostHog: {}", e); - } - } - - async fn log_status(&self, sys: &System) { + fn log_status(&self, sys: &System) { let pid = std::process::id(); - let main_process = match sys.process(sysinfo::Pid::from_u32(pid)) { - Some(p) => p, - None => { - warn!("Could not find main process"); - return; - } - }; + let main_process = sys.process(sysinfo::Pid::from_u32(pid)); let mut total_memory = 0.0; let mut total_cpu = 0.0; - total_memory += main_process.memory() as f64; - total_cpu += main_process.cpu_usage(); + if let Some(process) = main_process { + total_memory += process.memory() as f64; + total_cpu += process.cpu_usage(); - // Iterate through all processes to find children - for child_process in sys.processes().values() { - if child_process.parent() == Some(sysinfo::Pid::from_u32(pid)) { - total_memory += child_process.memory() as f64; - total_cpu += child_process.cpu_usage(); + // Iterate through all processes to find children + for child_process in sys.processes().values() { + if child_process.parent() == Some(sysinfo::Pid::from_u32(pid)) { + total_memory += child_process.memory() as f64; + total_cpu += child_process.cpu_usage(); + } } - } - - let 
total_memory_gb = total_memory / 1048576000.0; - let system_total_memory = sys.total_memory() as f64 / 1048576000.0; - let memory_usage_percent = (total_memory_gb / system_total_memory) * 100.0; - let runtime = self.start_time.elapsed(); - - let log_message = format!( - "Runtime: {}s, Total Memory: {:.0}% ({:.2} GB / {:.2} GB), Total CPU: {:.0}%", - runtime.as_secs(), - memory_usage_percent, - total_memory_gb, - system_total_memory, - total_cpu - ); - - info!("{}", log_message); - - if let Some(ref filename) = self.resource_log_file { - let file = OpenOptions::new() - .read(true) - .write(true) - .open(filename) - .map_err(|e| { - error!("Failed to open resource log file: {}", e); - e - }); - if let Ok(mut file) = file { + let total_memory_gb = total_memory / 1048576000.0; + let system_total_memory = sys.total_memory() as f64 / 1048576000.0; + let memory_usage_percent = (total_memory_gb / system_total_memory) * 100.0; + let runtime = self.start_time.elapsed(); + + let log_message = if cfg!(target_os = "macos") { + if let Some(npu_usage) = self.get_npu_usage() { + format!( + "Runtime: {}s, Total Memory: {:.0}% ({:.0} GB / {:.0} GB), Total CPU: {:.0}%, NPU: {:.0}%", + runtime.as_secs(), memory_usage_percent, total_memory_gb, system_total_memory, total_cpu, npu_usage + ) + } else { + format!( + "Runtime: {}s, Total Memory: {:.0}% ({:.0} GB / {:.0} GB), Total CPU: {:.0}%, NPU: N/A", + runtime.as_secs(), memory_usage_percent, total_memory_gb, system_total_memory, total_cpu + ) + } + } else { + format!( + "Runtime: {}s, Total Memory: {:.0}% ({:.2} GB / {:.2} GB), Total CPU: {:.0}%", + runtime.as_secs(), + memory_usage_percent, + total_memory_gb, + system_total_memory, + total_cpu + ) + }; + + info!("{}", log_message); + + if let Some(filename) = &self.resource_log_file { + let now = Local::now(); let json_data = json!({ - "timestamp": Local::now().to_rfc3339(), + "timestamp": now.to_rfc3339(), "runtime_seconds": runtime.as_secs(), "total_memory_gb": total_memory_gb, "system_total_memory_gb": system_total_memory, "memory_usage_percent": memory_usage_percent, "total_cpu_percent": total_cpu, + "npu_usage_percent": self.get_npu_usage().unwrap_or(-1.0), }); - // Create string buffer first - let mut contents = String::new(); - file.read_to_string(&mut contents).unwrap_or_default(); - if let Ok(mut json_array) = serde_json::from_str::(&contents) { - if let Some(array) = json_array.as_array_mut() { - array.push(json_data); - if file.set_len(0).is_ok() && file.seek(SeekFrom::Start(0)).is_ok() { - if let Err(e) = file.write_all(json_array.to_string().as_bytes()) { - error!("Failed to write JSON data to file: {}", e); + if let Ok(mut file) = OpenOptions::new().read(true).write(true).open(filename) { + let mut contents = String::new(); + if file.read_to_string(&mut contents).is_ok() { + if let Ok(mut json_array) = serde_json::from_str::(&contents) { + if let Some(array) = json_array.as_array_mut() { + array.push(json_data); + if file.set_len(0).is_ok() && file.seek(SeekFrom::Start(0)).is_ok() + { + if let Err(e) = + file.write_all(json_array.to_string().as_bytes()) + { + error!("Failed to write JSON data to file: {}", e); + } + } else { + error!("Failed to truncate and seek file: {}", filename); + } } } else { - error!("Failed to truncate and seek file: {}", filename); + error!("Failed to parse JSON from file: {}", filename); } + } else { + error!("Failed to read JSON file: {}", filename); } } else { - error!("Failed to parse JSON from file: {}", filename); - } - - let _ = file.flush(); - } - } - - if 
self.posthog_enabled { - tokio::select! { - _ = self.send_to_posthog(total_memory_gb, system_total_memory, total_cpu) => {}, - _ = tokio::time::sleep(Duration::from_secs(5)) => { - warn!("PostHog request timed out"); + error!("Failed to open JSON file: {}", filename); } } } @@ -201,31 +145,47 @@ impl ResourceMonitor { pub fn start_monitoring(self: &Arc, interval: Duration) { let monitor = Arc::clone(self); - tokio::spawn(async move { let mut sys = System::new_all(); - loop { tokio::select! { _ = tokio::time::sleep(interval) => { sys.refresh_all(); - monitor.log_status(&sys).await; - + monitor.log_status(&sys); } } } }); } - pub async fn shutdown(&self) { - if let Some(ref file) = self.resource_log_file { - if let Ok(mut f) = OpenOptions::new().write(true).open(file) { - let _ = f.flush(); + #[cfg(target_os = "macos")] + fn get_npu_usage(&self) -> Option { + let output = Command::new("ioreg") + .args(["-r", "-c", "AppleARMIODevice", "-n", "ane0"]) + .output() + .ok()?; + + let output_str = String::from_utf8_lossy(&output.stdout); + + // Parse the output to find the "ane_power" value + for line in output_str.lines() { + if line.contains("\"ane_power\"") { + if let Some(value) = line.split('=').nth(1) { + if let Ok(power) = value.trim().parse::() { + // Assuming max ANE power is 8.0W (adjust if needed) + let max_ane_power = 8.0; + let npu_usage_percent = (power / max_ane_power) * 100.0; + return Some(npu_usage_percent); + } + } } } - if let Some(_) = &self.posthog_client { - tokio::time::sleep(Duration::from_millis(100)).await; - } + None + } + + #[cfg(not(target_os = "macos"))] + fn get_npu_usage(&self) -> Option { + None } } diff --git a/screenpipe-server/src/server.rs b/screenpipe-server/src/server.rs index 67ad1e7c27..ec5b4ffbcd 100644 --- a/screenpipe-server/src/server.rs +++ b/screenpipe-server/src/server.rs @@ -18,7 +18,6 @@ use futures::{ SinkExt, StreamExt, }; use image::ImageFormat::{self}; -use screenpipe_core::{AudioDevice, AudioDeviceType, DeviceControl, DeviceManager, DeviceType}; use screenpipe_events::{send_event, subscribe_to_all_events, Event as ScreenpipeEvent}; use crate::{ @@ -34,7 +33,7 @@ use crate::{ }; use crate::{plugin::ApiPluginLayer, video_utils::extract_frame}; use chrono::{DateTime, Utc}; -use screenpipe_audio::{default_input_device, default_output_device, list_audio_devices}; +use screenpipe_audio::{default_input_device, default_output_device, list_audio_devices, AudioDevice, DeviceType}; use tracing::{debug, error, info}; use screenpipe_vision::monitor::list_monitors; @@ -73,7 +72,7 @@ use crate::text_embeds::generate_embedding; pub struct AppState { pub db: Arc, - pub device_manager: Arc, + // pub device_manager: Arc, pub app_start_time: DateTime, pub screenpipe_dir: PathBuf, pub pipe_manager: Arc, @@ -200,7 +199,7 @@ pub struct AudioContent { pub offset_index: i64, pub tags: Vec, pub device_name: String, - pub device_type: AudioDeviceType, + pub device_type: DeviceType, pub speaker: Option, pub start_time: Option, pub end_time: Option, @@ -887,7 +886,7 @@ async fn list_pipes_handler(State(state): State>) -> JsonResponse< pub struct Server { db: Arc, addr: SocketAddr, - device_manager: Arc, + // device_manager: Arc, screenpipe_dir: PathBuf, pipe_manager: Arc, vision_disabled: bool, @@ -900,7 +899,7 @@ impl Server { pub fn new( db: Arc, addr: SocketAddr, - device_manager: Arc, + // device_manager: Arc, screenpipe_dir: PathBuf, pipe_manager: Arc, vision_disabled: bool, @@ -910,7 +909,7 @@ impl Server { Server { db, addr, - device_manager, + // device_manager, 
screenpipe_dir, pipe_manager, vision_disabled, @@ -929,7 +928,7 @@ impl Server { { let app_state = Arc::new(AppState { db: self.db.clone(), - device_manager: self.device_manager.clone(), + // device_manager: self.device_manager.clone(), app_start_time: Utc::now(), screenpipe_dir: self.screenpipe_dir.clone(), pipe_manager: self.pipe_manager, @@ -1161,7 +1160,7 @@ async fn add_transcription_to_db( let device = AudioDevice { name: device_name.to_string(), - device_type: AudioDeviceType::Input, + device_type: DeviceType::Input, }; let dummy_audio_chunk_id = db.insert_audio_chunk("").await?; @@ -1670,7 +1669,7 @@ async fn get_similar_speakers_handler( pub struct AudioDeviceControlRequest { device_name: String, #[serde(default)] - device_type: Option, + device_type: Option, } #[derive(Serialize)] @@ -1680,94 +1679,94 @@ pub struct AudioDeviceControlResponse { } // Add these new handler functions before create_router() -async fn start_audio_device( - State(state): State>, - Json(payload): Json, -) -> Result, (StatusCode, JsonResponse)> { - let device = AudioDevice { - name: payload.device_name.clone(), - device_type: payload.device_type.unwrap_or(AudioDeviceType::Input), - }; - - // Validate device exists - let available_devices = list_audio_devices().await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - JsonResponse(json!({ - "error": format!("failed to list audio devices: {}", e), - "success": false - })), - ) - })?; - - if !available_devices.contains(&device) { - return Err(( - StatusCode::BAD_REQUEST, - JsonResponse(json!({ - "error": format!("device not found: {}", device.name), - "success": false - })), - )); - } - - let control = DeviceControl { - device: screenpipe_core::DeviceType::Audio(device.clone()), - is_running: true, - is_paused: false, - }; - - let _ = state.device_manager.update_device(control).await; - - Ok(JsonResponse(AudioDeviceControlResponse { - success: true, - message: format!("started audio device: {}", device.name), - })) -} - -async fn stop_audio_device( - State(state): State>, - Json(payload): Json, -) -> Result, (StatusCode, JsonResponse)> { - let device = AudioDevice { - name: payload.device_name.clone(), - device_type: payload.device_type.unwrap_or(AudioDeviceType::Input), - }; - - // Validate device exists - let available_devices = list_audio_devices().await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - JsonResponse(json!({ - "error": format!("failed to list audio devices: {}", e), - "success": false - })), - ) - })?; - - if !available_devices.contains(&device) { - return Err(( - StatusCode::BAD_REQUEST, - JsonResponse(json!({ - "error": format!("device not found: {}", device.name), - "success": false - })), - )); - } - - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Audio(device.clone()), - is_running: false, - is_paused: false, - }) - .await; - - Ok(JsonResponse(AudioDeviceControlResponse { - success: true, - message: format!("stopped audio device: {}", device.name), - })) -} +// async fn start_audio_device( +// State(state): State>, +// Json(payload): Json, +// ) -> Result, (StatusCode, JsonResponse)> { +// let device = AudioDevice { +// name: payload.device_name.clone(), +// device_type: payload.device_type.unwrap_or(AudioDeviceType::Input), +// }; + +// // Validate device exists +// let available_devices = list_audio_devices().await.map_err(|e| { +// ( +// StatusCode::INTERNAL_SERVER_ERROR, +// JsonResponse(json!({ +// "error": format!("failed to list audio devices: {}", e), +// 
"success": false +// })), +// ) +// })?; + +// if !available_devices.contains(&device) { +// return Err(( +// StatusCode::BAD_REQUEST, +// JsonResponse(json!({ +// "error": format!("device not found: {}", device.name), +// "success": false +// })), +// )); +// } + +// let control = DeviceControl { +// device: screenpipe_core::DeviceType::Audio(device.clone()), +// is_running: true, +// is_paused: false, +// }; + +// let _ = state.device_manager.update_device(control).await; + +// Ok(JsonResponse(AudioDeviceControlResponse { +// success: true, +// message: format!("started audio device: {}", device.name), +// })) +// } + +// async fn stop_audio_device( +// State(state): State>, +// Json(payload): Json, +// ) -> Result, (StatusCode, JsonResponse)> { +// let device = AudioDevice { +// name: payload.device_name.clone(), +// device_type: payload.device_type.unwrap_or(AudioDeviceType::Input), +// }; + +// // Validate device exists +// let available_devices = list_audio_devices().await.map_err(|e| { +// ( +// StatusCode::INTERNAL_SERVER_ERROR, +// JsonResponse(json!({ +// "error": format!("failed to list audio devices: {}", e), +// "success": false +// })), +// ) +// })?; + +// if !available_devices.contains(&device) { +// return Err(( +// StatusCode::BAD_REQUEST, +// JsonResponse(json!({ +// "error": format!("device not found: {}", device.name), +// "success": false +// })), +// )); +// } + +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Audio(device.clone()), +// is_running: false, +// is_paused: false, +// }) +// .await; + +// Ok(JsonResponse(AudioDeviceControlResponse { +// success: true, +// message: format!("stopped audio device: {}", device.name), +// })) +// } #[derive(Deserialize)] struct EventsQuery { @@ -1842,72 +1841,72 @@ pub struct VisionDeviceControlResponse { message: String, } -async fn start_vision_device( - State(state): State>, - Json(payload): Json, -) -> Result, (StatusCode, JsonResponse)> { - debug!("starting vision device: {}", payload.device_id); - // Validate device exists - let monitors = list_monitors().await; - if !monitors.iter().any(|m| m.id() == payload.device_id) { - return Err(( - StatusCode::BAD_REQUEST, - JsonResponse(json!({ - "error": format!("monitor not found: {}", payload.device_id), - "success": false - })), - )); - } - - debug!("starting vision device: {}", payload.device_id); - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Vision(payload.device_id), - is_running: true, - is_paused: false, - }) - .await; - - Ok(JsonResponse(VisionDeviceControlResponse { - success: true, - message: format!("started vision device: {}", payload.device_id), - })) -} - -async fn stop_vision_device( - State(state): State>, - Json(payload): Json, -) -> Result, (StatusCode, JsonResponse)> { - debug!("stopping vision device: {}", payload.device_id); - // Validate device exists - let monitors = list_monitors().await; - if !monitors.iter().any(|m| m.id() == payload.device_id) { - return Err(( - StatusCode::BAD_REQUEST, - JsonResponse(json!({ - "error": format!("monitor not found: {}", payload.device_id), - "success": false - })), - )); - } - - debug!("stopping vision device: {}", payload.device_id); - - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Vision(payload.device_id), - is_running: false, - is_paused: false, - }) - .await; - - Ok(JsonResponse(VisionDeviceControlResponse { - success: true, - 
message: format!("stopped vision device: {}", payload.device_id), - })) -} +// async fn start_vision_device( +// State(state): State>, +// Json(payload): Json, +// ) -> Result, (StatusCode, JsonResponse)> { +// debug!("starting vision device: {}", payload.device_id); +// // Validate device exists +// let monitors = list_monitors().await; +// if !monitors.iter().any(|m| m.id() == payload.device_id) { +// return Err(( +// StatusCode::BAD_REQUEST, +// JsonResponse(json!({ +// "error": format!("monitor not found: {}", payload.device_id), +// "success": false +// })), +// )); +// } + +// debug!("starting vision device: {}", payload.device_id); +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Vision(payload.device_id), +// is_running: true, +// is_paused: false, +// }) +// .await; + +// Ok(JsonResponse(VisionDeviceControlResponse { +// success: true, +// message: format!("started vision device: {}", payload.device_id), +// })) +// } + +// async fn stop_vision_device( +// State(state): State>, +// Json(payload): Json, +// ) -> Result, (StatusCode, JsonResponse)> { +// debug!("stopping vision device: {}", payload.device_id); +// // Validate device exists +// let monitors = list_monitors().await; +// if !monitors.iter().any(|m| m.id() == payload.device_id) { +// return Err(( +// StatusCode::BAD_REQUEST, +// JsonResponse(json!({ +// "error": format!("monitor not found: {}", payload.device_id), +// "success": false +// })), +// )); +// } + +// debug!("stopping vision device: {}", payload.device_id); + +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Vision(payload.device_id), +// is_running: false, +// is_paused: false, +// }) +// .await; + +// Ok(JsonResponse(VisionDeviceControlResponse { +// success: true, +// message: format!("stopped vision device: {}", payload.device_id), +// })) +// } // websocket events handler async fn ws_events_handler(ws: WebSocketUpgrade, query: Query) -> Response { @@ -2042,15 +2041,15 @@ pub fn create_router() -> Router> { .route("/speakers/similar", get(get_similar_speakers_handler)) .route("/experimental/frames/merge", post(merge_frames_handler)) .route("/experimental/validate/media", get(validate_media_handler)) - .route("/audio/start", post(start_audio_device)) - .route("/audio/stop", post(stop_audio_device)) + // .route("/audio/start", post(start_audio_device)) + // .route("/audio/stop", post(stop_audio_device)) .route("/ws/events", get(ws_events_handler)) .route("/semantic-search", get(semantic_search_handler)) .route("/frames/:frame_id", get(get_frame_data)) - .route("/vision/start", post(start_vision_device)) - .route("/vision/stop", post(stop_vision_device)) - .route("/audio/restart", post(restart_audio_devices)) - .route("/vision/restart", post(restart_vision_devices)) + // .route("/vision/start", post(start_vision_device)) + // .route("/vision/stop", post(stop_vision_device)) + // .route("/audio/restart", post(restart_audio_devices)) + // .route("/vision/restart", post(restart_vision_devices)) .layer(cors); #[cfg(feature = "experimental")] @@ -2408,62 +2407,62 @@ pub struct RestartAudioDevicesResponse { restarted_devices: Vec, } -async fn restart_audio_devices( - State(state): State>, -) -> Result, (StatusCode, JsonResponse)> { - debug!("restarting active audio devices"); - - // Get currently active devices from device manager - let active_devices = state.device_manager.get_active_devices().await; - let mut restarted_devices = 
Vec::new(); - - for (_, control) in active_devices { - debug!("restarting audio device: {:?}", control.device); - - let audio_device = match control.device { - DeviceType::Audio(device) => device, - _ => continue, - }; - // Stop the device - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Audio(audio_device.clone()), - is_running: false, - is_paused: false, - }) - .await; - - // Small delay to ensure clean shutdown - tokio::time::sleep(Duration::from_millis(1000)).await; - - // Start the device again - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Audio(audio_device.clone()), - is_running: true, - is_paused: false, - }) - .await; - - restarted_devices.push(audio_device.name.clone()); - } - - if restarted_devices.is_empty() { - Ok(JsonResponse(RestartAudioDevicesResponse { - success: true, - message: "no active audio devices to restart".to_string(), - restarted_devices, - })) - } else { - Ok(JsonResponse(RestartAudioDevicesResponse { - success: true, - message: format!("restarted {} audio devices", restarted_devices.len()), - restarted_devices, - })) - } -} +// async fn restart_audio_devices( +// State(state): State>, +// ) -> Result, (StatusCode, JsonResponse)> { +// debug!("restarting active audio devices"); + +// // Get currently active devices from device manager +// let active_devices = state.device_manager.get_active_devices().await; +// let mut restarted_devices = Vec::new(); + +// for (_, control) in active_devices { +// debug!("restarting audio device: {:?}", control.device); + +// let audio_device = match control.device { +// DeviceType::Audio(device) => device, +// _ => continue, +// }; +// // Stop the device +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Audio(audio_device.clone()), +// is_running: false, +// is_paused: false, +// }) +// .await; + +// // Small delay to ensure clean shutdown +// tokio::time::sleep(Duration::from_millis(1000)).await; + +// // Start the device again +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Audio(audio_device.clone()), +// is_running: true, +// is_paused: false, +// }) +// .await; + +// restarted_devices.push(audio_device.name.clone()); +// } + +// if restarted_devices.is_empty() { +// Ok(JsonResponse(RestartAudioDevicesResponse { +// success: true, +// message: "no active audio devices to restart".to_string(), +// restarted_devices, +// })) +// } else { +// Ok(JsonResponse(RestartAudioDevicesResponse { +// success: true, +// message: format!("restarted {} audio devices", restarted_devices.len()), +// restarted_devices, +// })) +// } +// } #[derive(Serialize)] pub struct RestartVisionDevicesResponse { @@ -2472,273 +2471,54 @@ pub struct RestartVisionDevicesResponse { restarted_devices: Vec, } -async fn restart_vision_devices( - State(state): State>, -) -> Result, (StatusCode, JsonResponse)> { - debug!("restarting active vision devices"); - - let active_devices = state.device_manager.get_active_devices().await; - let mut restarted_devices = Vec::new(); - - for (_, control) in active_devices { - let vision_device = match control.device { - DeviceType::Vision(device) => device, - _ => continue, - }; - - debug!("restarting vision device: {:?}", vision_device); - - // Stop the device - let _ = state - .device_manager - .update_device(DeviceControl { - device: 
screenpipe_core::DeviceType::Vision(vision_device.clone()), - is_running: false, - is_paused: false, - }) - .await; - - tokio::time::sleep(Duration::from_millis(1000)).await; - - // Start the device again - let _ = state - .device_manager - .update_device(DeviceControl { - device: screenpipe_core::DeviceType::Vision(vision_device.clone()), - is_running: true, - is_paused: false, - }) - .await; - - restarted_devices.push(vision_device.clone()); - } - - Ok(JsonResponse(RestartVisionDevicesResponse { - success: true, - message: if restarted_devices.is_empty() { - "no active vision devices to restart".to_string() - } else { - format!("restarted {} vision devices", restarted_devices.len()) - }, - restarted_devices, - })) -} - -/* - -Curl commands for reference: -# 1. Basic search query -curl "http://localhost:3030/search?q=test&limit=5&offset=0" | jq - -# 2. Search with content type filter (OCR) -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=ocr" | jq - -# 3. Search with content type filter (Audio) -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=audio" | jq - -# 4. Search with pagination -curl "http://localhost:3030/search?q=test&limit=10&offset=20" | jq - -# 6. Search with no query (should return all results) -curl "http://localhost:3030/search?limit=5&offset=0" - -// list devices -// # curl "http://localhost:3030/audio/list" | jq - - -echo "Listing audio devices:" -curl "http://localhost:3030/audio/list" | jq - - -echo "Searching for content:" -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=all" | jq -curl "http://localhost:3030/search?limit=5&offset=0&content_type=ocr" | jq - -curl "http://localhost:3030/search?q=libmp3&limit=5&offset=0&content_type=all" | jq - -# last 5 w frames -curl "http://localhost:3030/search?limit=5&offset=0&content_type=all&include_frames=true&start_time=$(date -u -v-5M +%Y-%m-%dT%H:%M:%SZ)" | jq - -# 30 min to 25 min ago -curl "http://localhost:3030/search?limit=5&offset=0&content_type=all&include_frames=true&start_time=$(date -u -v-30M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-25M +%Y-%m-%dT%H:%M:%SZ)" | jq - -curl "http://localhost:3030/search?limit=1&offset=0&content_type=all&include_frames=true&start_time=$(date -u -v-30M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-25M +%Y-%m-%dT%H:%M:%SZ)" | jq -r '.data[0].content.frame' | base64 --decode > /tmp/frame.png && open /tmp/frame.png - -# Search for content from the last 30 minutes -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=all&start_time=$(date -u -v-30M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-25M +%Y-%m-%dT%H:%M:%SZ)" | jq - -# Search for content up to 1 hour ago -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=all&end_time=$(date -u -v-1H +%Y-%m-%dT%H:%M:%SZ)" | jq - -# Search for content between 2 hours ago and 1 hour ago -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=all&start_time=$(date -u -v-2H +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-1H +%Y-%m-%dT%H:%M:%SZ)" | jq - -# Search for OCR content from yesterday -curl "http://localhost:3030/search?q=test&limit=5&offset=0&content_type=ocr&start_time=$(date -u -v-1d -v0H -v0M -v0S +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-1d -v23H -v59M -v59S +%Y-%m-%dT%H:%M:%SZ)" | jq - -# Search for audio content with a keyword from the beginning of the current month -curl "http://localhost:3030/search?q=libmp3&limit=5&offset=0&content_type=audio&start_time=$(date -u -v1d -v0H -v0M -v0S +%Y-%m-01T%H:%M:%SZ)" | jq - -curl 
"http://localhost:3030/search?app_name=cursor" -curl "http://localhost:3030/search?content_type=audio&min_length=20" - -curl "http://localhost:3030/search?q=Matt&offset=0&limit=50&start_time=$(date -u -v-2H +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-1H +%Y-%m-%dT%H:%M:%SZ)" | jq . - - -curl "http://localhost:3030/search?limit=50&offset=0&content_type=all&start_time=$(date -u -v-2H +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-1H +%Y-%m-%dT%H:%M:%SZ)" | jq - -date -u -v-2H +%Y-%m-%dT%H:%M:%SZ -2024-08-12T06:51:54Z -date -u -v-1H +%Y-%m-%dT%H:%M:%SZ -2024-08-12T07:52:17Z - -curl 'http://localhost:3030/search?limit=50&offset=0&content_type=all&start_time=2024-08-12T06:48:18Z&end_time=2024-08-12T07:48:34Z' | jq . - - -curl "http://localhost:3030/search?q=Matt&offset=0&limit=10&start_time=2024-08-12T04:00:00Z&end_time=2024-08-12T05:00:00Z&content_type=all" | jq . - -curl "http://localhost:3030/search?q=Matt&offset=0&limit=10&start_time=2024-08-12T06:43:53Z&end_time=2024-08-12T08:43:53Z&content_type=all" | jq . - -curl 'http://localhost:3030/search?offset=0&limit=10&start_time=2024-08-12T04%3A00%3A00Z&end_time=2024-08-12T05%3A00%3A00Z&content_type=all' | jq . - - - - - - - - - - - -# First, search for Rust-related content -curl "http://localhost:3030/search?q=debug&limit=5&offset=0&content_type=ocr" - -# Then, assuming you found a relevant item with id 123, tag it -curl -X POST "http://localhost:3030/tags/vision/626" \ - -H "Content-Type: application/json" \ - -d '{"tags": ["debug"]}' - - - - -# List all pipes -curl "http://localhost:3030/pipes/list" | jq - -# Download a new pipe -curl -X POST "http://localhost:3030/pipes/download" \ - -H "Content-Type: application/json" \ - -d '{"url": "./pipes/pipe-stream-ocr-text"}' | jq - -curl -X POST "http://localhost:3030/pipes/download" \ - -H "Content-Type: application/json" \ - -d '{"url": "./pipes/pipe-security-check"}' | jq - - -curl -X POST "http://localhost:3030/pipes/download" \ - -H "Content-Type: application/json" \ - -d '{"url": "https://github.com/mediar-ai/screenpipe/tree/main/pipes/pipe-stream-ocr-text"}' | jq - - -# Get info for a specific pipe -curl "http://localhost:3030/pipes/info/pipe-stream-ocr-text" | jq - -# Run a pipe -curl -X POST "http://localhost:3030/pipes/enable" \ - -H "Content-Type: application/json" \ - -d '{"pipe_id": "pipe-stream-ocr-text"}' | jq - - - - curl -X POST "http://localhost:3030/pipes/enable" \ - -H "Content-Type: application/json" \ - -d '{"pipe_id": "pipe-security-check"}' | jq - -# Stop a pipe -curl -X POST "http://localhost:3030/pipes/disable" \ - -H "Content-Type: application/json" \ - -d '{"pipe_id": "pipe-stream-ocr-text"}' | jq - -# Update pipe configuration -curl -X POST "http://localhost:3030/pipes/update" \ - -H "Content-Type: application/json" \ - -d '{ - "pipe_id": "pipe-stream-ocr-text", - "config": { - "key": "value", - "another_key": "another_value" - } - }' | jq - - - - - -# Basic search with min_length and max_length -curl "http://localhost:3030/search?q=test&limit=10&offset=0&min_length=5&max_length=50" | jq - -# Search for OCR content with length constraints -curl "http://localhost:3030/search?q=code&content_type=ocr&limit=5&offset=0&min_length=20&max_length=100" | jq - -# Search for audio content with length constraints -curl "http://localhost:3030/search?q=meeting&content_type=audio&limit=5&offset=0&min_length=50&max_length=200" | jq - -# Search with time range and length constraints -curl "http://localhost:3030/search?q=project&limit=10&offset=0&min_length=10&max_length=100&start_time=$(date 
-u -v-1H +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u +%Y-%m-%dT%H:%M:%SZ)" | jq - -# Search with app_name and length constraints -curl "http://localhost:3030/search?app_name=cursor&limit=5&offset=0&min_length=15&max_length=150" | jq - -# Search with window_name and length constraints -curl "http://localhost:3030/search?window_name=alacritty&min_length=5&max_length=50" | jq - -# Search for very short content -curl "http://localhost:3030/search?q=&limit=10&offset=0&max_length=10" | jq - -# Search for very long content -curl "http://localhost:3030/search?q=&limit=10&offset=0&min_length=500" | jq - - -curl "http://localhost:3030/search?limit=10&offset=0&min_length=500&content_type=audio" | jq - - -# read random data and generate a clip using the merge endpoint - - -# Perform the search and store the response - -# First, let's search for some recent video content -SEARCH_RESPONSE1=$(curl -s "http://localhost:3030/search?q=&limit=5&offset=0&content_type=ocr&start_time=$(date -u -v-30M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-25M +%Y-%m-%dT%H:%M:%SZ)") -SEARCH_RESPONSE2=$(curl -s "http://localhost:3030/search?q=&limit=5&offset=0&content_type=ocr&start_time=$(date -u -v-40M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-35M +%Y-%m-%dT%H:%M:%SZ)") -SEARCH_RESPONSE3=$(curl -s "http://localhost:3030/search?q=&limit=5&offset=0&content_type=ocr&start_time=$(date -u -v-50M +%Y-%m-%dT%H:%M:%SZ)&end_time=$(date -u -v-45M +%Y-%m-%dT%H:%M:%SZ)") - -# Extract the file paths from the search results without creating JSON arrays -VIDEO_PATHS1=$(echo "$SEARCH_RESPONSE1" | jq -r '.data[].content.file_path' | sort -u) -VIDEO_PATHS2=$(echo "$SEARCH_RESPONSE2" | jq -r '.data[].content.file_path' | sort -u) -VIDEO_PATHS3=$(echo "$SEARCH_RESPONSE3" | jq -r '.data[].content.file_path' | sort -u) - -# Merge the video paths and create a single JSON array -MERGED_VIDEO_PATHS=$(echo "$VIDEO_PATHS1"$'\n'"$VIDEO_PATHS2"$'\n'"$VIDEO_PATHS3" | sort -u | jq -R -s -c 'split("\n") | map(select(length > 0))') - -# Create the JSON payload for merging videos -MERGE_PAYLOAD=$(jq -n \ - --argjson video_paths "$MERGED_VIDEO_PATHS" \ - '{ - video_paths: $video_paths - }') - -echo "Merge Payload: $MERGE_PAYLOAD" - -# Send the merge request and store the response -MERGE_RESPONSE=$(curl -s -X POST "http://localhost:3030/experimental/frames/merge" \ - -H "Content-Type: application/json" \ - -d "$MERGE_PAYLOAD") - -echo "Merge Response: $MERGE_RESPONSE" - -# Extract the merged video path from the response -MERGED_VIDEO_PATH=$(echo "$MERGE_RESPONSE" | jq -r '.video_path') - -echo "Merged Video Path: $MERGED_VIDEO_PATH" - -*/ +// async fn restart_vision_devices( +// State(state): State>, +// ) -> Result, (StatusCode, JsonResponse)> { +// debug!("restarting active vision devices"); + +// let active_devices = state.device_manager.get_active_devices().await; +// let mut restarted_devices = Vec::new(); + +// for (_, control) in active_devices { +// let vision_device = match control.device { +// DeviceType::Vision(device) => device, +// _ => continue, +// }; + +// debug!("restarting vision device: {:?}", vision_device); + +// // Stop the device +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: screenpipe_core::DeviceType::Vision(vision_device.clone()), +// is_running: false, +// is_paused: false, +// }) +// .await; + +// tokio::time::sleep(Duration::from_millis(1000)).await; + +// // Start the device again +// let _ = state +// .device_manager +// .update_device(DeviceControl { +// device: 
screenpipe_core::DeviceType::Vision(vision_device.clone()), +// is_running: true, +// is_paused: false, +// }) +// .await; + +// restarted_devices.push(vision_device.clone()); +// } + +// Ok(JsonResponse(RestartVisionDevicesResponse { +// success: true, +// message: if restarted_devices.is_empty() { +// "no active vision devices to restart".to_string() +// } else { +// format!("restarted {} vision devices", restarted_devices.len()) +// }, +// restarted_devices, +// })) +// } diff --git a/screenpipe-server/src/video.rs b/screenpipe-server/src/video.rs index 8d75e540ba..8137d20168 100644 --- a/screenpipe-server/src/video.rs +++ b/screenpipe-server/src/video.rs @@ -1,24 +1,21 @@ use chrono::Utc; use crossbeam::queue::ArrayQueue; +use image::ImageFormat::{self}; +use tracing::{debug, error, info, warn}; use screenpipe_core::{find_ffmpeg_path, Language}; use screenpipe_vision::{ capture_screenshot_by_window::WindowFilters, continuous_capture, CaptureResult, OcrEngine, }; -use std::borrow::Cow; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; -use std::sync::Weak; use std::time::Duration; use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; -use tokio::process::{Child, ChildStderr, ChildStdin, Command}; +use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; use tokio::sync::mpsc::channel; -use tokio::sync::watch; -use tokio::task::JoinHandle; use tokio::time::sleep; -use tracing::{debug, error, info, warn}; pub(crate) const MAX_FPS: f64 = 30.0; // Adjust based on your needs const MAX_QUEUE_SIZE: usize = 10; @@ -27,30 +24,26 @@ pub struct VideoCapture { #[allow(unused)] video_frame_queue: Arc>>, pub ocr_frame_queue: Arc>>, - shutdown_tx: watch::Sender, - handles: Vec>, } impl VideoCapture { + #[allow(clippy::too_many_arguments)] pub fn new( output_path: &str, fps: f64, video_chunk_duration: Duration, new_chunk_callback: impl Fn(&str) + Send + Sync + 'static, - ocr_engine: Weak, + ocr_engine: Arc, monitor_id: u32, - ignore_list: Arc<[String]>, - include_list: Arc<[String]>, - languages: Arc<[Language]>, + ignore_list: &[String], + include_list: &[String], + languages: Vec, capture_unfocused_windows: bool, ) -> Self { let fps = if fps.is_finite() && fps > 0.0 { fps } else { - warn!( - "[monitor_id: {}] Invalid FPS value: {}. Using default of 1.0", - monitor_id, fps - ); + warn!("Invalid FPS value: {}. 
Using default of 1.0", fps); 1.0 }; let interval = Duration::from_secs_f64(1.0 / fps); @@ -62,82 +55,48 @@ impl VideoCapture { let capture_video_frame_queue = video_frame_queue.clone(); let capture_ocr_frame_queue = ocr_frame_queue.clone(); let (result_sender, mut result_receiver) = channel(512); - let window_filters = Arc::new(WindowFilters::new(&ignore_list, &include_list)); + let window_filters = Arc::new(WindowFilters::new(ignore_list, include_list)); let window_filters_clone = Arc::clone(&window_filters); - let (shutdown_tx, shutdown_rx) = watch::channel(false); - let mut handles = Vec::new(); - let shutdown_rx_capture = shutdown_rx.clone(); - let shutdown_rx_queue = shutdown_rx.clone(); - let shutdown_rx_video = shutdown_rx.clone(); - let languages_clone = languages.clone(); - let result_sender_inner = result_sender.clone(); - - let capture_handle = tokio::spawn(async move { - let mut rx = shutdown_rx_capture; - loop { - if *rx.borrow() { - info!( - "[monitor_id: {}] shutting down video capture thread", - monitor_id - ); - break; - } - let result_sender = result_sender_inner.clone(); - let window_filters_clone = Arc::clone(&window_filters_clone); - let languages_clone = languages_clone.clone(); - - let ocr_engine = match ocr_engine.upgrade() { - Some(engine) => engine, - None => { - warn!("[monitor_id: {}] OCR engine no longer exists", monitor_id); - break; - } - }; - - tokio::select! { - _ = continuous_capture( - result_sender, - interval, - ocr_engine, - monitor_id, - window_filters_clone, - languages_clone.clone(), - capture_unfocused_windows, - rx.clone(), - ) => { - debug!("[monitor_id: {}] continuous capture completed, restarting", monitor_id); + let _capture_thread = tokio::spawn(async move { + continuous_capture( + result_sender, + interval, + (*ocr_engine).clone(), + monitor_id, + window_filters_clone, + languages.clone(), + capture_unfocused_windows, + ) + .await; + }); + + // In the _queue_thread + let _queue_thread = tokio::spawn(async move { + // Helper function to push to queue and handle errors + fn push_to_queue( + queue: &ArrayQueue>, + result: &Arc, + queue_name: &str, + ) -> bool { + if queue.push(Arc::clone(result)).is_err() { + if queue.pop().is_none() { + error!("{} queue is in an inconsistent state", queue_name); + return false; } - _ = rx.changed() => { - if *rx.borrow() { - info!("[monitor_id: {}] shutting down video capture thread", monitor_id); - break; - } + if queue.push(Arc::clone(result)).is_err() { + error!( + "Failed to push to {} queue after removing oldest frame", + queue_name + ); + return false; } + debug!("{} queue was full, dropped oldest frame", queue_name); } + true } - debug!( - "[monitor_id: {}] exiting capture handle loop, dropping sender", - monitor_id - ); - drop(result_sender_inner); - }); - handles.push(capture_handle); - - let queue_handle = tokio::spawn(async move { - let rx = shutdown_rx_queue; while let Some(result) = result_receiver.recv().await { - if *rx.borrow() { - info!( - "[monitor_id: {}] shutting down video queue thread", - monitor_id - ); - break; - } let frame_number = result.frame_number; - debug!( - "[monitor_id: {}] received frame {} for queueing", - monitor_id, frame_number - ); + debug!("Received frame {} for queueing", frame_number); let result = Arc::new(result); @@ -146,105 +105,92 @@ impl VideoCapture { if !video_pushed || !ocr_pushed { error!( - "[monitor_id: {}] failed to push frame {} to one or more queues, queue lengths: {}, {}", - monitor_id, frame_number, - capture_video_frame_queue.len(), - 
capture_ocr_frame_queue.len() + "Failed to push frame {} to one or more queues", + frame_number ); continue; // Skip to next iteration instead of crashing } debug!( - "[monitor_id: {}] frame {} pushed to queues. Queue lengths: {}, {}", - monitor_id, + "Frame {} pushed to queues. Queue lengths: {}, {}", frame_number, capture_video_frame_queue.len(), capture_ocr_frame_queue.len() ); } }); - handles.push(queue_handle); let video_frame_queue_clone = video_frame_queue.clone(); let output_path = output_path.to_string(); - let video_handle = tokio::spawn(async move { - let rx = shutdown_rx_video; - save_frames_as_video_with_shutdown( + let _video_thread = tokio::spawn(async move { + save_frames_as_video( &video_frame_queue_clone, &output_path, fps, new_chunk_callback_clone, monitor_id, video_chunk_duration, - rx, ) .await; }); - handles.push(video_handle); VideoCapture { video_frame_queue, ocr_frame_queue, - shutdown_tx, - handles, - } - } - - pub async fn shutdown(self) -> Result<(), anyhow::Error> { - info!("shutting down video capture"); - self.shutdown_tx.send(true)?; - - for handle in self.handles { - if let Err(e) = handle.await { - error!("error joining handle: {}", e); - } } - - Ok(()) } } pub async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result { - let fps = fps.min(MAX_FPS); - - debug!("starting ffmpeg process for: {}", output_file); + // Overriding fps with max fps if over the max and warning user + let fps = if fps > MAX_FPS { + warn!("Overriding FPS from {} to {}", fps, MAX_FPS); + MAX_FPS + } else { + fps + }; + + info!("Starting FFmpeg process for file: {}", output_file); let fps_str = fps.to_string(); let mut command = Command::new(find_ffmpeg_path().unwrap()); - - // Updated FFmpeg arguments for better performance and quality - let args = vec![ + let mut args = vec![ "-f", "image2pipe", "-vcodec", - "mjpeg", + "png", "-r", &fps_str, "-i", "-", "-vf", - "format=yuv420p,pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2", - "-c:v", + "pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2", + ]; + + args.extend_from_slice(&[ + "-vcodec", "libx265", "-tag:v", "hvc1", "-preset", - "medium", // Changed from ultrafast for better compression + "ultrafast", "-crf", - "28", // Slightly higher CRF for smaller file size - "-x265-params", - "log-level=error", // Reduce x265 logging noise - output_file, - ]; + "23", + ]); + + args.extend_from_slice(&["-pix_fmt", "yuv420p", output_file]); command .args(&args) .stdin(Stdio::piped()) - .stdout(Stdio::null()) // Changed to null since we don't need stdout + .stdout(Stdio::piped()) .stderr(Stdio::piped()); - debug!("ffmpeg command: {:?}", command); + debug!("FFmpeg command: {:?}", command); + let child = command.spawn()?; + debug!("FFmpeg process spawned"); + Ok(child) } @@ -256,30 +202,29 @@ pub async fn write_frame_to_ffmpeg( Ok(()) } -async fn save_frames_as_video_with_shutdown( +async fn log_ffmpeg_output(stream: impl AsyncBufReadExt + Unpin, stream_name: &str) { + let reader = BufReader::new(stream); + let mut lines = reader.lines(); + while let Ok(Some(line)) = lines.next_line().await { + debug!("FFmpeg {}: {}", stream_name, line); + } +} + +async fn save_frames_as_video( frame_queue: &Arc>>, output_path: &str, fps: f64, new_chunk_callback: Arc, monitor_id: u32, video_chunk_duration: Duration, - mut shutdown_rx: watch::Receiver, ) { - debug!("starting save_frames_as_video function"); + debug!("Starting save_frames_as_video function"); let frames_per_video = (fps * video_chunk_duration.as_secs_f64()).ceil() as usize; let mut frame_count = 0; let 
mut current_ffmpeg: Option = None; let mut current_stdin: Option = None; loop { - if *shutdown_rx.borrow() { - info!("shutting down video capture thread"); - if let Some(child) = current_ffmpeg.take() { - finish_ffmpeg_process(child, current_stdin.take()).await; - } - break; - } - if frame_count >= frames_per_video || current_ffmpeg.is_none() { if let Some(child) = current_ffmpeg.take() { finish_ffmpeg_process(child, current_stdin.take()).await; @@ -295,20 +240,20 @@ async fn save_frames_as_video_with_shutdown( match start_ffmpeg_process(&output_file, fps).await { Ok(mut child) => { let mut stdin = child.stdin.take().expect("Failed to open stdin"); - spawn_ffmpeg_loggers(child.stderr.take()); + spawn_ffmpeg_loggers(child.stderr.take(), child.stdout.take()); if let Err(e) = write_frame_to_ffmpeg(&mut stdin, &buffer).await { - error!("failed to write first frame to ffmpeg: {}", e); + error!("Failed to write first frame to ffmpeg: {}", e); continue; } frame_count += 1; current_ffmpeg = Some(child); current_stdin = Some(stdin); - debug!("new FFmpeg process started for file: {}", output_file); + debug!("New FFmpeg process started for file: {}", output_file); } Err(e) => { - error!("failed to start FFmpeg process: {}", e); + error!("Failed to start FFmpeg process: {}", e); continue; } } @@ -320,23 +265,10 @@ async fn save_frames_as_video_with_shutdown( &mut frame_count, frames_per_video, fps, - &shutdown_rx, ) .await; - tokio::select! { - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - if let Some(child) = current_ffmpeg.take() { - finish_ffmpeg_process(child, current_stdin.take()).await; - } - break; - } - } - _ = tokio::time::sleep(Duration::from_millis(10)) => { - // Continue with normal processing - } - } + tokio::task::yield_now().await; } } @@ -354,12 +286,9 @@ async fn wait_for_first_frame( fn encode_frame(frame: &CaptureResult) -> Vec { let mut buffer = Vec::new(); - let rgb_image = frame.image.to_rgb8(); - rgb_image - .write_to( - &mut std::io::Cursor::new(&mut buffer), - image::ImageFormat::Jpeg, - ) + frame + .image + .write_to(&mut std::io::Cursor::new(&mut buffer), ImageFormat::Png) .expect("Failed to encode frame"); buffer } @@ -374,22 +303,12 @@ fn create_output_file(output_path: &str, monitor_id: u32) -> String { .to_string() } -fn spawn_ffmpeg_loggers(stderr: Option) { +fn spawn_ffmpeg_loggers(stderr: Option, stdout: Option) { if let Some(stderr) = stderr { - tokio::spawn(async move { - let reader = BufReader::new(stderr); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - // Only log important messages - if line.contains("error") || line.contains("fatal") { - error!("ffmpeg: {}", line); - } else if line.contains("warning") { - warn!("ffmpeg: {}", line); - } else { - debug!("ffmpeg: {}", line); - } - } - }); + tokio::spawn(log_ffmpeg_output(BufReader::new(stderr), "stderr")); + } + if let Some(stdout) = stdout { + tokio::spawn(log_ffmpeg_output(BufReader::new(stdout), "stdout")); } } @@ -399,39 +318,25 @@ async fn process_frames( frame_count: &mut usize, frames_per_video: usize, fps: f64, - shutdown_rx: &watch::Receiver, ) { let write_timeout = Duration::from_secs_f64(1.0 / fps); - let mut should_break = false; - - while *frame_count < frames_per_video && !should_break { - if *shutdown_rx.borrow() { - info!("process_frames: shutdown signal received, breaking out"); - should_break = true; - continue; - } - + while *frame_count < frames_per_video { if let Some(frame) = frame_queue.pop() { let buffer = encode_frame(&frame); if 
let Some(stdin) = current_stdin.as_mut() { if let Err(e) = write_frame_with_retry(stdin, &buffer).await { - error!("failed to write frame to ffmpeg after max retries: {}", e); - should_break = true; - continue; + error!("Failed to write frame to ffmpeg after max retries: {}", e); + break; } *frame_count += 1; - debug!("wrote frame {} to ffmpeg", frame_count); + debug!("Wrote frame {} to FFmpeg", frame_count); + flush_ffmpeg_input(stdin, *frame_count, fps).await; } } else { tokio::time::sleep(write_timeout).await; } } - - // Cleanup remaining frames - while frame_queue.pop().is_some() { - debug!("cleaning up remaining frame from queue"); - } } async fn write_frame_with_retry( @@ -479,34 +384,10 @@ pub async fn finish_ffmpeg_process(child: Child, stdin: Option) { match child.wait_with_output().await { Ok(output) => { debug!("FFmpeg process exited with status: {}", output.status); - let stderr = String::from_utf8_lossy(&output.stderr); - if !output.status.success() && stderr != Cow::Borrowed("") { - error!("FFmpeg stderr: {}", stderr); + if !output.status.success() { + error!("FFmpeg stderr: {}", String::from_utf8_lossy(&output.stderr)); } } Err(e) => error!("Failed to wait for FFmpeg process: {}", e), } } - -fn push_to_queue( - queue: &ArrayQueue>, - result: &Arc, - queue_name: &str, -) -> bool { - match queue.push(Arc::clone(result)) { - Ok(_) => { - debug!( - "{} queue: Successfully pushed frame {}", - queue_name, result.frame_number - ); - true - } - Err(_) => { - warn!( - "{} queue full, dropping frame {}", - queue_name, result.frame_number - ); - false - } - } -} diff --git a/screenpipe-server/tests/db.rs b/screenpipe-server/tests/db.rs index 9ddd43121a..9e96eae47f 100644 --- a/screenpipe-server/tests/db.rs +++ b/screenpipe-server/tests/db.rs @@ -3,7 +3,7 @@ mod tests { use std::sync::Arc; use chrono::Utc; - use screenpipe_core::{AudioDevice, AudioDeviceType}; + use screenpipe_audio::{AudioDevice, DeviceType}; use screenpipe_server::{ db_types::{ContentType, SearchResult}, DatabaseManager, @@ -69,7 +69,7 @@ mod tests { "Hello from audio", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -131,7 +131,7 @@ mod tests { "Hello from audio", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -220,7 +220,7 @@ mod tests { "Hello from audio", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -310,7 +310,7 @@ mod tests { "Hello from audio 1", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -346,7 +346,7 @@ mod tests { "Hello from audio 2", 1, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -467,7 +467,6 @@ mod tests { } #[tokio::test] - #[ignore] // TODO FIX async fn test_count_search_results_with_time_range() { let db = setup_test_db().await; @@ -498,7 +497,7 @@ mod tests { "Hello from audio 1", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -533,7 +532,7 @@ mod tests { "Hello from audio 2", 1, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + 
&AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, @@ -673,7 +672,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker.id), None, None, @@ -737,7 +736,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker.id), None, None, @@ -793,7 +792,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker.id), None, None, @@ -836,7 +835,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker.id), None, None, @@ -880,7 +879,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker.id), None, None, @@ -897,7 +896,7 @@ mod tests { "test transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), Some(speaker2.id), None, None, diff --git a/screenpipe-server/tests/endpoint_test.rs b/screenpipe-server/tests/endpoint_test.rs index 6717ba7f9f..63f064e91a 100644 --- a/screenpipe-server/tests/endpoint_test.rs +++ b/screenpipe-server/tests/endpoint_test.rs @@ -6,9 +6,7 @@ mod tests { use axum::Router; use chrono::DateTime; use chrono::{Duration, Utc}; - use screenpipe_core::AudioDevice; - use screenpipe_core::AudioDeviceType; - use screenpipe_core::DeviceManager; + use screenpipe_audio::{AudioDevice, DeviceType}; use screenpipe_server::db_types::ContentType; use screenpipe_server::db_types::SearchResult; use screenpipe_server::video_cache::FrameCache; @@ -18,7 +16,9 @@ mod tests { }; use screenpipe_vision::OcrEngine; // Adjust this import based on your actual module structure use serde::Deserialize; + use std::collections::HashMap; use std::path::PathBuf; + use std::sync::atomic::AtomicBool; use std::sync::Arc; use tower::ServiceExt; // for `oneshot` and `ready` @@ -32,7 +32,9 @@ mod tests { let app_state = Arc::new(AppState { db: db.clone(), - device_manager: Arc::new(DeviceManager::default()), + vision_control: Arc::new(AtomicBool::new(false)), + audio_devices_tx: Arc::new(tokio::sync::broadcast::channel(1000).0), + devices_status: HashMap::new(), app_start_time: Utc::now(), screenpipe_dir: PathBuf::from(""), pipe_manager: Arc::new(PipeManager::new(PathBuf::from(""))), @@ -42,7 +44,9 @@ mod tests { FrameCache::new(PathBuf::from(""), db).await.unwrap(), )), ui_monitoring_enabled: false, - frame_image_cache: None, + realtime_transcription_sender: Arc::new(tokio::sync::broadcast::channel(1000).0), + realtime_transcription_enabled: false, + realtime_vision_sender: Arc::new(tokio::sync::broadcast::channel(1000).0), }); let router = create_router(); @@ -67,7 +71,7 @@ mod tests { "Short", 0, "", - &AudioDevice::new("test1".to_string(), AudioDeviceType::Input), + &AudioDevice::new("test1".to_string(), DeviceType::Input), None, None, None, @@ -81,7 +85,7 @@ mod tests { "This is a longer transcription with more words", 0, "", - &AudioDevice::new("test2".to_string(), AudioDeviceType::Input), + &AudioDevice::new("test2".to_string(), DeviceType::Input), None, None, None, 
@@ -203,7 +207,7 @@ mod tests { "This is a test audio transcription that should definitely be longer than thirty characters", // >30 chars 0, "", - &AudioDevice::new("test1".to_string(), AudioDeviceType::Input), + &AudioDevice::new("test1".to_string(), DeviceType::Input), None, None, None, @@ -216,7 +220,7 @@ mod tests { "Short audio", // <30 chars 0, "", - &AudioDevice::new("test2".to_string(), AudioDeviceType::Input), + &AudioDevice::new("test2".to_string(), DeviceType::Input), None, None, None, @@ -352,7 +356,6 @@ mod tests { } #[tokio::test] - #[ignore] // FIX ME async fn test_search_with_time_constraints() { let (_, state) = setup_test_app().await; let db = &state.db; @@ -397,7 +400,7 @@ mod tests { "old audio transcription", 0, "", - &AudioDevice::new("test".to_string(), AudioDeviceType::Input), + &AudioDevice::new("test".to_string(), DeviceType::Input), None, None, None, @@ -532,7 +535,7 @@ mod tests { ) .await .unwrap(); - assert_eq!(audio_count, 1); // TODO fail here ? + assert_eq!(audio_count, 1); } #[tokio::test] @@ -655,16 +658,11 @@ mod tests { async fn test_recent_tasks_no_bleeding_production_db() { // Get home directory safely let home = std::env::var("HOME").expect("HOME environment variable not set"); - let source_db_path = format!("{}/.screenpipe/db.sqlite", home); + let db_path = format!("{}/.screenpipe/db.sqlite", home); - // Create temporary directory and copy database - let temp_dir = tempfile::tempdir().unwrap(); - let temp_db_path = temp_dir.path().join("temp_db.sqlite"); - std::fs::copy(&source_db_path, &temp_db_path).unwrap(); - - // Open temporary database copy + // Open database in read-only mode for safety let db = Arc::new( - DatabaseManager::new(&format!("sqlite:{}", temp_db_path.display())) + DatabaseManager::new(&format!("sqlite:{}?mode=ro", db_path)) .await .unwrap(), ); diff --git a/screenpipe-server/tests/tags_test.rs b/screenpipe-server/tests/tags_test.rs index a9b9fce6c7..57bca9a7b7 100644 --- a/screenpipe-server/tests/tags_test.rs +++ b/screenpipe-server/tests/tags_test.rs @@ -4,11 +4,12 @@ use axum::{ Router, }; use chrono::Utc; -use screenpipe_core::{AudioDevice, AudioDeviceType, DeviceManager}; +use screenpipe_audio::{AudioDevice, DeviceType}; use screenpipe_vision::OcrEngine; use serde_json::json; -use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::Arc; +use std::{collections::HashMap, path::PathBuf}; use tower::ServiceExt; use screenpipe_server::{ @@ -26,9 +27,11 @@ async fn setup_test_app() -> (Router, Arc) { let app_state = Arc::new(AppState { db: db.clone(), - device_manager: Arc::new(DeviceManager::default()), vision_disabled: false, audio_disabled: false, + vision_control: Arc::new(AtomicBool::new(false)), + audio_devices_tx: Arc::new(tokio::sync::broadcast::channel(1000).0), + devices_status: HashMap::new(), app_start_time: Utc::now(), screenpipe_dir: PathBuf::from(""), pipe_manager: Arc::new(PipeManager::new(PathBuf::from(""))), @@ -36,7 +39,9 @@ async fn setup_test_app() -> (Router, Arc) { FrameCache::new(PathBuf::from(""), db).await.unwrap(), )), ui_monitoring_enabled: false, - frame_image_cache: None, + realtime_transcription_sender: Arc::new(tokio::sync::broadcast::channel(1000).0), + realtime_transcription_enabled: false, + realtime_vision_sender: Arc::new(tokio::sync::broadcast::channel(1000).0), }); let app = create_router().with_state(app_state.clone()); @@ -387,7 +392,7 @@ async fn insert_test_data(db: &Arc) { "Test audio transcription", 0, "test_engine", - &AudioDevice::new("test".to_string(), 
AudioDeviceType::Output), + &AudioDevice::new("test".to_string(), DeviceType::Output), None, None, None, diff --git a/screenpipe-server/tests/video_utils_test.rs b/screenpipe-server/tests/video_utils_test.rs index 93fe7f99bf..b33db5b3a3 100644 --- a/screenpipe-server/tests/video_utils_test.rs +++ b/screenpipe-server/tests/video_utils_test.rs @@ -2,16 +2,16 @@ use anyhow::Result; use dirs::{self, home_dir}; use screenpipe_core::Language; use screenpipe_server::video_utils::extract_frames_from_video; -use screenpipe_vision::capture_screenshot_by_window::CapturedWindow; +use screenpipe_vision::{capture_screenshot_by_window::CapturedWindow, perform_ocr_apple}; use std::path::PathBuf; use tokio::fs; use tracing::info; async fn setup_test_env() -> Result<()> { - // enable tracing logging; use try_init to avoid setting the subscriber multiple times - let _ = tracing_subscriber::fmt() + // enable tracing logging + tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) - .try_init(); + .init(); Ok(()) } @@ -42,7 +42,6 @@ async fn create_test_video() -> Result { } #[tokio::test] -#[ignore] // TODO: fix this test async fn test_extract_frames() -> Result<()> { setup_test_env().await?; let video_path = create_test_video().await?; @@ -102,12 +101,8 @@ async fn test_extract_frames() -> Result<()> { Ok(()) } -#[cfg(target_os = "macos")] #[tokio::test] async fn test_extract_frames_and_ocr() -> Result<()> { - use std::sync::Arc; - - use screenpipe_vision::perform_ocr_apple; setup_test_env().await?; let video_path = create_test_video().await?; @@ -135,10 +130,7 @@ async fn test_extract_frames_and_ocr() -> Result<()> { }; // perform ocr using apple native (macos only) - let (text, _, confidence) = perform_ocr_apple( - &captured_window.image, - Arc::new([Language::English].to_vec()), - ); + let (text, _, confidence) = perform_ocr_apple(&captured_window.image, &[Language::English]); println!("ocr confidence: {}", confidence.unwrap_or(0.0)); println!("extracted text: {}", text); diff --git a/screenpipe-vision/Cargo.toml b/screenpipe-vision/Cargo.toml index 92792dd1ce..aebc8628a4 100644 --- a/screenpipe-vision/Cargo.toml +++ b/screenpipe-vision/Cargo.toml @@ -33,7 +33,6 @@ clap = { version = "4.0", features = ["derive"] } # Integrations screenpipe-integrations = { path = "../screenpipe-integrations" } -screenpipe-events = { path = "../screenpipe-events" } # Lanuage specification screenpipe-core = { path = "../screenpipe-core" } @@ -44,6 +43,7 @@ which = "6.0" serde = "1.0.200" once_cell = { workspace = true } +chrono = { version = "0.4.39", features = ["serde"] } base64 = "0.22.1" reqwest = { workspace = true } diff --git a/screenpipe-vision/benches/apple_leak_bench.rs b/screenpipe-vision/benches/apple_leak_bench.rs index 1196a13a08..e618f57168 100644 --- a/screenpipe-vision/benches/apple_leak_bench.rs +++ b/screenpipe-vision/benches/apple_leak_bench.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use image::GenericImageView; use memory_stats::memory_stats; use screenpipe_vision::perform_ocr_apple; -use std::{path::PathBuf, sync::Arc}; +use std::path::PathBuf; fn bytes_to_mb(bytes: usize) -> f64 { bytes as f64 / (1024.0 * 1024.0) @@ -38,7 +38,7 @@ fn apple_ocr_benchmark(c: &mut Criterion) { } } - let result = perform_ocr_apple(&image, Arc::new(vec![])); + let result = perform_ocr_apple(&image, &[]); assert!( result.0.contains("receiver_count"), "OCR failed: {:?}", diff --git a/screenpipe-vision/benches/vision_benchmark.rs b/screenpipe-vision/benches/vision_benchmark.rs 
index 2c41961697..b3b5ea0f28 100644 --- a/screenpipe-vision/benches/vision_benchmark.rs +++ b/screenpipe-vision/benches/vision_benchmark.rs @@ -25,7 +25,6 @@ async fn benchmark_continuous_capture(duration_secs: u64) -> f64 { window_filters, vec![], false, - tokio::sync::watch::channel(false).1, ) .await; }); diff --git a/screenpipe-vision/examples/websocket.rs b/screenpipe-vision/examples/websocket.rs index e34f017d8b..06f6ed1fcc 100644 --- a/screenpipe-vision/examples/websocket.rs +++ b/screenpipe-vision/examples/websocket.rs @@ -8,7 +8,7 @@ use screenpipe_vision::{ continuous_capture, monitor::get_default_monitor, CaptureResult, OcrEngine, }; use serde::Serialize; -use serde_json; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; @@ -24,10 +24,11 @@ struct SimplifiedResult { #[derive(Clone, Serialize)] pub struct SimplifiedWindowResult { + // pub image: String, pub window_name: String, pub app_name: String, pub text: String, - pub text_json: Vec, + pub text_json: Vec>, // Change this line pub focused: bool, pub confidence: f64, } @@ -91,17 +92,16 @@ async fn main() -> Result<()> { Duration::from_secs_f64(1.0 / cli.fps), // if apple use apple otherwise if windows use windows native otherwise use tesseract if cfg!(target_os = "macos") { - Arc::new(OcrEngine::AppleNative) + OcrEngine::AppleNative } else if cfg!(target_os = "windows") { - Arc::new(OcrEngine::WindowsNative) + OcrEngine::WindowsNative } else { - Arc::new(OcrEngine::Tesseract) + OcrEngine::Tesseract }, id, window_filters, - Arc::new([].to_vec()), + vec![], false, - tokio::sync::watch::channel(false).1, ) .await }); @@ -147,10 +147,11 @@ async fn run_websocket_server( let _base64_image = general_purpose::STANDARD.encode(buffer); SimplifiedWindowResult { + // image: base64_image, window_name: window.window_name, app_name: window.app_name, text: window.text, - text_json: window.text_json, + text_json: window.text_json, // Add this line focused: window.focused, confidence: window.confidence, } diff --git a/screenpipe-vision/src/apple.rs b/screenpipe-vision/src/apple.rs index 04ddbcb065..dcf1c62d7d 100644 --- a/screenpipe-vision/src/apple.rs +++ b/screenpipe-vision/src/apple.rs @@ -10,13 +10,12 @@ use log::error; use screenpipe_core::Language; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::sync::Arc; use std::sync::OnceLock; use std::{ffi::c_void, ptr::null_mut}; static APPLE_LANGUAGE_MAP: OnceLock> = OnceLock::new(); - -pub fn get_apple_languages(languages: Arc<[Language]>) -> Vec { + +pub fn get_apple_languages(languages: &[Language]) -> Vec { let map = APPLE_LANGUAGE_MAP.get_or_init(|| { let mut m = HashMap::new(); m.insert(Language::English, "en-US"); @@ -72,7 +71,7 @@ extern "C" fn release_callback(_refcon: *mut c_void, _data_ptr: *const *const c_ #[cfg(target_os = "macos")] pub fn perform_ocr_apple( image: &DynamicImage, - languages: Arc<[Language]>, + languages: &[Language], ) -> (String, String, Option) { cidre::objc::ar_pool(|| { // Convert languages to Apple format and create ns::Array diff --git a/screenpipe-vision/src/bin/screenpipe-vision.rs b/screenpipe-vision/src/bin/screenpipe-vision.rs index 0ab50aafcb..7c777096bb 100644 --- a/screenpipe-vision/src/bin/screenpipe-vision.rs +++ b/screenpipe-vision/src/bin/screenpipe-vision.rs @@ -5,7 +5,7 @@ use screenpipe_vision::{ OcrEngine, }; use std::{sync::Arc, time::Duration}; -use tokio::sync::{mpsc::channel, watch}; +use tokio::sync::mpsc::channel; use 
tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; #[derive(Parser)] @@ -39,18 +39,15 @@ async fn main() { let id = monitor.id(); let window_filters = WindowFilters::new(&[], &[]); - let (_, shutdown_rx) = watch::channel(false); - tokio::spawn(async move { continuous_capture( result_tx, Duration::from_secs_f32(1.0 / cli.fps), - Arc::new(OcrEngine::AppleNative), + OcrEngine::AppleNative, id, Arc::new(window_filters), - Arc::from(languages), + languages.clone(), false, - shutdown_rx, ) .await }); diff --git a/screenpipe-vision/src/core.rs b/screenpipe-vision/src/core.rs index 5869f8481f..6640f38d2a 100644 --- a/screenpipe-vision/src/core.rs +++ b/screenpipe-vision/src/core.rs @@ -22,14 +22,14 @@ use serde::Serialize; use serde::Serializer; use serde_json; use std::sync::Arc; -use std::time::{Duration, Instant, UNIX_EPOCH}; +use std::{ + collections::HashMap, + time::{Duration, Instant, UNIX_EPOCH}, +}; use tokio::fs::File; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::mpsc::Sender; -use tokio::sync::mpsc::WeakSender; -use tokio::sync::watch; -use tracing::info; -use tracing::warn; +use tokio::time::sleep; #[cfg(target_os = "macos")] use xcap_macos::Monitor; @@ -37,44 +37,54 @@ use xcap_macos::Monitor; #[cfg(not(target_os = "macos"))] use xcap::Monitor; -fn serialize_image(image: &Option>, serializer: S) -> Result +fn serialize_image(image: &Option, serializer: S) -> Result where S: serde::Serializer, { if let Some(image) = image { let mut webp_buffer = Vec::new(); let mut cursor = std::io::Cursor::new(&mut webp_buffer); + let mut encoder = JpegEncoder::new_with_quality(&mut cursor, 80); + // Encode the image as WebP encoder - .encode_image(image.as_ref()) + .encode_image(image) .map_err(serde::ser::Error::custom)?; + // Base64 encode the WebP data let base64_string = general_purpose::STANDARD.encode(webp_buffer); + + // Serialize the base64 string serializer.serialize_str(&base64_string) } else { serializer.serialize_none() } } -fn deserialize_image<'de, D>(deserializer: D) -> Result>, D::Error> +fn deserialize_image<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { + // Deserialize the base64 string let base64_string: String = serde::Deserialize::deserialize(deserializer)?; + // Check if the base64 string is empty or invalid if base64_string.trim().is_empty() { return Ok(None); } + // Decode base64 to bytes let image_bytes = general_purpose::STANDARD .decode(&base64_string) .map_err(serde::de::Error::custom)?; + // Create a cursor to read from the bytes let cursor = std::io::Cursor::new(image_bytes); - let image = image::load(cursor, image::ImageFormat::Jpeg).map_err(serde::de::Error::custom)?; - Ok(Some(Arc::new(image))) + // Decode the JPEG data back into an image + let image = image::load(cursor, image::ImageFormat::Jpeg).map_err(serde::de::Error::custom)?; + Ok(Some(image)) } fn serialize_instant(instant: &Instant, serializer: S) -> Result @@ -96,32 +106,24 @@ where } pub struct CaptureResult { - pub image: Arc, + pub image: DynamicImage, pub frame_number: u64, pub timestamp: Instant, pub window_ocr_results: Vec, } -impl Drop for CaptureResult { - fn drop(&mut self) { - if Arc::strong_count(&self.image) == 1 { - debug!("dropping last reference to captured image"); - } - } -} - pub struct WindowOcrResult { - pub image: Arc, + pub image: DynamicImage, pub window_name: String, pub app_name: String, pub text: String, - pub text_json: Vec, + pub text_json: Vec>, // Change this line pub focused: bool, pub confidence: f64, } pub struct 
OcrTaskData { - pub image: Arc, + pub image: DynamicImage, pub window_images: Vec, pub frame_number: u64, pub timestamp: Instant, @@ -131,15 +133,14 @@ pub struct OcrTaskData { pub async fn continuous_capture( result_tx: Sender, interval: Duration, - ocr_engine: Arc, + ocr_engine: OcrEngine, monitor_id: u32, window_filters: Arc, - languages: Arc<[Language]>, + languages: Vec, capture_unfocused_windows: bool, - mut shutdown_rx: watch::Receiver, ) { let mut frame_counter: u64 = 0; - let mut previous_image: Option> = None; + let mut previous_image: Option = None; let mut max_average: Option = None; let mut max_avg_value = 0.0; @@ -148,145 +149,118 @@ pub async fn continuous_capture( monitor_id ); - let monitor = get_monitor_by_id(monitor_id).await.unwrap(); - loop { - // Check shutdown signal - if *shutdown_rx.borrow() { - info!( - "continuous_capture: received shutdown signal for monitor {}", - monitor_id - ); - drop(result_tx); - break; - } - - // Use tokio::select! to handle both capture and shutdown - tokio::select! { - _ = shutdown_rx.changed() => { - if *shutdown_rx.borrow() { - info!("continuous_capture: shutdown signal received for monitor {}", monitor_id); - drop(result_tx); - break; + let monitor = match get_monitor_by_id(monitor_id).await { + Some(m) => m, + None => { + sleep(Duration::from_secs(1)).await; + continue; + } + }; + let capture_result = + match capture_screenshot(&monitor, &window_filters, capture_unfocused_windows).await { + Ok((image, window_images, image_hash, _capture_duration)) => { + debug!( + "Captured screenshot on monitor {} with hash: {}", + monitor_id, image_hash + ); + Some((image, window_images, image_hash)) + } + Err(e) => { + error!("Failed to capture screenshot: {}", e); + None } + }; + + if let Some((image, window_images, image_hash)) = capture_result { + let current_average = match compare_with_previous_image( + previous_image.as_ref(), + &image, + &mut max_average, + frame_counter, + &mut max_avg_value, + ) + .await + { + Ok(avg) => avg, + Err(e) => { + error!("Error comparing images: {}", e); + 0.0 + } + }; + + let current_average = if previous_image.is_none() { + 1.0 + } else { + current_average + }; + + if current_average < 0.006 { + debug!( + "Skipping frame {} due to low average difference: {:.3}", + frame_counter, current_average + ); + frame_counter += 1; + tokio::time::sleep(interval).await; + continue; + } + + if current_average > max_avg_value { + max_average = Some(MaxAverageFrame { + image: image.clone(), + window_images: window_images.clone(), + image_hash, + frame_number: frame_counter, + timestamp: Instant::now(), + result_tx: result_tx.clone(), + average: current_average, + }); + max_avg_value = current_average; } - _ = async { - let languages_clone = languages.clone(); - let capture_result = match capture_screenshot(&monitor, &window_filters, capture_unfocused_windows).await { - Ok((image, window_images, image_hash, _capture_duration)) => { - debug!( - "captured screenshot on monitor {} with hash: {}", - monitor.id(), - image_hash - ); - Some((image, window_images, image_hash)) - } - Err(e) => { - error!("Failed to capture screenshot: {}", e); - None - } + + previous_image = Some(image); + + if let Some(max_avg_frame) = max_average.take() { + let ocr_task_data = OcrTaskData { + image: max_avg_frame.image, + window_images: max_avg_frame.window_images, + frame_number: max_avg_frame.frame_number, + timestamp: max_avg_frame.timestamp, + result_tx: max_avg_frame.result_tx, }; - if let Some((image, window_images, image_hash)) = 
capture_result { - let image = Arc::new(image); - let current_average = match compare_with_previous_image( - previous_image.as_ref(), - &image, - &mut max_average, - frame_counter, - &mut max_avg_value, - ) - .await - { - Ok(avg) => avg, - Err(e) => { - error!("Error comparing images: {}", e); - previous_image = None; - 0.0 - } - }; - - let current_average = if previous_image.is_none() { - 1.0 - } else { - current_average - }; - - if current_average < 0.006 { - debug!( - "Skipping frame {} due to low average difference: {:.3}", - frame_counter, current_average - ); - frame_counter += 1; - tokio::time::sleep(interval).await; - return; - } - - if current_average > max_avg_value { - max_average = Some(MaxAverageFrame { - image: Arc::clone(&image), - window_images, - image_hash, - frame_number: frame_counter, - timestamp: Instant::now(), - result_tx: result_tx.downgrade(), - average: current_average, - }); - max_avg_value = current_average; - } - - previous_image = Some(Arc::clone(&image)); - - if let Some(max_avg_frame) = max_average.take() { - if let Some(sender) = max_avg_frame.result_tx.upgrade() { - let ocr_task_data = OcrTaskData { - image: max_avg_frame.image.clone(), - window_images: max_avg_frame.window_images.iter().cloned().collect(), - frame_number: max_avg_frame.frame_number, - timestamp: max_avg_frame.timestamp, - result_tx: sender, - }; - - if let Err(e) = - process_ocr_task(ocr_task_data, ocr_engine.clone(), languages_clone).await - { - error!("Error processing OCR task: {}", e); - } - - frame_counter = 0; - max_avg_value = 0.0; - } else { - warn!("result_tx was dropped, skipping OCR task"); - return; - } - } - } else { - debug!("Skipping frame {} due to capture failure", frame_counter); + if let Err(e) = + process_ocr_task(ocr_task_data, &ocr_engine, languages.clone()).await + { + error!("Error processing OCR task: {}", e); } - frame_counter += 1; - tokio::time::sleep(interval).await; - } => {} + frame_counter = 0; + max_avg_value = 0.0; + } + } else { + debug!("Skipping frame {} due to capture failure", frame_counter); } - } - debug!("Continuous capture stopped for monitor {}", monitor_id); + frame_counter += 1; + tokio::time::sleep(interval).await; + } } pub struct MaxAverageFrame { - pub image: Arc, + pub image: DynamicImage, pub window_images: Vec, pub image_hash: u64, pub frame_number: u64, pub timestamp: Instant, - pub result_tx: WeakSender, + pub result_tx: Sender, pub average: f64, } pub async fn process_ocr_task( ocr_task_data: OcrTaskData, - ocr_engine: Arc, - languages: Arc<[Language]>, + ocr_engine: &OcrEngine, + languages: Vec, ) -> Result<(), std::io::Error> { let OcrTaskData { image, @@ -307,7 +281,7 @@ pub async fn process_ocr_task( let mut window_count = 0; for captured_window in window_images { - let (window_text, window_json_output, confidence) = match ocr_engine.as_ref() { + let (window_text, window_json_output, confidence) = match ocr_engine { OcrEngine::Unstructured => perform_ocr_cloud(&captured_window.image, languages.clone()) .await .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, @@ -319,7 +293,7 @@ pub async fn process_ocr_task( .await .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, #[cfg(target_os = "macos")] - OcrEngine::AppleNative => perform_ocr_apple(&captured_window.image, languages.clone()), + OcrEngine::AppleNative => perform_ocr_apple(&captured_window.image, &languages), OcrEngine::Custom(config) => { perform_ocr_custom(&captured_window.image, languages.clone(), config) .await @@ -338,17 +312,15 @@ pub 
async fn process_ocr_task( window_count += 1; } - let ocr_result = WindowOcrResult { - image: Arc::new(captured_window.image), + window_ocr_results.push(WindowOcrResult { + image: captured_window.image, window_name: captured_window.window_name, app_name: captured_window.app_name, text: window_text, text_json: parse_json_output(&window_json_output), focused: captured_window.is_focused, confidence: confidence.unwrap_or(0.0), - }; - - window_ocr_results.push(ocr_result); + }); } let capture_result = CaptureResult { @@ -386,12 +358,14 @@ pub async fn process_ocr_task( Ok(()) } -fn parse_json_output(json_output: &str) -> Vec { - serde_json::from_str(json_output) // <-- Temporary allocations +fn parse_json_output(json_output: &str) -> Vec> { + let parsed_output: Vec> = serde_json::from_str(json_output) .unwrap_or_else(|e| { error!("Failed to parse JSON output: {}", e); Vec::new() - }) + }); + + parsed_output } pub fn trigger_screen_capture_permission() -> Result<()> { @@ -419,11 +393,11 @@ pub struct WindowOcr { serialize_with = "serialize_image", deserialize_with = "deserialize_image" )] - pub image: Option>, + pub image: Option, pub window_name: String, pub app_name: String, pub text: String, - pub text_json: Vec, + pub text_json: Vec>, // Change this line pub focused: bool, pub confidence: f64, #[serde( diff --git a/screenpipe-vision/src/custom_ocr.rs b/screenpipe-vision/src/custom_ocr.rs index 87c61dcfe9..bfd077ecd8 100644 --- a/screenpipe-vision/src/custom_ocr.rs +++ b/screenpipe-vision/src/custom_ocr.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use anyhow::Result; use base64::{engine::general_purpose, Engine as _}; use image::DynamicImage; @@ -25,7 +23,7 @@ impl Default for CustomOcrConfig { pub async fn perform_ocr_custom( image: &DynamicImage, - languages: Arc<[Language]>, + languages: Vec, config: &CustomOcrConfig, ) -> Result<(String, String, Option)> { // Convert image to RGB before encoding to JPEG diff --git a/screenpipe-vision/src/monitor.rs b/screenpipe-vision/src/monitor.rs index f96196751d..c48f07a48c 100644 --- a/screenpipe-vision/src/monitor.rs +++ b/screenpipe-vision/src/monitor.rs @@ -4,7 +4,6 @@ use xcap_macos::Monitor; #[cfg(not(target_os = "macos"))] use xcap::Monitor; - pub async fn list_monitors() -> Vec { Monitor::all().unwrap().to_vec() } diff --git a/screenpipe-vision/src/run_ui_monitoring_macos.rs b/screenpipe-vision/src/run_ui_monitoring_macos.rs index f852840d21..044a2f84c3 100644 --- a/screenpipe-vision/src/run_ui_monitoring_macos.rs +++ b/screenpipe-vision/src/run_ui_monitoring_macos.rs @@ -1,6 +1,5 @@ use anyhow::Result; use log::{debug, error, info, warn}; -use screenpipe_events::send_event; use std::fs; use std::io; use std::path::PathBuf; @@ -14,9 +13,12 @@ use tokio::signal; use tokio::time::{sleep, timeout, Duration}; use which::which; +use crate::core::RealtimeVisionEvent; use crate::UIFrame; -pub async fn run_ui() -> Result<()> { +pub async fn run_ui( + realtime_vision_sender: Arc>, +) -> Result<()> { info!("starting ui monitoring service..."); let binary_name = "ui_monitor"; @@ -131,7 +133,7 @@ pub async fn run_ui() -> Result<()> { frame = UIFrame::read_from_pipe(&mut reader) => { match frame { Ok(frame) => { - let _ = send_event("ui_frame", frame); + let _ = realtime_vision_sender.send(RealtimeVisionEvent::Ui(frame)); } Err(e) => { if let Some(io_err) = e.downcast_ref::() { diff --git a/screenpipe-vision/src/tesseract.rs b/screenpipe-vision/src/tesseract.rs index 04b93727e1..1c1c518ae7 100644 --- a/screenpipe-vision/src/tesseract.rs +++ 
b/screenpipe-vision/src/tesseract.rs @@ -1,10 +1,11 @@ use image::DynamicImage; use rusty_tesseract::{Args, DataOutput, Image}; use screenpipe_core::{Language, TESSERACT_LANGUAGES}; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; + pub fn perform_ocr_tesseract( image: &DynamicImage, - languages: Arc<[Language]>, + languages: Vec, ) -> (String, String, Option) { let language_string = match languages.is_empty() { true => "eng".to_string(), diff --git a/screenpipe-vision/src/utils.rs b/screenpipe-vision/src/utils.rs index dae184916d..28a8f2349d 100644 --- a/screenpipe-vision/src/utils.rs +++ b/screenpipe-vision/src/utils.rs @@ -7,7 +7,6 @@ use image::DynamicImage; use image_compare::{Algorithm, Metric, Similarity}; use log::{debug, error, warn}; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; use std::time::{Duration, Instant}; #[cfg(target_os = "macos")] @@ -83,7 +82,7 @@ pub async fn capture_screenshot( } pub async fn compare_with_previous_image( - previous_image: Option<&Arc>, + previous_image: Option<&DynamicImage>, current_image: &DynamicImage, max_average: &mut Option, frame_number: u64, diff --git a/screenpipe-vision/tests/apple_vision_test.rs b/screenpipe-vision/tests/apple_vision_test.rs index fffa752d53..e0a95e24fe 100644 --- a/screenpipe-vision/tests/apple_vision_test.rs +++ b/screenpipe-vision/tests/apple_vision_test.rs @@ -4,7 +4,7 @@ mod tests { use image::GenericImageView; use screenpipe_core::Language; use screenpipe_vision::perform_ocr_apple; - use std::{path::PathBuf, sync::Arc}; + use std::path::PathBuf; #[tokio::test] async fn test_apple_native_ocr() { @@ -26,7 +26,7 @@ mod tests { let rgb_image = image.to_rgb8(); println!("RGB image dimensions: {:?}", rgb_image.dimensions()); - let (ocr_text, _, _) = perform_ocr_apple(&image, Arc::new([].to_vec())); + let (ocr_text, _, _) = perform_ocr_apple(&image, &[]); println!("OCR text: {:?}", ocr_text); assert!( @@ -46,7 +46,7 @@ mod tests { let image = image::open(&path).expect("Failed to open Chinese test image"); println!("Image dimensions: {:?}", image.dimensions()); - let (ocr_text, _, _) = perform_ocr_apple(&image, Arc::new([Language::Chinese].to_vec())); + let (ocr_text, _, _) = perform_ocr_apple(&image, &[Language::Chinese]); println!("OCR text: {:?}", ocr_text); assert!( diff --git a/screenpipe-vision/tests/custom_ocr_test.rs b/screenpipe-vision/tests/custom_ocr_test.rs index 5761a8d0b9..701e5c9eb7 100644 --- a/screenpipe-vision/tests/custom_ocr_test.rs +++ b/screenpipe-vision/tests/custom_ocr_test.rs @@ -55,7 +55,7 @@ async def read_ocr(payload: dict): # # Configure your "CustomOcrConfig" in Rust to point to http://localhost:8000/ocr -# Clean up +# Clean up deactivate rm -rf venv app.py */ @@ -67,10 +67,8 @@ mod tests { use screenpipe_vision::custom_ocr::{perform_ocr_custom, CustomOcrConfig}; use screenpipe_vision::utils::OcrEngine; use std::path::PathBuf; - use std::sync::Arc; #[tokio::test] - #[ignore] // need to run server async fn test_custom_ocr() { println!("Starting custom OCR test..."); @@ -100,7 +98,7 @@ mod tests { // Perform the custom OCR. 
let (ocr_text, structured_data, confidence) = match ocr_engine { OcrEngine::Custom(ref config) => { - perform_ocr_custom(&image, Arc::new([Language::English].to_vec()), config) + perform_ocr_custom(&image, vec![Language::English], config) .await .expect("Custom OCR failed") } @@ -119,7 +117,6 @@ mod tests { } #[tokio::test] - #[ignore] // need to run server async fn test_custom_ocr_chinese() { println!("Starting custom OCR Chinese test..."); @@ -145,7 +142,7 @@ mod tests { let (ocr_text, _, _) = match ocr_engine { OcrEngine::Custom(ref config) => { - perform_ocr_custom(&image, Arc::new([Language::Chinese].to_vec()), config) + perform_ocr_custom(&image, vec![Language::Chinese], config) .await .expect("Custom OCR failed") } diff --git a/screenpipe-vision/tests/windows_vision_test.rs b/screenpipe-vision/tests/windows_vision_test.rs index 6f0e23a3ed..7f9d5ecfd4 100644 --- a/screenpipe-vision/tests/windows_vision_test.rs +++ b/screenpipe-vision/tests/windows_vision_test.rs @@ -3,13 +3,11 @@ mod tests { use screenpipe_vision::core::OcrTaskData; use screenpipe_vision::monitor::get_default_monitor; - use screenpipe_vision::{process_ocr_task, OcrEngine, WindowFilters}; + use screenpipe_vision::{process_ocr_task, OcrEngine}; use std::{path::PathBuf, time::Instant}; use tokio::sync::mpsc; - use screenpipe_vision::core::CapturedWindow; use screenpipe_vision::{continuous_capture, CaptureResult}; - use std::sync::Arc; use std::time::Duration; use tokio::time::timeout; @@ -23,18 +21,18 @@ mod tests { println!("Path to testing_OCR.png: {:?}", path); let image = image::open(&path).expect("Failed to open image"); - let image_arc = Arc::new(image.clone()); + let image_arc = image.clone(); let frame_number = 1; let timestamp = Instant::now(); let (tx, _rx) = mpsc::channel(1); - let ocr_engine = Arc::new(OcrEngine::WindowsNative); + let ocr_engine = OcrEngine::WindowsNative; - let window_images = vec![CapturedWindow { - image: image.clone(), - app_name: "test_app".to_string(), - window_name: "test_window".to_string(), - is_foreground: true, - }]; + let window_images = vec![( + image.clone(), + "test_app".to_string(), + "test_window".to_string(), + true, + )]; let result = process_ocr_task( OcrTaskData { @@ -44,8 +42,8 @@ mod tests { timestamp, result_tx: tx, }, - ocr_engine, - Arc::new(vec![]), + false, + &ocr_engine, ) .await; @@ -65,8 +63,7 @@ mod tests { // Set up test parameters let interval = Duration::from_millis(1000); let save_text_files_flag = false; - let ocr_engine = Arc::new(OcrEngine::WindowsNative); - let window_filters = Arc::new(WindowFilters::new(&[], &[])); + let ocr_engine = OcrEngine::WindowsNative; // Spawn the continuous_capture function let capture_handle = tokio::spawn(continuous_capture( @@ -75,10 +72,8 @@ mod tests { save_text_files_flag, ocr_engine, monitor, - window_filters, - vec![], - false, - tokio::sync::watch::channel(false).1, + &[], + &[], )); // Wait for a short duration to allow some captures to occur From ea3b6b49d1a5c2e371ef7bff79f9aecd56765941 Mon Sep 17 00:00:00 2001 From: Louis Beaumont Date: Fri, 14 Feb 2025 15:10:50 -0800 Subject: [PATCH 2/4] remove restart interval --- .../components/recording-settings.tsx | 4 +- .../lib/hooks/use-settings.tsx | 2 +- screenpipe-app-tauri/src-tauri/Cargo.lock | 2 +- screenpipe-app-tauri/src-tauri/src/sidecar.rs | 66 +++++++++---------- .../src/bin/screenpipe-server.rs | 4 +- 5 files changed, 39 insertions(+), 39 deletions(-) diff --git a/screenpipe-app-tauri/components/recording-settings.tsx 
b/screenpipe-app-tauri/components/recording-settings.tsx
index 72d6421c98..d3ffc3cda3 100644
--- a/screenpipe-app-tauri/components/recording-settings.tsx
+++ b/screenpipe-app-tauri/components/recording-settings.tsx
@@ -1372,7 +1372,7 @@ export function RecordingSettings() {
-
+ {/*
-
+
*/}
diff --git a/screenpipe-app-tauri/lib/hooks/use-settings.tsx b/screenpipe-app-tauri/lib/hooks/use-settings.tsx index d71dcbb11d..a6559bd4d2 100644 --- a/screenpipe-app-tauri/lib/hooks/use-settings.tsx +++ b/screenpipe-app-tauri/lib/hooks/use-settings.tsx @@ -126,7 +126,7 @@ const DEFAULT_SETTINGS: Settings = { monitorIds: ["default"], audioDevices: ["default"], usePiiRemoval: false, - restartInterval: 120, + restartInterval: 0, port: 3030, dataDir: "default", disableAudio: false, diff --git a/screenpipe-app-tauri/src-tauri/Cargo.lock b/screenpipe-app-tauri/src-tauri/Cargo.lock index 1e227b0783..3ebc029c88 100755 --- a/screenpipe-app-tauri/src-tauri/Cargo.lock +++ b/screenpipe-app-tauri/src-tauri/Cargo.lock @@ -5289,7 +5289,7 @@ dependencies = [ [[package]] name = "screenpipe-app" -version = "0.32.14" +version = "0.32.16" dependencies = [ "anyhow", "async-stream", diff --git a/screenpipe-app-tauri/src-tauri/src/sidecar.rs b/screenpipe-app-tauri/src-tauri/src/sidecar.rs index 63eae390a8..fbdc70a0a5 100644 --- a/screenpipe-app-tauri/src-tauri/src/sidecar.rs +++ b/screenpipe-app-tauri/src-tauri/src/sidecar.rs @@ -520,39 +520,39 @@ impl SidecarManager { // Spawn the sidecar let child = spawn_sidecar(app)?; self.child = Some(child); - self.last_restart = Instant::now(); - info!("last restart: {:?}", self.last_restart); - - // kill previous task if any - if let Some(task) = self.restart_task.take() { - task.abort(); - } - - let restart_interval = self.restart_interval.clone(); - info!("restart_interval: {:?}", restart_interval); - // Add this function outside the SidecarManager impl - async fn check_and_restart_sidecar(app_handle: &tauri::AppHandle) -> Result<(), String> { - let state = app_handle.state::(); - let mut manager = state.0.lock().await; - if let Some(manager) = manager.as_mut() { - manager.check_and_restart(app_handle).await - } else { - Ok(()) - } - } - - // In the spawn method - let app_handle = app.app_handle().clone(); - self.restart_task = Some(tauri::async_runtime::spawn(async move { - loop { - let interval = *restart_interval.lock().await; - info!("interval: {}", interval.as_secs()); - if let Err(e) = check_and_restart_sidecar(&app_handle).await { - error!("Failed to check and restart sidecar: {}", e); - } - sleep(Duration::from_secs(60)).await; - } - })); + // self.last_restart = Instant::now(); + // info!("last restart: {:?}", self.last_restart); + + // // kill previous task if any + // if let Some(task) = self.restart_task.take() { + // task.abort(); + // } + + // let restart_interval = self.restart_interval.clone(); + // info!("restart_interval: {:?}", restart_interval); + // // Add this function outside the SidecarManager impl + // async fn check_and_restart_sidecar(app_handle: &tauri::AppHandle) -> Result<(), String> { + // let state = app_handle.state::(); + // let mut manager = state.0.lock().await; + // if let Some(manager) = manager.as_mut() { + // manager.check_and_restart(app_handle).await + // } else { + // Ok(()) + // } + // } + + // // In the spawn method + // let app_handle = app.app_handle().clone(); + // self.restart_task = Some(tauri::async_runtime::spawn(async move { + // loop { + // let interval = *restart_interval.lock().await; + // info!("interval: {}", interval.as_secs()); + // if let Err(e) = check_and_restart_sidecar(&app_handle).await { + // error!("Failed to check and restart sidecar: {}", e); + // } + // sleep(Duration::from_secs(60)).await; + // } + // })); Ok(()) } diff --git a/screenpipe-server/src/bin/screenpipe-server.rs 
b/screenpipe-server/src/bin/screenpipe-server.rs index 741ca4d729..d856c4d93f 100644 --- a/screenpipe-server/src/bin/screenpipe-server.rs +++ b/screenpipe-server/src/bin/screenpipe-server.rs @@ -953,8 +953,8 @@ async fn main() -> anyhow::Result<()> { info!("watching pid {} for auto-destruction", pid); let shutdown_tx_clone = shutdown_tx.clone(); tokio::spawn(async move { - // sleep for 5 seconds - tokio::time::sleep(std::time::Duration::from_secs(5)).await; + // sleep for 1 second + tokio::time::sleep(std::time::Duration::from_secs(1)).await; if watch_pid(pid).await { info!("Watched pid ({}) has stopped, initiating shutdown", pid); let _ = shutdown_tx_clone.send(()); From 33d27458df29110acc239ab1d966b72b7d81b2f8 Mon Sep 17 00:00:00 2001 From: Louis Beaumont Date: Fri, 14 Feb 2025 15:33:46 -0800 Subject: [PATCH 3/4] fix: remove use all monitor --- .../components/recording-settings.tsx | 4 +-- .../src/bin/screenpipe-server.rs | 29 ++++++++++++++++--- screenpipe-server/src/server.rs | 22 +++++++------- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/screenpipe-app-tauri/components/recording-settings.tsx index d3ffc3cda3..46ffabd4a3 100644 --- a/screenpipe-app-tauri/components/recording-settings.tsx +++ b/screenpipe-app-tauri/components/recording-settings.tsx @@ -737,7 +737,7 @@ export function RecordingSettings() { {!settings.disableVision && ( <> -
+ {/*

use all monitors

@@ -752,7 +752,7 @@ export function RecordingSettings() { handleSettingsChange({ useAllMonitors: checked }) } /> -

+
*/}
diff --git a/screenpipe-server/src/bin/screenpipe-server.rs b/screenpipe-server/src/bin/screenpipe-server.rs index d856c4d93f..4b563b02ae 100644 --- a/screenpipe-server/src/bin/screenpipe-server.rs +++ b/screenpipe-server/src/bin/screenpipe-server.rs @@ -481,8 +481,6 @@ async fn main() -> anyhow::Result<()> { // Channel for controlling the recorder ! TODO RENAME SHIT let vision_control = Arc::new(AtomicBool::new(true)); - let vision_control_server_clone = vision_control.clone(); - let warning_ocr_engine_clone = cli.ocr_engine.clone(); let warning_audio_transcription_engine_clone = cli.audio_transcription_engine.clone(); let monitor_ids = if cli.monitor_id.is_empty() { @@ -524,7 +522,6 @@ async fn main() -> anyhow::Result<()> { let audio_chunk_duration = Duration::from_secs(cli.audio_chunk_duration); let (realtime_transcription_sender, _) = tokio::sync::broadcast::channel(1000); - let realtime_transcription_sender_clone = realtime_transcription_sender.clone(); let (realtime_vision_sender, _) = tokio::sync::broadcast::channel(1000); let realtime_vision_sender = Arc::new(realtime_vision_sender.clone()); let realtime_vision_sender_clone = realtime_vision_sender.clone(); @@ -604,7 +601,6 @@ async fn main() -> anyhow::Result<()> { }; let (audio_devices_tx, _) = broadcast::channel(100); - let audio_devices_tx_clone = Arc::new(audio_devices_tx.clone()); let realtime_vision_sender_clone = realtime_vision_sender_clone.clone(); // TODO: Add SSE stream for realtime audio transcription @@ -957,6 +953,31 @@ async fn main() -> anyhow::Result<()> { tokio::time::sleep(std::time::Duration::from_secs(1)).await; if watch_pid(pid).await { info!("Watched pid ({}) has stopped, initiating shutdown", pid); + + // Get list of enabled pipes + let pipes = pipe_manager.list_pipes().await; + let enabled_pipes: Vec<_> = pipes.into_iter().filter(|p| p.enabled).collect(); + // Stop all enabled pipes in parallel + let stop_futures = enabled_pipes.iter().map(|pipe| { + let pipe_manager = pipe_manager.clone(); + let pipe_id = pipe.id.clone(); + tokio::spawn(async move { + if let Err(e) = pipe_manager.stop_pipe(&pipe_id).await { + error!("failed to stop pipe {}: {}", pipe_id, e); + } + }) + }); + // Wait for all pipes to stop with timeout + let timeout = tokio::time::sleep(Duration::from_secs(10)); + tokio::pin!(timeout); + tokio::select! 
{ + _ = futures::future::join_all(stop_futures) => { + info!("all pipes stopped successfully"); + } + _ = &mut timeout => { + warn!("timeout waiting for pipes to stop"); + } + } let _ = shutdown_tx_clone.send(()); } }); diff --git a/screenpipe-server/src/server.rs b/screenpipe-server/src/server.rs index ec5b4ffbcd..ec40648ca7 100644 --- a/screenpipe-server/src/server.rs +++ b/screenpipe-server/src/server.rs @@ -1665,12 +1665,12 @@ async fn get_similar_speakers_handler( Ok(JsonResponse(similar_speakers)) } -#[derive(Deserialize)] -pub struct AudioDeviceControlRequest { - device_name: String, - #[serde(default)] - device_type: Option, -} +// #[derive(Deserialize)] +// pub struct AudioDeviceControlRequest { +// device_name: String, +// #[serde(default)] +// device_type: Option, +// } #[derive(Serialize)] pub struct AudioDeviceControlResponse { @@ -1829,11 +1829,11 @@ pub struct VisionDeviceControlRequest { device_id: u32, } -impl VisionDeviceControlRequest { - pub fn new(device_id: u32) -> Self { - Self { device_id } - } -} +// impl VisionDeviceControlRequest { +// pub fn new(device_id: u32) -> Self { +// Self { device_id } +// } +// } #[derive(Serialize)] pub struct VisionDeviceControlResponse { From a05438ed7e70d3014560518e04db4f3ac2cd9281 Mon Sep 17 00:00:00 2001 From: Louis Beaumont Date: Fri, 14 Feb 2025 17:12:46 -0800 Subject: [PATCH 4/4] v0 --- .../src/bin/screenpipe-server.rs | 1 + screenpipe-server/src/cli.rs | 4 + screenpipe-server/src/core.rs | 99 +++++- screenpipe-server/src/video.rs | 318 ++++++++++++------ .../src/bin/screenpipe-vision.rs | 5 +- screenpipe-vision/src/core.rs | 203 ++++++----- 6 files changed, 430 insertions(+), 200 deletions(-) diff --git a/screenpipe-server/src/bin/screenpipe-server.rs b/screenpipe-server/src/bin/screenpipe-server.rs index 4b563b02ae..818a6d99d6 100644 --- a/screenpipe-server/src/bin/screenpipe-server.rs +++ b/screenpipe-server/src/bin/screenpipe-server.rs @@ -560,6 +560,7 @@ async fn main() -> anyhow::Result<()> { cli.enable_realtime_audio_transcription, Arc::new(realtime_transcription_sender_clone), // Use the cloned sender realtime_vision_sender_clone, + cli.use_all_monitors, ); let result = tokio::select! 
{ diff --git a/screenpipe-server/src/cli.rs b/screenpipe-server/src/cli.rs index b158e14219..b5afba72b5 100644 --- a/screenpipe-server/src/cli.rs +++ b/screenpipe-server/src/cli.rs @@ -264,6 +264,10 @@ pub struct Cli { #[arg(long, default_value_t = false)] pub capture_unfocused_windows: bool, + /// Automatically detect and use all monitors, including newly connected ones + #[arg(long, default_value_t = false)] + pub use_all_monitors: bool, + #[command(subcommand)] pub command: Option, diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs index bb472819d5..bffe098f10 100644 --- a/screenpipe-server/src/core.rs +++ b/screenpipe-server/src/core.rs @@ -4,7 +4,6 @@ use crate::{DatabaseManager, VideoCapture}; use anyhow::Result; use dashmap::DashMap; use futures::future::join_all; -use tracing::{debug, error, info, warn}; use screenpipe_audio::realtime::RealtimeTranscriptionEvent; use screenpipe_audio::vad_engine::VadSensitivity; use screenpipe_audio::{ @@ -15,6 +14,7 @@ use screenpipe_audio::{start_realtime_recording, AudioStream}; use screenpipe_core::pii_removal::remove_pii; use screenpipe_core::Language; use screenpipe_vision::core::{RealtimeVisionEvent, WindowOcr}; +use screenpipe_vision::monitor::list_monitors; use screenpipe_vision::OcrEngine; use std::collections::HashMap; use std::path::PathBuf; @@ -23,6 +23,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::runtime::Handle; use tokio::task::JoinHandle; +use tracing::{debug, error, info, warn}; #[allow(clippy::too_many_arguments)] pub async fn start_continuous_recording( @@ -52,9 +53,10 @@ pub async fn start_continuous_recording( realtime_audio_enabled: bool, realtime_transcription_sender: Arc>, realtime_vision_sender: Arc>, + use_all_monitors: bool, ) -> Result<()> { debug!("Starting video recording for monitor {:?}", monitor_ids); - let video_tasks = if !vision_disabled { + let video_tasks = if !vision_disabled && !use_all_monitors { monitor_ids .iter() .map(|&monitor_id| { @@ -96,6 +98,85 @@ pub async fn start_continuous_recording( })] }; + let monitor_watcher = if !vision_disabled && use_all_monitors { + let vision_control_clone = Arc::clone(&vision_control); + let db_clone = Arc::clone(&db); + let output_path_clone = Arc::clone(&output_path); + let ocr_engine_clone = Arc::clone(&ocr_engine); + let ignored_windows_clone = ignored_windows.to_vec(); + let include_windows_clone = include_windows.to_vec(); + let languages_clone = languages.clone(); + let realtime_vision_sender_clone = realtime_vision_sender.clone(); + let vision_handle_clone = vision_handle.clone(); + + Some(tokio::spawn(async move { + let mut current_monitors = HashMap::new(); + + loop { + let available_monitors = list_monitors().await; + let available_monitors = Arc::new(available_monitors); + + // Check for new monitors + for monitor in available_monitors.iter() { + let monitor_id = monitor.id(); + if !current_monitors.contains_key(&monitor_id) { + debug!("New monitor detected: {}", monitor_id); + + let handle = vision_handle_clone.spawn({ + let db_clone = Arc::clone(&db_clone); + let output_path_clone = Arc::clone(&output_path_clone); + let vision_control_clone = Arc::clone(&vision_control_clone); + let ocr_engine_clone = Arc::clone(&ocr_engine_clone); + let ignored_windows = ignored_windows_clone.clone(); + let include_windows = include_windows_clone.clone(); + let languages = languages_clone.clone(); + let realtime_sender = realtime_vision_sender_clone.clone(); + + async move { + record_video( + db_clone, + output_path_clone, + fps, + 
vision_control_clone, + ocr_engine_clone, + monitor_id, + use_pii_removal, + &ignored_windows, + &include_windows, + video_chunk_duration, + languages, + capture_unfocused_windows, + realtime_sender, + ) + .await + } + }); + + current_monitors.insert(monitor_id, handle); + } + } + + // Check for removed monitors + current_monitors.retain(|monitor_id, handle| { + if !Arc::clone(&available_monitors) + .iter() + .any(|m| m.id() == *monitor_id) + { + debug!("Monitor removed: {}", monitor_id); + handle.abort(); + false + } else { + true + } + }); + + tokio::time::sleep(Duration::from_secs(5)).await; + } + })) + } else { + None + }; + let (whisper_sender, whisper_receiver, whisper_shutdown_flag) = if audio_disabled { // Create a dummy channel if no audio devices are available, e.g. audio disabled let (input_sender, _): ( @@ -163,6 +244,13 @@ pub async fn start_continuous_recording( error!("Audio recording error: {:?}", e); } + // Make sure to handle the monitor watcher task + if let Some(watcher) = monitor_watcher { + if let Err(e) = watcher.await { + error!("Monitor watcher task failed: {}", e); + } + } + // Shutdown the whisper channel whisper_shutdown_flag.store(true, Ordering::Relaxed); drop(whisper_sender_clone); // Close the sender channel @@ -172,6 +260,7 @@ pub async fn start_continuous_recording( // TODO: any additional cleanup like device controls to release info!("Stopped recording"); + Ok(()) } @@ -220,10 +309,10 @@ async fn record_video( fps, video_chunk_duration, new_chunk_callback, - Arc::clone(&ocr_engine), + (*ocr_engine).clone(), monitor_id, - ignored_windows, - include_windows, + ignored_windows.to_vec().into(), + include_windows.to_vec().into(), languages, capture_unfocused_windows, ); diff --git a/screenpipe-server/src/video.rs b/screenpipe-server/src/video.rs index 8137d20168..650a273f18 100644 --- a/screenpipe-server/src/video.rs +++ b/screenpipe-server/src/video.rs @@ -1,11 +1,10 @@ use chrono::Utc; use crossbeam::queue::ArrayQueue; -use image::ImageFormat::{self}; -use tracing::{debug, error, info, warn}; use screenpipe_core::{find_ffmpeg_path, Language}; use screenpipe_vision::{ capture_screenshot_by_window::WindowFilters, continuous_capture, CaptureResult, OcrEngine, }; +use std::borrow::Cow; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; @@ -13,9 +12,12 @@ use std::time::Duration; use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; -use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; +use tokio::process::{Child, ChildStderr, ChildStdin, Command}; use tokio::sync::mpsc::channel; +use tokio::sync::watch; +use tokio::task::JoinHandle; use tokio::time::sleep; +use tracing::{debug, error, info, warn}; pub(crate) const MAX_FPS: f64 = 30.0; // Adjust based on your needs const MAX_QUEUE_SIZE: usize = 10; @@ -24,26 +26,30 @@ pub struct VideoCapture { #[allow(unused)] video_frame_queue: Arc>>, pub ocr_frame_queue: Arc>>, + shutdown_tx: watch::Sender, + handles: Vec>, } impl VideoCapture { - #[allow(clippy::too_many_arguments)] pub fn new( output_path: &str, fps: f64, video_chunk_duration: Duration, new_chunk_callback: impl Fn(&str) + Send + Sync + 'static, - ocr_engine: Arc, + ocr_engine: OcrEngine, monitor_id: u32, - ignore_list: &[String], - include_list: &[String], + ignore_list: Arc<[String]>, + include_list: Arc<[String]>, languages: Vec, capture_unfocused_windows: bool, ) -> Self { let fps = if fps.is_finite() && fps > 0.0 { fps } else { - warn!("Invalid FPS value: {}. 
Using default of 1.0", fps); + warn!( + "[monitor_id: {}] Invalid FPS value: {}. Using default of 1.0", + monitor_id, fps + ); 1.0 }; let interval = Duration::from_secs_f64(1.0 / fps); @@ -55,48 +61,74 @@ impl VideoCapture { let capture_video_frame_queue = video_frame_queue.clone(); let capture_ocr_frame_queue = ocr_frame_queue.clone(); let (result_sender, mut result_receiver) = channel(512); - let window_filters = Arc::new(WindowFilters::new(ignore_list, include_list)); + let window_filters = Arc::new(WindowFilters::new(&ignore_list, &include_list)); let window_filters_clone = Arc::clone(&window_filters); - let _capture_thread = tokio::spawn(async move { - continuous_capture( - result_sender, - interval, - (*ocr_engine).clone(), - monitor_id, - window_filters_clone, - languages.clone(), - capture_unfocused_windows, - ) - .await; - }); - - // In the _queue_thread - let _queue_thread = tokio::spawn(async move { - // Helper function to push to queue and handle errors - fn push_to_queue( - queue: &ArrayQueue>, - result: &Arc, - queue_name: &str, - ) -> bool { - if queue.push(Arc::clone(result)).is_err() { - if queue.pop().is_none() { - error!("{} queue is in an inconsistent state", queue_name); - return false; + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let mut handles = Vec::new(); + let shutdown_rx_capture = shutdown_rx.clone(); + let shutdown_rx_queue = shutdown_rx.clone(); + let shutdown_rx_video = shutdown_rx.clone(); + let languages_clone = languages.clone(); + let result_sender_inner = result_sender.clone(); + + let capture_handle = tokio::spawn(async move { + let mut rx = shutdown_rx_capture; + loop { + if *rx.borrow() { + info!( + "[monitor_id: {}] shutting down video capture thread", + monitor_id + ); + break; + } + let result_sender = result_sender_inner.clone(); + let window_filters_clone = Arc::clone(&window_filters_clone); + let languages_clone = languages_clone.clone(); + + tokio::select! 
{ + _ = continuous_capture( + result_sender, + interval, + ocr_engine.clone(), + monitor_id, + window_filters_clone, + languages_clone.clone(), + capture_unfocused_windows, + rx.clone(), + ) => { + debug!("[monitor_id: {}] continuous capture completed, restarting", monitor_id); } - if queue.push(Arc::clone(result)).is_err() { - error!( - "Failed to push to {} queue after removing oldest frame", - queue_name - ); - return false; + _ = rx.changed() => { + if *rx.borrow() { + info!("[monitor_id: {}] shutting down video capture thread", monitor_id); + break; + } } - debug!("{} queue was full, dropped oldest frame", queue_name); } - true } + debug!( + "[monitor_id: {}] exiting capture handle loop, dropping sender", + monitor_id + ); + drop(result_sender_inner); + }); + handles.push(capture_handle); + + let queue_handle = tokio::spawn(async move { + let rx = shutdown_rx_queue; while let Some(result) = result_receiver.recv().await { + if *rx.borrow() { + info!( + "[monitor_id: {}] shutting down video queue thread", + monitor_id + ); + break; + } let frame_number = result.frame_number; - debug!("Received frame {} for queueing", frame_number); + debug!( + "[monitor_id: {}] received frame {} for queueing", + monitor_id, frame_number + ); let result = Arc::new(result); @@ -105,92 +137,105 @@ impl VideoCapture { if !video_pushed || !ocr_pushed { error!( - "Failed to push frame {} to one or more queues", - frame_number + "[monitor_id: {}] failed to push frame {} to one or more queues, queue lengths: {}, {}", + monitor_id, frame_number, + capture_video_frame_queue.len(), + capture_ocr_frame_queue.len() ); continue; // Skip to next iteration instead of crashing } debug!( - "Frame {} pushed to queues. Queue lengths: {}, {}", + "[monitor_id: {}] frame {} pushed to queues. 
Queue lengths: {}, {}", + monitor_id, frame_number, capture_video_frame_queue.len(), capture_ocr_frame_queue.len() ); } }); + handles.push(queue_handle); let video_frame_queue_clone = video_frame_queue.clone(); let output_path = output_path.to_string(); - let _video_thread = tokio::spawn(async move { - save_frames_as_video( + let video_handle = tokio::spawn(async move { + let rx = shutdown_rx_video; + save_frames_as_video_with_shutdown( &video_frame_queue_clone, &output_path, fps, new_chunk_callback_clone, monitor_id, video_chunk_duration, + rx, ) .await; }); + handles.push(video_handle); VideoCapture { video_frame_queue, ocr_frame_queue, + shutdown_tx, + handles, } } + + pub async fn shutdown(self) -> Result<(), anyhow::Error> { + info!("shutting down video capture"); + self.shutdown_tx.send(true)?; + + for handle in self.handles { + if let Err(e) = handle.await { + error!("error joining handle: {}", e); + } + } + + Ok(()) + } } pub async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result { - // Overriding fps with max fps if over the max and warning user - let fps = if fps > MAX_FPS { - warn!("Overriding FPS from {} to {}", fps, MAX_FPS); - MAX_FPS - } else { - fps - }; - - info!("Starting FFmpeg process for file: {}", output_file); + let fps = fps.min(MAX_FPS); + + debug!("starting ffmpeg process for: {}", output_file); let fps_str = fps.to_string(); let mut command = Command::new(find_ffmpeg_path().unwrap()); - let mut args = vec![ + + // Updated FFmpeg arguments for better performance and quality + let args = vec![ "-f", "image2pipe", "-vcodec", - "png", + "mjpeg", "-r", &fps_str, "-i", "-", "-vf", - "pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2", - ]; - - args.extend_from_slice(&[ - "-vcodec", + "format=yuv420p,pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2", + "-c:v", "libx265", "-tag:v", "hvc1", "-preset", - "ultrafast", + "medium", // Changed from ultrafast for better compression "-crf", - "23", - ]); - - args.extend_from_slice(&["-pix_fmt", "yuv420p", output_file]); + "28", // Slightly higher CRF for smaller file size + "-x265-params", + "log-level=error", // Reduce x265 logging noise + output_file, + ]; command .args(&args) .stdin(Stdio::piped()) - .stdout(Stdio::piped()) + .stdout(Stdio::null()) // Changed to null since we don't need stdout .stderr(Stdio::piped()); - debug!("FFmpeg command: {:?}", command); - + debug!("ffmpeg command: {:?}", command); let child = command.spawn()?; - debug!("FFmpeg process spawned"); - Ok(child) } @@ -202,29 +247,30 @@ pub async fn write_frame_to_ffmpeg( Ok(()) } -async fn log_ffmpeg_output(stream: impl AsyncBufReadExt + Unpin, stream_name: &str) { - let reader = BufReader::new(stream); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - debug!("FFmpeg {}: {}", stream_name, line); - } -} - -async fn save_frames_as_video( +async fn save_frames_as_video_with_shutdown( frame_queue: &Arc>>, output_path: &str, fps: f64, new_chunk_callback: Arc, monitor_id: u32, video_chunk_duration: Duration, + mut shutdown_rx: watch::Receiver, ) { - debug!("Starting save_frames_as_video function"); + debug!("starting save_frames_as_video function"); let frames_per_video = (fps * video_chunk_duration.as_secs_f64()).ceil() as usize; let mut frame_count = 0; let mut current_ffmpeg: Option = None; let mut current_stdin: Option = None; loop { + if *shutdown_rx.borrow() { + info!("shutting down video capture thread"); + if let Some(child) = current_ffmpeg.take() { + finish_ffmpeg_process(child, current_stdin.take()).await; + } 
+ break; + } + if frame_count >= frames_per_video || current_ffmpeg.is_none() { if let Some(child) = current_ffmpeg.take() { finish_ffmpeg_process(child, current_stdin.take()).await; @@ -240,20 +286,20 @@ async fn save_frames_as_video( match start_ffmpeg_process(&output_file, fps).await { Ok(mut child) => { let mut stdin = child.stdin.take().expect("Failed to open stdin"); - spawn_ffmpeg_loggers(child.stderr.take(), child.stdout.take()); + spawn_ffmpeg_loggers(child.stderr.take()); if let Err(e) = write_frame_to_ffmpeg(&mut stdin, &buffer).await { - error!("Failed to write first frame to ffmpeg: {}", e); + error!("failed to write first frame to ffmpeg: {}", e); continue; } frame_count += 1; current_ffmpeg = Some(child); current_stdin = Some(stdin); - debug!("New FFmpeg process started for file: {}", output_file); + debug!("new FFmpeg process started for file: {}", output_file); } Err(e) => { - error!("Failed to start FFmpeg process: {}", e); + error!("failed to start FFmpeg process: {}", e); continue; } } @@ -265,10 +311,23 @@ async fn save_frames_as_video( &mut frame_count, frames_per_video, fps, + &shutdown_rx, ) .await; - tokio::task::yield_now().await; + tokio::select! { + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + if let Some(child) = current_ffmpeg.take() { + finish_ffmpeg_process(child, current_stdin.take()).await; + } + break; + } + } + _ = tokio::time::sleep(Duration::from_millis(10)) => { + // Continue with normal processing + } + } } } @@ -286,9 +345,12 @@ async fn wait_for_first_frame( fn encode_frame(frame: &CaptureResult) -> Vec { let mut buffer = Vec::new(); - frame - .image - .write_to(&mut std::io::Cursor::new(&mut buffer), ImageFormat::Png) + let rgb_image = frame.image.to_rgb8(); + rgb_image + .write_to( + &mut std::io::Cursor::new(&mut buffer), + image::ImageFormat::Jpeg, + ) .expect("Failed to encode frame"); buffer } @@ -303,12 +365,22 @@ fn create_output_file(output_path: &str, monitor_id: u32) -> String { .to_string() } -fn spawn_ffmpeg_loggers(stderr: Option, stdout: Option) { +fn spawn_ffmpeg_loggers(stderr: Option) { if let Some(stderr) = stderr { - tokio::spawn(log_ffmpeg_output(BufReader::new(stderr), "stderr")); - } - if let Some(stdout) = stdout { - tokio::spawn(log_ffmpeg_output(BufReader::new(stdout), "stdout")); + tokio::spawn(async move { + let reader = BufReader::new(stderr); + let mut lines = reader.lines(); + while let Ok(Some(line)) = lines.next_line().await { + // Only log important messages + if line.contains("error") || line.contains("fatal") { + error!("ffmpeg: {}", line); + } else if line.contains("warning") { + warn!("ffmpeg: {}", line); + } else { + debug!("ffmpeg: {}", line); + } + } + }); } } @@ -318,25 +390,39 @@ async fn process_frames( frame_count: &mut usize, frames_per_video: usize, fps: f64, + shutdown_rx: &watch::Receiver, ) { let write_timeout = Duration::from_secs_f64(1.0 / fps); - while *frame_count < frames_per_video { + let mut should_break = false; + + while *frame_count < frames_per_video && !should_break { + if *shutdown_rx.borrow() { + info!("process_frames: shutdown signal received, breaking out"); + should_break = true; + continue; + } + if let Some(frame) = frame_queue.pop() { let buffer = encode_frame(&frame); if let Some(stdin) = current_stdin.as_mut() { if let Err(e) = write_frame_with_retry(stdin, &buffer).await { - error!("Failed to write frame to ffmpeg after max retries: {}", e); - break; + error!("failed to write frame to ffmpeg after max retries: {}", e); + should_break = true; + continue; } 
*frame_count += 1; - debug!("Wrote frame {} to FFmpeg", frame_count); - + debug!("wrote frame {} to ffmpeg", frame_count); flush_ffmpeg_input(stdin, *frame_count, fps).await; } } else { tokio::time::sleep(write_timeout).await; } } + + // Cleanup remaining frames + while frame_queue.pop().is_some() { + debug!("cleaning up remaining frame from queue"); + } } async fn write_frame_with_retry( @@ -384,10 +470,34 @@ pub async fn finish_ffmpeg_process(child: Child, stdin: Option) { match child.wait_with_output().await { Ok(output) => { debug!("FFmpeg process exited with status: {}", output.status); - if !output.status.success() { - error!("FFmpeg stderr: {}", String::from_utf8_lossy(&output.stderr)); + let stderr = String::from_utf8_lossy(&output.stderr); + if !output.status.success() && stderr != Cow::Borrowed("") { + error!("FFmpeg stderr: {}", stderr); } } Err(e) => error!("Failed to wait for FFmpeg process: {}", e), } } + +fn push_to_queue( + queue: &ArrayQueue>, + result: &Arc, + queue_name: &str, +) -> bool { + match queue.push(Arc::clone(result)) { + Ok(_) => { + debug!( + "{} queue: Successfully pushed frame {}", + queue_name, result.frame_number + ); + true + } + Err(_) => { + warn!( + "{} queue full, dropping frame {}", + queue_name, result.frame_number + ); + false + } + } +} diff --git a/screenpipe-vision/src/bin/screenpipe-vision.rs b/screenpipe-vision/src/bin/screenpipe-vision.rs index 7c777096bb..163f771ca6 100644 --- a/screenpipe-vision/src/bin/screenpipe-vision.rs +++ b/screenpipe-vision/src/bin/screenpipe-vision.rs @@ -5,7 +5,7 @@ use screenpipe_vision::{ OcrEngine, }; use std::{sync::Arc, time::Duration}; -use tokio::sync::mpsc::channel; +use tokio::sync::{mpsc::channel, watch}; use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; #[derive(Parser)] @@ -48,8 +48,9 @@ async fn main() { Arc::new(window_filters), languages.clone(), false, + watch::channel(false).1, ) - .await + .await; }); // Example: Process results for 10 seconds, then pause for 5 seconds, then stop diff --git a/screenpipe-vision/src/core.rs b/screenpipe-vision/src/core.rs index 6640f38d2a..fa045df616 100644 --- a/screenpipe-vision/src/core.rs +++ b/screenpipe-vision/src/core.rs @@ -13,7 +13,6 @@ use anyhow::{anyhow, Result}; use base64::{engine::general_purpose, Engine as _}; use image::codecs::jpeg::JpegEncoder; use image::DynamicImage; -use log::{debug, error}; use screenpipe_core::Language; use screenpipe_integrations::unstructured_ocr::perform_ocr_cloud; use serde::Deserialize; @@ -29,7 +28,10 @@ use std::{ use tokio::fs::File; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::mpsc::Sender; -use tokio::time::sleep; +use tokio::sync::watch; +use tracing::debug; +use tracing::error; +use tracing::info; #[cfg(target_os = "macos")] use xcap_macos::Monitor; @@ -138,6 +140,7 @@ pub async fn continuous_capture( window_filters: Arc, languages: Vec, capture_unfocused_windows: bool, + mut shutdown_rx: watch::Receiver, ) { let mut frame_counter: u64 = 0; let mut previous_image: Option = None; @@ -149,102 +152,124 @@ pub async fn continuous_capture( monitor_id ); + let monitor = get_monitor_by_id(monitor_id).await.unwrap(); + loop { - let monitor = match get_monitor_by_id(monitor_id).await { - Some(m) => m, - None => { - sleep(Duration::from_secs(1)).await; - continue; - } - }; - let capture_result = - match capture_screenshot(&monitor, &window_filters, capture_unfocused_windows).await { - Ok((image, window_images, image_hash, _capture_duration)) => { - debug!( - "Captured screenshot on monitor {} with 
hash: {}", - monitor_id, image_hash - ); - Some((image, window_images, image_hash)) - } - Err(e) => { - error!("Failed to capture screenshot: {}", e); - None - } - }; - - if let Some((image, window_images, image_hash)) = capture_result { - let current_average = match compare_with_previous_image( - previous_image.as_ref(), - &image, - &mut max_average, - frame_counter, - &mut max_avg_value, - ) - .await - { - Ok(avg) => avg, - Err(e) => { - error!("Error comparing images: {}", e); - 0.0 - } - }; - - let current_average = if previous_image.is_none() { - 1.0 - } else { - current_average - }; - - if current_average < 0.006 { - debug!( - "Skipping frame {} due to low average difference: {:.3}", - frame_counter, current_average - ); - frame_counter += 1; - tokio::time::sleep(interval).await; - continue; - } + // Check shutdown signal + if *shutdown_rx.borrow() { + info!( + "continuous_capture: received shutdown signal for monitor {}", + monitor_id + ); + drop(result_tx); + break; + } - if current_average > max_avg_value { - max_average = Some(MaxAverageFrame { - image: image.clone(), - window_images: window_images.clone(), - image_hash, - frame_number: frame_counter, - timestamp: Instant::now(), - result_tx: result_tx.clone(), - average: current_average, - }); - max_avg_value = current_average; + // Use tokio::select! to handle both capture and shutdown + tokio::select! { + _ = shutdown_rx.changed() => { + if *shutdown_rx.borrow() { + info!("continuous_capture: shutdown signal received for monitor {}", monitor_id); + drop(result_tx); + break; + } } - - previous_image = Some(image); - - if let Some(max_avg_frame) = max_average.take() { - let ocr_task_data = OcrTaskData { - image: max_avg_frame.image, - window_images: max_avg_frame.window_images, - frame_number: max_avg_frame.frame_number, - timestamp: max_avg_frame.timestamp, - result_tx: max_avg_frame.result_tx, + _ = async { + let languages_clone = languages.clone(); + let capture_result = match capture_screenshot(&monitor, &window_filters, capture_unfocused_windows).await { + Ok((image, window_images, image_hash, _capture_duration)) => { + debug!( + "captured screenshot on monitor {} with hash: {}", + monitor.id(), + image_hash + ); + Some((image, window_images, image_hash)) + } + Err(e) => { + error!("Failed to capture screenshot: {}", e); + None + } }; - if let Err(e) = - process_ocr_task(ocr_task_data, &ocr_engine, languages.clone()).await - { - error!("Error processing OCR task: {}", e); + if let Some((image, window_images, image_hash)) = capture_result { + let current_average = match compare_with_previous_image( + previous_image.as_ref(), + &image, + &mut max_average, + frame_counter, + &mut max_avg_value, + ) + .await + { + Ok(avg) => avg, + Err(e) => { + error!("Error comparing images: {}", e); + previous_image = None; + 0.0 + } + }; + + let current_average = if previous_image.is_none() { + 1.0 + } else { + current_average + }; + + if current_average < 0.006 { + debug!( + "Skipping frame {} due to low average difference: {:.3}", + frame_counter, current_average + ); + frame_counter += 1; + tokio::time::sleep(interval).await; + return; + } + + if current_average > max_avg_value { + max_average = Some(MaxAverageFrame { + image: image.clone(), + window_images, + image_hash, + frame_number: frame_counter, + timestamp: Instant::now(), + result_tx: result_tx.clone(), + average: current_average, + }); + max_avg_value = current_average; + } + + previous_image = Some(image); + + if let Some(max_avg_frame) = max_average.take() { + let 
ocr_task_data = OcrTaskData { + image: max_avg_frame.image.clone(), + window_images: max_avg_frame.window_images.iter().cloned().collect(), + frame_number: max_avg_frame.frame_number, + timestamp: max_avg_frame.timestamp, + result_tx: result_tx.clone(), + }; + + if let Err(e) = + process_ocr_task(ocr_task_data, &ocr_engine, languages_clone).await + { + error!("Error processing OCR task: {}", e); + } + + frame_counter = 0; + max_avg_value = 0.0; + + } + } else { + debug!("Skipping frame {} due to capture failure", frame_counter); } - frame_counter = 0; - max_avg_value = 0.0; - } - } else { - debug!("Skipping frame {} due to capture failure", frame_counter); + frame_counter += 1; + tokio::time::sleep(interval).await; + } => {} } - - frame_counter += 1; - tokio::time::sleep(interval).await; } + + debug!("Continuous capture stopped for monitor {}", monitor_id); } pub struct MaxAverageFrame {