Commit v0

louis030195 committed Aug 3, 2024
1 parent 537d134 commit 5290356
Showing 5 changed files with 221 additions and 2 deletions.
12 changes: 12 additions & 0 deletions screenpipe-vision/Cargo.toml
@@ -75,3 +75,15 @@ harness = false
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.58", features = ["Graphics_Imaging", "Media_Ocr", "Storage", "Storage_Streams"] }

[target.'cfg(target_os = "macos")'.dependencies]
core-foundation = "0.9.4"
core-graphics = "0.23.2"
objc = { version = "0.2.7", features = ["exception"] }
block = "0.1"
foreign-types-shared = "0.1"
cocoa-foundation = "0.1.2"
[build-dependencies]
cc = "1.0"
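These dependencies are gated to macOS, so any code that uses them has to sit behind the same cfg. A minimal sketch of the pattern (the module and function names here are illustrative, not part of this commit):

#[cfg(target_os = "macos")]
mod apple_ocr {
    // Compiled only on macOS, matching the Cargo.toml target gate above.
    use objc::runtime::Class;

    pub fn vision_available() -> bool {
        // Probe the Objective-C runtime for the Vision OCR request class.
        Class::get("VNRecognizeTextRequest").is_some()
    }
}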
10 changes: 10 additions & 0 deletions screenpipe-vision/build.rs
@@ -0,0 +1,10 @@
#[cfg(target_os = "macos")]
fn main() {
    // Link the macOS Vision framework so the Objective-C calls in
    // screenpipe-vision/src/utils.rs can resolve VNRecognizeTextRequest.
    println!("cargo:rustc-link-lib=framework=Vision");
}

#[cfg(not(target_os = "macos"))]
fn main() {}
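Note that the framework can be linked either from this build script or directly in source: the empty extern block in screenpipe-vision/src/utils.rs further down carries a #[link(name = "Vision", kind = "framework")] attribute that pulls in the same framework, so either mechanism alone should suffice.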
7 changes: 6 additions & 1 deletion screenpipe-vision/src/core.rs
@@ -16,11 +16,11 @@ use xcap::{Monitor, Window};

#[cfg(target_os = "windows")]
use crate::utils::perform_ocr_windows;
use crate::utils::{
    capture_screenshot, compare_with_previous_image, perform_ocr_cloud, perform_ocr_tesseract,
    save_text_files,
};
use crate::utils::{perform_ocr_apple, OcrEngine};
use rusty_tesseract::{Data, DataOutput};
pub enum ControlMessage {
    Pause,
@@ -232,6 +232,11 @@ pub async fn process_ocr_task(
            debug!("Windows Native OCR");
            perform_ocr_windows(&image_arc).await
        }
        #[cfg(target_os = "macos")]
        OcrEngine::AppleNative => {
            debug!("Apple Native OCR");
            perform_ocr_apple(&image_arc)
        }
        _ => {
            error!("Unsupported OCR engine");
            return Err(std::io::Error::new(
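For context, a caller selects the new engine like any other variant. A minimal sketch of the wiring (the surrounding setup is illustrative, not part of this diff):

use std::sync::Arc;
use screenpipe_vision::OcrEngine;

// Prefer the native Vision-framework engine on macOS; fall back to Tesseract elsewhere.
#[cfg(target_os = "macos")]
let ocr_engine = Arc::new(OcrEngine::AppleNative);
#[cfg(not(target_os = "macos"))]
let ocr_engine = Arc::new(OcrEngine::Tesseract);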
130 changes: 129 additions & 1 deletion screenpipe-vision/src/utils.rs
@@ -1,4 +1,7 @@
use crate::core::MaxAverageFrame; // Assuming core.rs is in the same crate under the `core` module
use core_foundation::base::FromVoid;
use core_graphics::base::kCGRenderingIntentDefault;
use image::codecs::png::PngEncoder;
use image::DynamicImage;
use image::ImageEncoder;
@@ -22,6 +25,7 @@ pub enum OcrEngine {
    Unstructured,
    Tesseract,
    WindowsNative,
    AppleNative,
}

impl Default for OcrEngine {
@@ -350,3 +354,127 @@ pub async fn perform_ocr_windows(image: &DynamicImage) -> (String, DataOutput, String)

    (text, data_output, json_output)
}

#[cfg(target_os = "macos")]
#[link(name = "Vision", kind = "framework")]
extern "C" {} // empty block: the link attribute alone is enough to link the framework

use core_foundation::string::CFString;
use core_graphics::color_space::CGColorSpace;
use core_graphics::data_provider::CGDataProvider;
use core_graphics::image::CGImage;
use core_graphics::image::CGImageAlphaInfo;
use objc::runtime::{Class, Object};
use objc::{msg_send, sel, sel_impl};
pub fn perform_ocr_apple(image: &DynamicImage) -> (String, DataOutput, String) {
    // Convert DynamicImage to CGImage
    let rgba_image = image.to_rgba8();
    let (width, height) = rgba_image.dimensions();
    let bytes_per_row = width as usize * 4;
    let data = rgba_image.into_raw();

    let cg_image = {
        let provider = CGDataProvider::from_buffer(Arc::new(data));
        CGImage::new(
            width as usize,
            height as usize,
            8,  // bits per component
            32, // bits per pixel
            bytes_per_row,
            &CGColorSpace::create_device_rgb(),
            CGImageAlphaInfo::CGImageAlphaPremultipliedLast as u32,
            &provider,
            true,
            kCGRenderingIntentDefault,
        )
    };

    unsafe {
        let vision_class =
            Class::get("VNRecognizeTextRequest").expect("VNRecognizeTextRequest class not found");
        let request: *mut Object = msg_send![vision_class, alloc];
        let request: *mut Object = msg_send![request, init];

        println!("VNRecognizeTextRequest created successfully");

        // Set up the request parameters
        let _: () = msg_send![request, setRecognitionLevel: 0]; // 0 = VNRequestTextRecognitionLevelAccurate
        let _: () = msg_send![request, setUsesLanguageCorrection: true];

        // Create VNImageRequestHandler
        let handler_class =
            Class::get("VNImageRequestHandler").expect("VNImageRequestHandler class not found");
        let handler: *mut Object = msg_send![handler_class, alloc];
        let handler: *mut Object =
            msg_send![handler, initWithCGImage: cg_image options: std::ptr::null::<std::ffi::c_void>()];

        // Perform the request; -performRequests:error: expects an NSArray of requests
        let requests: *mut Object = msg_send![
            Class::get("NSArray").expect("NSArray class not found"),
            arrayWithObject: request
        ];
        let mut error_ptr: *mut Object = std::ptr::null_mut();
        let success: bool = msg_send![handler, performRequests: requests error: &mut error_ptr];

        if !success {
            let error = if !error_ptr.is_null() {
                let description: *const Object = msg_send![error_ptr, localizedDescription];
                let cf_description = CFString::from_void(description as *const _);
                cf_description.to_string()
            } else {
                "Unknown error".to_string()
            };
            return (
                format!("Error performing OCR request: {}", error),
                DataOutput {
                    data: vec![],
                    output: "".to_string(),
                },
                "".to_string(),
            );
        }

        // Extract results
        let results: *const Object = msg_send![request, results];
        if results.is_null() {
            return (
                "Error: No results from OCR request".to_string(),
                DataOutput {
                    data: vec![],
                    output: "".to_string(),
                },
                "".to_string(),
            );
        }

        let count: usize = msg_send![results, count];
        println!("Number of OCR results: {}", count);

        let mut recognized_text = String::new();
        for i in 0..count {
            let observation: *const Object = msg_send![results, objectAtIndex: i];
            if observation.is_null() {
                println!("Warning: Null observation at index {}", i);
                continue;
            }

            // Each result is a VNRecognizedTextObservation; take the top candidate's string.
            let candidates: *const Object = msg_send![observation, topCandidates: 1usize];
            let top_candidate: *const Object = msg_send![candidates, firstObject];
            let text: *const Object = msg_send![top_candidate, string];
            if text.is_null() {
                println!("Warning: Null text for observation at index {}", i);
                continue;
            }

            let cf_string = CFString::from_void(text as *const _);
            recognized_text.push_str(&cf_string.to_string());
            recognized_text.push('\n');
        }

        let data_output = DataOutput {
            data: vec![],
            output: recognized_text.clone(),
        };

        let json_output = serde_json::json!({
            "text": recognized_text,
            "confidence": 1.0,
        })
        .to_string();

        (recognized_text, data_output, json_output)
    }
}
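A minimal usage sketch on a macOS host (the image path is illustrative; any readable image works):

let img = image::open("tests/testing_OCR.png").expect("failed to open image");
let (text, _data, json) = perform_ocr_apple(&img);
println!("recognized:\n{}\njson: {}", text, json);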
64 changes: 64 additions & 0 deletions screenpipe-vision/tests/apple_vision_test.rs
@@ -0,0 +1,64 @@
#[cfg(target_os = "macos")]
#[cfg(test)]
mod tests {
    use screenpipe_vision::{process_ocr_task, OcrEngine};
    use std::path::PathBuf;
    use std::{sync::Arc, time::Instant};
    use tokio::sync::{mpsc, Mutex};

    #[tokio::test]
    async fn test_process_ocr_task_apple() {
        // Build an absolute path that works in both local and CI environments
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests");
        path.push("testing_OCR.png");
        println!("Path to testing_OCR.png: {:?}", path);
        let image = image::open(&path).expect("Failed to open image");

        let image_arc = Arc::new(image);
        let frame_number = 1;
        let timestamp = Instant::now();
        let (tx, mut rx) = mpsc::channel(1);
        let previous_text_json = Arc::new(Mutex::new(None));
        let ocr_engine = Arc::new(OcrEngine::AppleNative);
        let app_name = "test_app".to_string();

        let result = process_ocr_task(
            image_arc,
            frame_number,
            timestamp,
            tx,
            &previous_text_json,
            false,
            ocr_engine,
            app_name,
        )
        .await;

        assert!(result.is_ok(), "process_ocr_task failed: {:?}", result);

        // Check that a result arrived on the channel
        let capture_result = rx.try_recv();
        assert!(capture_result.is_ok(), "Failed to receive OCR result");

        let capture_result = capture_result.unwrap();

        // Assertions on the expected capture metadata
        assert!(
            !capture_result.text.is_empty(),
            "OCR text should not be empty"
        );
        assert_eq!(capture_result.frame_number, 1, "Frame number should be 1");
        assert_eq!(
            capture_result.app_name, "test_app",
            "App name should be 'test_app'"
        );

        println!("OCR text: {:?}", capture_result.text);

        // Add content-specific checks as needed; e.g. if the test image contains
        // "Hello, World!": assert!(capture_result.text.contains("Hello, World!"));
    }
}
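On a macOS machine this test can be run directly with cargo test -p screenpipe-vision --test apple_vision_test -- --nocapture; it needs nothing beyond the checked-in tests/testing_OCR.png image.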
