From 74f4eb3a23660555ed22a625867d45f9626e87e9 Mon Sep 17 00:00:00 2001 From: mosure Date: Sun, 5 May 2024 10:36:50 -0500 Subject: [PATCH] feat: flame inference system --- Cargo.toml | 2 +- README.md | 105 ++++++++++++++------------------------------ src/models/flame.rs | 46 +++++++++++++++++++ tools/flame.rs | 55 ++++++++++------------- 4 files changed, 105 insertions(+), 103 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fd34eed..171bab3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.10.0" +version = "0.11.0" edition = "2021" authors = ["mosure "] license = "MIT" diff --git a/README.md b/README.md index d247aaa..18e0a75 100644 --- a/README.md +++ b/README.md @@ -34,12 +34,12 @@ use bevy::prelude::*; use bevy_ort::{ BevyOrtPlugin, - inputs, - models::modnet::{ - images_to_modnet_input, - modnet_output_to_luma_images, + models::flame::{ + FlameInput, + FlameOutput, + Flame, + FlamePlugin, }, - Onnx, }; @@ -48,88 +48,51 @@ fn main() { .add_plugins(( DefaultPlugins, BevyOrtPlugin, + FlamePlugin, )) - .init_resource::() - .add_systems(Startup, load_modnet) - .add_systems(Update, inference) + .add_systems(Startup, load_flame) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) .run(); } -#[derive(Resource, Default)] -pub struct Modnet { - pub onnx: Handle, - pub input: Handle, -} -fn load_modnet( +fn load_flame( asset_server: Res, - mut modnet: ResMut, + mut flame: ResMut, ) { - let modnet_handle: Handle = asset_server.load("modnet_photographic_portrait_matting.onnx"); - modnet.onnx = modnet_handle; + flame.onnx = asset_server.load("models/flame.onnx"); +} - let input_handle: Handle = asset_server.load("images/person.png"); - modnet.input = input_handle; + +fn setup( + mut commands: Commands, +) { + commands.spawn(FlameInput::default()); + commands.spawn(Camera3dBundle::default()); } -fn inference( +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( mut commands: Commands, - modnet: Res, - onnx_assets: Res>, - mut images: ResMut>, - mut complete: Local, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without, + >, ) { - if *complete { - return; - } + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); - let image = images.get(&modnet.input).expect("failed to get image asset"); - - let mask_image: Result = (|| { - let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; - let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; - let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; - - Ok(modnet_inference(session, &[image], None).pop().unwrap()) - })(); - - match mask_image { - Ok(mask_image) => { - let mask_image = images.add(mask_image); - - commands.spawn(NodeBundle { - style: Style { - display: Display::Grid, - width: Val::Percent(100.0), - height: Val::Percent(100.0), - grid_template_columns: RepeatedGridTrack::flex(1, 1.0), - grid_template_rows: RepeatedGridTrack::flex(1, 1.0), - ..default() - }, - background_color: BackgroundColor(Color::DARK_GRAY), - ..default() - }) - .with_children(|builder| { - builder.spawn(ImageBundle { - style: Style { - ..default() - }, - image: UiImage::new(mask_image.clone()), - ..default() - }); - }); - - commands.spawn(Camera2dBundle::default()); - - *complete = true; - } - Err(e) => { - println!("inference failed: {}", e); - } + println!("{:?}", flame_output); } } - ``` diff --git a/src/models/flame.rs b/src/models/flame.rs index 6da6e92..a73c6e2 100644 --- a/src/models/flame.rs +++ b/src/models/flame.rs @@ -13,6 +13,11 @@ pub struct FlamePlugin; impl Plugin for FlamePlugin { fn build(&self, app: &mut App) { app.init_resource::(); + + app.register_type::(); + app.register_type::(); + + app.add_systems(PreUpdate, flame_inference_system); } } @@ -22,9 +27,48 @@ pub struct Flame { } +fn flame_inference_system( + mut commands: Commands, + flame: Res, + onnx_assets: Res>, + flame_inputs: Query< + ( + Entity, + &FlameInput, + ), + Without, + >, +) { + for (entity, flame_input) in flame_inputs.iter() { + let flame_output: Result = (|| { + let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get flame ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get flame session from flame ONNX asset")?; + + Ok(flame_inference( + session, + flame_input, + )) + })(); + + match flame_output { + Ok(flame_output) => { + commands.entity(entity) + .insert(flame_output); + } + Err(e) => { + warn!("{}", e); + } + } + } +} + + #[derive( Debug, Clone, + Component, + Reflect, )] pub struct FlameInput { pub shape: [[f32; 100]; 8], @@ -62,8 +106,10 @@ impl Default for FlameInput { Debug, Default, Clone, + Component, Deserialize, Serialize, + Reflect, )] pub struct FlameOutput { pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding diff --git a/tools/flame.rs b/tools/flame.rs index c75a058..0623c9f 100644 --- a/tools/flame.rs +++ b/tools/flame.rs @@ -5,11 +5,9 @@ use bevy_ort::{ models::flame::{ FlameInput, FlameOutput, - flame_inference, Flame, FlamePlugin, }, - Onnx, }; @@ -21,7 +19,8 @@ fn main() { FlamePlugin, )) .add_systems(Startup, load_flame) - .add_systems(Update, inference) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) .run(); } @@ -30,41 +29,35 @@ fn load_flame( asset_server: Res, mut flame: ResMut, ) { - let flame_handle: Handle = asset_server.load("models/flame.onnx"); - flame.onnx = flame_handle; + flame.onnx = asset_server.load("models/flame.onnx"); } -fn inference( +fn setup( mut commands: Commands, - flame: Res, - onnx_assets: Res>, - mut complete: Local, ) { - if *complete { - return; - } + commands.spawn(FlameInput::default()); + commands.spawn(Camera3dBundle::default()); +} - let flame_output: Result = (|| { - let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get ONNX asset")?; - let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; - let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; - Ok(flame_inference( - session, - &FlameInput::default(), - )) - })(); +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( + mut commands: Commands, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without, + >, +) { + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); - match flame_output { - Ok(_flame_output) => { - // TODO: insert mesh - // TODO: insert pan orbit camera - commands.spawn(Camera3dBundle::default()); - *complete = true; - } - Err(e) => { - eprintln!("inference failed: {}", e); - } + println!("{:?}", flame_output); } }