diff --git a/examples/api/main.py b/examples/api/main.py index 1542b1740..90f2f138b 100644 --- a/examples/api/main.py +++ b/examples/api/main.py @@ -23,7 +23,10 @@ from pydantic import BaseModel - +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from tools.normalizer.en import normalizer_en_nemo_text +from tools.normalizer.zh import normalizer_zh_tn logger = get_logger("Command") @@ -35,14 +38,23 @@ async def startup_event(): global chat chat = ChatTTS.Chat(get_logger("ChatTTS")) + chat.normalizer.register("en", normalizer_en_nemo_text()) + chat.normalizer.register("zh", normalizer_zh_tn()) + logger.info("Initializing ChatTTS...") - if chat.load(): + if chat.load(source="huggingface"): logger.info("Models loaded successfully.") else: logger.error("Models load failed.") sys.exit(1) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc: RequestValidationError): + logger.error(f"Validation error: {exc.errors()}") + return JSONResponse(status_code=422, content={"detail": exc.errors()}) + + class ChatTTSParams(BaseModel): text: list[str] stream: bool = False @@ -52,7 +64,7 @@ class ChatTTSParams(BaseModel): use_decoder: bool = True do_text_normalization: bool = True do_homophone_replacement: bool = False - params_refine_text: ChatTTS.Chat.RefineTextParams + params_refine_text: ChatTTS.Chat.RefineTextParams = None params_infer_code: ChatTTS.Chat.InferCodeParams