diff --git a/src/draw.rs b/src/draw.rs index d754b74..d7dfee5 100644 --- a/src/draw.rs +++ b/src/draw.rs @@ -1,5 +1,8 @@ use bevy::{ - input::mouse::{MouseButtonInput, MouseMotion, MouseWheel}, + input::{ + mouse::{MouseButtonInput, MouseMotion, MouseWheel}, + ElementState, + }, prelude::*, window::CursorMoved, }; @@ -13,15 +16,13 @@ pub enum ImageEvent { Clear, } -const INPUT_IMG_SIZE: u32 = 128; - -const WINDOW_WIDTH: f32 = 1350.; -const WINDOW_HEIGHT: f32 = 700.; +pub const WINDOW_WIDTH: f32 = 1350.; +pub const WINDOW_HEIGHT: f32 = 700.; // Offset from left top corner -const OFFSET: f32 = WINDOW_HEIGHT / 14.; -const CANVAS_WIDTH: f32 = (WINDOW_WIDTH - OFFSET * 3.0) / 2.0; -const CANVAS_HEIGHT: f32 = WINDOW_HEIGHT - OFFSET * 2.0; +pub const OFFSET: f32 = WINDOW_HEIGHT / 14.; +pub const CANVAS_WIDTH: f32 = (WINDOW_WIDTH - OFFSET * 3.0) / 2.0; +pub const CANVAS_HEIGHT: f32 = WINDOW_HEIGHT - OFFSET * 2.0; pub fn clear_canvas( keyboard_input: Res>, @@ -51,6 +52,8 @@ pub fn create_canvas( // Setup images on the right canvas + return; + let texture1 = asset_server.load("axe1.png"); let texture2 = asset_server.load("axe2.png"); let texture3 = asset_server.load("axe3.png"); @@ -174,6 +177,7 @@ fn clear_inference( pub fn mouse_draw( mut cursor_moved_events: EventReader, + mut mouse_button_input_events: EventReader, mut image_events: EventWriter, mut last_mouse_position: Local>, drawable: Query<(&Interaction, &GlobalTransform, &Style), With>, @@ -198,13 +202,6 @@ pub fn mouse_draw( 0. }; - // dbg!(transform.translation); - // [examples/ui/ui.rs:89] transform.translation = Vec3( - // 400.0, - // 320.0, - // 0.001, - // ) - for event in cursor_moved_events.iter() { // info!("{:?}", event.position); @@ -235,8 +232,14 @@ pub fn mouse_draw( *last_mouse_position = Some(event.position); } - } else { - // println!("None"); + } + } + + for event in mouse_button_input_events.iter() { + // info!("mouse_button_input_events: {:?}", event); + + if event.state == ElementState::Released { + *last_mouse_position = None; } } } @@ -258,8 +261,8 @@ pub fn update_canvas( ImageEvent::DrawPos(pos) => { let x_scale = texture.size.width as f32 / CANVAS_WIDTH; let y_scale = texture.size.height as f32 / CANVAS_HEIGHT; - let line_scale = 5; - let line_radius = 5; + let line_scale = 2; + let line_radius = 1; for i in -line_radius..=line_radius { for j in -line_radius..=line_radius { diff --git a/src/model.rs b/src/model.rs index 7fa2657..5a0eac8 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,17 +4,31 @@ use bevy::{ reflect::TypeUuid, }; use image::{imageops::FilterType, ImageBuffer, RgbImage}; -use std::time::{Duration, Instant}; +use std::{ + path::PathBuf, + time::{Duration, Instant}, +}; use tract_ndarray::Array; use tract_onnx::prelude::*; use wasm_bindgen::prelude::*; -use crate::draw::{ImageEvent, TestCanvas}; +use crate::draw::{ + Canvas, ImageEvent, TestCanvas, CANVAS_HEIGHT, CANVAS_WIDTH, OFFSET, WINDOW_HEIGHT, + WINDOW_WIDTH, +}; + +const INPUT_IMG_SIZE: u32 = 128; #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = console)] fn log(s: &str); + + #[wasm_bindgen(js_namespace = console)] + pub fn time(s: &str); + + #[wasm_bindgen(js_namespace = console)] + pub fn timeEnd(s: &str); } macro_rules! console_log { @@ -68,24 +82,32 @@ pub struct State { pub inference_state: InferenceState, } +enum ImageClass { + Rabbit = 1, + Axe, + SmileyFace, +} + impl FromWorld for State { fn from_world(world: &mut World) -> Self { let asset_server = world.get_resource::().unwrap(); State { inference_state: InferenceState::Wait, - model: asset_server.load("resnet50.onnx"), + model: asset_server.load("cnn_sketch_3class.onnx"), } } } pub fn infer_sketch( + mut commands: Commands, + asset_server: Res, mut image_events: EventReader, keyboard_input: Res>, - materials: ResMut>, + mut materials: ResMut>, textures: Res>, models: Res>, mut state: ResMut, - drawable: Query<&Handle, With>, + drawable: Query<&Handle, With>, ) { // If canvas is cleared and nothing drawed then return without inference for event in image_events.iter() { @@ -97,9 +119,7 @@ pub fn infer_sketch( if keyboard_input.just_pressed(KeyCode::B) && state.inference_state == InferenceState::Infer { for mat in drawable.iter() { - println!("Save image"); - - let material = materials.get(mat).unwrap(); + let material = &materials.get(mat).unwrap(); let texture = textures.get(material.texture.as_ref().unwrap()).unwrap(); let mut img: RgbImage = ImageBuffer::new(texture.size.width, texture.size.height); @@ -115,23 +135,65 @@ pub fn infer_sketch( } } + // #[cfg(not(target_arch = "wasm32"))] + // img.save("image.png").unwrap(); + + let resized = + image::imageops::resize(&img, INPUT_IMG_SIZE, INPUT_IMG_SIZE, FilterType::Triangle); + + // let tensor_image = tract_ndarray::Array4::from_shape_fn( + // (1, 1, INPUT_IMG_SIZE as usize, INPUT_IMG_SIZE as usize), + // |(_, _, y, x)| resized[(x as _, y as _)][0] as f32, + // ); + // println!("tensor_image shape: {:?}", tensor_image.shape()); + // for i in 0..10 { + // for j in 0..10 { + // let a = resized.get_pixel(i, j); + // print!("({} {} {}),", a[0], a[1], a[2]); + // } + // println!(); + // } + + let tensor_image: Tensor = ((tract_ndarray::Array3::from_shape_fn( + (1, INPUT_IMG_SIZE as usize, INPUT_IMG_SIZE as usize), + |(_, y, x)| { + // Convert RGB to gray scale value + // let r = resized[(x as _, y as _)][0] as f32; + // let g = resized[(x as _, y as _)][1] as f32; + // let b = resized[(x as _, y as _)][2] as f32; + // (r * 0.3 + g * 0.59 + b * 0.11) / 255.0 + resized[(x as _, y as _)][0] as f32 / 255.0 + }, + ) - 0.5) + / 0.5) + .into(); + #[cfg(not(target_arch = "wasm32"))] - img.save("image.png").unwrap(); + resized.save("resized.png").unwrap(); - // Imagenet mean and standard deviation - let mean = Array::from_shape_vec((1, 3, 1, 1), vec![0.485, 0.456, 0.406]).unwrap(); - let std = Array::from_shape_vec((1, 3, 1, 1), vec![0.229, 0.224, 0.225]).unwrap(); + // dbg!(&tensor_image.shape()); + // dbg!(&tensor_image); + for i in 0..10 { + for j in 0..10 { + // print!("{:?} ", resized[(i, j)]); + // dbg!(resized[(i, j)]); - let resized = image::imageops::resize(&img, 224, 224, FilterType::Triangle); - let tensor_image: Tensor = - ((tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { - resized[(x as _, y as _)][c] as f32 / 255.0 - }) - mean) - / std) - .into(); + // let a = resized.get_pixel(i, j); + // print!("[({} {} {}),", a[0], a[1], a[2]); + } + // println!(); + } + #[cfg(not(target_arch = "wasm32"))] let start = Instant::now(); + // #[cfg(target_arch = "wasm32")] + // console_log!("aaaaaaa"); + // return; + + #[cfg(target_arch = "wasm32")] + time("infer"); + if let Some(model) = models.get(state.model.as_weak::()) { let result = model.model.run(tvec!(tensor_image)).unwrap(); @@ -148,13 +210,20 @@ pub fn infer_sketch( println!("{} {}", score, class); + // let a = result[0].to_array_view::().unwrap(); + // println!("{}", a); + + #[cfg(not(target_arch = "wasm32"))] let duration = start.elapsed(); + #[cfg(not(target_arch = "wasm32"))] println!("Inference time: {:?}", duration); #[cfg(target_arch = "wasm32")] console_log!("{} {}", score, class); #[cfg(target_arch = "wasm32")] - console_log!("Inference time: {:?}", duration); + timeEnd("infer"); + + show_infer_result(&mut commands, &asset_server, &mut materials, class); } state.inference_state = InferenceState::Wait; @@ -162,10 +231,87 @@ pub fn infer_sketch( } } +fn show_infer_result( + commands: &mut Commands, + asset_server: &Res, + materials: &mut ResMut>, + class: u32, +) { + let path = match class { + 1 => "rabbit", + 2 => "axe", + 3 => "smile", + _ => "err", + }; + let path = path.to_string(); + + let texture1 = asset_server.load(PathBuf::from(path.clone() + "1.png")); + let texture2 = asset_server.load(PathBuf::from(path.clone() + "2.png")); + let texture3 = asset_server.load(PathBuf::from(path.clone() + "3.png")); + let texture4 = asset_server.load(PathBuf::from(path + "4.png")); + + // Upper left + commands.spawn_bundle(SpriteBundle { + sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)), + material: materials.add(texture1.into()), + transform: Transform { + translation: Vec3::new( + WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 2. - CANVAS_WIDTH / 4., + WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4., + 0., + ), + ..Default::default() + }, + ..Default::default() + }); + // Upper right + commands.spawn_bundle(SpriteBundle { + sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)), + material: materials.add(texture2.into()), + transform: Transform { + translation: Vec3::new( + WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 4., + WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4., + 0., + ), + ..Default::default() + }, + ..Default::default() + }); + // Lower left + commands.spawn_bundle(SpriteBundle { + sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)), + material: materials.add(texture3.into()), + transform: Transform { + translation: Vec3::new( + WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 2. - CANVAS_WIDTH / 4., + -(WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.), + 0., + ), + ..Default::default() + }, + ..Default::default() + }); + // Lower right + commands.spawn_bundle(SpriteBundle { + sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)), + material: materials.add(texture4.into()), + transform: Transform { + translation: Vec3::new( + WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 4., + -(WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.), + 0., + ), + ..Default::default() + }, + ..Default::default() + }); +} + pub fn infer_timer(time: Res