From e25470b06f9cadc9293bd8b99089a8b0ee259bf8 Mon Sep 17 00:00:00 2001
From: takana-v <44311840+takana-v@users.noreply.github.com>
Date: Thu, 12 May 2022 21:46:28 +0900
Subject: [PATCH] =?UTF-8?q?=E3=83=A2=E3=83=87=E3=83=AB=E3=81=AE=E8=AA=AD?=
 =?UTF-8?q?=E3=81=BF=E8=BE=BC=E3=81=BF=E3=82=92=E9=81=85=E5=BB=B6=E3=81=95?=
 =?UTF-8?q?=E3=81=9B=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB=E5=A4=89=E6=9B=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Notes (ignored by `git am`):
  Subject, decoded: "Change model loading to be deferred".
  - run.py gains a --load_all_models flag and forwards it to
    make_synthesis_engines, which forwards it to CoreWrapper.
  - CoreWrapper's load_all_models default changes from True to False.
  - For cores detected by is_version_0_12_core_or_later, ctypes signatures
    for load_model / is_model_loaded are registered, and the three *_forward
    methods call _lazy_init(speaker_id[0]) before running inference.
  - _lazy_init checks BOTH exist_load_model and exist_is_model_loaded,
    because it calls both wrapper methods and each raises NameError when
    its flag is unset.

 run.py                                        |  2 ++
 .../synthesis_engine/core_wrapper.py          | 34 ++++++++++++++++---
 .../make_synthesis_engines.py                 |  5 ++-
 3 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/run.py b/run.py
index d4434e58d..662022c7c 100644
--- a/run.py
+++ b/run.py
@@ -861,6 +861,7 @@ def engine_manifest():
     parser.add_argument("--enable_cancellable_synthesis", action="store_true")
     parser.add_argument("--enable_guided_synthesis", action="store_true")
     parser.add_argument("--init_processes", type=int, default=2)
+    parser.add_argument("--load_all_models", action="store_true")
 
     # 引数へcpu_num_threadsの指定がなければ、環境変数をロールします。
     # 環境変数にもない場合は、Noneのままとします。
@@ -880,6 +881,7 @@ def engine_manifest():
         runtime_dirs=args.runtime_dir,
         cpu_num_threads=cpu_num_threads,
         enable_mock=args.enable_mock,
+        load_all_models=args.load_all_models,
     )
     assert len(synthesis_engines) != 0, "音声合成エンジンがありません。"
     latest_core_version = str(max([LooseVersion(ver) for ver in synthesis_engines]))
diff --git a/voicevox_engine/synthesis_engine/core_wrapper.py b/voicevox_engine/synthesis_engine/core_wrapper.py
index 7521724ab..e8c13f030 100644
--- a/voicevox_engine/synthesis_engine/core_wrapper.py
+++ b/voicevox_engine/synthesis_engine/core_wrapper.py
@@ -339,7 +339,7 @@ def __init__(
         use_gpu: bool,
         core_dir: Path,
         cpu_num_threads: int = 0,
-        load_all_models: bool = True,
+        load_all_models: bool = False,
     ) -> None:
 
         self.core = load_core(core_dir, use_gpu)
@@ -354,12 +354,17 @@ def __init__(
         self.exist_suppoted_devices = False
         self.exist_finalize = False
         exist_cpu_num_threads = False
-        # TODO: version 0.12 から追加された load_model, is_model_loaded 関数に対応するため、
-        # self.exist_load_model, self.exist_is_model_loaded 変数を定義する
+        self.exist_load_model = False
+        self.exist_is_model_loaded = False
 
         if is_version_0_12_core_or_later(core_dir):
             model_type = "onnxruntime"
-            # TODO: self.exist_load_model, self.exist_is_model_loaded 両方を True にする
+            self.exist_load_model = True
+            self.exist_is_model_loaded = True
+            self.core.load_model.argtypes = (c_int,)
+            self.core.load_model.restype = c_bool
+            self.core.is_model_loaded.argtypes = (c_int,)
+            self.core.is_model_loaded.restype = c_bool
         else:
             model_type = check_core_type(core_dir)
         assert model_type is not None
@@ -411,6 +416,12 @@ def __init__(
                     raise Exception(self.core.last_error_message().decode("utf-8"))
         finally:
             os.chdir(cwd)
+
+    def _lazy_init(self, speaker_id: int) -> None:
+        """Load the model for speaker_id if the core can and has not yet done so."""
+        if self.exist_load_model and self.exist_is_model_loaded:
+            if not self.is_model_loaded(speaker_id):
+                self.load_model(speaker_id)
 
     def metas(self) -> str:
         return self.core.metas().decode("utf-8")
@@ -421,6 +432,7 @@ def yukarin_s_forward(
         phoneme_list: np.ndarray,
         speaker_id: np.ndarray,
     ) -> np.ndarray:
+        self._lazy_init(speaker_id[0])
         output = np.zeros((length,), dtype=np.float32)
         success = self.core.yukarin_s_forward(
             c_int(length),
@@ -443,6 +455,7 @@ def yukarin_sa_forward(
         end_accent_phrase_list: np.ndarray,
         speaker_id: np.ndarray,
     ) -> np.ndarray:
+        self._lazy_init(speaker_id[0])
         output = np.empty(
             (
                 len(speaker_id),
@@ -473,6 +486,7 @@ def decode_forward(
         phoneme: np.ndarray,
         speaker_id: np.ndarray,
     ) -> np.ndarray:
+        self._lazy_init(speaker_id[0])
         output = np.empty((length * 256,), dtype=np.float32)
         success = self.core.decode_forward(
             c_int(length),
@@ -496,3 +510,15 @@ def finalize(self) -> None:
             self.core.finalize()
             return
         raise NameError
+
+    def load_model(self, speaker_id: int) -> bool:
+        """Call the core's load_model; raise NameError if the core lacks it."""
+        if self.exist_load_model:
+            return self.core.load_model(c_int(speaker_id))
+        raise NameError
+
+    def is_model_loaded(self, speaker_id: int) -> bool:
+        """Call the core's is_model_loaded; raise NameError if the core lacks it."""
+        if self.exist_is_model_loaded:
+            return self.core.is_model_loaded(c_int(speaker_id))
+        raise NameError
diff --git a/voicevox_engine/synthesis_engine/make_synthesis_engines.py b/voicevox_engine/synthesis_engine/make_synthesis_engines.py
index b34801779..b31ecebbf 100644
--- a/voicevox_engine/synthesis_engine/make_synthesis_engines.py
+++ b/voicevox_engine/synthesis_engine/make_synthesis_engines.py
@@ -16,6 +16,7 @@ def make_synthesis_engines(
     runtime_dirs: Optional[List[Path]] = None,
     cpu_num_threads: Optional[int] = None,
     enable_mock: bool = True,
+    load_all_models: bool = False,
 ) -> Dict[str, SynthesisEngineBase]:
     """
     音声ライブラリをロードして、音声合成エンジンを生成
@@ -36,6 +37,8 @@ def make_synthesis_engines(
         Noneのとき、ライブラリ側の挙動により論理コア数の半分か、物理コア数が指定される
     enable_mock: bool, optional, default=True
         コア読み込みに失敗したとき、代わりにmockを使用するかどうか
+    load_all_models: bool, optional, default=False
+        起動時に全てのモデルを読み込むかどうか
     """
     if cpu_num_threads == 0 or cpu_num_threads is None:
         print(
@@ -69,7 +72,7 @@ def make_synthesis_engines(
     synthesis_engines = {}
     for core_dir in voicelib_dirs:
         try:
-            core = CoreWrapper(use_gpu, core_dir, cpu_num_threads)
+            core = CoreWrapper(use_gpu, core_dir, cpu_num_threads, load_all_models)
             metas = json.loads(core.metas())
             core_version = metas[0]["version"]
             if core_version in synthesis_engines: