Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flame inference system #16

Merged
merged 1 commit into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.10.0"
version = "0.11.0"
edition = "2021"
authors = ["mosure <[email protected]>"]
license = "MIT"
Expand Down
105 changes: 34 additions & 71 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ use bevy::prelude::*;

use bevy_ort::{
BevyOrtPlugin,
inputs,
models::modnet::{
images_to_modnet_input,
modnet_output_to_luma_images,
models::flame::{
FlameInput,
FlameOutput,
Flame,
FlamePlugin,
},
Onnx,
};


Expand All @@ -48,88 +48,51 @@ fn main() {
.add_plugins((
DefaultPlugins,
BevyOrtPlugin,
FlamePlugin,
))
.init_resource::<Modnet>()
.add_systems(Startup, load_modnet)
.add_systems(Update, inference)
.add_systems(Startup, load_flame)
.add_systems(Startup, setup)
.add_systems(Update, on_flame_output)
.run();
}

#[derive(Resource, Default)]
pub struct Modnet {
pub onnx: Handle<Onnx>,
pub input: Handle<Image>,
}

fn load_modnet(
fn load_flame(
asset_server: Res<AssetServer>,
mut modnet: ResMut<Modnet>,
mut flame: ResMut<Flame>,
) {
let modnet_handle: Handle<Onnx> = asset_server.load("modnet_photographic_portrait_matting.onnx");
modnet.onnx = modnet_handle;
flame.onnx = asset_server.load("models/flame.onnx");
}

let input_handle: Handle<Image> = asset_server.load("images/person.png");
modnet.input = input_handle;

fn setup(
mut commands: Commands,
) {
commands.spawn(FlameInput::default());
commands.spawn(Camera3dBundle::default());
}


fn inference(
#[derive(Debug, Component, Reflect)]
struct HandledFlameOutput;

fn on_flame_output(
mut commands: Commands,
modnet: Res<Modnet>,
onnx_assets: Res<Assets<Onnx>>,
mut images: ResMut<Assets<Image>>,
mut complete: Local<bool>,
flame_outputs: Query<
(
Entity,
&FlameOutput,
),
Without<HandledFlameOutput>,
>,
) {
if *complete {
return;
}
for (entity, flame_output) in flame_outputs.iter() {
commands.entity(entity)
.insert(HandledFlameOutput);

let image = images.get(&modnet.input).expect("failed to get image asset");

let mask_image: Result<Image, String> = (|| {
let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

Ok(modnet_inference(session, &[image], None).pop().unwrap())
})();

match mask_image {
Ok(mask_image) => {
let mask_image = images.add(mask_image);

commands.spawn(NodeBundle {
style: Style {
display: Display::Grid,
width: Val::Percent(100.0),
height: Val::Percent(100.0),
grid_template_columns: RepeatedGridTrack::flex(1, 1.0),
grid_template_rows: RepeatedGridTrack::flex(1, 1.0),
..default()
},
background_color: BackgroundColor(Color::DARK_GRAY),
..default()
})
.with_children(|builder| {
builder.spawn(ImageBundle {
style: Style {
..default()
},
image: UiImage::new(mask_image.clone()),
..default()
});
});

commands.spawn(Camera2dBundle::default());

*complete = true;
}
Err(e) => {
println!("inference failed: {}", e);
}
println!("{:?}", flame_output);
}
}

```


Expand Down
46 changes: 46 additions & 0 deletions src/models/flame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ pub struct FlamePlugin;
impl Plugin for FlamePlugin {
fn build(&self, app: &mut App) {
app.init_resource::<Flame>();

app.register_type::<FlameInput>();
app.register_type::<FlameOutput>();

app.add_systems(PreUpdate, flame_inference_system);
}
}

Expand All @@ -22,9 +27,48 @@ pub struct Flame {
}


fn flame_inference_system(
mut commands: Commands,
flame: Res<Flame>,
onnx_assets: Res<Assets<Onnx>>,
flame_inputs: Query<
(
Entity,
&FlameInput,
),
Without<FlameOutput>,
>,
) {
for (entity, flame_input) in flame_inputs.iter() {
let flame_output: Result<FlameOutput, String> = (|| {
let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get flame ONNX asset")?;
let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
let session = session_lock.as_ref().ok_or("failed to get flame session from flame ONNX asset")?;

Ok(flame_inference(
session,
flame_input,
))
})();

match flame_output {
Ok(flame_output) => {
commands.entity(entity)
.insert(flame_output);
}
Err(e) => {
warn!("{}", e);
}
}
}
}


#[derive(
Debug,
Clone,
Component,
Reflect,
)]
pub struct FlameInput {
pub shape: [[f32; 100]; 8],
Expand Down Expand Up @@ -62,8 +106,10 @@ impl Default for FlameInput {
Debug,
Default,
Clone,
Component,
Deserialize,
Serialize,
Reflect,
)]
pub struct FlameOutput {
pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding
Expand Down
55 changes: 24 additions & 31 deletions tools/flame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ use bevy_ort::{
models::flame::{
FlameInput,
FlameOutput,
flame_inference,
Flame,
FlamePlugin,
},
Onnx,
};


Expand All @@ -21,7 +19,8 @@ fn main() {
FlamePlugin,
))
.add_systems(Startup, load_flame)
.add_systems(Update, inference)
.add_systems(Startup, setup)
.add_systems(Update, on_flame_output)
.run();
}

Expand All @@ -30,41 +29,35 @@ fn load_flame(
asset_server: Res<AssetServer>,
mut flame: ResMut<Flame>,
) {
let flame_handle: Handle<Onnx> = asset_server.load("models/flame.onnx");
flame.onnx = flame_handle;
flame.onnx = asset_server.load("models/flame.onnx");
}


fn inference(
fn setup(
mut commands: Commands,
flame: Res<Flame>,
onnx_assets: Res<Assets<Onnx>>,
mut complete: Local<bool>,
) {
if *complete {
return;
}
commands.spawn(FlameInput::default());
commands.spawn(Camera3dBundle::default());
}

let flame_output: Result<FlameOutput, String> = (|| {
let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get ONNX asset")?;
let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

Ok(flame_inference(
session,
&FlameInput::default(),
))
})();
#[derive(Debug, Component, Reflect)]
struct HandledFlameOutput;

fn on_flame_output(
mut commands: Commands,
flame_outputs: Query<
(
Entity,
&FlameOutput,
),
Without<HandledFlameOutput>,
>,
) {
for (entity, flame_output) in flame_outputs.iter() {
commands.entity(entity)
.insert(HandledFlameOutput);

match flame_output {
Ok(_flame_output) => {
// TODO: insert mesh
// TODO: insert pan orbit camera
commands.spawn(Camera3dBundle::default());
*complete = true;
}
Err(e) => {
eprintln!("inference failed: {}", e);
}
println!("{:?}", flame_output);
}
}
Loading