From 4c712533673a23a168406efd2d517a72a64d2f4f Mon Sep 17 00:00:00 2001 From: takana-v <44311840+takana-v@users.noreply.github.com> Date: Fri, 4 Feb 2022 08:02:49 +0900 Subject: [PATCH] =?UTF-8?q?=E5=AF=BE=E5=BF=9C=E3=83=87=E3=83=90=E3=82=A4?= =?UTF-8?q?=E3=82=B9=E3=81=8C=E5=88=86=E3=81=8B=E3=82=8BAPI=E3=82=A8?= =?UTF-8?q?=E3=83=B3=E3=83=89=E3=83=9D=E3=82=A4=E3=83=B3=E3=83=88=E3=82=92?= =?UTF-8?q?=E8=BF=BD=E5=8A=A0=20(#299)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * /supported_devicesを追加 * model修正 * 抽象プロパティを定義 * コメント追加 --- run.py | 13 +++++++++++++ test/test_mock_synthesis_engine.py | 2 +- voicevox_engine/dev/core/__init__.py | 2 ++ voicevox_engine/dev/core/mock.py | 9 +++++++++ voicevox_engine/dev/synthesis_engine/mock.py | 13 +++++++++++-- voicevox_engine/model.py | 9 +++++++++ .../synthesis_engine/make_synthesis_engines.py | 12 +++++++++++- .../synthesis_engine/synthesis_engine.py | 16 +++++++++++++++- .../synthesis_engine/synthesis_engine_base.py | 12 +++++++++++- 9 files changed, 82 insertions(+), 6 deletions(-) diff --git a/run.py b/run.py index 8b3a59b3b..de7ec39a7 100644 --- a/run.py +++ b/run.py @@ -31,6 +31,7 @@ ParseKanaError, Speaker, SpeakerInfo, + SupportedDevicesInfo, ) from voicevox_engine.morphing import synthesis_morphing from voicevox_engine.morphing import ( @@ -532,6 +533,18 @@ def speaker_info(speaker_uuid: str, core_version: Optional[str] = None): ret_data = {"policy": policy, "portrait": portrait, "style_infos": style_infos} return ret_data + @app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"]) + def supported_devices( + core_version: Optional[str] = None, + ): + supported_devices = get_engine(core_version).supported_devices + if supported_devices is None: + raise HTTPException(status_code=422, detail="非対応の機能です。") + return Response( + content=supported_devices, + media_type="application/json", + ) + return app diff --git a/test/test_mock_synthesis_engine.py b/test/test_mock_synthesis_engine.py index 27bf20bf5..c06a0504a 100644 --- a/test/test_mock_synthesis_engine.py +++ b/test/test_mock_synthesis_engine.py @@ -102,7 +102,7 @@ def setUp(self): pause_mora=None, ), ] - self.engine = MockSynthesisEngine(speakers="") + self.engine = MockSynthesisEngine(speakers="", supported_devices="") def test_replace_phoneme_length(self): self.assertEqual( diff --git a/voicevox_engine/dev/core/__init__.py b/voicevox_engine/dev/core/__init__.py index cbc1cc1e1..432b00b93 100644 --- a/voicevox_engine/dev/core/__init__.py +++ b/voicevox_engine/dev/core/__init__.py @@ -2,6 +2,7 @@ decode_forward, initialize, metas, + supported_devices, yukarin_s_forward, yukarin_sa_forward, ) @@ -12,4 +13,5 @@ "yukarin_s_forward", "yukarin_sa_forward", "metas", + "supported_devices", ] diff --git a/voicevox_engine/dev/core/mock.py b/voicevox_engine/dev/core/mock.py index d49964701..59eb63d70 100644 --- a/voicevox_engine/dev/core/mock.py +++ b/voicevox_engine/dev/core/mock.py @@ -110,3 +110,12 @@ def metas() -> str: }, ] ) + + +def supported_devices() -> str: + return json.dumps( + { + "cpu": True, + "cuda": False, + } + ) diff --git a/voicevox_engine/dev/synthesis_engine/mock.py b/voicevox_engine/dev/synthesis_engine/mock.py index cd265c6e3..ed710274f 100644 --- a/voicevox_engine/dev/synthesis_engine/mock.py +++ b/voicevox_engine/dev/synthesis_engine/mock.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np from pyopenjtalk import tts @@ -21,9 +21,18 @@ def __init__(self, **kwargs): """ super().__init__() - self.speakers = kwargs["speakers"] + self._speakers = kwargs["speakers"] + self._supported_devices = kwargs["supported_devices"] self.default_sampling_rate = 24000 + @property + def speakers(self) -> str: + return self._speakers + + @property + def supported_devices(self) -> Optional[str]: + return self._supported_devices + def replace_phoneme_length( self, accent_phrases: List[AccentPhrase], speaker_id: int ) -> List[AccentPhrase]: diff --git a/voicevox_engine/model.py b/voicevox_engine/model.py index 3ef6a88ee..5e19d8a43 100644 --- a/voicevox_engine/model.py +++ b/voicevox_engine/model.py @@ -140,3 +140,12 @@ class SpeakerInfo(BaseModel): policy: str = Field(title="policy.md") portrait: str = Field(title="portrait.pngをbase64エンコードしたもの") style_infos: List[StyleInfo] = Field(title="スタイルの追加情報") + + +class SupportedDevicesInfo(BaseModel): + """ + 対応しているデバイスの情報 + """ + + cpu: bool = Field(title="CPUに対応しているか") + cuda: bool = Field(title="CUDA(GPU)に対応しているか") diff --git a/voicevox_engine/synthesis_engine/make_synthesis_engines.py b/voicevox_engine/synthesis_engine/make_synthesis_engines.py index 5d140ad23..95996ba81 100644 --- a/voicevox_engine/synthesis_engine/make_synthesis_engines.py +++ b/voicevox_engine/synthesis_engine/make_synthesis_engines.py @@ -81,11 +81,18 @@ def make_synthesis_engines( file=sys.stderr, ) continue + try: + supported_devices = core.supported_devices() + except NameError: + # libtorch版コアは対応していないのでNameErrorになる + # 対応デバイスが不明であることを示すNoneを代入する + supported_devices = None synthesis_engines[core_version] = SynthesisEngine( yukarin_s_forwarder=core.yukarin_s_forward, yukarin_sa_forwarder=core.yukarin_sa_forward, decode_forwarder=core.decode_forward, speakers=core.metas(), + supported_devices=supported_devices, ) except Exception: if not enable_mock: @@ -96,9 +103,12 @@ def make_synthesis_engines( file=sys.stderr, ) from ..dev.core import metas as mock_metas + from ..dev.core import supported_devices as mock_supported_devices from ..dev.synthesis_engine import MockSynthesisEngine if "0.0.0" not in synthesis_engines: - synthesis_engines["0.0.0"] = MockSynthesisEngine(speakers=mock_metas()) + synthesis_engines["0.0.0"] = MockSynthesisEngine( + speakers=mock_metas(), supported_devices=mock_supported_devices() + ) return synthesis_engines diff --git a/voicevox_engine/synthesis_engine/synthesis_engine.py b/voicevox_engine/synthesis_engine/synthesis_engine.py index e5fa0ae03..add5cf6ef 100644 --- a/voicevox_engine/synthesis_engine/synthesis_engine.py +++ b/voicevox_engine/synthesis_engine/synthesis_engine.py @@ -132,6 +132,7 @@ def __init__( yukarin_sa_forwarder, decode_forwarder, speakers: str, + supported_devices: Optional[str] = None, ): """ yukarin_s_forwarder: 音素列から、音素ごとの長さを求める関数 @@ -160,15 +161,28 @@ def __init__( return: 音声波形 speakers: coreから取得したspeakersに関するjsonデータの文字列 + + supported_devices: + coreから取得した対応デバイスに関するjsonデータの文字列 + Noneの場合はコアが情報の取得に対応していないため、対応デバイスは不明 """ super().__init__() self.yukarin_s_forwarder = yukarin_s_forwarder self.yukarin_sa_forwarder = yukarin_sa_forwarder self.decode_forwarder = decode_forwarder - self.speakers = speakers + self._speakers = speakers + self._supported_devices = supported_devices self.default_sampling_rate = 24000 + @property + def speakers(self) -> str: + return self._speakers + + @property + def supported_devices(self) -> Optional[str]: + return self._supported_devices + def replace_phoneme_length( self, accent_phrases: List[AccentPhrase], speaker_id: int ) -> List[AccentPhrase]: diff --git a/voicevox_engine/synthesis_engine/synthesis_engine_base.py b/voicevox_engine/synthesis_engine/synthesis_engine_base.py index 272550b78..b2f4277c4 100644 --- a/voicevox_engine/synthesis_engine/synthesis_engine_base.py +++ b/voicevox_engine/synthesis_engine/synthesis_engine_base.py @@ -1,6 +1,6 @@ import copy from abc import ABCMeta, abstractmethod -from typing import List +from typing import List, Optional from .. import full_context_label from ..full_context_label import extract_full_context_label @@ -78,6 +78,16 @@ def full_context_label_moras_to_moras( class SynthesisEngineBase(metaclass=ABCMeta): + @property + @abstractmethod + def speakers(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def supported_devices(self) -> Optional[str]: + raise NotImplementedError + @abstractmethod def replace_phoneme_length( self, accent_phrases: List[AccentPhrase], speaker_id: int