optimize(model): more torch funcalls
and fix webui infer without refine text
fumiama committed Jun 25, 2024
1 parent 40cda89 commit 37f7663
Showing 5 changed files with 58 additions and 54 deletions.
ChatTTS/model/gpt.py (11 additions, 16 deletions)
@@ -198,21 +198,15 @@ def _prepare_generation_inputs(
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 past_length = (
-                    cache_position[0]
+                    int(cache_position[0])
                     if cache_position is not None
                     else past_key_values.get_seq_length()
                 )
-                max_cache_length = (
-                    torch.tensor(
-                        past_key_values.get_max_length(), device=input_ids.device
-                    )
-                    if past_key_values.get_max_length() is not None
-                    else None
-                )
+                max_cache_length = past_key_values.get_max_length()
                 cache_length = (
                     past_length
                     if max_cache_length is None
-                    else torch.min(max_cache_length, past_length)
+                    else min(max_cache_length, past_length)
                 )
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
@@ -227,11 +221,12 @@ def _prepare_generation_inputs(
                 attention_mask is not None
                 and attention_mask.shape[1] > input_ids.shape[1]
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+                start = -(attention_mask.shape[1] - past_length)
+                input_ids = input_ids.narrow(1, start, -start)
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
+                input_ids = input_ids.narrow(1, past_length, input_ids.size(1) - past_length)
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

         # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
@@ -240,14 +235,14 @@ def _prepare_generation_inputs(
                 and attention_mask is not None
                 and cache_length + input_ids.shape[1] > max_cache_length
             ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+                attention_mask = attention_mask.narrow(1, -max_cache_length, max_cache_length)

         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1])

         input_length = (
             position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
@@ -257,7 +252,7 @@ def _prepare_generation_inputs(
                 past_length, past_length + input_length, device=input_ids.device
             )
         else:
-            cache_position = cache_position[-input_length:]
+            cache_position = cache_position.narrow(0, -input_length, input_length)

         if has_static_cache:
             past_key_values = None
@@ -365,7 +360,7 @@ def generate(
             device=inputs_ids.device,
         )
         if attention_mask is not None:
-            attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
+            attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(attention_mask)

         with tqdm(
             total=max_new_token,
@@ -379,7 +374,7 @@ def generate(
                 model_input = self._prepare_generation_inputs(
                     inputs_ids,
                     past_key_values,
-                    attention_mask_cache[:, : inputs_ids.shape[1]],
+                    attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
                     use_cache=True,
                 )

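The gpt.py changes above are behavior-preserving: `Tensor.narrow(dim, start, length)` returns the same view as the corresponding slice, while replacing Python indexing sugar with plain torch function calls (presumably friendlier to tracing and torch.compile, in the spirit of the commit title). A minimal sketch of the equivalence; the shapes are made up for illustration:

import torch

x = torch.arange(12).reshape(2, 6)

# keep the last 4 columns: the slice and the narrow are the same view
assert torch.equal(x[:, -4:], x.narrow(1, -4, 4))  # negative start counts from the end

# narrow returns a view, so in-place ops write through to the base tensor;
# this is what attention_mask_cache.narrow(...).copy_(attention_mask) relies on
x.narrow(1, 0, 2).zero_()
assert bool(x[:, :2].eq(0).all())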
ChatTTS/model/processors.py (11 additions, 7 deletions)
@@ -17,15 +17,19 @@ def __init__(self, penalty: float, max_input_ids: int, past_window: int):
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window :]
+        if input_ids.size(1) > self.past_window:
+            input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
         freq = F.one_hot(input_ids, scores.size(1)).sum(1)
-        freq[self.max_input_ids :] = 0
-        alpha = self.penalty**freq
+        if freq.size(0) > self.max_input_ids:
+            freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_()
+        alpha = torch.pow(self.penalty, freq)
         scores = scores.contiguous()
-        scores = torch.where(scores < 0, scores * alpha, scores / alpha)
-
-        return scores
+        inp = scores.multiply(alpha)
+        oth = scores.divide(alpha)
+        con = scores < 0
+        out = torch.where(con, inp, oth)
+        del inp, oth, scores, con, alpha
+        return out


 """class CustomRepetitionPenaltyLogitsProcessor():
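The penalty math is unchanged by this rewrite: token counts over the recent window give alpha = penalty ** freq, and the torch.where pushes already-generated tokens toward lower probability from both sides of zero (negative logits are multiplied by alpha, positive ones divided). The named temporaries plus `del` only make the intermediate buffers eligible for immediate release. A standalone sketch of the same computation; the sizes and inputs below are toy values, not from the repo:

import torch
import torch.nn.functional as F

penalty, past_window = 1.2, 4
vocab_size = 10
input_ids = torch.tensor([[1, 2, 2, 3, 2]])  # (batch, seq): token 2 repeats
scores = torch.randn(1, vocab_size)          # (batch, vocab) logits

window = input_ids.narrow(1, -past_window, past_window)
freq = F.one_hot(window, vocab_size).sum(1)  # per-token counts in the window
alpha = torch.pow(penalty, freq)             # repeated tokens get alpha > 1
# multiplying a negative logit and dividing a positive one both lower
# the probability of tokens that already appeared
out = torch.where(scores < 0, scores * alpha, scores / alpha)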
examples/web/ex.py (new file, 29 additions)
@@ -0,0 +1,29 @@
+ex = [
+    [
+        "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
+        0.3,
+        0.7,
+        20,
+        2,
+        42,
+        True,
+    ],
+    [
+        "What is your favorite english food?",
+        0.5,
+        0.5,
+        10,
+        245,
+        531,
+        True,
+    ],
+    [
+        "chat T T S is a text to speech model designed for dialogue applications. [uv_break]it supports mixed language input [uv_break]and offers multi speaker capabilities with precise control over prosodic elements like [uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation. [uv_break]it delivers natural and expressive speech,[uv_break]so please[uv_break] use the project responsibly at your own risk.[uv_break]",
+        0.8,
+        0.4,
+        7,
+        70,
+        165,
+        False,
+    ],
+]
examples/web/funcs.py (2 additions)
@@ -1,6 +1,7 @@
 import sys
 import random
 from typing import Optional
+from time import sleep

 import gradio as gr
 import numpy as np
@@ -107,6 +108,7 @@ def refine_text(
     has_interrupted = False

     if not refine_text_flag:
+        sleep(1)  # brief delay so the loading indicator is not skipped by the instant return
         return text, *set_generate_buttons(
             generate_button, interrupt_button, is_reset=True
         )
examples/web/webui.py (5 additions, 31 deletions)
@@ -11,6 +11,7 @@
 import gradio as gr

 from examples.web.funcs import *
+from examples.web.ex import ex


 def main():
@@ -35,7 +36,7 @@ def main():
                 maximum=1.0,
                 step=0.00001,
                 value=0.3,
-                label="Audio temperature",
+                label="Audio Temperature",
                 interactive=True,
             )
             top_p_slider = gr.Slider(
@@ -54,7 +55,7 @@ def main():
             voice_selection = gr.Dropdown(
                 label="Timbre", choices=voices.keys(), value="Default"
             )
-            audio_seed_input = gr.Number(value=2, label="Audio Seed")
+            audio_seed_input = gr.Number(value=2, label="Audio Seed", interactive=True)
             generate_audio_seed = gr.Button("\U0001F3B2")
             text_seed_input = gr.Number(value=42, label="Text Seed")
             generate_text_seed = gr.Button("\U0001F3B2")
@@ -144,35 +145,7 @@ def make_audio(autoplay, stream):
         )

         gr.Examples(
-            examples=[
-                [
-                    "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
-                    0.3,
-                    0.7,
-                    20,
-                    2,
-                    42,
-                    True,
-                ],
-                [
-                    "What is [uv_break]your favorite english food?[laugh][lbreak]",
-                    0.5,
-                    0.5,
-                    10,
-                    245,
-                    531,
-                    False,
-                ],
-                [
-                    "chat T T S is a text to speech model designed for dialogue applications. [uv_break]it supports mixed language input [uv_break]and offers multi speaker capabilities with precise control over prosodic elements [laugh]like like [uv_break]laughter[laugh], [uv_break]pauses, [uv_break]and intonation. [uv_break]it delivers natural and expressive speech,[uv_break]so please[uv_break] use the project responsibly at your own risk.[uv_break]",
-                    0.2,
-                    0.6,
-                    15,
-                    67,
-                    165,
-                    False,
-                ],
-            ],
+            examples=ex,
             inputs=[
                 text_input,
                 temperature_slider,
@@ -213,6 +186,7 @@ def make_audio(autoplay, stream):
         server_port=args.server_port,
         root_path=args.root_path,
         inbrowser=True,
+        show_api=False,
     )
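With the example rows factored out into examples/web/ex.py, any entry point can reuse them through gr.Examples. A minimal sketch of the wiring, assuming the seven columns map onto the inputs in this order; the real inputs list in webui.py is longer than the excerpt above shows, so the component set below is illustrative, not the actual layout:

import gradio as gr

from examples.web.ex import ex  # rows: [text, temperature, top_p, top_k, audio_seed, text_seed, refine_flag]

with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Input Text")
    temperature_slider = gr.Slider(0.00001, 1.0, value=0.3, label="Audio Temperature")
    top_p_slider = gr.Slider(0.1, 0.9, value=0.7, label="top_P")
    top_k_slider = gr.Slider(1, 20, value=20, step=1, label="top_K")
    audio_seed_input = gr.Number(value=2, label="Audio Seed", interactive=True)
    text_seed_input = gr.Number(value=42, label="Text Seed")
    refine_text_checkbox = gr.Checkbox(value=True, label="Refine text")
    gr.Examples(
        examples=ex,
        inputs=[
            text_input,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            audio_seed_input,
            text_seed_input,
            refine_text_checkbox,
        ],
    )

demo.launch(show_api=False)  # show_api=False hides Gradio's "Use via API" footer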


