FEAT: Fish speech stream (#2562)
codingl2k1 authored Nov 21, 2024
1 parent a49c1a6 commit c456ef9
Showing 2 changed files with 81 additions and 35 deletions.
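
In short: speech() now honors stream=True by returning a generator that encodes audio incrementally with torchaudio.io.StreamWriter; _inference() gains seed and streaming parameters; the inference dtype is read from a precision model kwarg instead of being hard-coded to torch.bfloat16; and the Llama generation request now respects the compile kwarg instead of always passing compile=False.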
101 changes: 66 additions & 35 deletions xinference/model/audio/fish_speech.py
@@ -82,11 +82,12 @@ def load(self):
             raise ValueError(f"Device {self._device} is not available!")
 
         enable_compile = self._kwargs.get("compile", False)
+        precision = self._kwargs.get("precision", torch.bfloat16)
         logger.info("Loading Llama model, compile=%s...", enable_compile)
         self._llama_queue = launch_thread_safe_queue(
             checkpoint_path=self._model_path,
             device=self._device,
-            precision=torch.bfloat16,
+            precision=precision,
             compile=enable_compile,
         )
         logger.info("Llama model loaded, loading VQ-GAN model...")
@@ -113,16 +114,22 @@ def _inference(
         top_p,
         repetition_penalty,
         temperature,
+        seed="0",
+        streaming=False,
     ):
-        from fish_speech.utils import autocast_exclude_mps
+        from fish_speech.utils import autocast_exclude_mps, set_seed
         from tools.api import decode_vq_tokens, encode_reference
         from tools.llama.generate import (
             GenerateRequest,
             GenerateResponse,
             WrappedGenerateResponse,
         )
 
+        seed = int(seed)
+        if seed != 0:
+            set_seed(seed)
+            logger.warning(f"set seed: {seed}")
+
         # Parse reference audio aka prompt
         prompt_tokens = encode_reference(
             decoder_model=self._model,
@@ -138,7 +145,7 @@ def _inference(
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             temperature=temperature,
-            compile=False,
+            compile=self._kwargs.get("compile", False),
             iterative_prompt=chunk_length > 0,
             chunk_length=chunk_length,
             max_length=2048,
@@ -154,22 +161,20 @@
             )
         )
 
-        if streaming:
-            yield wav_chunk_header(), None, None
-
         segments = []
 
         while True:
-            result: WrappedGenerateResponse = response_queue.get()  # type: ignore
+            result: WrappedGenerateResponse = response_queue.get()
             if result.status == "error":
-                raise Exception(str(result.response))
+                raise result.response
 
-            result: GenerateResponse = result.response  # type: ignore
+            result: GenerateResponse = result.response
             if result.action == "next":
                 break
 
             with autocast_exclude_mps(
-                device_type=self._model.device.type, dtype=torch.bfloat16
+                device_type=self._model.device.type,
+                dtype=self._kwargs.get("precision", torch.bfloat16),
             ):
                 fake_audios = decode_vq_tokens(
                     decoder_model=self._model,
@@ -180,7 +185,7 @@
             segments.append(fake_audios)
 
             if streaming:
-                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+                yield fake_audios, None, None
 
         if len(segments) == 0:
             raise Exception("No audio generated, please check the input text.")
@@ -205,32 +210,58 @@ def speech(
         if voice:
             logger.warning("Fish speech does not support setting voice: %s.", voice)
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
-        if stream is True:
-            logger.warning("stream mode is not implemented.")
         import torchaudio
 
         prompt_speech = kwargs.get("prompt_speech")
-        result = list(
-            self._inference(
-                text=input,
-                enable_reference_audio=kwargs.get(
-                    "enable_reference_audio", prompt_speech is not None
-                ),
-                reference_audio=prompt_speech,
-                reference_text=kwargs.get("reference_text", ""),
-                max_new_tokens=kwargs.get("max_new_tokens", 1024),
-                chunk_length=kwargs.get("chunk_length", 200),
-                top_p=kwargs.get("top_p", 0.7),
-                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-                temperature=kwargs.get("temperature", 0.7),
-            )
-        )
+        result = self._inference(
+            text=input,
+            enable_reference_audio=kwargs.get(
+                "enable_reference_audio", prompt_speech is not None
+            ),
+            reference_audio=prompt_speech,
+            reference_text=kwargs.get("reference_text", ""),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
+            chunk_length=kwargs.get("chunk_length", 200),
+            top_p=kwargs.get("top_p", 0.7),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+            temperature=kwargs.get("temperature", 0.7),
+            streaming=stream,
+        )
-        sample_rate, audio = result[0][1]
-        audio = np.array([audio])
-
-        # Save the generated audio
-        with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(audio), sample_rate, format=response_format
-            )
-            return out.getvalue()
+        if stream:
+
+            def _stream_generator():
+                with BytesIO() as out:
+                    writer = torchaudio.io.StreamWriter(out, format=response_format)
+                    writer.add_audio_stream(
+                        sample_rate=self._model.spec_transform.sample_rate,
+                        num_channels=1,
+                    )
+                    i = 0
+                    last_pos = 0
+                    with writer.open():
+                        for chunk in result:
+                            chunk = chunk[0]
+                            if chunk is not None:
+                                chunk = chunk.reshape((chunk.shape[0], 1))
+                                trans_chunk = torch.from_numpy(chunk)
+                                writer.write_audio_chunk(i, trans_chunk)
+                                new_last_pos = out.tell()
+                                if new_last_pos != last_pos:
+                                    out.seek(last_pos)
+                                    encoded_bytes = out.read()
+                                    yield encoded_bytes
+                                    last_pos = new_last_pos
+
+            return _stream_generator()
+        else:
+            result = list(result)
+            sample_rate, audio = result[0][1]
+            audio = np.array([audio])
+
+            # Save the generated audio
+            with BytesIO() as out:
+                torchaudio.save(
+                    out, torch.from_numpy(audio), sample_rate, format=response_format
+                )
+                return out.getvalue()
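
The core of _stream_generator is torchaudio's StreamWriter encoding into an in-memory BytesIO buffer: after each PCM chunk is written, whatever bytes the encoder has flushed past last_pos are read back out and yielded to the caller. A standalone sketch of that polling pattern, under assumed inputs (a synthetic sine wave, 44.1 kHz, mp3 output; none of these values come from the commit):

from io import BytesIO

import torch
import torchaudio


def encode_stream(chunks, sample_rate=44100, fmt="mp3"):
    # Encode float PCM chunks on the fly, yielding whatever encoded bytes
    # the muxer has flushed so far (the same trick as _stream_generator).
    with BytesIO() as out:
        writer = torchaudio.io.StreamWriter(out, format=fmt)
        writer.add_audio_stream(sample_rate=sample_rate, num_channels=1)
        last_pos = 0
        with writer.open():
            for chunk in chunks:
                # StreamWriter expects (frames, channels) tensors.
                writer.write_audio_chunk(0, chunk.reshape(-1, 1))
                new_pos = out.tell()  # advances only when the encoder flushes
                if new_pos != last_pos:
                    out.seek(last_pos)
                    yield out.read()
                    last_pos = new_pos


t = torch.linspace(0, 1, 44100)
sine = torch.sin(2 * torch.pi * 440 * t)
encoded = b"".join(encode_stream(sine.chunk(10)))

Note that bytes flushed when the writer closes are never yielded, in the sketch and in the commit alike; that suits trailer-free formats such as mp3, and the wav_chunk_header approach removed above is no longer needed.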
15 changes: 15 additions & 0 deletions xinference/model/audio/tests/test_fish_speech.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import tempfile
 
@@ -25,11 +26,25 @@ def test_fish_speech(setup):
         model_name="FishSpeech-1.4", model_type="audio", compile=False
     )
     model = client.get_model(model_uid)
+
     input_string = "你好，你是谁？"
     response = model.speech(input_string)
     assert type(response) is bytes
     assert len(response) > 0
 
+    # Test stream
+    input_string = "瑞典王国，通称瑞典，是一个位于斯堪的纳维亚半岛的北欧国家，首都及最大城市为斯德哥尔摩。"
+    response = model.speech(input_string, chunk_length=20, stream=True)
+    assert inspect.isgenerator(response)
+    i = 0
+    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as f:
+        for chunk in response:
+            f.write(chunk)
+            i += 1
+            assert type(chunk) is bytes
+            assert len(chunk) > 0
+    assert i > 5
+
     # Test openai API
     import openai
 
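The truncated tail of the test exercises xinference's OpenAI-compatible endpoint. A hedged sketch of what such a call can look like against a locally running server (the base_url, the api_key placeholder, and the voice value are illustrative assumptions, not taken from this commit; model_uid and input_string are reused from the test above):

import openai

# Hypothetical local endpoint; xinference exposes an OpenAI-compatible API.
client = openai.Client(api_key="not empty", base_url="http://127.0.0.1:9997/v1")
response = client.audio.speech.create(
    model=model_uid,  # uid of the launched FishSpeech-1.4 model
    voice="echo",  # accepted for API compatibility, ignored by Fish speech
    input=input_string,
)
assert isinstance(response.content, bytes)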
