From 0cdfb43fb27b1bf3667525f45d3fff00f1b15874 Mon Sep 17 00:00:00 2001 From: codingl2k1 <138426806+codingl2k1@users.noreply.github.com> Date: Tue, 19 Nov 2024 14:24:22 +0100 Subject: [PATCH] ENH: Support fish speech reference audio (#2542) --- doc/source/models/model_abilities/audio.rst | 33 ++++++++++++++++++- xinference/client/restful/restful_client.py | 2 ++ xinference/model/audio/fish_speech.py | 12 ++++--- .../model/audio/tests/test_fish_speech.py | 3 +- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/doc/source/models/model_abilities/audio.rst b/doc/source/models/model_abilities/audio.rst index 4ca3af1db8..45f3a540cf 100644 --- a/doc/source/models/model_abilities/audio.rst +++ b/doc/source/models/model_abilities/audio.rst @@ -331,7 +331,7 @@ Clone voice, launch model ``CosyVoice-300M``. zero_shot_prompt_text = "" # The zero shot prompt file is the voice file - # the words said in the file shoule be identical to zero_shot_prompt_text + # the words said in the file should be identical to zero_shot_prompt_text with open(zero_shot_prompt_file, "rb") as f: zero_shot_prompt = f.read() @@ -379,3 +379,34 @@ Instruction based, launch model ``CosyVoice-300M-Instruct``. ) More instructions and examples, could be found at https://fun-audio-llm.github.io/ . + + +FishSpeech Usage +~~~~~~~~~~~~~~~~ + +Basic usage, refer to :ref:`audio speech usage `. + +Clone voice, launch model ``FishSpeech-1.4``. Please use `prompt_speech` instead of `reference_audio` +to provide the reference audio to the FishSpeech model. + +.. code-block:: + + from xinference.client import Client + + client = Client("http://:") + + model = client.get_model("") + + reference_text = "" + # The reference audio file is the voice file + # the words said in the file should be identical to reference_text + with open(reference_audio_file, "rb") as f: + reference_audio = f.read() + + speech_bytes = model.speech( + "", + reference_text=reference_text, + prompt_speech=reference_audio, + enable_reference_audio=True, + ) +- \ No newline at end of file diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index ab03c566c1..e145d963de 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -704,6 +704,8 @@ def speech( The speed of the generated audio. stream: bool Use stream or not. + prompt_speech: bytes + The audio bytes to be provided to the model. Returns ------- diff --git a/xinference/model/audio/fish_speech.py b/xinference/model/audio/fish_speech.py index 4a6412f04a..d3f019e875 100644 --- a/xinference/model/audio/fish_speech.py +++ b/xinference/model/audio/fish_speech.py @@ -81,12 +81,13 @@ def load(self): if not is_device_available(self._device): raise ValueError(f"Device {self._device} is not available!") - logger.info("Loading Llama model...") + enable_compile = self._kwargs.get("compile", False) + logger.info("Loading Llama model, compile=%s...", enable_compile) self._llama_queue = launch_thread_safe_queue( checkpoint_path=self._model_path, device=self._device, precision=torch.bfloat16, - compile=False, + compile=enable_compile, ) logger.info("Llama model loaded, loading VQ-GAN model...") @@ -208,11 +209,14 @@ def speech( logger.warning("stream mode is not implemented.") import torchaudio + prompt_speech = kwargs.get("prompt_speech") result = list( self._inference( text=input, - enable_reference_audio=False, - reference_audio=None, + enable_reference_audio=kwargs.get( + "enable_reference_audio", prompt_speech is not None + ), + reference_audio=prompt_speech, reference_text=kwargs.get("reference_text", ""), max_new_tokens=kwargs.get("max_new_tokens", 1024), chunk_length=kwargs.get("chunk_length", 200), diff --git a/xinference/model/audio/tests/test_fish_speech.py b/xinference/model/audio/tests/test_fish_speech.py index ce57566b19..b0d1e60382 100644 --- a/xinference/model/audio/tests/test_fish_speech.py +++ b/xinference/model/audio/tests/test_fish_speech.py @@ -22,8 +22,7 @@ def test_fish_speech(setup): client = Client(endpoint) model_uid = client.launch_model( - model_name="FishSpeech-1.4", - model_type="audio", + model_name="FishSpeech-1.4", model_type="audio", compile=False ) model = client.get_model(model_uid) input_string = "你好,你是谁?"