Skip to content

Commit

Permalink
整理: 名前付き引数削除 (#968)
Browse files Browse the repository at this point in the history
refactor: 名前付き引数削除
  • Loading branch information
tarepan authored Jan 3, 2024
1 parent dc597c0 commit 6452c89
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 51 deletions.
43 changes: 13 additions & 30 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def audio_query(
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
core = get_core(core_version)
accent_phrases = engine.create_accent_phrases(text, style_id=style_id)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=1,
Expand Down Expand Up @@ -323,7 +323,7 @@ def audio_query_from_preset(
raise HTTPException(status_code=422, detail="該当するプリセットIDが見つかりません")

accent_phrases = engine.create_accent_phrases(
text, style_id=StyleId(selected_preset.style_id)
text, StyleId(selected_preset.style_id)
)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -376,13 +376,11 @@ def accent_phrases(
status_code=400,
detail=ParseKanaBadRequest(err).dict(),
)
accent_phrases = engine.replace_mora_data(
accent_phrases=accent_phrases, style_id=style_id
)
accent_phrases = engine.replace_mora_data(accent_phrases, style_id)

return accent_phrases
else:
return engine.create_accent_phrases(text, style_id=style_id)
return engine.create_accent_phrases(text, style_id)

@app.post(
"/mora_data",
Expand All @@ -398,7 +396,7 @@ def mora_data(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_mora_data(accent_phrases, style_id=style_id)
return engine.replace_mora_data(accent_phrases, style_id)

@app.post(
"/mora_length",
Expand All @@ -414,9 +412,7 @@ def mora_length(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_phoneme_length(
accent_phrases=accent_phrases, style_id=style_id
)
return engine.replace_phoneme_length(accent_phrases, style_id)

@app.post(
"/mora_pitch",
Expand All @@ -432,9 +428,7 @@ def mora_pitch(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_mora_pitch(
accent_phrases=accent_phrases, style_id=style_id
)
return engine.replace_mora_pitch(accent_phrases, style_id)

@app.post(
"/synthesis",
Expand Down Expand Up @@ -462,9 +456,7 @@ def synthesis(
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
wave = engine.synthesis(
query=query,
style_id=style_id,
enable_interrogative_upspeak=enable_interrogative_upspeak,
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -505,10 +497,7 @@ def cancellable_synthesis(
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
f_name = cancellable_engine._synthesis_impl(
query=query,
style_id=style_id,
request=request,
core_version=core_version,
query, style_id, request, core_version=core_version
)
if f_name == "":
raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down Expand Up @@ -553,7 +542,7 @@ def multi_synthesis(
)

with TemporaryFile() as wav_file:
wave = engine.synthesis(query=queries[i], style_id=style_id)
wave = engine.synthesis(queries[i], style_id)
soundfile.write(
file=wav_file,
data=wave,
Expand Down Expand Up @@ -1012,9 +1001,7 @@ def initialize_style_id(
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = get_core(core_version)
core.initialize_style_id_synthesis(
style_id=StyleId(style_id), skip_reinit=skip_reinit
)
core.initialize_style_id_synthesis(StyleId(style_id), skip_reinit=skip_reinit)
return Response(status_code=204)

@app.get("/is_initialized_style_id", response_model=bool, tags=["その他"])
Expand Down Expand Up @@ -1046,9 +1033,7 @@ def initialize_speaker(
stacklevel=1,
)
return initialize_style_id(
style_id=StyleId(speaker),
skip_reinit=skip_reinit,
core_version=core_version,
StyleId(speaker), skip_reinit=skip_reinit, core_version=core_version
)

@app.get(
Expand All @@ -1066,9 +1051,7 @@ def is_initialized_speaker(
"使用しているAPI(/is_initialize_speaker)は非推奨です。/is_initialized_style_idを利用してください。",
stacklevel=1,
)
return is_initialized_style_id(
style_id=StyleId(speaker), core_version=core_version
)
return is_initialized_style_id(StyleId(speaker), core_version=core_version)

@app.get("/user_dict", response_model=dict[str, UserDictWord], tags=["ユーザー辞書"])
def get_user_dict_words() -> dict[str, UserDictWord]:
Expand Down
12 changes: 3 additions & 9 deletions test/test_mock_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ def setUp(self):

def test_replace_phoneme_length(self):
"""`.replace_phoneme_length()` がエラー無く生成をおこなう"""
self.engine.replace_phoneme_length(
accent_phrases=self.accent_phrases_hello_hiho,
style_id=StyleId(0),
)
self.engine.replace_phoneme_length(self.accent_phrases_hello_hiho, StyleId(0))

def test_replace_mora_pitch(self):
"""`.replace_mora_pitch()` がエラー無く生成をおこなう"""
self.engine.replace_mora_pitch(
accent_phrases=self.accent_phrases_hello_hiho,
style_id=StyleId(0),
)
self.engine.replace_mora_pitch(self.accent_phrases_hello_hiho, StyleId(0))

def test_synthesis(self):
"""`.synthesis()` がエラー無く生成をおこなう"""
Expand All @@ -75,5 +69,5 @@ def test_synthesis(self):
outputStereo=False,
kana=create_kana(self.accent_phrases_hello_hiho),
),
style_id=StyleId(0),
StyleId(0),
)
6 changes: 3 additions & 3 deletions test/test_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def test_replace_phoneme_length(self):
# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
# Outputs & Indirect Outputs(yukarin_sに渡される値)
result = self.tts_engine.replace_phoneme_length(hello_hiho, style_id=StyleId(1))
result = self.tts_engine.replace_phoneme_length(hello_hiho, StyleId(1))
yukarin_s_args = self.yukarin_s_mock.call_args[1]
list_length = yukarin_s_args["length"]
phoneme_list = yukarin_s_args["phoneme_list"]
Expand Down Expand Up @@ -628,7 +628,7 @@ def test_replace_mora_pitch(self):
# Inputs
phrases: list = []
# Outputs
result = self.tts_engine.replace_mora_pitch(phrases, style_id=StyleId(1))
result = self.tts_engine.replace_mora_pitch(phrases, StyleId(1))
# Expects
true_result: list = []
# Tests
Expand All @@ -637,7 +637,7 @@ def test_replace_mora_pitch(self):
# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
# Outputs & Indirect Outputs(yukarin_saに渡される値)
result = self.tts_engine.replace_mora_pitch(hello_hiho, style_id=StyleId(1))
result = self.tts_engine.replace_mora_pitch(hello_hiho, StyleId(1))
yukarin_sa_args = self.yukarin_sa_mock.call_args[1]
list_length = yukarin_sa_args["length"]
vowel_phoneme_list = yukarin_sa_args["vowel_phoneme_list"][0]
Expand Down
4 changes: 2 additions & 2 deletions voicevox_engine/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def synthesis_morphing_parameter(
# WORLDに掛けるため合成はモノラルで行う
query.outputStereo = False

base_wave = engine.synthesis(query=query, style_id=base_speaker).astype("float")
target_wave = engine.synthesis(query=query, style_id=target_speaker).astype("float")
base_wave = engine.synthesis(query, style_id=base_speaker).astype("float")
target_wave = engine.synthesis(query, style_id=target_speaker).astype("float")

return create_morphing_parameter(
base_wave=base_wave,
Expand Down
9 changes: 2 additions & 7 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,7 @@ def replace_mora_data(
) -> list[AccentPhrase]:
"""アクセント句系列の音素長・モーラ音高をスタイルIDに基づいて更新する"""
return self.replace_mora_pitch(
accent_phrases=self.replace_phoneme_length(
accent_phrases=accent_phrases, style_id=style_id
),
accent_phrases=self.replace_phoneme_length(accent_phrases, style_id),
style_id=style_id,
)

Expand All @@ -436,10 +434,7 @@ def create_accent_phrases(self, text: str, style_id: StyleId) -> list[AccentPhra
accent_phrases = text_to_accent_phrases(text)

# 音素長・モーラ音高の推定と更新
accent_phrases = self.replace_mora_data(
accent_phrases=accent_phrases,
style_id=style_id,
)
accent_phrases = self.replace_mora_data(accent_phrases, style_id)
return accent_phrases

def synthesis(
Expand Down

0 comments on commit 6452c89

Please sign in to comment.