Skip to content

Commit

Permalink
fix: windows native ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Aug 5, 2024
1 parent fa97e66 commit 55972a0
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 91 deletions.
1 change: 0 additions & 1 deletion screenpipe-server/src/bin/screenpipe-video.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ async fn main() {
}
}

video_capture.stop().await;
println!("Video capture completed. Output saved to: {}", output_path);
println!("JSON data saved to: {}", json_output_path);
}
1 change: 0 additions & 1 deletion screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ async fn record_video(
tokio::time::sleep(Duration::from_secs_f64(1.0 / fps)).await;
}

video_capture.stop().await;
Ok(())
}

Expand Down
74 changes: 17 additions & 57 deletions screenpipe-server/src/video.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use chrono::Utc;
use image::ImageFormat::{self};
use log::{debug, error, info, warn};
use screenpipe_core::find_ffmpeg_path;
use screenpipe_vision::{continuous_capture, CaptureResult, ControlMessage, OcrEngine};
use screenpipe_vision::{continuous_capture, get_monitor, CaptureResult, OcrEngine};
use std::collections::VecDeque;
use std::path::PathBuf;
use std::process::Stdio;
Expand All @@ -18,12 +18,9 @@ use std::time::Duration;
const MAX_FPS: f64 = 30.0; // Adjust based on your needs

pub struct VideoCapture {
control_tx: Sender<ControlMessage>,
frame_queue: Arc<Mutex<VecDeque<CaptureResult>>>,
video_frame_queue: Arc<Mutex<VecDeque<CaptureResult>>>,
pub ocr_frame_queue: Arc<Mutex<VecDeque<CaptureResult>>>,
ffmpeg_handle: Arc<Mutex<Option<Child>>>,
is_running: Arc<Mutex<bool>>,
}

impl VideoCapture {
Expand All @@ -35,27 +32,23 @@ impl VideoCapture {
ocr_engine: Arc<OcrEngine>,
) -> Self {
info!("Starting new video capture");
let (control_tx, mut control_rx) = channel(512);
let frame_queue = Arc::new(Mutex::new(VecDeque::new()));
let video_frame_queue = Arc::new(Mutex::new(VecDeque::new()));
let ocr_frame_queue = Arc::new(Mutex::new(VecDeque::new()));
let ffmpeg_handle = Arc::new(Mutex::new(None));
let is_running = Arc::new(Mutex::new(true));
let new_chunk_callback = Arc::new(new_chunk_callback);
let new_chunk_callback_clone = Arc::clone(&new_chunk_callback);

let capture_frame_queue = frame_queue.clone();
let capture_video_frame_queue = video_frame_queue.clone();
let capture_ocr_frame_queue = ocr_frame_queue.clone();
let capture_thread_is_running = is_running.clone();
let (result_sender, mut result_receiver) = channel(512);
let _capture_thread = tokio::spawn(async move {
continuous_capture(
&mut control_rx,
result_sender,
Duration::from_secs_f64(1.0 / fps),
save_text_files,
ocr_engine,
get_monitor().await,
)
.await;
});
Expand All @@ -64,66 +57,40 @@ impl VideoCapture {

// Spawn another thread to handle receiving and queueing the results
let _queue_thread = tokio::spawn(async move {
while *capture_thread_is_running.lock().await {
if let Some(result) = result_receiver.recv().await {
let frame_number = result.frame_number;
debug!("Received frame {} for queueing", frame_number);
let mut queue = capture_frame_queue.lock().await;
let mut video_queue = capture_video_frame_queue.lock().await;
let mut ocr_queue = capture_ocr_frame_queue.lock().await;
queue.push_back(result.clone());
video_queue.push_back(result.clone());
ocr_queue.push_back(result);
debug!("Frame {} pushed to queues. Queue length: {}, Video queue length: {}, OCR queue length: {}", frame_number, queue.len(), video_queue.len(), ocr_queue.len());

// Clear the old queue after processing
if queue.len() > 1 {
queue.pop_front();
}
while let Some(result) = result_receiver.recv().await {
let frame_number = result.frame_number;
debug!("Received frame {} for queueing", frame_number);
let mut queue = capture_frame_queue.lock().await;
let mut video_queue = capture_video_frame_queue.lock().await;
let mut ocr_queue = capture_ocr_frame_queue.lock().await;
queue.push_back(result.clone());
video_queue.push_back(result.clone());
ocr_queue.push_back(result);
debug!("Frame {} pushed to queues. Queue length: {}, Video queue length: {}, OCR queue length: {}", frame_number, queue.len(), video_queue.len(), ocr_queue.len());

// Clear the old queue after processing
if queue.len() > 1 {
queue.pop_front();
}
}
});

let video_frame_queue_clone = video_frame_queue.clone();
let video_thread_is_running = is_running.clone();
let output_path = output_path.to_string();
let _video_thread = tokio::spawn(async move {
save_frames_as_video(
&video_frame_queue_clone,
&output_path,
fps,
video_thread_is_running,
new_chunk_callback_clone,
)
.await;
});

VideoCapture {
control_tx,
frame_queue,
video_frame_queue,
ocr_frame_queue,
ffmpeg_handle,
is_running,
}
}

pub async fn pause(&self) {
self.control_tx.send(ControlMessage::Pause).await.unwrap();
}

pub async fn resume(&self) {
self.control_tx.send(ControlMessage::Resume).await.unwrap();
}

pub async fn stop(&self) {
self.control_tx.send(ControlMessage::Stop).await.unwrap();
*self.is_running.lock().await = false;
if let Some(mut child) = self.ffmpeg_handle.lock().await.take() {
child
.wait()
.await
.expect("Failed to wait for ffmpeg process");
}
}

Expand All @@ -142,7 +109,6 @@ async fn save_frames_as_video(
frame_queue: &Arc<Mutex<VecDeque<CaptureResult>>>,
output_path: &str,
fps: f64,
is_running: Arc<Mutex<bool>>,
new_chunk_callback: Arc<dyn Fn(&str) + Send + Sync>,
) {
debug!("Starting save_frames_as_video function");
Expand All @@ -153,7 +119,7 @@ async fn save_frames_as_video(
let mut current_ffmpeg: Option<Child> = None;
let mut current_stdin: Option<ChildStdin> = None;

while *is_running.lock().await {
loop {
if frame_count % frames_per_video == 0 || current_ffmpeg.is_none() {
debug!("Starting new FFmpeg process");
// Close previous FFmpeg process if exists
Expand Down Expand Up @@ -288,12 +254,6 @@ async fn save_frames_as_video(
tokio::task::yield_now().await;
}
}

// Close the final FFmpeg process
if let Some(mut child) = current_ffmpeg.take() {
drop(current_stdin.take()); // Ensure stdin is closed
child.wait().await.expect("ffmpeg process failed");
}
}

use std::env;
Expand Down
17 changes: 11 additions & 6 deletions screenpipe-vision/benches/vision_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@
// cargo bench --bench vision_benchmark
// ! not very useful bench

use std::sync::Arc;

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use screenpipe_vision::{continuous_capture, ControlMessage};
use screenpipe_vision::{continuous_capture, get_monitor, OcrEngine};
use tokio::sync::mpsc;
use tokio::time::Duration;

async fn benchmark_continuous_capture(duration_secs: u64) -> f64 {
let (control_tx, mut control_rx) = mpsc::channel(1);
let (result_tx, mut result_rx) = mpsc::channel(100);

let capture_handle = tokio::spawn(async move {
continuous_capture(&mut control_rx, result_tx, Duration::from_millis(100), false).await;
continuous_capture(
result_tx,
Duration::from_millis(100),
false,
Arc::new(OcrEngine::Tesseract),
get_monitor().await,
)
.await;
});

// Run for specified duration
tokio::time::sleep(Duration::from_secs(duration_secs)).await;

// Stop the capture
control_tx.send(ControlMessage::Stop).await.unwrap();

// Wait for the capture to finish
capture_handle.await.unwrap();

Expand Down
10 changes: 3 additions & 7 deletions screenpipe-vision/src/bin/screenpipe-vision.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use clap::Parser;
use screenpipe_vision::{continuous_capture, ControlMessage, OcrEngine};
use screenpipe_vision::{continuous_capture, get_monitor, OcrEngine};
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::channel;

Expand All @@ -19,18 +19,17 @@ struct Cli {
async fn main() {
let cli = Cli::parse();

let (control_tx, mut control_rx) = channel(512);
let (result_tx, mut result_rx) = channel(512);

let save_text_files = cli.save_text_files;

let capture_thread = tokio::spawn(async move {
continuous_capture(
&mut control_rx,
result_tx,
Duration::from_secs(1),
save_text_files,
Arc::new(OcrEngine::Tesseract),
get_monitor().await,
)
.await
});
Expand All @@ -43,10 +42,7 @@ async fn main() {
}

let elapsed = start_time.elapsed();
if elapsed >= Duration::from_secs(10) && elapsed < Duration::from_secs(15) {
control_tx.send(ControlMessage::Pause).await.unwrap();
} else if elapsed >= Duration::from_secs(15) {
control_tx.send(ControlMessage::Stop).await.unwrap();
if elapsed >= Duration::from_secs(15) {
break;
}

Expand Down
19 changes: 7 additions & 12 deletions screenpipe-vision/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use std::{
time::{Duration, Instant},
};
use strsim::levenshtein;
use tokio::sync::{
mpsc::{Receiver, Sender},
Mutex,
}; // Corrected import for Mutex
use tokio::sync::{mpsc::Sender, Mutex}; // Corrected import for Mutex
use xcap::{Monitor, Window};

#[cfg(target_os = "windows")]
Expand All @@ -22,11 +19,7 @@ use crate::utils::{
save_text_files,
};
use rusty_tesseract::{Data, DataOutput}; // Add this import
pub enum ControlMessage {
Pause,
Resume,
Stop,
}


pub struct DataOutputWrapper {
pub data_output: rusty_tesseract::tesseract::output_data::DataOutput,
Expand Down Expand Up @@ -102,15 +95,17 @@ pub struct OcrTaskData {
pub result_tx: Sender<CaptureResult>,
}

pub async fn get_monitor() -> Monitor {
Monitor::all().unwrap().first().unwrap().clone()
}

pub async fn continuous_capture(
_control_rx: &mut Receiver<ControlMessage>,
result_tx: Sender<CaptureResult>,
interval: Duration,
save_text_files_flag: bool,
ocr_engine: Arc<OcrEngine>,
monitor: Monitor,
) {
let monitor = Monitor::all().unwrap().first().unwrap().clone(); // Simplified monitor retrieval

debug!("continuous_capture: Starting using monitor: {:?}", monitor);
let previous_text_json = Arc::new(Mutex::new(None));
let ocr_task_running = Arc::new(AtomicBool::new(false));
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-vision/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod core;
pub mod utils;
pub use core::{continuous_capture, process_ocr_task, CaptureResult, ControlMessage};
pub use core::{continuous_capture, get_monitor, process_ocr_task, CaptureResult};
pub use utils::{perform_ocr_tesseract, OcrEngine};
66 changes: 60 additions & 6 deletions screenpipe-vision/tests/windows_vision_test.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#[cfg(target_os = "windows")]
#[cfg(test)]
mod tests {
use screenpipe_vision::{process_ocr_task, OcrEngine};
use std::path::PathBuf;
use std::{sync::Arc, time::Instant};
use screenpipe_vision::{get_monitor, process_ocr_task, OcrEngine};
use std::sync::Arc;
use std::{path::PathBuf, time::Instant};
use tokio::sync::{mpsc, Mutex};
#[cfg(target_os = "windows")]

use screenpipe_vision::{continuous_capture, CaptureResult};
use std::time::Duration;
use tokio::time::timeout;

// #[cfg(target_os = "windows")]
#[tokio::test]
async fn test_process_ocr_task_windows() {
// Use an absolute path that works in both local and CI environments
Expand All @@ -20,7 +24,7 @@ mod tests {
let timestamp = Instant::now();
let (tx, _rx) = mpsc::channel(1);
let previous_text_json = Arc::new(Mutex::new(None));
let ocr_engine = Arc::new(OcrEngine::WindowsNative);
let ocr_engine = Arc::new(OcrEngine::Tesseract);
let app_name = "test_app".to_string();

let result = process_ocr_task(
Expand All @@ -38,4 +42,54 @@ mod tests {
assert!(result.is_ok());
// Add more specific assertions based on expected behavior
}

#[tokio::test]
#[ignore] // TODO require UI
async fn test_continuous_capture() {
// Create channels for communication
let (result_tx, mut result_rx) = mpsc::channel::<CaptureResult>(10);

// Create a mock monitor
let monitor = get_monitor().await;

// Set up test parameters
let interval = Duration::from_millis(1000);
let save_text_files_flag = false;
let ocr_engine = Arc::new(OcrEngine::Tesseract);

// Spawn the continuous_capture function
let capture_handle = tokio::spawn(continuous_capture(
result_tx,
interval,
save_text_files_flag,
ocr_engine,
monitor,
));

// Wait for a short duration to allow some captures to occur
let timeout_duration = Duration::from_secs(5);
let _result = timeout(timeout_duration, async {
let mut capture_count = 0;
while let Some(capture_result) = result_rx.recv().await {
capture_count += 1;
// assert!(
// capture_result.image.width() == 100 && capture_result.image.height() == 100
// );
println!("capture_result: {:?}\n\n", capture_result.text);
if capture_count >= 3 {
break;
}
}
})
.await;

// Stop the continuous_capture task
capture_handle.abort();

// Assert that we received some results without timing out
// assert!(
// result.is_ok(),
// "Test timed out or failed to receive captures"
// );
}
}

0 comments on commit 55972a0

Please sign in to comment.