diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 66dcfc1ba..fb50033bb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -105,7 +105,7 @@ jobs: # cargoのキャッシュが原因でテストが失敗する場合はバージョン部分をカウントアップすること key: "v2-cargo-test-cache-${{ matrix.features }}-${{ matrix.os }}" - name: Run cargo test - run: cargo test -vv --features ,${{ matrix.features }} -- --include-ignored + run: RUST_BACKTRACE=full cargo test -vv --features ,${{ matrix.features }} -- --include-ignored c-header: runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index ba7ac4171..a783bb6d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -489,7 +489,7 @@ checksum = "a6358dedf60f4d9b8db43ad187391afe959746101346fe51bb978126bec61dfb" dependencies = [ "clap 3.2.22", "heck", - "indexmap", + "indexmap 1.9.1", "log", "proc-macro2", "quote", @@ -581,7 +581,7 @@ dependencies = [ "atty", "bitflags", "clap_lex 0.2.4", - "indexmap", + "indexmap 1.9.1", "strsim", "termcolor", "textwrap", @@ -1110,6 +1110,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "erased-serde" version = "0.3.25" @@ -1130,6 +1136,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "errno-dragonfly" version = "0.1.2" @@ -1173,7 +1190,7 @@ checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "windows-sys 0.36.1", ] @@ -1422,7 +1439,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.1", 
"slab", "tokio", "tokio-util", @@ -1435,6 +1452,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.0" @@ -1459,6 +1482,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" + [[package]] name = "hkdf" version = "0.10.0" @@ -1658,7 +1687,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", + "serde", ] [[package]] @@ -1706,12 +1746,13 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.4" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7d6c6f8c91b4b9ed43484ad1a938e393caf35960fce7f82a040497207bd8e9e" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ + "hermit-abi 0.3.2", "libc", - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] @@ -1728,7 +1769,7 @@ checksum = "28dfb6c8100ccc63462345b67d1bbc3679177c75ee4bf59bf29c8b1d110b8189" dependencies = [ "hermit-abi 0.2.6", "io-lifetimes", - "rustix", + "rustix 0.36.7", "windows-sys 0.42.0", ] @@ -1893,6 +1934,12 @@ version = "0.1.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +[[package]] +name = "linux-raw-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" + [[package]] name = "lock_api" version = "0.4.9" @@ -2052,7 +2099,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af5a8477ac96877b5bd1fd67e0c28736c12943aba24eda92b127e036b0c8f400" dependencies = [ - "indexmap", + "indexmap 1.9.1", "itertools", "ndarray", "noisy_float", @@ -2238,7 +2285,7 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "open_jtalk" version = "0.1.25" -source = "git+https://github.com/VOICEVOX/open_jtalk-rs.git?rev=d766a52bad4ccafe18597e57bd6842f59dca881e#d766a52bad4ccafe18597e57bd6842f59dca881e" +source = "git+https://github.com/VOICEVOX/open_jtalk-rs.git?rev=a16714ce16dec76fd0e3041a7acfa484921db3b5#a16714ce16dec76fd0e3041a7acfa484921db3b5" dependencies = [ "open_jtalk-sys", "thiserror", @@ -2247,7 +2294,7 @@ dependencies = [ [[package]] name = "open_jtalk-sys" version = "0.16.111" -source = "git+https://github.com/VOICEVOX/open_jtalk-rs.git?rev=d766a52bad4ccafe18597e57bd6842f59dca881e#d766a52bad4ccafe18597e57bd6842f59dca881e" +source = "git+https://github.com/VOICEVOX/open_jtalk-rs.git?rev=a16714ce16dec76fd0e3041a7acfa484921db3b5#a16714ce16dec76fd0e3041a7acfa484921db3b5" dependencies = [ "bindgen", "cmake", @@ -2334,7 +2381,7 @@ checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "smallvec", "windows-sys 0.36.1", ] @@ -2764,6 +2811,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.6.0" @@ -2907,13 +2963,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fdebc4b395b7fbb9ab11e462e20ed9051e7b16e42d24042c776eca0ac81b03" dependencies = [ "bitflags", - "errno", + "errno 0.2.8", "io-lifetimes", "libc", - "linux-raw-sys", + "linux-raw-sys 0.1.4", "windows-sys 0.42.0", ] +[[package]] +name = "rustix" +version = "0.37.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +dependencies = [ + "bitflags", + "errno 0.3.1", + "io-lifetimes", + "libc", + "linux-raw-sys 0.3.8", + "windows-sys 0.48.0", +] + [[package]] name = "rustls" version = "0.20.6" @@ -3011,22 +3081,22 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.145" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.145" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", - "syn 1.0.102", + "syn 2.0.15", ] [[package]] @@ -3035,7 +3105,7 @@ version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" dependencies = [ - "indexmap", + "indexmap 1.9.1", "itoa", "ryu", "serde", @@ -3399,15 +3469,16 @@ checksum = 
"c02424087780c9b71cc96799eaeddff35af2bc513278cda5c99fc1f5d026d3c1" [[package]] name = "tempfile" -version = "3.4.0" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" dependencies = [ + "autocfg", "cfg-if", "fastrand", - "redox_syscall", - "rustix", - "windows-sys 0.42.0", + "redox_syscall 0.3.5", + "rustix 0.37.19", + "windows-sys 0.48.0", ] [[package]] @@ -3646,7 +3717,7 @@ version = "0.19.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6a7712b49e1775fb9a7b998de6635b299237f48b404dde71704f2e0e7f37e5" dependencies = [ - "indexmap", + "indexmap 1.9.1", "nom8", "serde", "serde_spanned", @@ -3865,6 +3936,16 @@ dependencies = [ "serde", ] +[[package]] +name = "uuid" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d023da39d1fde5a8a3fe1f3e01ca9632ada0a63e9797de55a879d6e2236277be" +dependencies = [ + "getrandom 0.2.7", + "serde", +] + [[package]] name = "valuable" version = "0.1.0" @@ -3910,6 +3991,8 @@ dependencies = [ "futures", "heck", "humansize", + "indexmap 2.0.0", + "itertools", "nanoid", "once_cell", "onnxruntime", @@ -3922,10 +4005,12 @@ dependencies = [ "serde_json", "strum", "tar", + "tempfile", "test_util", "thiserror", "tokio", "tracing", + "uuid", "windows", ] @@ -3955,12 +4040,14 @@ dependencies = [ "serde", "serde_json", "strum", + "tempfile", "test_util", "thiserror", "tokio", "toml 0.7.2", "tracing-subscriber", "typetag", + "uuid", "voicevox_core", ] @@ -3979,6 +4066,7 @@ dependencies = [ "test_util", "tokio", "tracing", + "uuid", "voicevox_core", ] diff --git a/Cargo.toml b/Cargo.toml index e5d45dea1..fc5c52efc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,15 +15,18 @@ clap = { version = "4.0.10", features = ["derive"] } const-default = { version = "1.0.0", features = ["derive"] } 
easy-ext = "1.0.1" fs-err = { version = "2.9.0", features = ["tokio"] } +itertools = "0.10.5" once_cell = "1.15.0" regex = "1.6.0" serde = { version = "1.0.145", features = ["derive"] } serde_json = { version = "1.0.85", features = ["preserve_order"] } strum = { version = "0.24.1", features = ["derive"] } +tempfile = "3.6.0" test_util = { path = "crates/test_util" } thiserror = "1.0.37" tracing = { version = "0.1.37", features = ["log"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +uuid = { version = "1.4.0", features = ["v4", "serde"] } voicevox_core = { path = "crates/voicevox_core" } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "macros", "sync"] } derive-getters = "0.2.0" diff --git a/crates/download/Cargo.toml b/crates/download/Cargo.toml index 16a81978f..98a59d60b 100644 --- a/crates/download/Cargo.toml +++ b/crates/download/Cargo.toml @@ -14,7 +14,7 @@ fs-err.workspace = true futures-core = "0.3.25" futures-util = "0.3.25" indicatif = "0.17.3" -itertools = "0.10.5" +itertools.workspace = true octocrab = { version = "0.19.0", default-features = false, features = ["rustls-tls", "stream"] } once_cell.workspace = true platforms = "3.0.2" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 1a399ff75..9e5febe22 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -19,6 +19,8 @@ duplicate = "1.0.0" easy-ext.workspace = true fs-err.workspace = true futures = "0.3.26" +indexmap = { version = "2.0.0", features = ["serde"] } +itertools.workspace = true nanoid = "0.4.0" once_cell.workspace = true process_path.workspace = true @@ -26,9 +28,11 @@ regex.workspace = true serde.workspace = true serde_json.workspace = true strum.workspace = true +tempfile.workspace = true thiserror.workspace = true tokio.workspace = true tracing.workspace = true +uuid.workspace = true [dependencies.onnxruntime] git = "https://github.com/VOICEVOX/onnxruntime-rs.git" @@ -36,7 +40,7 
@@ rev = "ebb9dcb9b26ee681889b52b6db3b4f642b04a250" [dependencies.open_jtalk] git = "https://github.com/VOICEVOX/open_jtalk-rs.git" -rev = "d766a52bad4ccafe18597e57bd6842f59dca881e" +rev = "a16714ce16dec76fd0e3041a7acfa484921db3b5" [dev-dependencies] flate2 = "1.0.24" diff --git a/crates/voicevox_core/src/engine/open_jtalk.rs b/crates/voicevox_core/src/engine/open_jtalk.rs index f07fe89cb..79bfcc3de 100644 --- a/crates/voicevox_core/src/engine/open_jtalk.rs +++ b/crates/voicevox_core/src/engine/open_jtalk.rs @@ -1,11 +1,13 @@ +use std::io::Write; use std::{ path::{Path, PathBuf}, sync::Mutex, }; +use tempfile::NamedTempFile; use ::open_jtalk::*; -use crate::Error; +use crate::{Error, UserDict}; #[derive(thiserror::Error, Debug)] pub enum OpenJtalkError { @@ -23,7 +25,7 @@ pub type Result = std::result::Result; pub struct OpenJtalk { resources: Mutex, - dict_loaded: bool, + dict_dir: Option, } struct Resources { @@ -43,7 +45,7 @@ impl OpenJtalk { njd: ManagedResource::initialize(), jpcommon: ManagedResource::initialize(), }), - dict_loaded: false, + dict_dir: None, } } pub fn new_with_initialize( @@ -55,6 +57,54 @@ impl OpenJtalk { Ok(s) } + /// ユーザー辞書を設定する。 + /// 先に [`Self::load`] を呼ぶ必要がある。 + /// この関数を呼んだ後にユーザー辞書を変更した場合は、再度この関数を呼ぶ必要がある。 + pub fn use_user_dict(&self, user_dict: &UserDict) -> crate::result::Result<()> { + let dict_dir = self + .dict_dir + .as_ref() + .and_then(|dict_dir| dict_dir.to_str()) + .ok_or(Error::NotLoadedOpenjtalkDict)?; + + // ユーザー辞書用のcsvを作成 + let mut temp_csv = NamedTempFile::new().map_err(|e| Error::UseUserDict(e.to_string()))?; + temp_csv + .write_all(user_dict.to_mecab_format().as_bytes()) + .map_err(|e| Error::UseUserDict(e.to_string()))?; + let temp_csv_path = temp_csv.into_temp_path(); + let temp_dict = NamedTempFile::new().map_err(|e| Error::UseUserDict(e.to_string()))?; + let temp_dict_path = temp_dict.into_temp_path(); + + // Mecabでユーザー辞書をコンパイル + // TODO: エラー(SEGV)が出るパターンを把握し、それをRust側で防ぐ。 + mecab_dict_index(&[ +
"mecab-dict-index", + "-d", + dict_dir, + "-u", + temp_dict_path.to_str().unwrap(), + "-f", + "utf-8", + "-t", + "utf-8", + temp_csv_path.to_str().unwrap(), + "-q", + ]); + + let Resources { mecab, .. } = &mut *self.resources.lock().unwrap(); + + let result = mecab.load_with_userdic(Path::new(dict_dir), Some(Path::new(&temp_dict_path))); + + if !result { + return Err(Error::UseUserDict( + "辞書のコンパイルに失敗しました".to_string(), + )); + } + + Ok(()) + } + pub fn extract_fullcontext(&self, text: impl AsRef) -> Result> { let Resources { mecab, @@ -112,10 +162,10 @@ impl OpenJtalk { .mecab .load(open_jtalk_dict_dir.as_ref()); if result { - self.dict_loaded = true; + self.dict_dir = Some(open_jtalk_dict_dir.as_ref().into()); Ok(()) } else { - self.dict_loaded = false; + self.dict_dir = None; Err(OpenJtalkError::Load { mecab_dict_dir: open_jtalk_dict_dir.as_ref().into(), }) @@ -123,7 +173,7 @@ impl OpenJtalk { } pub fn dict_loaded(&self) -> bool { - self.dict_loaded + self.dict_dir.is_some() } } diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 6c4916412..9068d92e2 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -4,6 +4,7 @@ use super::*; //use engine:: use std::path::PathBuf; use thiserror::Error; +use uuid::Uuid; /* * 新しいエラーを定義したら、必ずresult_code.rsにあるVoicevoxResultCodeに対応するコードを定義し、 @@ -82,6 +83,21 @@ pub enum Error { #[error("{},{0}", base_error_message(VOICEVOX_RESULT_PARSE_KANA_ERROR))] ParseKana(#[from] KanaParseError), + + #[error("{}: {0}", base_error_message(VOICEVOX_LOAD_USER_DICT_ERROR))] + LoadUserDict(String), + + #[error("{}: {0}", base_error_message(VOICEVOX_SAVE_USER_DICT_ERROR))] + SaveUserDict(String), + + #[error("{}: {0}", base_error_message(VOICEVOX_UNKNOWN_USER_DICT_WORD_ERROR))] + UnknownWord(Uuid), + + #[error("{}: {0}", base_error_message(VOICEVOX_USE_USER_DICT_ERROR))] + UseUserDict(String), + + #[error("{}: {0}", base_error_message(VOICEVOX_INVALID_USER_DICT_WORD_ERROR))] + 
InvalidWord(InvalidWordError), } fn base_error_message(result_code: VoicevoxResultCode) -> &'static str { diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index ec8928d79..0175ca2a2 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -12,6 +12,7 @@ mod numerics; mod result; pub mod result_code; mod status; +mod user_dict; mod version; mod voice_model; mod voice_synthesizer; @@ -31,6 +32,7 @@ pub use self::result::*; pub use self::voice_model::*; pub use devices::*; pub use manifest::*; +pub use user_dict::*; pub use version::*; pub use voice_synthesizer::*; diff --git a/crates/voicevox_core/src/result_code.rs b/crates/voicevox_core/src/result_code.rs index 55f59ab2f..011df479d 100644 --- a/crates/voicevox_core/src/result_code.rs +++ b/crates/voicevox_core/src/result_code.rs @@ -43,6 +43,18 @@ pub enum VoicevoxResultCode { VOICEVOX_ALREADY_LOADED_MODEL_ERROR = 18, /// Modelが読み込まれていない VOICEVOX_UNLOADED_MODEL_ERROR = 19, + /// ユーザー辞書を読み込めなかった + VOICEVOX_LOAD_USER_DICT_ERROR = 20, + /// ユーザー辞書を書き込めなかった + VOICEVOX_SAVE_USER_DICT_ERROR = 21, + /// ユーザー辞書に単語が見つからなかった + VOICEVOX_UNKNOWN_USER_DICT_WORD_ERROR = 22, + /// OpenJTalkのユーザー辞書の設定に失敗した + VOICEVOX_USE_USER_DICT_ERROR = 23, + /// ユーザー辞書の単語のバリデーションに失敗した + VOICEVOX_INVALID_USER_DICT_WORD_ERROR = 24, + /// UUIDの変換に失敗した + VOICEVOX_RESULT_INVALID_UUID_ERROR = 25, } pub const fn error_result_to_message(result_code: VoicevoxResultCode) -> &'static str { @@ -79,5 +91,13 @@ pub const fn error_result_to_message(result_code: VoicevoxResultCode) -> &'stati "すでに読み込まれているModelを読み込もうとしました\0" } VOICEVOX_UNLOADED_MODEL_ERROR => "Modelが読み込まれていません\0", + VOICEVOX_LOAD_USER_DICT_ERROR => "ユーザー辞書を読み込めませんでした\0", + VOICEVOX_SAVE_USER_DICT_ERROR => "ユーザー辞書を書き込めませんでした\0", + VOICEVOX_UNKNOWN_USER_DICT_WORD_ERROR => "ユーザー辞書に単語が見つかりませんでした\0", + VOICEVOX_USE_USER_DICT_ERROR => "OpenJTalkのユーザー辞書の設定に失敗しました\0", + VOICEVOX_INVALID_USER_DICT_WORD_ERROR => { + "ユーザー辞書の単語のバリデーションに失敗しました\0" + 
} + VOICEVOX_RESULT_INVALID_UUID_ERROR => "UUIDの変換に失敗しました\0", } } diff --git a/crates/voicevox_core/src/user_dict/dict.rs b/crates/voicevox_core/src/user_dict/dict.rs new file mode 100644 index 000000000..380a8ac24 --- /dev/null +++ b/crates/voicevox_core/src/user_dict/dict.rs @@ -0,0 +1,82 @@ +use derive_getters::Getters; +use fs_err::File; +use indexmap::IndexMap; +use itertools::join; +use uuid::Uuid; + +use super::word::*; +use crate::{Error, Result}; + +/// ユーザー辞書。 +/// 単語はJSONとの相互変換のために挿入された順序を保つ。 +#[derive(Clone, Debug, Default, Getters)] +pub struct UserDict { + words: IndexMap, +} + +impl UserDict { + /// ユーザー辞書を作成する。 + pub fn new() -> Self { + Default::default() + } + + /// ユーザー辞書をファイルから読み込む。 + /// + /// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。 + pub fn load(&mut self, store_path: &str) -> Result<()> { + let store_path = std::path::Path::new(store_path); + + let store_file = File::open(store_path).map_err(|e| Error::LoadUserDict(e.to_string()))?; + + let words: IndexMap = + serde_json::from_reader(store_file).map_err(|e| Error::LoadUserDict(e.to_string()))?; + + self.words.extend(words); + Ok(()) + } + + /// ユーザー辞書に単語を追加する。 + pub fn add_word(&mut self, word: UserDictWord) -> Result { + let word_uuid = Uuid::new_v4(); + self.words.insert(word_uuid, word); + Ok(word_uuid) + } + + /// ユーザー辞書の単語を変更する。 + pub fn update_word(&mut self, word_uuid: Uuid, new_word: UserDictWord) -> Result<()> { + if !self.words.contains_key(&word_uuid) { + return Err(Error::UnknownWord(word_uuid)); + } + self.words.insert(word_uuid, new_word); + Ok(()) + } + + /// ユーザー辞書から単語を削除する。 + pub fn remove_word(&mut self, word_uuid: Uuid) -> Result { + let Some(word) = self.words.remove(&word_uuid) else { + return Err(Error::UnknownWord(word_uuid)); + }; + Ok(word) + } + + /// 他のユーザー辞書をインポートする。 + pub fn import(&mut self, other: &Self) -> Result<()> { + for (word_uuid, word) in &other.words { + self.words.insert(*word_uuid, word.clone()); + } + Ok(()) + } + + /// ユーザー辞書を保存する。 + pub fn save(&self, 
store_path: &str) -> Result<()> { + let mut file = File::create(store_path).map_err(|e| Error::SaveUserDict(e.to_string()))?; + serde_json::to_writer(&mut file, &self.words) + .map_err(|e| Error::SaveUserDict(e.to_string()))?; + Ok(()) + } + + /// MeCabで使用する形式に変換する。 + pub(crate) fn to_mecab_format(&self) -> String { + join(self.words.values().map(UserDictWord::to_mecab_format), "\n") + } +} diff --git a/crates/voicevox_core/src/user_dict/mod.rs b/crates/voicevox_core/src/user_dict/mod.rs new file mode 100644 index 000000000..58def046f --- /dev/null +++ b/crates/voicevox_core/src/user_dict/mod.rs @@ -0,0 +1,6 @@ +mod dict; +mod part_of_speech_data; +mod word; + +pub use dict::*; +pub use word::*; diff --git a/crates/voicevox_core/src/user_dict/part_of_speech_data.rs b/crates/voicevox_core/src/user_dict/part_of_speech_data.rs new file mode 100644 index 000000000..712cef885 --- /dev/null +++ b/crates/voicevox_core/src/user_dict/part_of_speech_data.rs @@ -0,0 +1,121 @@ +use derive_getters::Getters; +use once_cell::sync::Lazy; +use std::collections::HashMap; + +use crate::UserDictWordType; + +/// 最小の優先度 +pub static MIN_PRIORITY: u32 = 0; +/// 最大の優先度 +pub static MAX_PRIORITY: u32 = 10; + +/// 品詞ごとの情報 +#[derive(Debug, Getters)] +pub struct PartOfSpeechDetail { + /// 品詞 + pub part_of_speech: &'static str, + /// 品詞細分類1 + pub part_of_speech_detail_1: &'static str, + /// 品詞細分類2 + pub part_of_speech_detail_2: &'static str, + /// 品詞細分類3 + pub part_of_speech_detail_3: &'static str, + /// 文脈IDは辞書の左・右文脈IDのこと + /// + /// 参考: + pub context_id: i32, + /// コストのパーセンタイル + pub cost_candidates: Vec, + /// アクセント結合規則の一覧 + pub accent_associative_rules: Vec<&'static str>, +} + +// 元データ: https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/part_of_speech_data.py +pub static PART_OF_SPEECH_DETAIL: Lazy> = + Lazy::new(|| { + HashMap::from_iter([ + ( + UserDictWordType::ProperNoun, + PartOfSpeechDetail { + part_of_speech: "名詞", + part_of_speech_detail_1: "固有名詞", + 
part_of_speech_detail_2: "一般", + part_of_speech_detail_3: "*", + context_id: 1348, + cost_candidates: vec![ + -988, 3488, 4768, 6048, 7328, 8609, 8734, 8859, 8984, 9110, 14176, + ], + accent_associative_rules: vec!["*", "C1", "C2", "C3", "C4", "C5"], + }, + ), + ( + UserDictWordType::CommonNoun, + PartOfSpeechDetail { + part_of_speech: "名詞", + part_of_speech_detail_1: "一般", + part_of_speech_detail_2: "*", + part_of_speech_detail_3: "*", + context_id: 1345, + cost_candidates: vec![ + -4445, 49, 1473, 2897, 4321, 5746, 6554, 7362, 8170, 8979, 15001, + ], + accent_associative_rules: vec!["*", "C1", "C2", "C3", "C4", "C5"], + }, + ), + ( + UserDictWordType::Verb, + PartOfSpeechDetail { + part_of_speech: "動詞", + part_of_speech_detail_1: "自立", + part_of_speech_detail_2: "*", + part_of_speech_detail_3: "*", + context_id: 642, + cost_candidates: vec![ + 3100, 6160, 6360, 6561, 6761, 6962, 7414, 7866, 8318, 8771, 13433, + ], + accent_associative_rules: vec!["*"], + }, + ), + ( + UserDictWordType::Adjective, + PartOfSpeechDetail { + part_of_speech: "形容詞", + part_of_speech_detail_1: "自立", + part_of_speech_detail_2: "*", + part_of_speech_detail_3: "*", + context_id: 20, + cost_candidates: vec![ + 1527, 3266, 3561, 3857, 4153, 4449, 5149, 5849, 6549, 7250, 10001, + ], + accent_associative_rules: vec!["*"], + }, + ), + ( + UserDictWordType::Suffix, + PartOfSpeechDetail { + part_of_speech: "名詞", + part_of_speech_detail_1: "接尾", + part_of_speech_detail_2: "一般", + part_of_speech_detail_3: "*", + context_id: 1358, + cost_candidates: vec![ + 4399, 5373, 6041, 6710, 7378, 8047, 9440, 10834, 12228, 13622, 15847, + ], + accent_associative_rules: vec!["*", "C1", "C2", "C3", "C4", "C5"], + }, + ), + ]) + }); + +fn search_cost_candidates(context_id: i32) -> &'static [i32] { + &PART_OF_SPEECH_DETAIL + .values() + .find(|x| x.context_id == context_id) + .expect("品詞IDが不正です") + .cost_candidates +} + +pub fn priority2cost(context_id: i32, priority: u32) -> i32 { + let cost_candidates = 
search_cost_candidates(context_id); + cost_candidates[(MAX_PRIORITY - priority) as usize] +} diff --git a/crates/voicevox_core/src/user_dict/word.rs b/crates/voicevox_core/src/user_dict/word.rs new file mode 100644 index 000000000..2d1ae8ed9 --- /dev/null +++ b/crates/voicevox_core/src/user_dict/word.rs @@ -0,0 +1,294 @@ +use crate::{ + error::Error, + result::Result, + user_dict::part_of_speech_data::{ + priority2cost, MAX_PRIORITY, MIN_PRIORITY, PART_OF_SPEECH_DETAIL, + }, +}; +use derive_getters::Getters; +use once_cell::sync::Lazy; +use regex::Regex; +use serde::{de::Error as _, Deserialize, Serialize}; +use std::ops::RangeToInclusive; + +/// ユーザー辞書の単語。 +#[derive(Clone, Debug, Getters, Serialize)] +pub struct UserDictWord { + /// 単語の表記。 + pub surface: String, + /// 単語の読み。 + pub pronunciation: String, + /// アクセント型。 + pub accent_type: usize, + /// 単語の種類。 + pub word_type: UserDictWordType, + /// 単語の優先度。 + pub priority: u32, + + /// モーラ数。 + mora_count: usize, +} + +impl<'de> Deserialize<'de> for UserDictWord { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let raw = UserDictWord::deserialize(deserializer)?; + return Self::new( + &raw.surface, + raw.pronunciation, + raw.accent_type, + raw.word_type, + raw.priority, + ) + .map_err(D::Error::custom); + + #[derive(Deserialize)] + struct UserDictWord { + surface: String, + pronunciation: String, + accent_type: usize, + word_type: UserDictWordType, + priority: u32, + } + } +} + +#[derive(thiserror::Error, Debug, PartialEq)] +pub enum InvalidWordError { + #[error("無効な発音です({1}): {0:?}")] + InvalidPronunciation(String, &'static str), + #[error("優先度は{MIN_PRIORITY}以上{MAX_PRIORITY}以下である必要があります: {0}")] + InvalidPriority(u32), + #[error("誤ったアクセント型です({1:?}の範囲から外れています): {0}")] + InvalidAccentType(usize, RangeToInclusive), +} +type InvalidWordResult = std::result::Result; + +static PRONUNCIATION_REGEX: Lazy = Lazy::new(|| Regex::new(r"^[ァ-ヴー]+$").unwrap()); +static MORA_REGEX: 
Lazy = Lazy::new(|| { + Regex::new(concat!( + "(?:", + "[イ][ェ]|[ヴ][ャュョ]|[トド][ゥ]|[テデ][ィャュョ]|[デ][ェ]|[クグ][ヮ]|", // rule_others + "[キシチニヒミリギジビピ][ェャュョ]|", // rule_line_i + "[ツフヴ][ァ]|[ウスツフヴズ][ィ]|[ウツフヴ][ェォ]|", // rule_line_u + "[ァ-ヴー]", // rule_one_mora + ")", + )) + .unwrap() +}); +static SPACE_REGEX: Lazy = Lazy::new(|| Regex::new(r"\p{Z}").unwrap()); + +impl Default for UserDictWord { + fn default() -> Self { + Self { + surface: "".to_string(), + pronunciation: "".to_string(), + accent_type: 0, + word_type: UserDictWordType::CommonNoun, + priority: 0, + mora_count: 0, + } + } +} + +impl UserDictWord { + pub fn new( + surface: &str, + pronunciation: String, + accent_type: usize, + word_type: UserDictWordType, + priority: u32, + ) -> Result { + if MIN_PRIORITY > priority || priority > MAX_PRIORITY { + return Err(Error::InvalidWord(InvalidWordError::InvalidPriority( + priority, + ))); + } + validate_pronunciation(&pronunciation).map_err(Error::InvalidWord)?; + let mora_count = + calculate_mora_count(&pronunciation, accent_type).map_err(Error::InvalidWord)?; + Ok(Self { + surface: to_zenkaku(surface), + pronunciation, + accent_type, + word_type, + priority, + mora_count, + }) + } +} + +/// カタカナの文字列が発音として有効かどうかを判定する。 +fn validate_pronunciation(pronunciation: &str) -> InvalidWordResult<()> { + // 元実装:https://github.com/VOICEVOX/voicevox_engine/blob/39747666aa0895699e188f3fd03a0f448c9cf746/voicevox_engine/model.py#L190-L210 + if !PRONUNCIATION_REGEX.is_match(pronunciation) { + return Err(InvalidWordError::InvalidPronunciation( + pronunciation.to_string(), + "カタカナ以外の文字", + )); + } + let sutegana = ['ァ', 'ィ', 'ゥ', 'ェ', 'ォ', 'ャ', 'ュ', 'ョ', 'ヮ', 'ッ']; + + let pronunciation_chars = pronunciation.chars().collect::>(); + + for i in 0..pronunciation_chars.len() { + // 「キャット」のように、捨て仮名が連続する可能性が考えられるので、 + // 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする + if sutegana.contains(&pronunciation_chars[i]) + && i < pronunciation_chars.len() - 1 + && (sutegana[..sutegana.len() - 
1].contains(pronunciation_chars.get(i + 1).unwrap()) + || (pronunciation_chars.get(i).unwrap() == &'ッ' + && sutegana.contains(pronunciation_chars.get(i + 1).unwrap()))) + { + return Err(InvalidWordError::InvalidPronunciation( + pronunciation.to_string(), + "捨て仮名の連続", + )); + } + + if pronunciation_chars.get(i).unwrap() == &'ヮ' + && i != 0 + && !['ク', 'グ'].contains(&pronunciation_chars[i - 1]) + { + return Err(InvalidWordError::InvalidPronunciation( + pronunciation.to_string(), + "「くゎ」「ぐゎ」以外の「ゎ」の使用", + )); + } + } + Ok(()) +} + +/// カタカナの発音からモーラ数を計算する。 +fn calculate_mora_count(pronunciation: &str, accent_type: usize) -> InvalidWordResult { + // 元実装:https://github.com/VOICEVOX/voicevox_engine/blob/39747666aa0895699e188f3fd03a0f448c9cf746/voicevox_engine/model.py#L212-L236 + let mora_count = MORA_REGEX.find_iter(pronunciation).count(); + + if accent_type > mora_count { + return Err(InvalidWordError::InvalidAccentType( + accent_type, + ..=mora_count, + )); + } + + Ok(mora_count) +} + +/// 一部の種類の文字を、全角文字に置き換える。 +/// +/// 具体的には +/// - "!"から"~"までの範囲の文字(数字やアルファベット)は、対応する全角文字に +/// - " "などの目に見えない文字は、まとめて全角スペース(0x3000)に +/// 変換する。 +fn to_zenkaku(surface: &str) -> String { + // 元実装:https://github.com/VOICEVOX/voicevox/blob/69898f5dd001d28d4de355a25766acb0e0833ec2/src/components/DictionaryManageDialog.vue#L379-L387 + SPACE_REGEX + .replace_all(surface, "\u{3000}") + .chars() + .map(|c| match u32::from(c) { + i @ 0x21..=0x7e => char::from_u32(0xfee0 + i).unwrap_or(c), + _ => c, + }) + .collect() +} +/// ユーザー辞書の単語の種類。 +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum UserDictWordType { + /// 固有名詞。 + ProperNoun, + /// 一般名詞。 + CommonNoun, + /// 動詞。 + Verb, + /// 形容詞。 + Adjective, + /// 接尾辞。 + Suffix, +} + +impl UserDictWord { + pub fn to_mecab_format(&self) -> String { + let pos = PART_OF_SPEECH_DETAIL.get(&self.word_type).unwrap(); + format!( + "{},{},{},{},{},{},{},{},{},{},{},{},{},{}/{},{}\n", + 
self.surface, + pos.context_id, + pos.context_id, + priority2cost(pos.context_id, self.priority), + pos.part_of_speech, + pos.part_of_speech_detail_1, + pos.part_of_speech_detail_2, + pos.part_of_speech_detail_3, + "*", // inflectional_type + "*", // inflectional_form + "*", // stem + self.pronunciation, // yomi + self.pronunciation, + self.accent_type, + self.mora_count, + "*" // accent_associative_rule + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[rstest] + #[case("abcdefg", "ａｂｃｄｅｆｇ")] + #[case("あいうえお", "あいうえお")] + #[case("a_b_c_d_e_f_g", "ａ＿ｂ＿ｃ＿ｄ＿ｅ＿ｆ＿ｇ")] + #[case("a b c d e f g", "ａ　ｂ　ｃ　ｄ　ｅ　ｆ　ｇ")] + fn to_zenkaku_works(#[case] before: &str, #[case] after: &str) { + assert_eq!(to_zenkaku(before), after); + } + + #[rstest] + fn to_mecab_format_works() { + // テストの期待値は、VOICEVOX Engineが一時的に出力するcsvの内容を使用した。 + let word = UserDictWord::new( + "単語", + "ヨミ".to_string(), + 0, + UserDictWordType::ProperNoun, + 5, + ) + .unwrap(); + assert_eq!( + word.to_mecab_format(), + "単語,1348,1348,8609,名詞,固有名詞,一般,*,*,*,*,ヨミ,ヨミ,0/2,*\n" + ); + } + + #[rstest] + #[case("ヨミ", None)] + #[case("漢字", Some("カタカナ以外の文字"))] + #[case("ひらがな", Some("カタカナ以外の文字"))] + #[case("ッッッ", Some("捨て仮名の連続"))] + #[case("ァァァァ", Some("捨て仮名の連続"))] + #[case("ヌヮ", Some("「くゎ」「ぐゎ」以外の「ゎ」の使用"))] + fn pronunciation_validation_works( + #[case] pronunciation: &str, + #[case] expected_error_message: Option<&str>, + ) { + let result = validate_pronunciation(pronunciation); + + if let Some(expected_error_message) = expected_error_message { + match result { + Ok(_) => unreachable!(), + Err(InvalidWordError::InvalidPronunciation(err_pronunciation, err_message)) => { + assert_eq!(err_pronunciation, pronunciation); + assert_eq!(err_message, expected_error_message); + } + Err(_) => unreachable!(), + } + } else { + assert!(result.is_ok()); + } + } +} diff --git a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml index 81b65ff58..5ba235311 100644 ---
a/crates/voicevox_core_c_api/Cargo.toml +++ b/crates/voicevox_core_c_api/Cargo.toml @@ -24,6 +24,7 @@ serde_json.workspace = true thiserror.workspace = true tokio.workspace = true tracing-subscriber.workspace = true +uuid.workspace = true voicevox_core.workspace = true [dependencies.chrono] @@ -48,6 +49,7 @@ regex.workspace = true rstest = "0.15.0" serde.workspace = true strum.workspace = true +tempfile.workspace = true test_util.workspace = true toml = "0.7.2" typetag = "0.2.5" diff --git a/crates/voicevox_core_c_api/include/voicevox_core.h b/crates/voicevox_core_c_api/include/voicevox_core.h index d3149ffd5..ecb358846 100644 --- a/crates/voicevox_core_c_api/include/voicevox_core.h +++ b/crates/voicevox_core_c_api/include/voicevox_core.h @@ -117,11 +117,68 @@ enum VoicevoxResultCode * Modelが読み込まれていない */ VOICEVOX_UNLOADED_MODEL_ERROR = 19, + /** + * ユーザー辞書を読み込めなかった + */ + VOICEVOX_LOAD_USER_DICT_ERROR = 20, + /** + * ユーザー辞書を書き込めなかった + */ + VOICEVOX_SAVE_USER_DICT_ERROR = 21, + /** + * ユーザー辞書に単語が見つからなかった + */ + VOICEVOX_UNKNOWN_USER_DICT_WORD_ERROR = 22, + /** + * OpenJTalkのユーザー辞書の設定に失敗した + */ + VOICEVOX_USE_USER_DICT_ERROR = 23, + /** + * ユーザー辞書の単語のバリデーションに失敗した + */ + VOICEVOX_INVALID_USER_DICT_WORD_ERROR = 24, + /** + * UUIDの変換に失敗した + */ + VOICEVOX_RESULT_INVALID_UUID_ERROR = 25, }; #ifndef __cplusplus typedef int32_t VoicevoxResultCode; #endif // __cplusplus +/** + * ユーザー辞書の単語の種類 + */ +enum VoicevoxUserDictWordType +#ifdef __cplusplus + : int32_t +#endif // __cplusplus + { + /** + * 固有名詞。 + */ + VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN = 0, + /** + * 一般名詞。 + */ + VOICEVOX_USER_DICT_WORD_TYPE_COMMON_NOUN = 1, + /** + * 動詞。 + */ + VOICEVOX_USER_DICT_WORD_TYPE_VERB = 2, + /** + * 形容詞。 + */ + VOICEVOX_USER_DICT_WORD_TYPE_ADJECTIVE = 3, + /** + * 接尾辞。 + */ + VOICEVOX_USER_DICT_WORD_TYPE_SUFFIX = 4, +}; +#ifndef __cplusplus +typedef int32_t VoicevoxUserDictWordType; +#endif // __cplusplus + /** * 参照カウントで管理されたOpenJtalk */ @@ -129,6 +186,11 @@ typedef struct OpenJtalkRc 
OpenJtalkRc; typedef struct VoicevoxSynthesizer VoicevoxSynthesizer; +/** + * ユーザー辞書 + */ +typedef struct VoicevoxUserDict VoicevoxUserDict; + /** * 音声モデル */ @@ -207,6 +269,32 @@ typedef struct VoicevoxTtsOptions { bool enable_interrogative_upspeak; } VoicevoxTtsOptions; +/** + * ユーザー辞書の単語 + */ +typedef struct VoicevoxUserDictWord { + /** + * 表記 + */ + const char *surface; + /** + * 読み + */ + const char *pronunciation; + /** + * アクセント型 + */ + uintptr_t accent_type; + /** + * 単語の種類 + */ + VoicevoxUserDictWordType word_type; + /** + * 優先度 + */ + uint32_t priority; +} VoicevoxUserDictWord; + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -227,7 +315,7 @@ extern const struct VoicevoxTtsOptions voicevox_default_tts_options; * 参照カウントで管理されたOpenJtalkを生成する * * # Safety - * @out_open_jtalk 自動でheap領域が割り当てられるため :voicevox_open_jtalk_rc_delete で開放する必要がある + * @out_open_jtalk 自動でheap領域が割り当てられるため :voicevox_open_jtalk_rc_delete で解放する必要がある */ #ifdef _WIN32 __declspec(dllimport) @@ -235,6 +323,22 @@ __declspec(dllimport) VoicevoxResultCode voicevox_open_jtalk_rc_new(const char *open_jtalk_dic_dir, struct OpenJtalkRc **out_open_jtalk); +/** + * OpenJtalkの使うユーザー辞書を設定する + * この関数を呼び出した後にユーザー辞書を変更した場合、再度この関数を呼び出す必要がある。 + * @param [in] open_jtalk 参照カウントで管理されたOpenJtalk + * @param [in] user_dict ユーザー辞書 + * + * # Safety + * @open_jtalk 有効な :OpenJtalkRc のポインタであること + * @user_dict 有効な :VoicevoxUserDict のポインタであること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_open_jtalk_rc_use_user_dict(const struct OpenJtalkRc *open_jtalk, + const struct VoicevoxUserDict *user_dict); + /** * 参照カウントで管理されたOpenJtalkを削除する * @param [in] open_jtalk 参照カウントで管理されたOpenJtalk @@ -598,6 +702,151 @@ __declspec(dllimport) #endif const char *voicevox_error_result_to_message(VoicevoxResultCode result_code); +/** + * VoicevoxUserDictWordを最低限のパラメータで作成する。 + * @param [in] surface 表記 + * @param [in] pronunciation 読み + * @return VoicevoxUserDictWord + * + * # Safety + * @param surface, 
pronunciation は有効な文字列へのポインタであること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +struct VoicevoxUserDictWord voicevox_user_dict_word_make(const char *surface, + const char *pronunciation); + +/** + * ユーザー辞書を作成する + * @return VoicevoxUserDict + * + * # Safety + * @return 自動で解放されることはないので、呼び出し側で :voicevox_user_dict_delete で解放する必要がある + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +struct VoicevoxUserDict *voicevox_user_dict_new(void); + +/** + * ユーザー辞書にファイルを読み込ませる + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] dict_path 読み込む辞書ファイルのパス + * @return 結果コード #VoicevoxResultCode + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + * @param dict_path パスが有効な文字列を指していること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_load(const struct VoicevoxUserDict *user_dict, + const char *dict_path); + +/** + * ユーザー辞書に単語を追加する + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] word 追加する単語 + * @param [out] output_word_uuid 追加した単語のUUID + * @return 結果コード #VoicevoxResultCode + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + * + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_add_word(const struct VoicevoxUserDict *user_dict, + const struct VoicevoxUserDictWord *word, + uint8_t (*output_word_uuid)[16]); + +/** + * ユーザー辞書の単語を更新する + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] word_uuid 更新する単語のUUID + * @param [in] word 新しい単語のデータ + * @return 結果コード #VoicevoxResultCode + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_update_word(const struct VoicevoxUserDict *user_dict, + const uint8_t (*word_uuid)[16], + const struct VoicevoxUserDictWord *word); + +/** + * ユーザー辞書から単語を削除する + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] word_uuid 削除する単語のUUID + * @return 結果コード #VoicevoxResultCode + */ +#ifdef _WIN32 
+__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_remove_word(const struct VoicevoxUserDict *user_dict, + const uint8_t (*word_uuid)[16]); + +/** + * ユーザー辞書の単語をJSON形式で出力する + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [out] output_json JSON形式の文字列 + * @return 結果コード #VoicevoxResultCode + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + * @param output_json 自動でheapメモリが割り当てられるので ::voicevox_json_free で解放する必要がある + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_to_json(const struct VoicevoxUserDict *user_dict, + char **output_json); + +/** + * 他のユーザー辞書をインポートする + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] other_dict インポートするユーザー辞書 + * @return 結果コード #VoicevoxResultCode + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_import(const struct VoicevoxUserDict *user_dict, + const struct VoicevoxUserDict *other_dict); + +/** + * ユーザー辞書をファイルに保存する + * @param [in] user_dict VoicevoxUserDictのポインタ + * @param [in] path 保存先のファイルパス + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + * @param path は有効なUTF-8文字列であること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +VoicevoxResultCode voicevox_user_dict_save(const struct VoicevoxUserDict *user_dict, + const char *path); + +/** + * ユーザー辞書を廃棄する。 + * @param [in] user_dict VoicevoxUserDictのポインタ + * + * # Safety + * @param user_dict は有効な :VoicevoxUserDict のポインタであること + */ +#ifdef _WIN32 +__declspec(dllimport) +#endif +void voicevox_user_dict_delete(struct VoicevoxUserDict *user_dict); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/crates/voicevox_core_c_api/src/helpers.rs b/crates/voicevox_core_c_api/src/helpers.rs index c62fdeb83..08d1116de 100644 --- a/crates/voicevox_core_c_api/src/helpers.rs +++ b/crates/voicevox_core_c_api/src/helpers.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use voicevox_core::UserDictWord; use const_default::ConstDefault; 
use thiserror::Error; @@ -39,9 +40,15 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxRes Err(RustApi(OpenFile { .. })) => VOICEVOX_OPEN_FILE_ERROR, Err(RustApi(VvmRead { .. })) => VOICEVOX_VVM_MODEL_READ_ERROR, Err(RustApi(ParseKana(_))) => VOICEVOX_RESULT_PARSE_KANA_ERROR, + Err(RustApi(LoadUserDict(_))) => VOICEVOX_LOAD_USER_DICT_ERROR, + Err(RustApi(SaveUserDict(_))) => VOICEVOX_SAVE_USER_DICT_ERROR, + Err(RustApi(UnknownWord(_))) => VOICEVOX_UNKNOWN_USER_DICT_WORD_ERROR, + Err(RustApi(UseUserDict(_))) => VOICEVOX_USE_USER_DICT_ERROR, + Err(RustApi(InvalidWord(_))) => VOICEVOX_INVALID_USER_DICT_WORD_ERROR, Err(InvalidUtf8Input) => VOICEVOX_RESULT_INVALID_UTF8_INPUT_ERROR, Err(InvalidAudioQuery(_)) => VOICEVOX_RESULT_INVALID_AUDIO_QUERY_ERROR, Err(InvalidAccentPhrase(_)) => VOICEVOX_RESULT_INVALID_ACCENT_PHRASE_ERROR, + Err(InvalidUuid(_)) => VOICEVOX_RESULT_INVALID_UUID_ERROR, } } } @@ -58,6 +65,8 @@ pub(crate) enum CApiError { InvalidAudioQuery(serde_json::Error), #[error("無効なAccentPhraseです: {0}")] InvalidAccentPhrase(serde_json::Error), + #[error("無効なUUIDです: {0}")] + InvalidUuid(uuid::Error), } pub(crate) fn audio_query_model_to_json(audio_query_model: &AudioQueryModel) -> String { @@ -175,3 +184,45 @@ impl ConstDefault for VoicevoxSynthesisOptions { } }; } + +impl VoicevoxUserDictWord { + pub(crate) unsafe fn try_into_word(&self) -> CApiResult { + Ok(UserDictWord::new( + ensure_utf8(CStr::from_ptr(self.surface))?, + ensure_utf8(CStr::from_ptr(self.pronunciation))?.to_string(), + self.accent_type, + self.word_type.into(), + self.priority, + )?) 
+ } +} + +impl From for voicevox_core::UserDictWordType { + fn from(value: VoicevoxUserDictWordType) -> Self { + match value { + VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN => Self::ProperNoun, + VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_COMMON_NOUN => Self::CommonNoun, + VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_VERB => Self::Verb, + VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_ADJECTIVE => Self::Adjective, + VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_SUFFIX => Self::Suffix, + } + } +} + +impl From for VoicevoxUserDictWordType { + fn from(value: voicevox_core::UserDictWordType) -> Self { + match value { + voicevox_core::UserDictWordType::ProperNoun => { + Self::VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN + } + voicevox_core::UserDictWordType::CommonNoun => { + Self::VOICEVOX_USER_DICT_WORD_TYPE_COMMON_NOUN + } + voicevox_core::UserDictWordType::Verb => Self::VOICEVOX_USER_DICT_WORD_TYPE_VERB, + voicevox_core::UserDictWordType::Adjective => { + Self::VOICEVOX_USER_DICT_WORD_TYPE_ADJECTIVE + } + voicevox_core::UserDictWordType::Suffix => Self::VOICEVOX_USER_DICT_WORD_TYPE_SUFFIX, + } + } +} diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index dfee95707..ef141b4ae 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -21,9 +21,10 @@ use std::sync::{Arc, Mutex, MutexGuard}; use tokio::runtime::Runtime; use tracing_subscriber::fmt::format::Writer; use tracing_subscriber::EnvFilter; +use uuid::Uuid; use voicevox_core::{ - AccentPhraseModel, AudioQueryModel, AudioQueryOptions, OpenJtalk, TtsOptions, VoiceModel, - VoiceModelId, + AccentPhraseModel, AudioQueryModel, AudioQueryOptions, OpenJtalk, TtsOptions, UserDictWord, + VoiceModel, VoiceModelId, }; use voicevox_core::{StyleId, SupportedDevices, SynthesisOptions, Synthesizer}; @@ -84,7 +85,7 @@ pub struct OpenJtalkRc { /// 参照カウントで管理されたOpenJtalkを生成する /// /// # Safety 
-/// @out_open_jtalk 自動でheap領域が割り当てられるため :voicevox_open_jtalk_rc_delete で開放する必要がある +/// @out_open_jtalk 自動でheap領域が割り当てられるため :voicevox_open_jtalk_rc_delete で解放する必要がある #[no_mangle] pub unsafe extern "C" fn voicevox_open_jtalk_rc_new( open_jtalk_dic_dir: *const c_char, @@ -98,6 +99,29 @@ pub unsafe extern "C" fn voicevox_open_jtalk_rc_new( })()) } +/// OpenJtalkの使うユーザー辞書を設定する +/// この関数を呼び出した後にユーザー辞書を変更した場合、再度この関数を呼び出す必要がある。 +/// @param [in] open_jtalk 参照カウントで管理されたOpenJtalk +/// @param [in] user_dict ユーザー辞書 +/// +/// # Safety +/// @open_jtalk 有効な :OpenJtalkRc のポインタであること +/// @user_dict 有効な :VoicevoxUserDict のポインタであること +#[no_mangle] +pub extern "C" fn voicevox_open_jtalk_rc_use_user_dict( + open_jtalk: &OpenJtalkRc, + user_dict: &VoicevoxUserDict, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let user_dict = user_dict.to_owned(); + { + let dict = user_dict.dict.as_ref().lock().expect("lock failed"); + open_jtalk.open_jtalk.use_user_dict(&dict)?; + } + Ok(()) + })()) +} + /// 参照カウントで管理されたOpenJtalkを削除する /// @param [in] open_jtalk 参照カウントで管理されたOpenJtalk /// @@ -692,6 +716,245 @@ pub extern "C" fn voicevox_error_result_to_message( C_STRING_DROP_CHECKER.blacklist(message).as_ptr() } +/// ユーザー辞書 +#[derive(Default)] +pub struct VoicevoxUserDict { + dict: Arc>, +} + +/// ユーザー辞書の単語 +#[repr(C)] +pub struct VoicevoxUserDictWord { + /// 表記 + surface: *const c_char, + /// 読み + pronunciation: *const c_char, + /// アクセント型 + accent_type: usize, + /// 単語の種類 + word_type: VoicevoxUserDictWordType, + /// 優先度 + priority: u32, +} + +/// ユーザー辞書の単語の種類 +#[repr(i32)] +#[allow(non_camel_case_types)] +#[derive(Copy, Clone)] +pub enum VoicevoxUserDictWordType { + /// 固有名詞。 + VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN = 0, + /// 一般名詞。 + VOICEVOX_USER_DICT_WORD_TYPE_COMMON_NOUN = 1, + /// 動詞。 + VOICEVOX_USER_DICT_WORD_TYPE_VERB = 2, + /// 形容詞。 + VOICEVOX_USER_DICT_WORD_TYPE_ADJECTIVE = 3, + /// 接尾辞。 + VOICEVOX_USER_DICT_WORD_TYPE_SUFFIX = 4, +} + +/// 
VoicevoxUserDictWordを最低限のパラメータで作成する。 +/// @param [in] surface 表記 +/// @param [in] pronunciation 読み +/// @return VoicevoxUserDictWord +/// +/// # Safety +/// @param surface, pronunciation は有効な文字列へのポインタであること +#[no_mangle] +pub extern "C" fn voicevox_user_dict_word_make( + surface: *const c_char, + pronunciation: *const c_char, +) -> VoicevoxUserDictWord { + VoicevoxUserDictWord { + surface, + pronunciation, + accent_type: UserDictWord::default().accent_type, + word_type: UserDictWord::default().word_type.into(), + priority: UserDictWord::default().priority, + } +} + +/// ユーザー辞書を作成する +/// @return VoicevoxUserDict +/// +/// # Safety +/// @return 自動で解放されることはないので、呼び出し側で :voicevox_user_dict_delete で解放する必要がある +#[no_mangle] +pub extern "C" fn voicevox_user_dict_new() -> Box { + Default::default() +} + +/// ユーザー辞書にファイルを読み込ませる +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] dict_path 読み込む辞書ファイルのパス +/// @return 結果コード #VoicevoxResultCode +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +/// @param dict_path パスが有効な文字列を指していること +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_load( + user_dict: &VoicevoxUserDict, + dict_path: *const c_char, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let dict_path = ensure_utf8(unsafe { CStr::from_ptr(dict_path) })?; + let mut dict = user_dict.dict.lock().unwrap(); + dict.load(dict_path)?; + + Ok(()) + })()) +} + +/// ユーザー辞書に単語を追加する +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] word 追加する単語 +/// @param [out] output_word_uuid 追加した単語のUUID +/// @return 結果コード #VoicevoxResultCode +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +/// +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_add_word( + user_dict: &VoicevoxUserDict, + word: &VoicevoxUserDictWord, + output_word_uuid: NonNull<[u8; 16]>, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let word = word.try_into_word()?; + let uuid = { + let mut dict = 
user_dict.dict.lock().expect("lock failed"); + dict.add_word(word)? + }; + output_word_uuid.as_ptr().copy_from(uuid.as_bytes(), 16); + + Ok(()) + })()) +} + +/// ユーザー辞書の単語を更新する +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] word_uuid 更新する単語のUUID +/// @param [in] word 新しい単語のデータ +/// @return 結果コード #VoicevoxResultCode +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_update_word( + user_dict: &VoicevoxUserDict, + word_uuid: &[u8; 16], + word: &VoicevoxUserDictWord, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let word_uuid = Uuid::from_slice(word_uuid).map_err(CApiError::InvalidUuid)?; + let word = word.try_into_word()?; + { + let mut dict = user_dict.dict.lock().expect("lock failed"); + dict.update_word(word_uuid, word)?; + }; + + Ok(()) + })()) +} + +/// ユーザー辞書から単語を削除する +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] word_uuid 削除する単語のUUID +/// @return 結果コード #VoicevoxResultCode +#[no_mangle] +pub extern "C" fn voicevox_user_dict_remove_word( + user_dict: &VoicevoxUserDict, + word_uuid: &[u8; 16], +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let word_uuid = Uuid::from_slice(word_uuid).map_err(CApiError::InvalidUuid)?; + { + let mut dict = user_dict.dict.lock().expect("lock failed"); + dict.remove_word(word_uuid)?; + }; + + Ok(()) + })()) +} + +/// ユーザー辞書の単語をJSON形式で出力する +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [out] output_json JSON形式の文字列 +/// @return 結果コード #VoicevoxResultCode +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +/// @param output_json 自動でheapメモリが割り当てられるので ::voicevox_json_free で解放する必要がある +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_to_json( + user_dict: &VoicevoxUserDict, + output_json: NonNull<*mut c_char>, +) -> VoicevoxResultCode { + let dict = user_dict.dict.lock().expect("lock failed"); + let json = 
serde_json::to_string(&dict.words()).expect("should be always valid"); + let json = CString::new(json).expect("\\0を含まない文字列であることが保証されている"); + output_json + .as_ptr() + .write_unaligned(C_STRING_DROP_CHECKER.whitelist(json).into_raw()); + VoicevoxResultCode::VOICEVOX_RESULT_OK +} + +/// 他のユーザー辞書をインポートする +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] other_dict インポートするユーザー辞書 +/// @return 結果コード #VoicevoxResultCode +#[no_mangle] +pub extern "C" fn voicevox_user_dict_import( + user_dict: &VoicevoxUserDict, + other_dict: &VoicevoxUserDict, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + { + let mut dict = user_dict.dict.lock().expect("lock failed"); + let other_dict = other_dict.dict.lock().expect("lock failed"); + dict.import(&other_dict)?; + }; + + Ok(()) + })()) +} + +/// ユーザー辞書をファイルに保存する +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// @param [in] path 保存先のファイルパス +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +/// @param path は有効なUTF-8文字列であること +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_save( + user_dict: &VoicevoxUserDict, + path: *const c_char, +) -> VoicevoxResultCode { + into_result_code_with_error((|| { + let path = ensure_utf8(CStr::from_ptr(path))?; + { + let dict = user_dict.dict.lock().expect("lock failed"); + dict.save(path)?; + }; + + Ok(()) + })()) +} + +/// ユーザー辞書を廃棄する。 +/// @param [in] user_dict VoicevoxUserDictのポインタ +/// +/// # Safety +/// @param user_dict は有効な :VoicevoxUserDict のポインタであること +#[no_mangle] +pub unsafe extern "C" fn voicevox_user_dict_delete(user_dict: Box) { + drop(user_dict); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/voicevox_core_c_api/tests/e2e/snapshots.toml b/crates/voicevox_core_c_api/tests/e2e/snapshots.toml index 6250c6809..bf420e2b0 100644 --- a/crates/voicevox_core_c_api/tests/e2e/snapshots.toml +++ b/crates/voicevox_core_c_api/tests/e2e/snapshots.toml @@ -64,3 +64,12 @@ stderr.windows = ''' {windows-video-cards} ''' 
stderr.unix = "" + +[user_dict] +stderr.windows = ''' +{windows-video-cards} +''' +stderr.unix = "" + +[user_dict_manipulate] +stderr = "" diff --git a/crates/voicevox_core_c_api/tests/e2e/symbols.rs b/crates/voicevox_core_c_api/tests/e2e/symbols.rs index 7f0f1e079..810fbf91b 100644 --- a/crates/voicevox_core_c_api/tests/e2e/symbols.rs +++ b/crates/voicevox_core_c_api/tests/e2e/symbols.rs @@ -15,6 +15,10 @@ pub(crate) struct Symbols<'lib> { 'lib, unsafe extern "C" fn(*const c_char, *mut *mut OpenJtalkRc) -> VoicevoxResultCode, >, + pub(crate) voicevox_open_jtalk_rc_use_user_dict: Symbol< + 'lib, + unsafe extern "C" fn(*mut OpenJtalkRc, *const VoicevoxUserDict) -> VoicevoxResultCode, + >, pub(crate) voicevox_open_jtalk_rc_delete: Symbol<'lib, unsafe extern "C" fn(*mut OpenJtalkRc)>, pub(crate) voicevox_voice_model_new_from_path: Symbol< 'lib, @@ -121,6 +125,52 @@ pub(crate) struct Symbols<'lib> { 'lib, unsafe extern "C" fn(i64, i64, *mut f32, *mut f32, *mut i64, *mut f32) -> bool, >, + + pub(crate) voicevox_user_dict_word_make: + Symbol<'lib, unsafe extern "C" fn(*const c_char, *const c_char) -> VoicevoxUserDictWord>, + pub(crate) voicevox_user_dict_new: + Symbol<'lib, unsafe extern "C" fn() -> *mut VoicevoxUserDict>, + pub(crate) voicevox_user_dict_load: Symbol< + 'lib, + unsafe extern "C" fn(*const VoicevoxUserDict, *const c_char) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_add_word: Symbol< + 'lib, + unsafe extern "C" fn( + *const VoicevoxUserDict, + *const VoicevoxUserDictWord, + *mut [u8; 16], + ) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_update_word: Symbol< + 'lib, + unsafe extern "C" fn( + *const VoicevoxUserDict, + *const [u8; 16], + *const VoicevoxUserDictWord, + ) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_remove_word: Symbol< + 'lib, + unsafe extern "C" fn(*const VoicevoxUserDict, *const [u8; 16]) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_to_json: Symbol< + 'lib, + unsafe extern "C" 
fn(*const VoicevoxUserDict, *mut *mut c_char) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_import: Symbol< + 'lib, + unsafe extern "C" fn( + *const VoicevoxUserDict, + *const VoicevoxUserDict, + ) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_save: Symbol< + 'lib, + unsafe extern "C" fn(*const VoicevoxUserDict, *const c_char) -> VoicevoxResultCode, + >, + pub(crate) voicevox_user_dict_delete: + Symbol<'lib, unsafe extern "C" fn(*mut VoicevoxUserDict) -> VoicevoxResultCode>, } impl<'lib> Symbols<'lib> { @@ -140,6 +190,7 @@ impl<'lib> Symbols<'lib> { voicevox_default_synthesis_options, voicevox_default_tts_options, voicevox_open_jtalk_rc_new, + voicevox_open_jtalk_rc_use_user_dict, voicevox_open_jtalk_rc_delete, voicevox_voice_model_new_from_path, voicevox_voice_model_id, @@ -169,6 +220,16 @@ impl<'lib> Symbols<'lib> { yukarin_s_forward, yukarin_sa_forward, decode_forward, + voicevox_user_dict_word_make, + voicevox_user_dict_new, + voicevox_user_dict_load, + voicevox_user_dict_add_word, + voicevox_user_dict_update_word, + voicevox_user_dict_remove_word, + voicevox_user_dict_to_json, + voicevox_user_dict_import, + voicevox_user_dict_save, + voicevox_user_dict_delete, )) } } @@ -210,3 +271,23 @@ pub(crate) struct VoicevoxTtsOptions { _kana: bool, _enable_interrogative_upspeak: bool, } + +#[repr(C)] +pub(crate) struct VoicevoxUserDict { + _private: [u8; 0], +} + +#[repr(C)] +pub(crate) struct VoicevoxUserDictWord { + pub(crate) surface: *const c_char, + pub(crate) pronunciation: *const c_char, + pub(crate) accent_type: usize, + pub(crate) word_type: VoicevoxUserDictWordType, + pub(crate) priority: u32, +} + +#[repr(i32)] +#[allow(non_camel_case_types)] +pub(crate) enum VoicevoxUserDictWordType { + VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN = 0, +} diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases.rs b/crates/voicevox_core_c_api/tests/e2e/testcases.rs index 8a02bdcad..f4d2920c7 100644 --- 
a/crates/voicevox_core_c_api/tests/e2e/testcases.rs +++ b/crates/voicevox_core_c_api/tests/e2e/testcases.rs @@ -3,3 +3,5 @@ mod compatible_engine_load_model_before_initialize; mod global_info; mod simple_tts; mod tts_via_audio_query; +mod user_dict_load; +mod user_dict_manipulate; diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs new file mode 100644 index 000000000..b9284c153 --- /dev/null +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_load.rs @@ -0,0 +1,167 @@ +// ユーザー辞書の登録によって読みが変化することを確認するテスト。 +// 辞書ロード前後でAudioQueryのkanaが変化するかどうかで確認する。 + +use crate::symbols::VoicevoxInitializeOptions; +use assert_cmd::assert::AssertResult; +use once_cell::sync::Lazy; +use std::ffi::{CStr, CString}; +use std::mem::MaybeUninit; +use test_util::OPEN_JTALK_DIC_DIR; +use voicevox_core::result_code::VoicevoxResultCode; + +use libloading::Library; +use serde::{Deserialize, Serialize}; + +use crate::{ + assert_cdylib::{self, case, Utf8Output}, + snapshots, + symbols::{Symbols, VoicevoxAccelerationMode, VoicevoxUserDictWordType}, +}; + +macro_rules! cstr { + ($s:literal $(,)?) 
=> { + CStr::from_bytes_with_nul(concat!($s, '\0').as_ref()).unwrap() + }; +} + +case!(TestCase); + +#[derive(Serialize, Deserialize)] +struct TestCase; + +#[typetag::serde(name = "user_dict_load")] +impl assert_cdylib::TestCase for TestCase { + unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { + let Symbols { + voicevox_user_dict_word_make, + voicevox_user_dict_new, + voicevox_user_dict_add_word, + voicevox_user_dict_delete, + voicevox_default_initialize_options, + voicevox_default_audio_query_options, + voicevox_open_jtalk_rc_new, + voicevox_open_jtalk_rc_use_user_dict, + voicevox_open_jtalk_rc_delete, + voicevox_voice_model_new_from_path, + voicevox_voice_model_delete, + voicevox_synthesizer_new_with_initialize, + voicevox_synthesizer_delete, + voicevox_synthesizer_load_voice_model, + voicevox_synthesizer_audio_query, + .. + } = Symbols::new(lib)?; + + let dict = voicevox_user_dict_new(); + + let mut word_uuid = [0u8; 16]; + + let word = { + let mut word = voicevox_user_dict_word_make( + cstr!("this_word_should_not_exist_in_default_dictionary").as_ptr(), + cstr!("アイウエオ").as_ptr(), + ); + word.word_type = VoicevoxUserDictWordType::VOICEVOX_USER_DICT_WORD_TYPE_PROPER_NOUN; + word.priority = 10; + + word + }; + + assert_ok(voicevox_user_dict_add_word(dict, &word, &mut word_uuid)); + + let model = { + let mut model = MaybeUninit::uninit(); + assert_ok(voicevox_voice_model_new_from_path( + cstr!("../../model/sample.vvm").as_ptr(), + model.as_mut_ptr(), + )); + model.assume_init() + }; + + let openjtalk = { + let mut openjtalk = MaybeUninit::uninit(); + let open_jtalk_dic_dir = CString::new(OPEN_JTALK_DIC_DIR).unwrap(); + assert_ok(voicevox_open_jtalk_rc_new( + open_jtalk_dic_dir.as_ptr(), + openjtalk.as_mut_ptr(), + )); + openjtalk.assume_init() + }; + + let synthesizer = { + let mut synthesizer = MaybeUninit::uninit(); + assert_ok(voicevox_synthesizer_new_with_initialize( + openjtalk, + VoicevoxInitializeOptions { + acceleration_mode: 
VoicevoxAccelerationMode::VOICEVOX_ACCELERATION_MODE_CPU, + ..**voicevox_default_initialize_options + }, + synthesizer.as_mut_ptr(), + )); + synthesizer.assume_init() + }; + + assert_ok(voicevox_synthesizer_load_voice_model(synthesizer, model)); + + let mut audio_query_without_dict = std::ptr::null_mut(); + assert_ok(voicevox_synthesizer_audio_query( + synthesizer, + cstr!("this_word_should_not_exist_in_default_dictionary").as_ptr(), + STYLE_ID, + **voicevox_default_audio_query_options, + &mut audio_query_without_dict, + )); + let audio_query_without_dict = serde_json::from_str::( + CStr::from_ptr(audio_query_without_dict).to_str()?, + )?; + + assert_ok(voicevox_open_jtalk_rc_use_user_dict(openjtalk, dict)); + + let mut audio_query_with_dict = std::ptr::null_mut(); + assert_ok(voicevox_synthesizer_audio_query( + synthesizer, + cstr!("this_word_should_not_exist_in_default_dictionary").as_ptr(), + STYLE_ID, + **voicevox_default_audio_query_options, + &mut audio_query_with_dict, + )); + + let audio_query_with_dict = serde_json::from_str::( + CStr::from_ptr(audio_query_with_dict).to_str()?, + )?; + + assert_ne!( + audio_query_without_dict.get("kana"), + audio_query_with_dict.get("kana") + ); + + voicevox_voice_model_delete(model); + voicevox_open_jtalk_rc_delete(openjtalk); + voicevox_synthesizer_delete(synthesizer); + voicevox_user_dict_delete(dict); + + return Ok(()); + + fn assert_ok(result_code: VoicevoxResultCode) { + std::assert_eq!(VoicevoxResultCode::VOICEVOX_RESULT_OK, result_code); + } + const STYLE_ID: u32 = 0; + } + + fn assert_output(&self, output: Utf8Output) -> AssertResult { + output + .mask_timestamps() + .mask_windows_video_cards() + .assert() + .try_success()? + .try_stdout("")? 
+ .try_stderr(&*SNAPSHOTS.stderr) + } +} + +static SNAPSHOTS: Lazy = snapshots::section!(user_dict); + +#[derive(Deserialize)] +struct Snapshots { + #[serde(deserialize_with = "snapshots::deserialize_platform_specific_snapshot")] + stderr: String, +} diff --git a/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_manipulate.rs b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_manipulate.rs new file mode 100644 index 000000000..2468e8f20 --- /dev/null +++ b/crates/voicevox_core_c_api/tests/e2e/testcases/user_dict_manipulate.rs @@ -0,0 +1,189 @@ +// ユーザー辞書の操作をテストする。 + +use assert_cmd::assert::AssertResult; +use once_cell::sync::Lazy; +use std::{ + ffi::{CStr, CString}, + mem::MaybeUninit, +}; +use tempfile::NamedTempFile; +use uuid::Uuid; +use voicevox_core::result_code::VoicevoxResultCode; + +use libloading::Library; +use serde::{Deserialize, Serialize}; + +use crate::{ + assert_cdylib::{self, case, Utf8Output}, + snapshots, + symbols::{Symbols, VoicevoxUserDict, VoicevoxUserDictWord}, +}; + +case!(TestCase); + +#[derive(Serialize, Deserialize)] +struct TestCase; + +macro_rules! cstr { + ($s:literal $(,)?) => { + CStr::from_bytes_with_nul(concat!($s, '\0').as_ref()).unwrap() + }; +} + +#[typetag::serde(name = "user_dict_manipulate")] +impl assert_cdylib::TestCase for TestCase { + unsafe fn exec(&self, lib: &Library) -> anyhow::Result<()> { + let Symbols { + voicevox_user_dict_word_make, + voicevox_user_dict_new, + voicevox_user_dict_add_word, + voicevox_user_dict_update_word, + voicevox_user_dict_remove_word, + voicevox_user_dict_to_json, + voicevox_user_dict_import, + voicevox_user_dict_load, + voicevox_user_dict_save, + voicevox_user_dict_delete, + voicevox_json_free, + .. 
+ } = Symbols::new(lib)?; + + let get_json = |dict: &*mut VoicevoxUserDict| -> String { + let mut json = MaybeUninit::uninit(); + assert_ok(voicevox_user_dict_to_json( + (*dict) as *const _, + json.as_mut_ptr(), + )); + + let ret = CStr::from_ptr(json.assume_init()) + .to_str() + .unwrap() + .to_string(); + + voicevox_json_free(json.assume_init()); + + serde_json::from_str::(&ret).expect("invalid json"); + + ret + }; + + let add_word = |dict: *const VoicevoxUserDict, word: &VoicevoxUserDictWord| -> Uuid { + let mut word_uuid = [0u8; 16]; + + assert_ok(voicevox_user_dict_add_word( + dict, + word as *const _, + &mut word_uuid, + )); + + Uuid::from_slice(&word_uuid).expect("invalid uuid") + }; + + // テスト用の辞書ファイルを作成 + let dict = voicevox_user_dict_new(); + + // 単語の追加のテスト + let word = voicevox_user_dict_word_make(cstr!("hoge").as_ptr(), cstr!("ホゲ").as_ptr()); + + let word_uuid = add_word(dict, &word); + + let json = get_json(&dict); + + assert!(json.contains("hoge")); + assert!(json.contains("ホゲ")); + assert_contains_uuid(&json, &word_uuid); + + // 単語の変更のテスト + let word = voicevox_user_dict_word_make(cstr!("fuga").as_ptr(), cstr!("フガ").as_ptr()); + + assert_ok(voicevox_user_dict_update_word( + dict, + &word_uuid.into_bytes(), + &word, + )); + + let json = get_json(&dict); + + assert!(!json.contains("hoge")); + assert!(!json.contains("ホゲ")); + assert!(json.contains("fuga")); + assert!(json.contains("フガ")); + assert_contains_uuid(&json, &word_uuid); + + // 辞書のインポートのテスト。 + let other_dict = voicevox_user_dict_new(); + + let other_word = + voicevox_user_dict_word_make(cstr!("piyo").as_ptr(), cstr!("ピヨ").as_ptr()); + + let other_word_uuid = add_word(other_dict, &other_word); + + assert_ok(voicevox_user_dict_import(dict, other_dict)); + + let json = get_json(&dict); + assert!(json.contains("fuga")); + assert!(json.contains("フガ")); + assert_contains_uuid(&json, &word_uuid); + assert!(json.contains("piyo")); + assert!(json.contains("ピヨ")); + assert_contains_uuid(&json, 
&other_word_uuid); + + // 単語の削除のテスト + assert_ok(voicevox_user_dict_remove_word( + dict, + &word_uuid.into_bytes(), + )); + + let json = get_json(&dict); + assert_not_contains_uuid(&json, &word_uuid); + // 他の単語は残っている + assert_contains_uuid(&json, &other_word_uuid); + + // 辞書のセーブ・ロードのテスト + let temp_path = NamedTempFile::new().unwrap().into_temp_path(); + let temp_path = CString::new(temp_path.to_str().unwrap()).unwrap(); + let word = voicevox_user_dict_word_make(cstr!("hoge").as_ptr(), cstr!("ホゲ").as_ptr()); + let word_uuid = add_word(dict, &word); + + assert_ok(voicevox_user_dict_save(dict, temp_path.as_ptr())); + assert_ok(voicevox_user_dict_load(other_dict, temp_path.as_ptr())); + + let json = get_json(&other_dict); + assert_contains_uuid(&json, &word_uuid); + assert_contains_uuid(&json, &other_word_uuid); + + voicevox_user_dict_delete(dict); + voicevox_user_dict_delete(other_dict); + + return Ok(()); + + fn assert_ok(result_code: VoicevoxResultCode) { + std::assert_eq!(VoicevoxResultCode::VOICEVOX_RESULT_OK, result_code); + } + + fn assert_contains_uuid(text: &str, pattern: &Uuid) { + assert!(text.contains(pattern.to_string().as_str())); + } + + fn assert_not_contains_uuid(text: &str, pattern: &Uuid) { + assert!(!text.contains(pattern.to_string().as_str())); + } + } + + fn assert_output(&self, output: Utf8Output) -> AssertResult { + output + .mask_timestamps() + .mask_windows_video_cards() + .assert() + .try_success()? + .try_stdout("")? 
+ .try_stderr(&*SNAPSHOTS.stderr) + } +} + +static SNAPSHOTS: Lazy = snapshots::section!(user_dict_manipulate); + +#[derive(Deserialize)] +struct Snapshots { + stderr: String, +} diff --git a/crates/voicevox_core_python_api/.gitignore b/crates/voicevox_core_python_api/.gitignore index 2caff6041..f08864f3d 100644 --- a/crates/voicevox_core_python_api/.gitignore +++ b/crates/voicevox_core_python_api/.gitignore @@ -2,6 +2,11 @@ /python/voicevox_core/model/ # Maturin -*.abi3.dll +*.pyd *.abi3.dylib *.abi3.so + +# onnxruntime +onnxruntime*.dll +libonnxruntime.so* +libonnxruntime.dylib* diff --git a/crates/voicevox_core_python_api/Cargo.toml b/crates/voicevox_core_python_api/Cargo.toml index bad9b7fb5..a5b09bd6e 100644 --- a/crates/voicevox_core_python_api/Cargo.toml +++ b/crates/voicevox_core_python_api/Cargo.toml @@ -26,4 +26,5 @@ serde_json.workspace = true test_util.workspace = true tokio.workspace = true tracing.workspace = true +uuid.workspace = true voicevox_core.workspace = true diff --git a/crates/voicevox_core_python_api/python/test/conftest.py b/crates/voicevox_core_python_api/python/test/conftest.py index ca947f850..54449324c 100644 --- a/crates/voicevox_core_python_api/python/test/conftest.py +++ b/crates/voicevox_core_python_api/python/test/conftest.py @@ -9,6 +9,11 @@ root_dir = Path(os.path.dirname(os.path.abspath(__file__))) +open_jtalk_dic_dir = ( + root_dir.parent.parent.parent / "test_util" / "data" / "open_jtalk_dic_utf_8-1.11" +) +model_dir = root_dir.parent.parent.parent.parent / "model" / "sample.vvm" + class DurationExampleData(TypedDict): length: int diff --git a/crates/voicevox_core_python_api/python/test/test_user_dict_load.py b/crates/voicevox_core_python_api/python/test/test_user_dict_load.py new file mode 100644 index 000000000..8c506a50f --- /dev/null +++ b/crates/voicevox_core_python_api/python/test/test_user_dict_load.py @@ -0,0 +1,38 @@ +# ユーザー辞書の単語が反映されるかをテストする。 +# AudioQueryのkanaを比較して変化するかどうかで判断する。 + +from uuid import UUID +import 
pytest +import conftest # noqa: F401 +import voicevox_core # noqa: F401 + + +@pytest.mark.asyncio +async def test_user_dict_load() -> None: + open_jtalk = voicevox_core.OpenJtalk(conftest.open_jtalk_dic_dir) + model = await voicevox_core.VoiceModel.from_path(conftest.model_dir) + synthesizer = await voicevox_core.Synthesizer.new_with_initialize( + open_jtalk=open_jtalk, + ) + + await synthesizer.load_voice_model(model) + + audio_query_without_dict = await synthesizer.audio_query( + "this_word_should_not_exist_in_default_dictionary", style_id=0, kana=False + ) + + temp_dict = voicevox_core.UserDict() + uuid = temp_dict.add_word( + voicevox_core.UserDictWord( + surface="this_word_should_not_exist_in_default_dictionary", + pronunciation="アイウエオ", + ) + ) + assert isinstance(uuid, UUID) + + open_jtalk.use_user_dict(temp_dict) + + audio_query_with_dict = await synthesizer.audio_query( + "this_word_should_not_exist_in_default_dictionary", style_id=0, kana=False + ) + assert audio_query_without_dict != audio_query_with_dict diff --git a/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py b/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py new file mode 100644 index 000000000..8ed860cf9 --- /dev/null +++ b/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py @@ -0,0 +1,68 @@ +# ユーザー辞書の操作をテストする。 +# どのコードがどの操作を行っているかはコメントを参照。 + +import os +from uuid import UUID +import tempfile +import pytest +import voicevox_core # noqa: F401 + + +@pytest.mark.asyncio +async def test_user_dict_load() -> None: + dict_a = voicevox_core.UserDict() + + # 単語の追加 + uuid_a = dict_a.add_word( + voicevox_core.UserDictWord( + surface="hoge", + pronunciation="ホゲ", + ) + ) + assert isinstance(uuid_a, UUID) + assert dict_a.words[uuid_a].surface == "hoge" + assert dict_a.words[uuid_a].pronunciation == "ホゲ" + + # 単語の更新 + dict_a.update_word( + uuid_a, + voicevox_core.UserDictWord( + surface="fuga", + pronunciation="フガ", + ), + ) + + assert 
dict_a.words[uuid_a].surface == "fuga" + assert dict_a.words[uuid_a].pronunciation == "フガ" + + # ユーザー辞書のインポート + dict_b = voicevox_core.UserDict() + uuid_b = dict_b.add_word( + voicevox_core.UserDictWord( + surface="foo", + pronunciation="フー", + ) + ) + + dict_a.import_dict(dict_b) + assert uuid_b in dict_a.words + + # ユーザー辞書のエクスポート + dict_c = voicevox_core.UserDict() + uuid_c=dict_c.add_word( + voicevox_core.UserDictWord( + surface="bar", + pronunciation="バー", + ) + ) + temp_path_fd, temp_path = tempfile.mkstemp() + os.close(temp_path_fd) + dict_c.save(temp_path) + dict_a.load(temp_path) + assert uuid_a in dict_a.words + assert uuid_c in dict_a.words + + # 単語の削除 + dict_a.remove_word(uuid_a) + assert uuid_a not in dict_a.words + assert uuid_c in dict_a.words diff --git a/crates/voicevox_core_python_api/python/voicevox_core/__init__.py b/crates/voicevox_core_python_api/python/voicevox_core/__init__.py index 7b1dce0c8..da1866584 100644 --- a/crates/voicevox_core_python_api/python/voicevox_core/__init__.py +++ b/crates/voicevox_core_python_api/python/voicevox_core/__init__.py @@ -6,8 +6,16 @@ Mora, SpeakerMeta, SupportedDevices, + UserDictWord, + UserDictWordType, ) -from ._rust import OpenJtalk, Synthesizer, VoiceModel, supported_devices # noqa: F401 +from ._rust import ( + OpenJtalk, + Synthesizer, + VoiceModel, + UserDict, + supported_devices, +) # noqa: F401 __all__ = [ "AccelerationMode", @@ -20,4 +28,7 @@ "Synthesizer", "VoiceModel", "supported_devices", + "UserDict", + "UserDictWord", + "UserDictWordType", ] diff --git a/crates/voicevox_core_python_api/python/voicevox_core/_models.py b/crates/voicevox_core_python_api/python/voicevox_core/_models.py index 213e9bfee..94d7a8a16 100644 --- a/crates/voicevox_core_python_api/python/voicevox_core/_models.py +++ b/crates/voicevox_core_python_api/python/voicevox_core/_models.py @@ -1,3 +1,4 @@ +import dataclasses from enum import Enum from typing import List, Optional @@ -69,3 +70,22 @@ class AudioQuery: 
output_sampling_rate: int output_stereo: bool kana: Optional[str] + + +class UserDictWordType(str, Enum): + PROPER_NOUN = "PROPER_NOUN" + COMMON_NOUN = "COMMON_NOUN" + VERB = "VERB" + ADJECTIVE = "ADJECTIVE" + SUFFIX = "SUFFIX" + + +@pydantic.dataclasses.dataclass +class UserDictWord: + surface: str + pronunciation: str + accent_type: int = dataclasses.field(default=0) + word_type: UserDictWordType = dataclasses.field( + default=UserDictWordType.COMMON_NOUN + ) + priority: int = dataclasses.field(default=5) diff --git a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi index 0ae709e3f..020545528 100644 --- a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi +++ b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi @@ -1,5 +1,6 @@ from pathlib import Path -from typing import Final, List, Literal, Union +from typing import Dict, Final, List, Literal, Union +from uuid import UUID import numpy as np from numpy.typing import NDArray @@ -9,6 +10,8 @@ from voicevox_core import ( AudioQuery, SpeakerMeta, SupportedDevices, + UserDict, + UserDictWord, ) __version__: str @@ -39,6 +42,17 @@ class OpenJtalk: open_jtalkの辞書ディレクトリ。 """ ... + def use_user_dict(self, user_dict: UserDict) -> None: + """ユーザー辞書を設定する。 + + この関数を読んだ後にユーザー辞書を変更した場合は、再度この関数を呼ぶ必要がある。 + + Parameters + ---------- + user_dict + ユーザー辞書。 + """ + ... class Synthesizer: @staticmethod @@ -102,8 +116,6 @@ class Synthesizer: モデルが読み込まれているのであればtrue、そうでないならfalse """ ... - def unload_voice_model(self, voice_model_id: str) -> None: - """指定したvoice_model_idのモデルがを破棄する""" async def audio_query( self, text: str, @@ -245,3 +257,77 @@ class Synthesizer: 疑問文の調整を有効にする。 """ ... + +class UserDict: + """ユーザー辞書。 + + Attributes + ---------- + words + エントリーのリスト。 + """ + + words: Dict[UUID, UserDictWord] + def __init__(self) -> None: + """ユーザー辞書をまたは新規作成する。""" + ... 
+ def load(self, path: str) -> None: + """ファイルに保存されたユーザー辞書を読み込む。 + + Parameters + ---------- + path + ユーザー辞書のパス。 + """ + ... + def save(self, path: str) -> None: + """ユーザー辞書をファイルに保存する。 + + Parameters + ---------- + path + ユーザー辞書のパス。 + """ + ... + def add_word(self, word: UserDictWord) -> UUID: + """単語を追加する。 + + Parameters + ---------- + word + 追加する単語。 + + Returns + ------- + 単語のUUID。 + """ + ... + def update_word(self, word_uuid: UUID, word: UserDictWord) -> None: + """単語を更新する。 + + Parameters + ---------- + word_uuid + 更新する単語のUUID。 + word + 新しい単語のデータ。 + """ + ... + def remove_word(self, word_uuid: UUID) -> None: + """単語を削除する。 + + Parameters + ---------- + word_uuid + 削除する単語のUUID。 + """ + ... + def import_dict(self, other: UserDict) -> None: + """ユーザー辞書をインポートする。 + + Parameters + ---------- + other + インポートするユーザー辞書。 + """ + ... diff --git a/crates/voicevox_core_python_api/requirements-test.txt b/crates/voicevox_core_python_api/requirements-test.txt index 9e24dd984..c310a65bc 100644 --- a/crates/voicevox_core_python_api/requirements-test.txt +++ b/crates/voicevox_core_python_api/requirements-test.txt @@ -1,2 +1,3 @@ pytest==7.3.1 +pytest-asyncio==0.21.0 numpy==1.24.3 diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs index e40294609..21346680f 100644 --- a/crates/voicevox_core_python_api/src/lib.rs +++ b/crates/voicevox_core_python_api/src/lib.rs @@ -7,14 +7,17 @@ use pyo3::{ create_exception, exceptions::PyException, pyclass, pyfunction, pymethods, pymodule, - types::{PyBytes, PyList, PyModule}, - wrap_pyfunction, FromPyObject as _, PyAny, PyResult, Python, ToPyObject, + types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule}, + wrap_pyfunction, FromPyObject as _, PyAny, PyObject, PyResult, Python, ToPyObject, }; use serde::{de::DeserializeOwned, Serialize}; +use serde_json::json; use tokio::{runtime::Runtime, sync::Mutex}; +use uuid::Uuid; use voicevox_core::{ AccelerationMode, AccentPhraseModel, 
AccentPhrasesOptions, AudioQueryModel, AudioQueryOptions, - InitializeOptions, StyleId, SynthesisOptions, TtsOptions, VoiceModelId, VoiceModelMeta, + InitializeOptions, StyleId, SynthesisOptions, TtsOptions, UserDictWord, UserDictWordType, + VoiceModelId, VoiceModelMeta, }; static RUNTIME: Lazy = Lazy::new(|| Runtime::new().unwrap()); @@ -29,7 +32,9 @@ fn rust(_py: Python<'_>, module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; - module.add_class::() + module.add_class::()?; + module.add_class::()?; + Ok(()) } create_exception!( @@ -98,6 +103,12 @@ impl OpenJtalk { ), }) } + + fn use_user_dict(&self, user_dict: UserDict) -> PyResult<()> { + self.open_jtalk + .use_user_dict(&user_dict.dict) + .into_py_result() + } } #[pyclass] @@ -361,6 +372,75 @@ impl Synthesizer { } } +#[pyclass] +#[derive(Default, Debug, Clone)] +struct UserDict { + dict: voicevox_core::UserDict, +} + +#[pymethods] +impl UserDict { + #[new] + fn new() -> Self { + Self::default() + } + + fn load(&mut self, path: &str) -> PyResult<()> { + self.dict.load(path).into_py_result() + } + + fn save(&self, path: &str) -> PyResult<()> { + self.dict.save(path).into_py_result() + } + + fn add_word( + &mut self, + #[pyo3(from_py_with = "to_rust_user_dict_word")] word: UserDictWord, + py: Python, + ) -> PyResult { + let uuid = self.dict.add_word(word).into_py_result()?; + + to_py_uuid(py, uuid) + } + + fn update_word( + &mut self, + #[pyo3(from_py_with = "to_rust_uuid")] word_uuid: Uuid, + #[pyo3(from_py_with = "to_rust_user_dict_word")] word: UserDictWord, + ) -> PyResult<()> { + self.dict.update_word(word_uuid, word).into_py_result()?; + Ok(()) + } + + fn remove_word( + &mut self, + #[pyo3(from_py_with = "to_rust_uuid")] word_uuid: Uuid, + ) -> PyResult<()> { + self.dict.remove_word(word_uuid).into_py_result()?; + Ok(()) + } + + fn import_dict(&mut self, other: &UserDict) -> PyResult<()> { + self.dict.import(&other.dict).into_py_result()?; + Ok(()) + } + + #[getter] + fn 
words<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> { + let words = self + .dict + .words() + .iter() + .map(|(&uuid, word)| { + let uuid = to_py_uuid(py, uuid)?; + let word = to_py_user_dict_word(py, word)?; + Ok((uuid, word)) + }) + .collect::>>()?; + Ok(words.into_py_dict(py)) + } +} + fn from_acceleration_mode(ob: &PyAny) -> PyResult { let py = ob.py(); @@ -457,6 +537,41 @@ where ) } +fn to_rust_uuid(ob: &PyAny) -> PyResult { + let uuid = ob.getattr("hex")?.extract::()?; + uuid.parse().into_py_result() +} +fn to_py_uuid(py: Python, uuid: Uuid) -> PyResult { + let uuid = uuid.hyphenated().to_string(); + let uuid = py.import("uuid")?.call_method1("UUID", (uuid,))?; + Ok(uuid.to_object(py)) +} +fn to_rust_user_dict_word(ob: &PyAny) -> PyResult { + voicevox_core::UserDictWord::new( + ob.getattr("surface")?.extract()?, + ob.getattr("pronunciation")?.extract()?, + ob.getattr("accent_type")?.extract()?, + to_rust_word_type(ob.getattr("word_type")?.extract()?)?, + ob.getattr("priority")?.extract()?, + ) + .into_py_result() +} +fn to_py_user_dict_word<'py>( + py: Python<'py>, + word: &voicevox_core::UserDictWord, +) -> PyResult<&'py PyAny> { + let class = py + .import("voicevox_core")? + .getattr("UserDictWord")? + .downcast()?; + to_pydantic_dataclass(word, class) +} +fn to_rust_word_type(word_type: &PyAny) -> PyResult { + let name = word_type.getattr("name")?.extract::()?; + + serde_json::from_value::(json!(name)).into_py_result() +} + impl Drop for Synthesizer { fn drop(&mut self) { debug!("Destructing a VoicevoxCore");