Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: 名前付き引数削除 #968

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def audio_query(
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
core = get_core(core_version)
accent_phrases = engine.create_accent_phrases(text, style_id=style_id)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=1,
Expand Down Expand Up @@ -323,7 +323,7 @@ def audio_query_from_preset(
raise HTTPException(status_code=422, detail="該当するプリセットIDが見つかりません")

accent_phrases = engine.create_accent_phrases(
text, style_id=StyleId(selected_preset.style_id)
text, StyleId(selected_preset.style_id)
)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -376,13 +376,11 @@ def accent_phrases(
status_code=400,
detail=ParseKanaBadRequest(err).dict(),
)
accent_phrases = engine.replace_mora_data(
accent_phrases=accent_phrases, style_id=style_id
)
accent_phrases = engine.replace_mora_data(accent_phrases, style_id)

return accent_phrases
else:
return engine.create_accent_phrases(text, style_id=style_id)
return engine.create_accent_phrases(text, style_id)

@app.post(
"/mora_data",
Expand All @@ -398,7 +396,7 @@ def mora_data(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_mora_data(accent_phrases, style_id=style_id)
return engine.replace_mora_data(accent_phrases, style_id)

@app.post(
"/mora_length",
Expand All @@ -414,9 +412,7 @@ def mora_length(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_phoneme_length(
accent_phrases=accent_phrases, style_id=style_id
)
return engine.replace_phoneme_length(accent_phrases, style_id)

@app.post(
"/mora_pitch",
Expand All @@ -432,9 +428,7 @@ def mora_pitch(
) -> list[AccentPhrase]:
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
return engine.replace_mora_pitch(
accent_phrases=accent_phrases, style_id=style_id
)
return engine.replace_mora_pitch(accent_phrases, style_id)

@app.post(
"/synthesis",
Expand Down Expand Up @@ -462,9 +456,7 @@ def synthesis(
style_id = get_style_id_from_deprecated(style_id=style_id, speaker_id=speaker)
engine = get_engine(core_version)
wave = engine.synthesis(
query=query,
style_id=style_id,
enable_interrogative_upspeak=enable_interrogative_upspeak,
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -505,10 +497,7 @@ def cancellable_synthesis(
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
f_name = cancellable_engine._synthesis_impl(
query=query,
style_id=style_id,
request=request,
core_version=core_version,
query, style_id, request, core_version=core_version
)
if f_name == "":
raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down Expand Up @@ -553,7 +542,7 @@ def multi_synthesis(
)

with TemporaryFile() as wav_file:
wave = engine.synthesis(query=queries[i], style_id=style_id)
wave = engine.synthesis(queries[i], style_id)
soundfile.write(
file=wav_file,
data=wave,
Expand Down Expand Up @@ -1012,9 +1001,7 @@ def initialize_style_id(
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = get_core(core_version)
core.initialize_style_id_synthesis(
style_id=StyleId(style_id), skip_reinit=skip_reinit
)
core.initialize_style_id_synthesis(StyleId(style_id), skip_reinit=skip_reinit)
return Response(status_code=204)

@app.get("/is_initialized_style_id", response_model=bool, tags=["その他"])
Expand Down Expand Up @@ -1046,9 +1033,7 @@ def initialize_speaker(
stacklevel=1,
)
return initialize_style_id(
style_id=StyleId(speaker),
skip_reinit=skip_reinit,
core_version=core_version,
StyleId(speaker), skip_reinit=skip_reinit, core_version=core_version
)

@app.get(
Expand All @@ -1066,9 +1051,7 @@ def is_initialized_speaker(
"使用しているAPI(/is_initialize_speaker)は非推奨です。/is_initialized_style_idを利用してください。",
stacklevel=1,
)
return is_initialized_style_id(
style_id=StyleId(speaker), core_version=core_version
)
return is_initialized_style_id(StyleId(speaker), core_version=core_version)

@app.get("/user_dict", response_model=dict[str, UserDictWord], tags=["ユーザー辞書"])
def get_user_dict_words() -> dict[str, UserDictWord]:
Expand Down
12 changes: 3 additions & 9 deletions test/test_mock_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ def setUp(self):

def test_replace_phoneme_length(self):
"""`.replace_phoneme_length()` がエラー無く生成をおこなう"""
self.engine.replace_phoneme_length(
accent_phrases=self.accent_phrases_hello_hiho,
style_id=StyleId(0),
)
self.engine.replace_phoneme_length(self.accent_phrases_hello_hiho, StyleId(0))

def test_replace_mora_pitch(self):
"""`.replace_mora_pitch()` がエラー無く生成をおこなう"""
self.engine.replace_mora_pitch(
accent_phrases=self.accent_phrases_hello_hiho,
style_id=StyleId(0),
)
self.engine.replace_mora_pitch(self.accent_phrases_hello_hiho, StyleId(0))

def test_synthesis(self):
"""`.synthesis()` がエラー無く生成をおこなう"""
Expand All @@ -75,5 +69,5 @@ def test_synthesis(self):
outputStereo=False,
kana=create_kana(self.accent_phrases_hello_hiho),
),
style_id=StyleId(0),
StyleId(0),
)
6 changes: 3 additions & 3 deletions test/test_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def test_replace_phoneme_length(self):
# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
# Outputs & Indirect Outputs(yukarin_sに渡される値)
result = self.tts_engine.replace_phoneme_length(hello_hiho, style_id=StyleId(1))
result = self.tts_engine.replace_phoneme_length(hello_hiho, StyleId(1))
yukarin_s_args = self.yukarin_s_mock.call_args[1]
list_length = yukarin_s_args["length"]
phoneme_list = yukarin_s_args["phoneme_list"]
Expand Down Expand Up @@ -628,7 +628,7 @@ def test_replace_mora_pitch(self):
# Inputs
phrases: list = []
# Outputs
result = self.tts_engine.replace_mora_pitch(phrases, style_id=StyleId(1))
result = self.tts_engine.replace_mora_pitch(phrases, StyleId(1))
# Expects
true_result: list = []
# Tests
Expand All @@ -637,7 +637,7 @@ def test_replace_mora_pitch(self):
# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
# Outputs & Indirect Outputs(yukarin_saに渡される値)
result = self.tts_engine.replace_mora_pitch(hello_hiho, style_id=StyleId(1))
result = self.tts_engine.replace_mora_pitch(hello_hiho, StyleId(1))
yukarin_sa_args = self.yukarin_sa_mock.call_args[1]
list_length = yukarin_sa_args["length"]
vowel_phoneme_list = yukarin_sa_args["vowel_phoneme_list"][0]
Expand Down
4 changes: 2 additions & 2 deletions voicevox_engine/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def synthesis_morphing_parameter(
# WORLDに掛けるため合成はモノラルで行う
query.outputStereo = False

base_wave = engine.synthesis(query=query, style_id=base_speaker).astype("float")
target_wave = engine.synthesis(query=query, style_id=target_speaker).astype("float")
base_wave = engine.synthesis(query, style_id=base_speaker).astype("float")
target_wave = engine.synthesis(query, style_id=target_speaker).astype("float")

return create_morphing_parameter(
base_wave=base_wave,
Expand Down
9 changes: 2 additions & 7 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,7 @@ def replace_mora_data(
) -> list[AccentPhrase]:
"""アクセント句系列の音素長・モーラ音高をスタイルIDに基づいて更新する"""
return self.replace_mora_pitch(
accent_phrases=self.replace_phoneme_length(
accent_phrases=accent_phrases, style_id=style_id
),
accent_phrases=self.replace_phoneme_length(accent_phrases, style_id),
style_id=style_id,
)

Expand All @@ -436,10 +434,7 @@ def create_accent_phrases(self, text: str, style_id: StyleId) -> list[AccentPhra
accent_phrases = text_to_accent_phrases(text)

# 音素長・モーラ音高の推定と更新
accent_phrases = self.replace_mora_data(
accent_phrases=accent_phrases,
style_id=style_id,
)
accent_phrases = self.replace_mora_data(accent_phrases, style_id)
return accent_phrases

def synthesis(
Expand Down