fix(gpt): set attention mask and address other warnings
eginhard committed Oct 25, 2024
1 parent b66c782 commit 964b813
Showing 4 changed files with 48 additions and 2 deletions.
38 changes: 38 additions & 0 deletions TTS/tts/layers/tortoise/autoregressive.py
@@ -1,14 +1,22 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import transformers
+from packaging.version import Version
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
 
+if Version(transformers.__version__) >= Version("4.45"):
+    isin = transformers.pytorch_utils.isin_mps_friendly
+else:
+    isin = torch.isin
+
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -596,6 +604,8 @@ def inference_speech(
         max_length = (
             trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
+        stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor)
         gen = self.inference_model.generate(
             inputs,
             bos_token_id=self.start_mel_token,
@@ -604,11 +614,39 @@ def inference_speech(
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]
 
 
+def _prepare_attention_mask_for_generation(
+    inputs: torch.Tensor,
+    pad_token_id: Optional[torch.Tensor],
+    eos_token_id: Optional[torch.Tensor],
+) -> torch.LongTensor:
+    # No information for attention mask inference -> return default attention mask
+    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+    if pad_token_id is None:
+        return default_attention_mask
+
+    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+    if not is_input_ids:
+        return default_attention_mask
+
+    is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any())
+    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+        isin(elements=eos_token_id, test_elements=pad_token_id).any()
+    )
+    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+    attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
+    attention_mask = (
+        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+    )
+    return attention_mask
+
+
 if __name__ == "__main__":
     gpt = UnifiedVoice(
         model_dim=256,
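Note (not part of the diff): `_prepare_attention_mask_for_generation` above mirrors the private helper of the same name in transformers' `GenerationMixin`, with `isin` swapped for `isin_mps_friendly` on transformers >= 4.45 as a workaround for `torch.isin` on Apple's MPS backend. Because `inference_speech` passes the stop token as both `pad_token_id` and `eos_token_id`, padding cannot be distinguished from end-of-sequence and the helper returns the default all-ones mask; padding-based masking only applies when the two ids differ. A minimal sketch with hypothetical token ids, assuming the function above is in scope:

import torch

prompt = torch.tensor([[5, 7, 0, 0]])  # hypothetical ids; 0 acts as pad
stop = torch.tensor(0)

# pad == eos, as in the call sites of this commit: fall back to all ones.
print(_prepare_attention_mask_for_generation(prompt, stop, stop))
# tensor([[1, 1, 1, 1]])

# pad != eos and the pad token occurs in the prompt: mask padded positions.
print(_prepare_attention_mask_for_generation(prompt, torch.tensor(0), torch.tensor(2)))
# tensor([[1, 1, 0, 0]])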
4 changes: 4 additions & 0 deletions TTS/tts/layers/xtts/gpt.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from transformers import GPT2Config
 
+from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -586,12 +587,15 @@ def generate(
         **hf_generate_kwargs,
     ):
         gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
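Note (not part of the diff): `compute_embeddings` returns placeholder input ids for the prompt (the real prefix embedding is cached on the inference model), so the prompt contains no padding and the inferred mask is all ones; passing it explicitly mainly silences the transformers warning emitted when `pad_token_id` equals `eos_token_id` and no mask is given. A sketch under those assumptions, with hypothetical shapes and ids:

import torch

gpt_inputs = torch.ones((1, 32), dtype=torch.long)  # unpadded placeholder prompt
stop_token_tensor = torch.tensor(1025)  # hypothetical stop_audio_token id
mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
assert mask.shape == (1, 32) and bool(mask.all())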
7 changes: 5 additions & 2 deletions TTS/tts/layers/xtts/gpt_inference.py
@@ -1,10 +1,12 @@
 import torch
 from torch import nn
-from transformers import GPT2PreTrainedModel
+from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
+from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
+
 
-class GPT2InferenceModel(GPT2PreTrainedModel):
+class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
 
     def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
@@ -15,6 +17,7 @@ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
+        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
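Note (not part of the diff): inheriting `GenerationMixin` explicitly addresses a transformers deprecation: starting with v4.45 the library warns that models overriding generation hooks without inheriting the mixin will lose `generate()` once `PreTrainedModel` stops providing it (announced for v4.50). Setting `generation_config` in `__init__` mirrors what `PreTrainedModel` does for generation-capable models; `StreamGenerationConfig` is this repository's `GenerationConfig` subclass used by the streaming generator. The general pattern, sketched for a hypothetical model using the stock config class:

from transformers import GenerationConfig, GenerationMixin, GPT2PreTrainedModel

class MyCustomLM(GPT2PreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        # can_generate() is True once GenerationMixin is in the MRO,
        # so a generation config can be derived from the model config.
        if self.can_generate():
            self.generation_config = GenerationConfig.from_model_config(config)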
1 change: 1 addition & 0 deletions TTS/tts/models/xtts.py
@@ -667,6 +667,7 @@ def inference_stream(
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            return_dict_in_generate=True,
             **hf_generate_kwargs,
         )
 
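Note (not part of the diff): with `return_dict_in_generate=True`, `generate()` returns a `ModelOutput` (a `GenerateDecoderOnlyOutput` for decoder-only models) instead of a bare tensor of ids, which the streaming path relies on to read hidden states alongside tokens. A sketch of consuming such an output, where `model` and `input_ids` are placeholders:

out = model.generate(
    input_ids,
    return_dict_in_generate=True,
    output_hidden_states=True,
)
tokens = out.sequences      # (batch, seq_len) token ids, prompt included
hidden = out.hidden_states  # one tuple of per-layer states per generated step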
