Change model loading so it no longer happens at startup (#400)
* Change model loading to be deferred (lazy loading)

* fix format

* c_int to c_long

Co-authored-by: Gray Suitcase <[email protected]>

* c_int to c_long

Co-authored-by: Gray Suitcase <[email protected]>

* c_int to c_long

Co-authored-by: Gray Suitcase <[email protected]>

* Change SynthesisEngine's arguments to take a CoreWrapper

* Temporarily disable the tests

* Change code so that CoreWrapper attributes are not referenced directly

* Fix the tests

Co-authored-by: Gray Suitcase <[email protected]>
takana-v and PickledChair authored May 15, 2022
1 parent cf1c95e commit dc918b2
Showing 6 changed files with 95 additions and 51 deletions.
2 changes: 2 additions & 0 deletions run.py
@@ -861,6 +861,7 @@ def engine_manifest():
parser.add_argument("--enable_cancellable_synthesis", action="store_true")
parser.add_argument("--enable_guided_synthesis", action="store_true")
parser.add_argument("--init_processes", type=int, default=2)
parser.add_argument("--load_all_models", action="store_true")

# If cpu_num_threads is not specified as an argument, fall back to the environment variable.
# If the environment variable is not set either, it remains None.
@@ -880,6 +881,7 @@ def engine_manifest():
runtime_dirs=args.runtime_dir,
cpu_num_threads=cpu_num_threads,
enable_mock=args.enable_mock,
load_all_models=args.load_all_models,
)
assert len(synthesis_engines) != 0, "音声合成エンジンがありません。"
latest_core_version = str(max([LooseVersion(ver) for ver in synthesis_engines]))
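Since `--load_all_models` is declared with `action="store_true"`, lazy loading is now the default and eager loading is opt-in. Below is a minimal, self-contained sketch of that argparse behavior; only the argument definition is taken from the diff above, the rest is illustrative.

```python
import argparse

# Same flag definition as in run.py above; store_true defaults to False.
parser = argparse.ArgumentParser()
parser.add_argument("--load_all_models", action="store_true")

print(parser.parse_args([]).load_all_models)                     # False -> models load lazily on first use
print(parser.parse_args(["--load_all_models"]).load_all_models)  # True  -> load every model at startup
```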
27 changes: 20 additions & 7 deletions test/test_synthesis_engine.py
@@ -81,6 +81,21 @@ def decode_mock(
return numpy.array(result)


class MockCore:
yukarin_s_forward = Mock(side_effect=yukarin_s_mock)
yukarin_sa_forward = Mock(side_effect=yukarin_sa_mock)
decode_forward = Mock(side_effect=decode_mock)

def metas(self):
return ""

def supported_devices(self):
return ""

def is_model_loaded(self, speaker_id):
raise NameError


class TestSynthesisEngine(TestCase):
def setUp(self):
super().setUp()
@@ -186,14 +201,12 @@ def setUp(self):
pause_mora=None,
),
]
self.yukarin_s_mock = Mock(side_effect=yukarin_s_mock)
self.yukarin_sa_mock = Mock(side_effect=yukarin_sa_mock)
self.decode_mock = Mock(side_effect=decode_mock)
core = MockCore()
self.yukarin_s_mock = core.yukarin_s_forward
self.yukarin_sa_mock = core.yukarin_sa_forward
self.decode_mock = core.decode_forward
self.synthesis_engine = SynthesisEngine(
yukarin_s_forwarder=self.yukarin_s_mock,
yukarin_sa_forwarder=self.yukarin_sa_mock,
decode_forwarder=self.decode_mock,
speakers="",
core=core,
)

def test_to_flatten_moras(self):
20 changes: 16 additions & 4 deletions test/test_synthesis_engine_base.py
@@ -168,14 +168,26 @@ def create_mock_query(accent_phrases):
)


class MockCore:
yukarin_s_forward = Mock(side_effect=yukarin_s_mock)
yukarin_sa_forward = Mock(side_effect=yukarin_sa_mock)
decode_forward = Mock(side_effect=decode_mock)

def metas(self):
return ""

def supported_devices(self):
return ""

def is_model_loaded(self, speaker_id):
raise NameError


class TestSynthesisEngineBase(TestCase):
def setUp(self):
super().setUp()
self.synthesis_engine = SynthesisEngine(
yukarin_s_forwarder=Mock(side_effect=yukarin_s_mock),
yukarin_sa_forwarder=Mock(side_effect=yukarin_sa_mock),
decode_forwarder=Mock(side_effect=decode_mock),
speakers="",
core=MockCore(),
)
self.synthesis_engine._synthesis_impl = Mock()

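Note on the test changes: both test files now define MockCore with Mock(side_effect=...) stored as class attributes, so a test can keep a reference to the exact Mock object the engine will call and still assert on it afterwards. A standalone sketch of that idea follows; the names and the fake forwarder are illustrative, not part of this diff.

```python
from unittest.mock import Mock


def fake_forward(**kwargs):
    # Stand-in for yukarin_s_mock / yukarin_sa_mock / decode_mock.
    return "dummy"


class SketchCore:
    # Class attribute: every SketchCore instance shares this one Mock object.
    yukarin_s_forward = Mock(side_effect=fake_forward)


core = SketchCore()
s_mock = core.yukarin_s_forward           # same object the engine would call

core.yukarin_s_forward(length=3)          # "engine-side" call
s_mock.assert_called_once_with(length=3)  # test-side assertion still sees the call
```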
23 changes: 19 additions & 4 deletions voicevox_engine/synthesis_engine/core_wrapper.py
@@ -339,7 +339,7 @@ def __init__(
use_gpu: bool,
core_dir: Path,
cpu_num_threads: int = 0,
load_all_models: bool = True,
load_all_models: bool = False,
) -> None:

self.core = load_core(core_dir, use_gpu)
@@ -354,12 +354,17 @@ def __init__(
self.exist_suppoted_devices = False
self.exist_finalize = False
exist_cpu_num_threads = False
# TODO: to support the load_model and is_model_loaded functions added in version 0.12,
# define the self.exist_load_model and self.exist_is_model_loaded variables
self.exist_load_model = False
self.exist_is_model_loaded = False

if is_version_0_12_core_or_later(core_dir):
model_type = "onnxruntime"
# TODO: set both self.exist_load_model and self.exist_is_model_loaded to True
self.exist_load_model = True
self.exist_is_model_loaded = True
self.core.load_model.argtypes = (c_long,)
self.core.load_model.restype = c_bool
self.core.is_model_loaded.argtypes = (c_long,)
self.core.is_model_loaded.restype = c_bool
else:
model_type = check_core_type(core_dir)
assert model_type is not None
@@ -496,3 +501,13 @@ def finalize(self) -> None:
self.core.finalize()
return
raise NameError

def load_model(self, speaker_id: int) -> bool:
if self.exist_load_model:
return self.core.load_model(c_long(speaker_id))
raise NameError

def is_model_loaded(self, speaker_id: int) -> bool:
if self.exist_is_model_loaded:
return self.core.is_model_loaded(c_long(speaker_id))
raise NameError
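The new load_model / is_model_loaded wrappers follow the usual ctypes recipe: declare argtypes and restype on the foreign function once, then expose a thin Python method that converts its arguments. Here is a standalone sketch of that pattern using libc's labs purely for illustration; it assumes a POSIX libc is available, and nothing below comes from the VOICEVOX core itself.

```python
import ctypes
import ctypes.util
from ctypes import c_long

# Load a shared library and declare the C signature, as CoreWrapper.__init__
# does for load_model and is_model_loaded above.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.labs.argtypes = (c_long,)   # cf. self.core.load_model.argtypes = (c_long,)
libc.labs.restype = c_long       # cf. ...restype = c_bool


def absolute(value: int) -> int:
    # Thin wrapper in the same spirit as CoreWrapper.load_model(speaker_id).
    return libc.labs(c_long(value))


print(absolute(-42))  # 42
```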
19 changes: 5 additions & 14 deletions voicevox_engine/synthesis_engine/make_synthesis_engines.py
@@ -16,6 +16,7 @@ def make_synthesis_engines(
runtime_dirs: Optional[List[Path]] = None,
cpu_num_threads: Optional[int] = None,
enable_mock: bool = True,
load_all_models: bool = False,
) -> Dict[str, SynthesisEngineBase]:
"""
Load the voice libraries and create speech synthesis engines
@@ -36,6 +37,8 @@ def make_synthesis_engines(
When None, the library decides the value: either half the number of logical cores or the number of physical cores is used
enable_mock: bool, optional, default=True
Whether to use the mock instead when loading the core fails
load_all_models: bool, optional, default=False
Whether to load all models at startup
"""
if cpu_num_threads == 0 or cpu_num_threads is None:
print(
@@ -69,7 +72,7 @@ def make_synthesis_engines(
synthesis_engines = {}
for core_dir in voicelib_dirs:
try:
core = CoreWrapper(use_gpu, core_dir, cpu_num_threads)
core = CoreWrapper(use_gpu, core_dir, cpu_num_threads, load_all_models)
metas = json.loads(core.metas())
core_version = metas[0]["version"]
if core_version in synthesis_engines:
@@ -78,19 +81,7 @@ def make_synthesis_engines(
file=sys.stderr,
)
continue
try:
supported_devices = core.supported_devices()
except NameError:
# The libtorch core does not support this, so a NameError is raised
# Assign None to indicate that the supported devices are unknown
supported_devices = None
synthesis_engines[core_version] = SynthesisEngine(
yukarin_s_forwarder=core.yukarin_s_forward,
yukarin_sa_forwarder=core.yukarin_sa_forward,
decode_forwarder=core.decode_forward,
speakers=core.metas(),
supported_devices=supported_devices,
)
synthesis_engines[core_version] = SynthesisEngine(core=core)
except Exception:
if not enable_mock:
raise
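For context, the loop above tries to build a real CoreWrapper for each directory and, unless enable_mock is disabled, swallows the failure so a mock engine can be used instead (the mock construction itself sits in lines not shown here). A self-contained sketch of that load-or-fall-back shape, with placeholder class names:

```python
class RealCore:
    def __init__(self, core_dir):
        # Simulate the case where the native core library cannot be loaded.
        raise OSError(f"no core library found in {core_dir}")


class MockEngine:
    """Placeholder for the mock fallback."""


def build_engine(core_dir, enable_mock=True):
    try:
        return RealCore(core_dir)
    except Exception:
        if not enable_mock:
            raise
        return MockEngine()


print(type(build_engine("./core")).__name__)  # MockEngine
```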
55 changes: 33 additions & 22 deletions voicevox_engine/synthesis_engine/synthesis_engine.py
@@ -16,6 +16,7 @@
from ..acoustic_feature_extractor import OjtPhoneme
from ..kana_parser import create_kana
from ..model import AccentPhrase, AudioQuery, Mora
from .core_wrapper import CoreWrapper
from .synthesis_engine_base import SynthesisEngineBase

unvoiced_mora_phoneme_list = ["A", "I", "U", "E", "O", "cl", "pau"]
@@ -138,20 +139,16 @@ def pre_process(
class SynthesisEngine(SynthesisEngineBase):
def __init__(
self,
yukarin_s_forwarder,
yukarin_sa_forwarder,
decode_forwarder,
speakers: str,
supported_devices: Optional[str] = None,
core: CoreWrapper,
):
"""
yukarin_s_forwarder: function that estimates the duration of each phoneme from a phoneme sequence
core.yukarin_s_forward: function that estimates the duration of each phoneme from a phoneme sequence
length: length of the phoneme sequence
phoneme_list: phoneme sequence
speaker_id: speaker ID
return: duration of each phoneme
yukarin_sa_forwarder: function that estimates the pitch of each mora from per-mora phoneme sequences and accent information
core.yukarin_sa_forward: function that estimates the pitch of each mora from per-mora phoneme sequences and accent information
length: length of the mora sequence
vowel_phoneme_list: vowel phoneme sequence
consonant_phoneme_list: consonant phoneme sequence
@@ -162,7 +159,7 @@ def __init__(
speaker_id: speaker ID
return: pitch of each mora
decode_forwarder: function that generates the waveform from per-frame phonemes and pitch
core.decode_forward: function that generates the waveform from per-frame phonemes and pitch
length: number of frames
phoneme_size: number of phoneme classes
f0: pitch of each frame
@@ -177,12 +174,12 @@ def __init__(
If None, the core does not support querying this information, so the supported devices are unknown
"""
super().__init__()
self.yukarin_s_forwarder = yukarin_s_forwarder
self.yukarin_sa_forwarder = yukarin_sa_forwarder
self.decode_forwarder = decode_forwarder

self._speakers = speakers
self._supported_devices = supported_devices
self.core = core
self._speakers = self.core.metas()
try:
self._supported_devices = self.core.supported_devices()
except NameError:
self._supported_devices = None
self.default_sampling_rate = 24000

@property
@@ -193,6 +190,14 @@ def speakers(self) -> str:
def supported_devices(self) -> Optional[str]:
return self._supported_devices

def _lazy_init(self, speaker_id: int):
try:
is_model_loaded = self.core.is_model_loaded(speaker_id)
except NameError:
return
if not is_model_loaded:
self.core.load_model(speaker_id)

def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
) -> List[AccentPhrase]:
@@ -209,6 +214,8 @@ def replace_phoneme_length(
accent_phrases : List[AccentPhrase]
List of accent phrase models with vowel and consonant durations set
"""
# Load the model if it is not loaded yet
self._lazy_init(speaker_id)
# phoneme
# Decompose every AccentPhrase into Mora and OjtPhoneme so that it can be processed
flatten_moras, phoneme_data_list = pre_process(accent_phrases)
@@ -220,8 +227,8 @@
phoneme_list_s = numpy.array(
[p.phoneme_id for p in phoneme_data_list], dtype=numpy.int64
)
# Run the list of phoneme IDs (phoneme_list_s) through yukarin_s_forwarder so the inference model assigns an appropriate duration to each phoneme
phoneme_length = self.yukarin_s_forwarder(
# Run the list of phoneme IDs (phoneme_list_s) through yukarin_s_forward so the inference model assigns an appropriate duration to each phoneme
phoneme_length = self.core.yukarin_s_forward(
length=len(phoneme_list_s),
phoneme_list=phoneme_list_s,
speaker_id=numpy.array(speaker_id, dtype=numpy.int64).reshape(-1),
@@ -255,6 +262,8 @@ def replace_mora_pitch(
accent_phrases : List[AccentPhrase]
List of accent phrase models with pitch set
"""
# Load the model if it is not loaded yet
self._lazy_init(speaker_id)
# numpy.concatenate raises an error on an empty list, so check for that first
if len(accent_phrases) == 0:
return []
@@ -354,8 +363,8 @@ def _create_one_hot(accent_phrase: AccentPhrase, position: int):
dtype=numpy.int64,
)

# Run the information generated so far through yukarin_sa_forwarder so the inference model assigns an appropriate pitch to each mora
f0_list = self.yukarin_sa_forwarder(
# Run the information generated so far through yukarin_sa_forward so the inference model assigns an appropriate pitch to each mora
f0_list = self.core.yukarin_sa_forward(
length=vowel_phoneme_list.shape[0],
vowel_phoneme_list=vowel_phoneme_list[numpy.newaxis],
consonant_phoneme_list=consonant_phoneme_list[numpy.newaxis],
@@ -392,7 +401,8 @@ def _synthesis_impl(self, query: AudioQuery, speaker_id: int):
wave : numpy.ndarray
Result of the speech synthesis
"""

# Load the model if it is not loaded yet
self._lazy_init(speaker_id)
# phoneme
# Decompose every AccentPhrase into Mora and OjtPhoneme so that it can be processed
flatten_moras, phoneme_data_list = pre_process(query.accent_phrases)
@@ -461,8 +471,8 @@ def _synthesis_impl(self, query: AudioQuery, speaker_id: int):
array[numpy.arange(len(phoneme)), phoneme] = 1
phoneme = array

# Run the information generated so far through decode_forwarder so the inference model generates the audio waveform
wave = self.decode_forwarder(
# Run the information generated so far through decode_forward so the inference model generates the audio waveform
wave = self.core.decode_forward(
length=phoneme.shape[0],
phoneme_size=phoneme.shape[1],
f0=f0[:, numpy.newaxis],
@@ -495,6 +505,7 @@ def guided_synthesis(
normalize: bool,
core_version: Optional[str] = None,
):
self._lazy_init(speaker)
kana = create_kana(query.accent_phrases)
f0, phonemes = extract_guided_feature(audio_path, kana)

@@ -524,7 +535,7 @@ def guided_synthesis(
f0 = resample(f0, int(len(f0) / query.speedScale))
phone_list = resample(phone_list, int(len(phone_list) / query.speedScale))

wave = self.decode_forwarder(
wave = self.core.decode_forward(
length=phone_list.shape[0],
phoneme_size=phone_list.shape[1],
f0=f0[:, numpy.newaxis].astype(numpy.float32),
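Taken together, _lazy_init plus the new CoreWrapper methods implement a load-on-first-use scheme: ask the core whether the speaker's model is loaded, treat NameError as "this core predates lazy loading", and otherwise load the model on demand. A self-contained sketch of that flow follows; FakeCore and lazy_init are illustrative stand-ins, not the engine's real classes.

```python
class FakeCore:
    """Stand-in for CoreWrapper; remembers which speaker models are 'loaded'."""

    def __init__(self):
        self._loaded = set()

    def is_model_loaded(self, speaker_id: int) -> bool:
        return speaker_id in self._loaded

    def load_model(self, speaker_id: int) -> bool:
        self._loaded.add(speaker_id)
        return True


def lazy_init(core, speaker_id: int) -> None:
    # Mirrors SynthesisEngine._lazy_init: older cores raise NameError from
    # is_model_loaded, in which case there is nothing to load lazily.
    try:
        loaded = core.is_model_loaded(speaker_id)
    except NameError:
        return
    if not loaded:
        core.load_model(speaker_id)


core = FakeCore()
lazy_init(core, speaker_id=1)   # first use of speaker 1: model gets loaded
print(core.is_model_loaded(1))  # True
lazy_init(core, speaker_id=1)   # second use: already loaded, nothing to do
```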
