Skip to content

Commit

Permalink
feat: windows native ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Aug 3, 2024
1 parent 1631114 commit 0756163
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 69 deletions.
24 changes: 23 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
branches: [main]

jobs:
test:
test-ubuntu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand All @@ -37,3 +37,25 @@ jobs:
- name: Run tests
run: cargo test

# Windows CI job: checks out the repo, restores the cargo/target cache,
# installs the stable Rust toolchain and runs the Windows OCR test.
test-windows:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: actions/cache@v3
with:
# NOTE(review): cargo's default home on Windows is ~\.cargo rather than
# ~\AppData\Local\cargo — confirm the first path below actually matches anything.
path: |
~\AppData\Local\cargo\
target\
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable

# The full test suite is currently disabled for this job; only the
# filtered test step further down is executed.
# - name: Run tests
# run: cargo test

# `cargo test <filter>` runs only tests whose name contains the filter string.
- name: Run specific Windows OCR test
run: cargo test test_process_ocr_task_windows
39 changes: 33 additions & 6 deletions screenpipe-server/src/bin/screenpipe-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,34 @@ use screenpipe_audio::{
default_input_device, default_output_device, list_audio_devices, parse_audio_device,
DeviceControl,
};
use screenpipe_vision::OcrEngine;
use std::io::Write;

use screenpipe_core::find_ffmpeg_path;
use screenpipe_server::logs::MultiWriter;
use screenpipe_server::{start_continuous_recording, DatabaseManager, ResourceMonitor, Server};
use tokio::sync::mpsc::channel;

use clap::ValueEnum;
use screenpipe_vision::utils::OcrEngine as CoreOcrEngine;

// OCR engine choices accepted by the `--ocr-engine` command-line flag.
//
// Each variant has an identically named counterpart on the vision crate's
// engine type (imported in this file as `CoreOcrEngine`).
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments on
// `ValueEnum` variants into per-value help text, which would change `--help`.
#[derive(Debug, Clone, PartialEq, ValueEnum)]
enum CliOcrEngine {
    // Cloud OCR engine; selecting it triggers the cloud privacy warning in `main`.
    Deepgram,
    // Local OCR engine; this is the CLI default.
    Tesseract,
    // Local OCR engine for Windows.
    WindowsNative,
}

/// Maps the CLI-facing engine selection onto the vision crate's engine type.
///
/// The mapping is one-to-one by variant name. The `match` has no catch-all
/// arm, so adding a CLI variant without a mapping is a compile error.
impl From<CliOcrEngine> for CoreOcrEngine {
    fn from(engine: CliOcrEngine) -> Self {
        match engine {
            CliOcrEngine::Deepgram => Self::Deepgram,
            CliOcrEngine::Tesseract => Self::Tesseract,
            CliOcrEngine::WindowsNative => Self::WindowsNative,
        }
    }
}

// keep in mind this is the most important feature ever // TODO: add a pipe and a ⭐️ e.g screen | ⭐️ somehow in ascii ♥️🤓
const DISPLAY: &str = r"
_
Expand Down Expand Up @@ -88,9 +109,11 @@ struct Cli {
#[arg(long, default_value_t = false)]
cloud_audio_on: bool,

/// Enable cloud OCR processing
#[arg(long, default_value_t = false)]
cloud_ocr_on: bool,
/// OCR engine to use. Tesseract is a local OCR engine (default).
/// WindowsNative is a local OCR engine for Windows.
/// Deepgram is a cloud OCR engine (free of charge on us)
#[arg(long, value_enum, default_value_t = CliOcrEngine::Tesseract)]
ocr_engine: CliOcrEngine,

/// UID key for sending data to friend wearable (if not provided, data won't be sent)
#[arg(long)]
Expand Down Expand Up @@ -278,6 +301,8 @@ async fn main() -> anyhow::Result<()> {
// Before the loop starts, clone friend_wearable_uid
let friend_wearable_uid = cli.friend_wearable_uid.clone();

let warning_ocr_engine_clone = cli.ocr_engine.clone();

// Function to start or restart the recording task
let _start_recording = tokio::spawn(async move {
// hack
Expand All @@ -301,6 +326,8 @@ async fn main() -> anyhow::Result<()> {
recording_task.abort();
}
}
let core_ocr_engine: CoreOcrEngine = cli.ocr_engine.clone().into();
let ocr_engine = Arc::new(OcrEngine::from(core_ocr_engine));
recording_task = tokio::spawn(async move {
let result = start_continuous_recording(
db_clone,
Expand All @@ -311,7 +338,7 @@ async fn main() -> anyhow::Result<()> {
audio_devices_control,
cli.save_text_files,
cli.cloud_audio_on,
cli.cloud_ocr_on,
ocr_engine,
friend_wearable_uid_clone, // Use the cloned version
)
.await;
Expand Down Expand Up @@ -362,7 +389,7 @@ async fn main() -> anyhow::Result<()> {
);

// Add warning for cloud arguments
if cli.cloud_audio_on || cli.cloud_ocr_on {
if cli.cloud_audio_on || warning_ocr_engine_clone == CliOcrEngine::Deepgram {
println!(
"{}",
"WARNING: You are using cloud now. Make sure to understand the data privacy risks."
Expand All @@ -380,4 +407,4 @@ async fn main() -> anyhow::Result<()> {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
18 changes: 12 additions & 6 deletions screenpipe-server/src/bin/screenpipe-video.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use chrono::Utc;
use clap::Parser;
use env_logger::Env;
use image::GenericImageView;
use log::info;
use screenpipe_server::core::DataOutputWrapper;
use screenpipe_server::VideoCapture;
use screenpipe_vision::OcrEngine;
use serde_json::{json, Value};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Write};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use clap::Parser;
use screenpipe_server::core::DataOutputWrapper; // Correct import
use tokio::sync::mpsc::{channel, Receiver, Sender}; // Correct import

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -39,7 +40,6 @@ async fn main() {

let cli = Cli::parse();
let save_text_files = cli.save_text_files;
let cloud_ocr = !cli.cloud_ocr_off; // Determine the cloud_ocr flag

let time = Utc::now();
let formatted_time = time.format("%Y-%m-%d_%H-%M-%S").to_string();
Expand All @@ -57,7 +57,13 @@ async fn main() {
}
};

let video_capture = VideoCapture::new(output_path, fps, new_chunk_callback, save_text_files, cloud_ocr); // Pass the cloud_ocr flag
let video_capture = VideoCapture::new(
output_path,
fps,
new_chunk_callback,
save_text_files,
Arc::new(OcrEngine::Tesseract),
); // Pass the OCR engine (this binary always uses Tesseract)
let (_tx, rx): (Sender<()>, Receiver<()>) = channel(32);
let rx = Arc::new(Mutex::new(rx));
let rx_thread = rx.clone();
Expand Down Expand Up @@ -111,4 +117,4 @@ async fn main() {
video_capture.stop().await;
println!("Video capture completed. Output saved to: {}", output_path);
println!("JSON data saved to: {}", json_output_path);
}
}
22 changes: 13 additions & 9 deletions screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ use crate::{DatabaseManager, VideoCapture};
use anyhow::Result;
use chrono::Utc;
use crossbeam::queue::SegQueue;
use external_cloud_integrations::friend_wearable::send_data_to_friend_wearable;
use log::{debug, error, info, warn};
use screenpipe_audio::{
create_whisper_channel, record_and_transcribe, AudioDevice, AudioInput, DeviceControl,
TranscriptionResult,
};
use screenpipe_vision::OcrEngine;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
use external_cloud_integrations::friend_wearable::send_data_to_friend_wearable;

pub enum RecorderControl {
Pause,
Expand Down Expand Up @@ -52,7 +53,7 @@ pub async fn start_continuous_recording(
audio_devices_control: Arc<SegQueue<(AudioDevice, DeviceControl)>>,
save_text_files: bool,
cloud_audio: bool,
cloud_ocr: bool,
ocr_engine: Arc<OcrEngine>,
friend_wearable_uid: Option<String>, // Updated parameter
) -> Result<()> {
info!("Recording now");
Expand All @@ -76,7 +77,7 @@ pub async fn start_continuous_recording(
fps,
is_running_video,
save_text_files,
cloud_ocr,
ocr_engine,
friend_wearable_uid_video, // Use the cloned version
)
.await
Expand Down Expand Up @@ -108,7 +109,7 @@ async fn record_video(
fps: f64,
is_running: Arc<AtomicBool>,
save_text_files: bool,
cloud_ocr: bool,
ocr_engine: Arc<OcrEngine>,
friend_wearable_uid: Option<String>, // Updated parameter
) -> Result<()> {
debug!("record_video: Starting");
Expand All @@ -130,7 +131,7 @@ async fn record_video(
fps,
new_chunk_callback,
save_text_files,
cloud_ocr,
ocr_engine,
);

while is_running.load(Ordering::SeqCst) {
Expand All @@ -152,7 +153,7 @@ async fn record_video(
&text_json,
&new_text_json_vs_previous_frame,
&raw_data_output_from_ocr,
&frame.app_name
&frame.app_name,
)
.await
{
Expand Down Expand Up @@ -353,7 +354,7 @@ async fn process_audio_result(
"Inserted audio transcription for chunk {} from device {}",
audio_chunk_id, result.input.device
);

// Send data to friend wearable
if let Some(uid) = friend_wearable_uid {
if let Err(e) = send_data_to_friend_wearable(
Expand All @@ -364,7 +365,10 @@ async fn process_audio_result(
) {
error!("Failed to send data to friend wearable: {}", e);
} else {
debug!("Sent audio data to friend wearable for chunk {}", audio_chunk_id);
debug!(
"Sent audio data to friend wearable for chunk {}",
audio_chunk_id
);
}
}
}
Expand All @@ -374,4 +378,4 @@ async fn process_audio_result(
result.input.device, e
),
}
}
}
8 changes: 4 additions & 4 deletions screenpipe-server/src/video.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use chrono::Utc;
use image::ImageFormat::{self};
use log::{debug, error, info, warn};
use screenpipe_core::find_ffmpeg_path;
use screenpipe_vision::{continuous_capture, CaptureResult, ControlMessage};
use screenpipe_vision::{continuous_capture, CaptureResult, ControlMessage, OcrEngine};
use std::collections::VecDeque;
use std::path::PathBuf;
use std::process::Stdio;
Expand Down Expand Up @@ -32,7 +32,7 @@ impl VideoCapture {
fps: f64,
new_chunk_callback: impl Fn(&str) + Send + Sync + 'static,
save_text_files: bool,
cloud_ocr: bool, // Added cloud_ocr parameter
ocr_engine: Arc<OcrEngine>,
) -> Self {
info!("Starting new video capture");
let (control_tx, mut control_rx) = channel(512);
Expand All @@ -55,7 +55,7 @@ impl VideoCapture {
result_sender,
Duration::from_secs_f64(1.0 / fps),
save_text_files,
cloud_ocr, // Pass the cloud_ocr flag
ocr_engine,
)
.await;
});
Expand Down Expand Up @@ -351,4 +351,4 @@ async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result<Child, anyh
debug!("FFmpeg process spawned");

Ok(child)
}
}
7 changes: 6 additions & 1 deletion screenpipe-vision/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,9 @@ path = "src/bin/screenpipe-vision.rs"

[[bench]]
name = "vision_benchmark"
harness = false
harness = false


[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.48", features = ["Graphics_Imaging", "Media_Ocr", "Storage", "Storage_Streams"] }

18 changes: 12 additions & 6 deletions screenpipe-vision/src/bin/screenpipe-vision.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use screenpipe_vision::{continuous_capture, ControlMessage};
use std::time::Duration;
use tokio::sync::mpsc::channel;
use clap::Parser;
use screenpipe_vision::{continuous_capture, ControlMessage, OcrEngine};
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::channel;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
Expand All @@ -23,10 +23,16 @@ async fn main() {
let (result_tx, mut result_rx) = channel(512);

let save_text_files = cli.save_text_files;
let cloud_ocr = !cli.cloud_ocr_off; // Determine the cloud_ocr flag

let capture_thread = tokio::spawn(async move {
continuous_capture(&mut control_rx, result_tx, Duration::from_secs(1), save_text_files, cloud_ocr).await
continuous_capture(
&mut control_rx,
result_tx,
Duration::from_secs(1),
save_text_files,
Arc::new(OcrEngine::Tesseract),
)
.await
});

// Example: Process results for 10 seconds, then pause for 5 seconds, then stop
Expand All @@ -48,4 +54,4 @@ async fn main() {
}

capture_thread.await.unwrap();
}
}
Loading

0 comments on commit 0756163

Please sign in to comment.