Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
toomuat committed Dec 20, 2021
1 parent 76f046b commit 030db50
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 40 deletions.
41 changes: 22 additions & 19 deletions src/draw.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use bevy::{
input::mouse::{MouseButtonInput, MouseMotion, MouseWheel},
input::{
mouse::{MouseButtonInput, MouseMotion, MouseWheel},
ElementState,
},
prelude::*,
window::CursorMoved,
};
Expand All @@ -13,15 +16,13 @@ pub enum ImageEvent {
Clear,
}

const INPUT_IMG_SIZE: u32 = 128;

const WINDOW_WIDTH: f32 = 1350.;
const WINDOW_HEIGHT: f32 = 700.;
pub const WINDOW_WIDTH: f32 = 1350.;
pub const WINDOW_HEIGHT: f32 = 700.;

// Offset from left top corner
const OFFSET: f32 = WINDOW_HEIGHT / 14.;
const CANVAS_WIDTH: f32 = (WINDOW_WIDTH - OFFSET * 3.0) / 2.0;
const CANVAS_HEIGHT: f32 = WINDOW_HEIGHT - OFFSET * 2.0;
pub const OFFSET: f32 = WINDOW_HEIGHT / 14.;
pub const CANVAS_WIDTH: f32 = (WINDOW_WIDTH - OFFSET * 3.0) / 2.0;
pub const CANVAS_HEIGHT: f32 = WINDOW_HEIGHT - OFFSET * 2.0;

pub fn clear_canvas(
keyboard_input: Res<Input<KeyCode>>,
Expand Down Expand Up @@ -51,6 +52,8 @@ pub fn create_canvas(

// Setup images on the right canvas

return;

let texture1 = asset_server.load("axe1.png");
let texture2 = asset_server.load("axe2.png");
let texture3 = asset_server.load("axe3.png");
Expand Down Expand Up @@ -174,6 +177,7 @@ fn clear_inference(

pub fn mouse_draw(
mut cursor_moved_events: EventReader<CursorMoved>,
mut mouse_button_input_events: EventReader<MouseButtonInput>,
mut image_events: EventWriter<ImageEvent>,
mut last_mouse_position: Local<Option<Vec2>>,
drawable: Query<(&Interaction, &GlobalTransform, &Style), With<Canvas>>,
Expand All @@ -198,13 +202,6 @@ pub fn mouse_draw(
0.
};

// dbg!(transform.translation);
// [examples/ui/ui.rs:89] transform.translation = Vec3(
// 400.0,
// 320.0,
// 0.001,
// )

for event in cursor_moved_events.iter() {
// info!("{:?}", event.position);

Expand Down Expand Up @@ -235,8 +232,14 @@ pub fn mouse_draw(

*last_mouse_position = Some(event.position);
}
} else {
// println!("None");
}
}

for event in mouse_button_input_events.iter() {
// info!("mouse_button_input_events: {:?}", event);

if event.state == ElementState::Released {
*last_mouse_position = None;
}
}
}
Expand All @@ -258,8 +261,8 @@ pub fn update_canvas(
ImageEvent::DrawPos(pos) => {
let x_scale = texture.size.width as f32 / CANVAS_WIDTH;
let y_scale = texture.size.height as f32 / CANVAS_HEIGHT;
let line_scale = 5;
let line_radius = 5;
let line_scale = 2;
let line_radius = 1;

for i in -line_radius..=line_radius {
for j in -line_radius..=line_radius {
Expand Down
188 changes: 167 additions & 21 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,31 @@ use bevy::{
reflect::TypeUuid,
};
use image::{imageops::FilterType, ImageBuffer, RgbImage};
use std::time::{Duration, Instant};
use std::{
path::PathBuf,
time::{Duration, Instant},
};
use tract_ndarray::Array;
use tract_onnx::prelude::*;
use wasm_bindgen::prelude::*;

use crate::draw::{ImageEvent, TestCanvas};
use crate::draw::{
Canvas, ImageEvent, TestCanvas, CANVAS_HEIGHT, CANVAS_WIDTH, OFFSET, WINDOW_HEIGHT,
WINDOW_WIDTH,
};

const INPUT_IMG_SIZE: u32 = 128;

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);

#[wasm_bindgen(js_namespace = console)]
pub fn time(s: &str);

#[wasm_bindgen(js_namespace = console)]
pub fn timeEnd(s: &str);
}

macro_rules! console_log {
Expand Down Expand Up @@ -68,24 +82,32 @@ pub struct State {
pub inference_state: InferenceState,
}

enum ImageClass {
Rabbit = 1,
Axe,
SmileyFace,
}

impl FromWorld for State {
fn from_world(world: &mut World) -> Self {
let asset_server = world.get_resource::<AssetServer>().unwrap();
State {
inference_state: InferenceState::Wait,
model: asset_server.load("resnet50.onnx"),
model: asset_server.load("cnn_sketch_3class.onnx"),
}
}
}

pub fn infer_sketch(
mut commands: Commands,
asset_server: Res<AssetServer>,
mut image_events: EventReader<ImageEvent>,
keyboard_input: Res<Input<KeyCode>>,
materials: ResMut<Assets<ColorMaterial>>,
mut materials: ResMut<Assets<ColorMaterial>>,
textures: Res<Assets<Texture>>,
models: Res<Assets<OnnxModelAsset>>,
mut state: ResMut<State>,
drawable: Query<&Handle<ColorMaterial>, With<TestCanvas>>,
drawable: Query<&Handle<ColorMaterial>, With<Canvas>>,
) {
// If canvas is cleared and nothing drawed then return without inference
for event in image_events.iter() {
Expand All @@ -97,9 +119,7 @@ pub fn infer_sketch(

if keyboard_input.just_pressed(KeyCode::B) && state.inference_state == InferenceState::Infer {
for mat in drawable.iter() {
println!("Save image");

let material = materials.get(mat).unwrap();
let material = &materials.get(mat).unwrap();
let texture = textures.get(material.texture.as_ref().unwrap()).unwrap();

let mut img: RgbImage = ImageBuffer::new(texture.size.width, texture.size.height);
Expand All @@ -115,23 +135,65 @@ pub fn infer_sketch(
}
}

// #[cfg(not(target_arch = "wasm32"))]
// img.save("image.png").unwrap();

let resized =
image::imageops::resize(&img, INPUT_IMG_SIZE, INPUT_IMG_SIZE, FilterType::Triangle);

// let tensor_image = tract_ndarray::Array4::from_shape_fn(
// (1, 1, INPUT_IMG_SIZE as usize, INPUT_IMG_SIZE as usize),
// |(_, _, y, x)| resized[(x as _, y as _)][0] as f32,
// );
// println!("tensor_image shape: {:?}", tensor_image.shape());
// for i in 0..10 {
// for j in 0..10 {
// let a = resized.get_pixel(i, j);
// print!("({} {} {}),", a[0], a[1], a[2]);
// }
// println!();
// }

let tensor_image: Tensor = ((tract_ndarray::Array3::from_shape_fn(
(1, INPUT_IMG_SIZE as usize, INPUT_IMG_SIZE as usize),
|(_, y, x)| {
// Convert RGB to gray scale value
// let r = resized[(x as _, y as _)][0] as f32;
// let g = resized[(x as _, y as _)][1] as f32;
// let b = resized[(x as _, y as _)][2] as f32;
// (r * 0.3 + g * 0.59 + b * 0.11) / 255.0
resized[(x as _, y as _)][0] as f32 / 255.0
},
) - 0.5)
/ 0.5)
.into();

#[cfg(not(target_arch = "wasm32"))]
img.save("image.png").unwrap();
resized.save("resized.png").unwrap();

// Imagenet mean and standard deviation
let mean = Array::from_shape_vec((1, 3, 1, 1), vec![0.485, 0.456, 0.406]).unwrap();
let std = Array::from_shape_vec((1, 3, 1, 1), vec![0.229, 0.224, 0.225]).unwrap();
// dbg!(&tensor_image.shape());
// dbg!(&tensor_image);
for i in 0..10 {
for j in 0..10 {
// print!("{:?} ", resized[(i, j)]);
// dbg!(resized[(i, j)]);

let resized = image::imageops::resize(&img, 224, 224, FilterType::Triangle);
let tensor_image: Tensor =
((tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
resized[(x as _, y as _)][c] as f32 / 255.0
}) - mean)
/ std)
.into();
// let a = resized.get_pixel(i, j);
// print!("[({} {} {}),", a[0], a[1], a[2]);
}
// println!();
}

#[cfg(not(target_arch = "wasm32"))]
let start = Instant::now();

// #[cfg(target_arch = "wasm32")]
// console_log!("aaaaaaa");
// return;

#[cfg(target_arch = "wasm32")]
time("infer");

if let Some(model) = models.get(state.model.as_weak::<OnnxModelAsset>()) {
let result = model.model.run(tvec!(tensor_image)).unwrap();

Expand All @@ -148,24 +210,108 @@ pub fn infer_sketch(

println!("{} {}", score, class);

// let a = result[0].to_array_view::<f32>().unwrap();
// println!("{}", a);

#[cfg(not(target_arch = "wasm32"))]
let duration = start.elapsed();
#[cfg(not(target_arch = "wasm32"))]
println!("Inference time: {:?}", duration);

#[cfg(target_arch = "wasm32")]
console_log!("{} {}", score, class);
#[cfg(target_arch = "wasm32")]
console_log!("Inference time: {:?}", duration);
timeEnd("infer");

show_infer_result(&mut commands, &asset_server, &mut materials, class);
}

state.inference_state = InferenceState::Wait;
}
}
}

fn show_infer_result(
commands: &mut Commands,
asset_server: &Res<AssetServer>,
materials: &mut ResMut<Assets<ColorMaterial>>,
class: u32,
) {
let path = match class {
1 => "rabbit",
2 => "axe",
3 => "smile",
_ => "err",
};
let path = path.to_string();

let texture1 = asset_server.load(PathBuf::from(path.clone() + "1.png"));
let texture2 = asset_server.load(PathBuf::from(path.clone() + "2.png"));
let texture3 = asset_server.load(PathBuf::from(path.clone() + "3.png"));
let texture4 = asset_server.load(PathBuf::from(path + "4.png"));

// Upper left
commands.spawn_bundle(SpriteBundle {
sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)),
material: materials.add(texture1.into()),
transform: Transform {
translation: Vec3::new(
WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 2. - CANVAS_WIDTH / 4.,
WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.,
0.,
),
..Default::default()
},
..Default::default()
});
// Upper right
commands.spawn_bundle(SpriteBundle {
sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)),
material: materials.add(texture2.into()),
transform: Transform {
translation: Vec3::new(
WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 4.,
WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.,
0.,
),
..Default::default()
},
..Default::default()
});
// Lower left
commands.spawn_bundle(SpriteBundle {
sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)),
material: materials.add(texture3.into()),
transform: Transform {
translation: Vec3::new(
WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 2. - CANVAS_WIDTH / 4.,
-(WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.),
0.,
),
..Default::default()
},
..Default::default()
});
// Lower right
commands.spawn_bundle(SpriteBundle {
sprite: Sprite::new(Vec2::new(CANVAS_WIDTH / 2., CANVAS_HEIGHT / 2.)),
material: materials.add(texture4.into()),
transform: Transform {
translation: Vec3::new(
WINDOW_WIDTH / 2. - OFFSET - CANVAS_WIDTH / 4.,
-(WINDOW_HEIGHT / 2. - OFFSET - CANVAS_HEIGHT / 4.),
0.,
),
..Default::default()
},
..Default::default()
});
}

pub fn infer_timer(time: Res<Time>, mut state: ResMut<State>, mut query: Query<&mut Timer>) {
for mut timer in query.iter_mut() {
if timer.tick(time.delta()).finished() {
info!("Entity timer just finished");
// info!("Entity timer just finished");

state.inference_state = InferenceState::Infer;
}
Expand Down

0 comments on commit 030db50

Please sign in to comment.