diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py
index 14d881bc10..aaae695516 100644
--- a/TTS/tts/layers/tortoise/autoregressive.py
+++ b/TTS/tts/layers/tortoise/autoregressive.py
@@ -1,14 +1,22 @@
 # AGPL: a notification must be added stating that changes have been made to that file.
 import functools
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import transformers
+from packaging.version import Version
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
 
+if Version(transformers.__version__) >= Version("4.45"):
+    isin = transformers.pytorch_utils.isin_mps_friendly
+else:
+    isin = torch.isin
+
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -596,6 +604,8 @@ def inference_speech(
         max_length = (
             trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
+        stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor)
         gen = self.inference_model.generate(
             inputs,
             bos_token_id=self.start_mel_token,
@@ -604,11 +614,39 @@ def inference_speech(
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]
 
 
+def _prepare_attention_mask_for_generation(
+    inputs: torch.Tensor,
+    pad_token_id: Optional[torch.Tensor],
+    eos_token_id: Optional[torch.Tensor],
+) -> torch.LongTensor:
+    # No information for attention mask inference -> return default attention mask
+    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+    if pad_token_id is None:
+        return default_attention_mask
+
+    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+    if not is_input_ids:
+        return default_attention_mask
+
+    is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any())
+    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+        isin(elements=eos_token_id, test_elements=pad_token_id).any()
+    )
+    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+    attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
+    attention_mask = (
+        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+    )
+    return attention_mask
+
+
 if __name__ == "__main__":
     gpt = UnifiedVoice(
         model_dim=256,
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index b55b84d90e..b3c3b31b47 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from transformers import GPT2Config
 
+from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -586,12 +587,15 @@ def generate(
         **hf_generate_kwargs,
     ):
         gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long)
+        attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask=attention_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
diff --git a/TTS/tts/layers/xtts/gpt_inference.py b/TTS/tts/layers/xtts/gpt_inference.py
index 4625ae1ba9..e94683524a 100644
--- a/TTS/tts/layers/xtts/gpt_inference.py
+++ b/TTS/tts/layers/xtts/gpt_inference.py
@@ -1,10 +1,12 @@
 import torch
 from torch import nn
-from transformers import GPT2PreTrainedModel
+from transformers import GenerationMixin, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
+from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
 
-class GPT2InferenceModel(GPT2PreTrainedModel):
+
+class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
     """Override GPT2LMHeadModel to allow for prefix conditioning."""
 
     def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
@@ -15,6 +17,7 @@ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
         self.final_norm = norm
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
+        self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None
 
     def store_prefix_emb(self, prefix_emb):
         self.cached_prefix_emb = prefix_emb
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 0b7652e450..c92db9c1d0 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -667,6 +667,7 @@ def inference_stream(
             repetition_penalty=float(repetition_penalty),
             output_attentions=False,
             output_hidden_states=True,
+            return_dict_in_generate=True,
             **hf_generate_kwargs,
         )
 
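
For reference, a minimal sketch (not part of the patch) of how the new _prepare_attention_mask_for_generation helper behaves, assuming the patched TTS package is importable; the token values below are made up for illustration:

import torch

from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation

input_ids = torch.tensor([[5, 7, 9, 0, 0]])  # toy token ids; 0 plays the role of padding
pad = torch.tensor(0)
eos = torch.tensor(9)

# Distinct pad token: padding positions are masked out.
print(_prepare_attention_mask_for_generation(input_ids, pad, eos))  # tensor([[1, 1, 1, 0, 0]])

# pad == eos (as in the generate() calls above, which pass the stop token for both):
# padding cannot be inferred, so the default all-ones mask is returned.
print(_prepare_attention_mask_for_generation(input_ids, pad, pad))  # tensor([[1, 1, 1, 1, 1]])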