Skip to content

Commit

Permalink
feat!: ユーザー辞書データに改行やnull文字が入っていた場合にエラーとする (#1522)
Browse files Browse the repository at this point in the history
Co-authored-by: Hiroshiba Kazuyuki <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Hiroshiba <[email protected]>
  • Loading branch information
4 people authored Feb 7, 2025
1 parent ab1df7e commit ab6f180
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 48 deletions.
57 changes: 56 additions & 1 deletion test/unit/user_dict/test_user_dict_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""UserDictWord のテスト"""

from typing import TypedDict
from typing import Literal, TypedDict, get_args

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -55,6 +55,46 @@ def test_valid_word() -> None:
UserDictWord(**args)


# Names of the UserDictWord string fields that are validated as CsvSafeStr;
# used via typing.get_args() to parametrize the CSV-safety tests below.
CsvSafeStrFieldName = Literal[
    "part_of_speech",
    "part_of_speech_detail_1",
    "part_of_speech_detail_2",
    "part_of_speech_detail_3",
    "inflectional_type",
    "inflectional_form",
    "stem",
    "yomi",
    "accent_associative_rule",
]


@pytest.mark.parametrize(
    "field",
    get_args(CsvSafeStrFieldName),
)
def test_invalid_csv_safe_str(field: CsvSafeStrFieldName) -> None:
    """UserDictWord rejects characters not allowed in CSV string fields."""
    # Inputs: each value carries one character class that must be rejected
    # (CRLF newline, NUL, comma, double quote).
    bad_values = ("te\r\nst", "te\x00st", "te,st", 'te"st')

    # Test
    for bad_value in bad_values:
        args = generate_model()
        args[field] = bad_value
        with pytest.raises(ValidationError):
            UserDictWord(**args)


def test_convert_to_zenkaku() -> None:
"""UserDictWord は surface を全角にする。"""
# Inputs
Expand Down Expand Up @@ -126,6 +166,21 @@ def test_invalid_pronunciation_not_katakana() -> None:
UserDictWord(**test_value)


def test_invalid_pronunciation_newlines_and_null() -> None:
    """UserDictWord rejects newline and NUL characters inside pronunciation."""
    # Inputs: otherwise-valid katakana with a forbidden character embedded.
    bad_pronunciations = ("ボイ\r\nボ", "ボイ\x00ボ")

    # Test
    for bad_pronunciation in bad_pronunciations:
        args = generate_model()
        args["pronunciation"] = bad_pronunciation
        with pytest.raises(ValidationError):
            UserDictWord(**args)


def test_invalid_pronunciation_invalid_sutegana() -> None:
"""UserDictWord は無効な pronunciation をエラーとする。"""
# Inputs
Expand Down
1 change: 1 addition & 0 deletions voicevox_engine/app/routers/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_user_dict_words() -> dict[str, UserDictWord]:
status_code=500, detail="辞書の読み込みに失敗しました。"
)

# TODO: CsvSafeStrを使う
@router.post("/user_dict_word", dependencies=[Depends(verify_mutability)])
def add_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
Expand Down
120 changes: 73 additions & 47 deletions voicevox_engine/user_dict/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from re import findall, fullmatch
from typing import Self

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import Annotated


class WordTypes(str, Enum):
Expand All @@ -26,65 +27,90 @@ class WordTypes(str, Enum):
USER_DICT_MAX_PRIORITY = 10


def _check_newlines_and_null(text: str) -> str:
if "\n" in text or "\r" in text:
raise ValueError("ユーザー辞書データ内に改行が含まれています。")
if "\x00" in text:
raise ValueError("ユーザー辞書データ内にnull文字が含まれています。")
return text


def _check_comma_and_double_quote(text: str) -> str:
if "," in text:
raise ValueError("ユーザー辞書データ内にカンマが含まれています。")
if '"' in text:
raise ValueError("ユーザー辞書データ内にダブルクォートが含まれています。")
return text


def _convert_to_zenkaku(surface: str) -> str:
return surface.translate(
str.maketrans(
"".join(chr(0x21 + i) for i in range(94)),
"".join(chr(0xFF01 + i) for i in range(94)),
)
)


def _check_is_katakana(pronunciation: str) -> str:
if not fullmatch(r"[ァ-ヴー]+", pronunciation):
raise ValueError("発音は有効なカタカナでなくてはいけません。")
sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"]
for i in range(len(pronunciation)):
if pronunciation[i] in sutegana:
# 「キャット」のように、捨て仮名が連続する可能性が考えられるので、
# 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする
if i < len(pronunciation) - 1 and (
pronunciation[i + 1] in sutegana[:-1]
or (
pronunciation[i] == sutegana[-1]
and pronunciation[i + 1] == sutegana[-1]
)
):
raise ValueError("無効な発音です。(捨て仮名の連続)")
if pronunciation[i] == "ヮ":
if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]:
raise ValueError("無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)")
return pronunciation


# A string type that is safe to embed in a CSV cell without escaping:
# the chained AfterValidators reject newlines/NUL and commas/double quotes.
CsvSafeStr = Annotated[
    str,
    AfterValidator(_check_newlines_and_null),
    AfterValidator(_check_comma_and_double_quote),
]


class UserDictWord(BaseModel):
"""
辞書のコンパイルに使われる情報
"""

model_config = ConfigDict(validate_assignment=True)

surface: str = Field(description="表層形")
surface: Annotated[
str,
AfterValidator(_convert_to_zenkaku),
AfterValidator(_check_newlines_and_null),
] = Field(description="表層形")
priority: int = Field(
description="優先度", ge=USER_DICT_MIN_PRIORITY, le=USER_DICT_MAX_PRIORITY
)
context_id: int = Field(description="文脈ID", default=1348)
part_of_speech: str = Field(description="品詞")
part_of_speech_detail_1: str = Field(description="品詞細分類1")
part_of_speech_detail_2: str = Field(description="品詞細分類2")
part_of_speech_detail_3: str = Field(description="品詞細分類3")
inflectional_type: str = Field(description="活用型")
inflectional_form: str = Field(description="活用形")
stem: str = Field(description="原形")
yomi: str = Field(description="読み")
pronunciation: str = Field(description="発音")
part_of_speech: CsvSafeStr = Field(description="品詞")
part_of_speech_detail_1: CsvSafeStr = Field(description="品詞細分類1")
part_of_speech_detail_2: CsvSafeStr = Field(description="品詞細分類2")
part_of_speech_detail_3: CsvSafeStr = Field(description="品詞細分類3")
inflectional_type: CsvSafeStr = Field(description="活用型")
inflectional_form: CsvSafeStr = Field(description="活用形")
stem: CsvSafeStr = Field(description="原形")
yomi: CsvSafeStr = Field(description="読み")
pronunciation: Annotated[CsvSafeStr, AfterValidator(_check_is_katakana)] = Field(
description="発音"
)
accent_type: int = Field(description="アクセント型")
mora_count: int | SkipJsonSchema[None] = Field(default=None, description="モーラ数")
accent_associative_rule: str = Field(description="アクセント結合規則")

@field_validator("surface")
@classmethod
def convert_to_zenkaku(cls, surface: str) -> str:
return surface.translate(
str.maketrans(
"".join(chr(0x21 + i) for i in range(94)),
"".join(chr(0xFF01 + i) for i in range(94)),
)
)

@field_validator("pronunciation", mode="before")
@classmethod
def check_is_katakana(cls, pronunciation: str) -> str:
if not fullmatch(r"[ァ-ヴー]+", pronunciation):
raise ValueError("発音は有効なカタカナでなくてはいけません。")
sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"]
for i in range(len(pronunciation)):
if pronunciation[i] in sutegana:
# 「キャット」のように、捨て仮名が連続する可能性が考えられるので、
# 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする
if i < len(pronunciation) - 1 and (
pronunciation[i + 1] in sutegana[:-1]
or (
pronunciation[i] == sutegana[-1]
and pronunciation[i + 1] == sutegana[-1]
)
):
raise ValueError("無効な発音です。(捨て仮名の連続)")
if pronunciation[i] == "ヮ":
if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]:
raise ValueError(
"無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)"
)
return pronunciation
accent_associative_rule: CsvSafeStr = Field(description="アクセント結合規則")

@model_validator(mode="after")
def check_mora_count_and_accent_type(self) -> Self:
Expand Down

0 comments on commit ab6f180

Please sign in to comment.