Skip to content

Commit

Permalink
対応デバイスが分かるAPIエンドポイントを追加 (#299)
Browse files Browse the repository at this point in the history
* /supported_devicesを追加

* model修正

* 抽象プロパティを定義

* コメント追加
  • Loading branch information
takana-v authored Feb 3, 2022
1 parent ad1484c commit 4c71253
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 6 deletions.
13 changes: 13 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ParseKanaError,
Speaker,
SpeakerInfo,
SupportedDevicesInfo,
)
from voicevox_engine.morphing import synthesis_morphing
from voicevox_engine.morphing import (
Expand Down Expand Up @@ -532,6 +533,18 @@ def speaker_info(speaker_uuid: str, core_version: Optional[str] = None):
ret_data = {"policy": policy, "portrait": portrait, "style_infos": style_infos}
return ret_data

# Expose the selected engine's supported-device information over HTTP.
@app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"])
def supported_devices(core_version: Optional[str] = None):
    # The engine hands this back either as a string that is sent out verbatim
    # as the JSON body, or as None when no device information is available.
    devices_json = get_engine(core_version).supported_devices
    if devices_json is not None:
        # Pass the string through unchanged with a JSON content type.
        return Response(content=devices_json, media_type="application/json")
    # No device information: report the feature as unsupported (HTTP 422).
    raise HTTPException(status_code=422, detail="非対応の機能です。")

return app


Expand Down
2 changes: 1 addition & 1 deletion test/test_mock_synthesis_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def setUp(self):
pause_mora=None,
),
]
self.engine = MockSynthesisEngine(speakers="")
self.engine = MockSynthesisEngine(speakers="", supported_devices="")

def test_replace_phoneme_length(self):
self.assertEqual(
Expand Down
2 changes: 2 additions & 0 deletions voicevox_engine/dev/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
decode_forward,
initialize,
metas,
supported_devices,
yukarin_s_forward,
yukarin_sa_forward,
)
Expand All @@ -12,4 +13,5 @@
"yukarin_s_forward",
"yukarin_sa_forward",
"metas",
"supported_devices",
]
9 changes: 9 additions & 0 deletions voicevox_engine/dev/core/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,12 @@ def metas() -> str:
},
]
)


def supported_devices() -> str:
    """Return the mock core's supported-device flags as a JSON string.

    The mock reports ``cpu`` as true and ``cuda`` as false.
    """
    device_flags = {"cpu": True, "cuda": False}
    return json.dumps(device_flags)
13 changes: 11 additions & 2 deletions voicevox_engine/dev/synthesis_engine/mock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import numpy as np
from pyopenjtalk import tts
Expand All @@ -21,9 +21,18 @@ def __init__(self, **kwargs):
"""
super().__init__()

self.speakers = kwargs["speakers"]
self._speakers = kwargs["speakers"]
self._supported_devices = kwargs["supported_devices"]
self.default_sampling_rate = 24000

@property
def speakers(self) -> str:
    """Return the ``speakers`` value this engine was constructed with."""
    return self._speakers

@property
def supported_devices(self) -> Optional[str]:
    """Return the ``supported_devices`` value this engine was constructed with."""
    return self._supported_devices

def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
) -> List[AccentPhrase]:
Expand Down
9 changes: 9 additions & 0 deletions voicevox_engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,12 @@ class SpeakerInfo(BaseModel):
policy: str = Field(title="policy.md")
portrait: str = Field(title="portrait.pngをbase64エンコードしたもの")
style_infos: List[StyleInfo] = Field(title="スタイルの追加情報")


# Model with one boolean flag per device kind; it is the response_model of
# the /supported_devices endpoint.
# NOTE(review): the Japanese docstring and Field titles below are left
# untranslated because pydantic presumably surfaces them in the generated
# schema / API docs — confirm before changing them.
class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

# True when CPU is supported.
cpu: bool = Field(title="CPUに対応しているか")
# True when CUDA (GPU) is supported.
cuda: bool = Field(title="CUDA(GPU)に対応しているか")
12 changes: 11 additions & 1 deletion voicevox_engine/synthesis_engine/make_synthesis_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ def make_synthesis_engines(
file=sys.stderr,
)
continue
try:
supported_devices = core.supported_devices()
except NameError:
# libtorch版コアは対応していないのでNameErrorになる
# 対応デバイスが不明であることを示すNoneを代入する
supported_devices = None
synthesis_engines[core_version] = SynthesisEngine(
yukarin_s_forwarder=core.yukarin_s_forward,
yukarin_sa_forwarder=core.yukarin_sa_forward,
decode_forwarder=core.decode_forward,
speakers=core.metas(),
supported_devices=supported_devices,
)
except Exception:
if not enable_mock:
Expand All @@ -96,9 +103,12 @@ def make_synthesis_engines(
file=sys.stderr,
)
from ..dev.core import metas as mock_metas
from ..dev.core import supported_devices as mock_supported_devices
from ..dev.synthesis_engine import MockSynthesisEngine

if "0.0.0" not in synthesis_engines:
synthesis_engines["0.0.0"] = MockSynthesisEngine(speakers=mock_metas())
synthesis_engines["0.0.0"] = MockSynthesisEngine(
speakers=mock_metas(), supported_devices=mock_supported_devices()
)

return synthesis_engines
16 changes: 15 additions & 1 deletion voicevox_engine/synthesis_engine/synthesis_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
yukarin_sa_forwarder,
decode_forwarder,
speakers: str,
supported_devices: Optional[str] = None,
):
"""
yukarin_s_forwarder: 音素列から、音素ごとの長さを求める関数
Expand Down Expand Up @@ -160,15 +161,28 @@ def __init__(
return: 音声波形
speakers: coreから取得したspeakersに関するjsonデータの文字列
supported_devices:
coreから取得した対応デバイスに関するjsonデータの文字列
Noneの場合はコアが情報の取得に対応していないため、対応デバイスは不明
"""
super().__init__()
self.yukarin_s_forwarder = yukarin_s_forwarder
self.yukarin_sa_forwarder = yukarin_sa_forwarder
self.decode_forwarder = decode_forwarder

self.speakers = speakers
self._speakers = speakers
self._supported_devices = supported_devices
self.default_sampling_rate = 24000

@property
def speakers(self) -> str:
    """Return the speakers JSON string that was passed to the constructor."""
    return self._speakers

@property
def supported_devices(self) -> Optional[str]:
    """Return the supported-devices JSON string passed to the constructor.

    The value is ``None`` when the core cannot report device information,
    in which case the supported devices are unknown.
    """
    return self._supported_devices

def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
) -> List[AccentPhrase]:
Expand Down
12 changes: 11 additions & 1 deletion voicevox_engine/synthesis_engine/synthesis_engine_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from abc import ABCMeta, abstractmethod
from typing import List
from typing import List, Optional

from .. import full_context_label
from ..full_context_label import extract_full_context_label
Expand Down Expand Up @@ -78,6 +78,16 @@ def full_context_label_moras_to_moras(


class SynthesisEngineBase(metaclass=ABCMeta):
@property
@abstractmethod
def speakers(self) -> str:
    """Abstract read-only property; concrete engines must override it.

    The base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError

@property
@abstractmethod
def supported_devices(self) -> Optional[str]:
    """Abstract read-only property; concrete engines must override it.

    The base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError

@abstractmethod
def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
Expand Down

0 comments on commit 4c71253

Please sign in to comment.