Skip to content

Commit

Permalink
Merge pull request #94 from tuna2134/refine
Browse files Browse the repository at this point in the history
コードのリファイン
  • Loading branch information
tuna2134 authored Oct 18, 2024
2 parents c312fb0 + c7d9112 commit d1cc8de
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sbv2_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ axum = "0.7.5"
dotenvy.workspace = true
env_logger.workspace = true
log = "0.4.22"
sbv2_core = { version = "0.2.0-alpha", path = "../sbv2_core" }
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" }
serde = { version = "1.0.210", features = ["derive"] }
tokio = { version = "1.40.0", features = ["full"] }
utoipa = { version = "5.0.0", features = ["axum_extras"] }
Expand Down
4 changes: 2 additions & 2 deletions sbv2_bindings/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sbv2_bindings"
version = "0.2.0-alpha1"
version = "0.2.0-alpha2"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -12,4 +12,4 @@ crate-type = ["cdylib"]
anyhow.workspace = true
ndarray.workspace = true
pyo3 = { version = "0.22.0", features = ["anyhow"] }
sbv2_core = { version = "0.2.0-alpha", path = "../sbv2_core" }
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" }
2 changes: 1 addition & 1 deletion sbv2_core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "sbv2_core"
description = "Style-Bert-VITSの推論ライブラリ"
version = "0.2.0-alpha1"
version = "0.2.0-alpha2"
edition = "2021"
license = "MIT"
readme = "../README.md"
Expand Down
12 changes: 6 additions & 6 deletions sbv2_core/src/bert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::Result;
use ndarray::Array2;
use ndarray::{Array2, Ix2};
use ort::Session;

pub fn predict(
Expand All @@ -14,10 +14,10 @@ pub fn predict(
}?
)?;

let output = outputs.get("output").unwrap();
let output = outputs["output"]
.try_extract_tensor::<f32>()?
.into_dimensionality::<Ix2>()?
.to_owned();

let content = output.try_extract_tensor::<f32>()?.to_owned();
let (data, _) = content.clone().into_raw_vec_and_offset();

Ok(Array2::from_shape_vec((content.shape()[0], content.shape()[1]), data).unwrap())
Ok(output)
}
16 changes: 4 additions & 12 deletions sbv2_core/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::Result;
use ndarray::{array, Array1, Array2, Array3, Axis};
use ndarray::{array, Array1, Array2, Array3, Axis, Ix3};
use ort::{GraphOptimizationLevel, Session};

#[allow(clippy::vec_init_then_push, unused_variables)]
Expand Down Expand Up @@ -76,18 +76,10 @@ pub fn synthesize(
"length_scale" => array![length_scale],
}?)?;

let audio_array = outputs
.get("output")
.unwrap()
let audio_array = outputs["output"]
.try_extract_tensor::<f32>()?
.into_dimensionality::<Ix3>()?
.to_owned();

Ok(Array3::from_shape_vec(
(
audio_array.shape()[0],
audio_array.shape()[1],
audio_array.shape()[2],
),
audio_array.into_raw_vec_and_offset().0,
)?)
Ok(audio_array)
}

0 comments on commit d1cc8de

Please sign in to comment.