diff --git a/Cargo.lock b/Cargo.lock index 6372f6c..096231d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1901,7 +1901,7 @@ dependencies = [ [[package]] name = "sbv2_bindings" -version = "0.2.0-alpha1" +version = "0.2.0-alpha2" dependencies = [ "anyhow", "ndarray", @@ -1911,7 +1911,7 @@ dependencies = [ [[package]] name = "sbv2_core" -version = "0.2.0-alpha1" +version = "0.2.0-alpha2" dependencies = [ "anyhow", "dotenvy", diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index 45b6838..b13e631 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -9,7 +9,7 @@ axum = "0.7.5" dotenvy.workspace = true env_logger.workspace = true log = "0.4.22" -sbv2_core = { version = "0.2.0-alpha", path = "../sbv2_core" } +sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" } serde = { version = "1.0.210", features = ["derive"] } tokio = { version = "1.40.0", features = ["full"] } utoipa = { version = "5.0.0", features = ["axum_extras"] } diff --git a/sbv2_bindings/Cargo.toml b/sbv2_bindings/Cargo.toml index 21bb694..91c6e20 100644 --- a/sbv2_bindings/Cargo.toml +++ b/sbv2_bindings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sbv2_bindings" -version = "0.2.0-alpha1" +version = "0.2.0-alpha2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -12,4 +12,4 @@ crate-type = ["cdylib"] anyhow.workspace = true ndarray.workspace = true pyo3 = { version = "0.22.0", features = ["anyhow"] } -sbv2_core = { version = "0.2.0-alpha", path = "../sbv2_core" } +sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" } diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 209470d..0c5259b 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sbv2_core" description = "Style-Bert-VITSの推論ライブラリ" -version = "0.2.0-alpha1" +version = "0.2.0-alpha2" edition = "2021" license = "MIT" readme = "../README.md" diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index 3bd8396..0710bf6 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use ndarray::Array2; +use ndarray::{Array2, Ix2}; use ort::Session; pub fn predict( @@ -14,10 +14,10 @@ pub fn predict( }? )?; - let output = outputs.get("output").unwrap(); + let output = outputs["output"] + .try_extract_tensor::()? + .into_dimensionality::()? + .to_owned(); - let content = output.try_extract_tensor::()?.to_owned(); - let (data, _) = content.clone().into_raw_vec_and_offset(); - - Ok(Array2::from_shape_vec((content.shape()[0], content.shape()[1]), data).unwrap()) + Ok(output) } diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 485c3f7..9f2a221 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use ndarray::{array, Array1, Array2, Array3, Axis}; +use ndarray::{array, Array1, Array2, Array3, Axis, Ix3}; use ort::{GraphOptimizationLevel, Session}; #[allow(clippy::vec_init_then_push, unused_variables)] @@ -76,18 +76,10 @@ pub fn synthesize( "length_scale" => array![length_scale], }?)?; - let audio_array = outputs - .get("output") - .unwrap() + let audio_array = outputs["output"] .try_extract_tensor::()? + .into_dimensionality::()? .to_owned(); - Ok(Array3::from_shape_vec( - ( - audio_array.shape()[0], - audio_array.shape()[1], - audio_array.shape()[2], - ), - audio_array.into_raw_vec_and_offset().0, - )?) + Ok(audio_array) }