diff --git a/TTS/tts/configs/stylefast_pitch_config.py b/TTS/tts/configs/stylefast_pitch_config.py
new file mode 100644
index 0000000000..c47344f7a8
--- /dev/null
+++ b/TTS/tts/configs/stylefast_pitch_config.py
@@ -0,0 +1,180 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.styleforward_tts import StyleForwardTTSArgs
+from TTS.style_encoder.configs.style_encoder_config import StyleEncoderConfig
+
+
+@dataclass
+class StyleFastPitchConfig(BaseTTSConfig):
+    """Configure `StyleForwardTTS` as a StyleFastPitch model.
+
+    Example:
+
+        >>> from TTS.tts.configs.stylefast_pitch_config import StyleFastPitchConfig
+        >>> config = StyleFastPitchConfig()
+
+    Args:
+        model (str):
+            Model name used for selecting the right model at initialization. Defaults to `stylefast_pitch`.
+
+        base_model (str):
+            Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
+            the base model rather than searching for the `model` implementation. Defaults to `styleforward_tts`.
+
+        model_args (Coqpit):
+            Model class arguments. Check `StyleForwardTTSArgs` for more details. Defaults to `StyleForwardTTSArgs()`.
+
+        data_dep_init_steps (int):
+            Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
+            Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
+            for the rest. Defaults to 10.
+
+        speakers_file (str):
+            Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
+            speaker names. Defaults to `None`.
+
+        use_speaker_embedding (bool):
+            Enable/disable speaker embeddings for multi-speaker models. If set to True, the model runs in
+            multi-speaker mode. Defaults to False.
+
+        use_d_vector_file (bool):
+            Enable/disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
+
+        d_vector_file (str):
+            Path to the file including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Dimension of the external speaker embeddings. Defaults to 0.
+
+        optimizer (str):
+            Name of the model optimizer. Defaults to `Adam`.
+
+        optimizer_params (dict):
+            Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
+
+        lr_scheduler (str):
+            Name of the learning rate scheduler. Defaults to `NoamLR`.
+
+        lr_scheduler_params (dict):
+            Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
+
+        lr (float):
+            Initial learning rate. Defaults to `1e-4`.
+
+        grad_clip (float):
+            Gradient norm clipping value. Defaults to `5.0`.
+
+        spec_loss_type (str):
+            Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        duration_loss_type (str):
+            Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        use_ssim_loss (bool):
+            Enable/disable the use of the SSIM (Structural Similarity) loss. Defaults to True.
+
+        ssim_loss_alpha (float):
+            Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.
+
+        dur_loss_alpha (float):
+            Weight for the duration predictor's loss. If set to 0, disables the duration loss. Defaults to 1.0.
+
+        spec_loss_alpha (float):
+            Weight for the spectrogram loss. If set to 0, disables the spectrogram loss. Defaults to 1.0.
+
+        pitch_loss_alpha (float):
+            Weight for the pitch predictor's loss. If set to 0, disables the pitch loss. Defaults to 1.0.
+
+        aligner_loss_alpha (float):
+            Weight for the aligner's loss. If set to 0, disables the aligner loss. Defaults to 1.0.
+
+        binary_align_loss_alpha (float):
+            Weight for the binary alignment loss. If set to 0, disables the binary alignment loss. Defaults to 1.0.
+
+        binary_align_loss_start_step (int):
+            Start the binary alignment loss after this many steps. Defaults to 20000.
+
+        min_seq_len (int):
+            Minimum input sequence length to be used at training.
+
+        max_seq_len (int):
+            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
+    """
+
+    model: str = "stylefast_pitch"
+    base_model: str = "styleforward_tts"
+
+    # style encoder params
+    style_encoder_config: StyleEncoderConfig = None
+
+    # model specific params
+    model_args: StyleForwardTTSArgs = StyleForwardTTSArgs()
+
+    # multi-speaker settings
+    num_speakers: int = 0
+    speakers_file: str = None
+    use_speaker_embedding: bool = False
+    use_d_vector_file: bool = False
+    d_vector_file: str = None
+    d_vector_dim: int = 0
+
+    # optimizer parameters
+    optimizer: str = "Adam"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
+    lr: float = 1e-4
+    grad_clip: float = 5.0
+
+    # loss params
+    spec_loss_type: str = "mse"
+    duration_loss_type: str = "mse"
+    use_ssim_loss: bool = True
+    ssim_loss_alpha: float = 1.0
+    dur_loss_alpha: float = 1.0
+    spec_loss_alpha: float = 1.0
+    pitch_loss_alpha: float = 1.0
+    aligner_loss_alpha: float = 1.0
+    binary_align_loss_alpha: float = 1.0
+    binary_align_loss_start_step: int = 20000
+
+    # overrides
+    min_seq_len: int = 13
+    max_seq_len: int = 200
+    r: int = 1  # DO NOT CHANGE
+
+    # dataset configs
+    compute_f0: bool = True
+    f0_cache_path: str = None
+
+    # testing
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )
+
+    def __post_init__(self):
+        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for them there.
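+        # For example, if the user sets `use_d_vector_file=True` and `d_vector_dim=512` on this config
+        # (illustrative values, not defaults), the same settings are mirrored onto `self.model_args` below,
+        # so `model.init_multispeaker()` only ever has to read one place.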
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/models/styleforward_tts.py b/TTS/tts/models/styleforward_tts.py new file mode 100644 index 0000000000..ba908cfaa3 --- /dev/null +++ b/TTS/tts/models/styleforward_tts.py @@ -0,0 +1,735 @@ +from dataclasses import dataclass, field +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram + +# Import Style Encoder +from TTS.style_encoder.style_encoder import StyleEncoder + + +@dataclass +class StyleForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. 
One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + use_aligner: bool = True + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + detach_duration_predictor: bool = False + max_duration: int = 75 + num_speakers: int = 1 + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_dim: int = None + d_vector_file: str = None + + +class StyleForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - TODO: FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. + speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models. + Defaults to None. 
+ + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + + super().__init__(config) + + self.speaker_manager = speaker_manager + self.init_multispeaker(config) + # # pass all config fields to `self` + # # for fewer code change + # for key in config: + # setattr(self, key, config[key]) + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.use_binary_alignment_loss = False + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.embedded_speaker_dim, + ) + + self.style_encoder_layer = StyleEncoder(self.config.style_encoder_config) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels + self.embedded_speaker_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels + self.embedded_speaker_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # init speaker manager + if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." + ) + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # init d-vector embedding + if config.use_d_vector_file: + self.embedded_speaker_dim = config.d_vector_dim + if self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. 
+ + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes: + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Sum encoder outputs and speaker embeddings + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = self.emb_g(g) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + # speaker conditioning + # TODO: try different ways of conditioning + if g is not None: + o_en = o_en + g + return o_en, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. 
+ x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_over_durations(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. 
+ + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. 
+ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = self._set_speaker_input(aux_input) + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder( + o_en, dr, x_mask, y_lengths, g=None + ) # TODO: maybe pass speaker embedding (g) too + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. 
+ + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] if self.args.use_pitch else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + """Create common logger outputs.""" + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch = batch["pitch"] + pitch_avg_expanded, _ = self.expand_encoder_outputs( + outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] + ) + pitch = pitch[0, 0].data.cpu().numpy() + # TODO: denormalize before plotting + pitch = abs(pitch) + pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() + pitch_figures = { + "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), + "pitch_avg_predicted": 
plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), + } + figures.update(pitch_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Enable binary alignment loss when needed""" + if trainer.total_steps_done > self.config.binary_align_loss_start_step: + self.use_binary_alignment_loss = True diff --git a/debug/Understanding Coqui-TTS pipeline to better customize.ipynb b/debug/Understanding Coqui-TTS pipeline to better customize.ipynb index 6a17f236a8..0e33c795a4 100644 --- a/debug/Understanding Coqui-TTS pipeline to better customize.ipynb +++ b/debug/Understanding Coqui-TTS pipeline to better customize.ipynb @@ -10,6 +10,452 @@ "sys.path.append('../')" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\lucas\\Anaconda3\\envs\\m_audio\\lib\\site-packages\\numpy\\_distributor_init.py:32: UserWarning: loaded more than 1 DLL from .libs:\n", + "C:\\Users\\lucas\\Anaconda3\\envs\\m_audio\\lib\\site-packages\\numpy\\.libs\\libopenblas.TXA6YQSD3GCQQC22GEQ54J2UDCXDXHWN.gfortran-win_amd64.dll\n", + "C:\\Users\\lucas\\Anaconda3\\envs\\m_audio\\lib\\site-packages\\numpy\\.libs\\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll\n", + " stacklevel=1)\n", + "C:\\Users\\lucas\\Anaconda3\\envs\\m_audio\\lib\\site-packages\\torchaudio\\extension\\extension.py:14: UserWarning: torchaudio C++ extension is not available.\n", + " warnings.warn('torchaudio C++ extension is not available.')\n", + "C:\\Users\\lucas\\Anaconda3\\envs\\m_audio\\lib\\site-packages\\torchaudio\\backend\\utils.py:64: UserWarning: The interface of \"soundfile\" backend is planned to change in 0.8.0 to match that of \"sox_io\" backend and the current interface will be removed in 0.9.0. To use the new interface, do `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before setting the backend to \"soundfile\". 
Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n", + " 'The interface of \"soundfile\" backend is planned to change in 0.8.0 to '\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import torch\n", + "from TTS.config.shared_configs import BaseAudioConfig\n", + "from TTS.tts.configs.shared_configs import BaseDatasetConfig, GSTConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.trainer_windows import Trainer, TrainingArgs\n", + "\n", + "# Old Tacotron Imports\n", + "from TTS.tts.models.tacotron2 import Tacotron2\n", + "from TTS.tts.configs.tacotron2_config import Tacotron2Config\n", + "from TTS.tts.configs.shared_configs import GSTConfig\n", + "\n", + "# Style Tacotron Imports\n", + "from TTS.tts.models.styletacotron2 import Styletacotron2\n", + "from TTS.tts.configs.styletacotron2_config import Styletacotron2Config\n", + "from TTS.style_encoder.configs.style_encoder_config import StyleEncoderConfig\n", + "\n", + "# Style forward TTS Imports\n", + "from TTS.tts.models.styleforward_tts import StyleForwardTTS\n", + "from TTS.tts.configs.stylefast_pitch_config import StyleFastPitchConfig\n", + "\n", + "def seed_everything(seed: int):\n", + " import random, os\n", + " import numpy as np\n", + " import torch\n", + " \n", + " random.seed(seed)\n", + " os.environ['PYTHONHASHSEED'] = str(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = True" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " > Setting up Audio Processor...\n", + " | > sample_rate:22050\n", + " | > resample:False\n", + " | > num_mels:80\n", + " | > log_func:np.log\n", + " | > min_level_db:-100\n", + " | > frame_shift_ms:None\n", + " | > frame_length_ms:None\n", + " | > ref_level_db:20\n", + " | > fft_size:1024\n", + " | > power:1.5\n", + " | > preemphasis:0.0\n", + " | > griffin_lim_iters:60\n", + " | > signal_norm:False\n", + " | > symmetric_norm:True\n", + " | > mel_fmin:0\n", + " | > mel_fmax:8000\n", + " | > spec_gain:1.0\n", + " | > stft_pad_mode:reflect\n", + " | > max_norm:4.0\n", + " | > clip_norm:True\n", + " | > do_trim_silence:True\n", + " | > trim_db:60.0\n", + " | > do_sound_norm:False\n", + " | > do_amp_to_db_linear:True\n", + " | > do_amp_to_db_mel:True\n", + " | > do_rms_norm:False\n", + " | > db_level:None\n", + " | > stats_path:None\n", + " | > base:2.718281828459045\n", + " | > hop_length:256\n", + " | > win_length:1024\n", + " | > Found 13100 files in D:\\Mestrado\\Emotion Audio Synthesis (TTS)\\repo_final\\pt_etts\\data\\LJSpeech\\LJSpeech-1.1\n" + ] + } + ], + "source": [ + "seed_everything(42)\n", + "output_path = './'\n", + "\n", + "# init configs\n", + "dataset_config = BaseDatasetConfig(\n", + " name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=os.path.join(output_path, \"D:/Mestrado/Emotion Audio Synthesis (TTS)/repo_final/pt_etts/data/LJSpeech\\LJSpeech-1.1\")\n", + ")\n", + "\n", + "audio_config = BaseAudioConfig(\n", + " sample_rate=22050,\n", + " do_trim_silence=True,\n", + " trim_db=60.0,\n", + " signal_norm=False,\n", + " mel_fmin=0.0,\n", + " mel_fmax=8000,\n", + " spec_gain=1.0,\n", + " log_func=\"np.log\",\n", + " ref_level_db=20,\n", + " preemphasis=0.0,\n", + ")\n", + "\n", + "style_config = 
StyleEncoderConfig(se_type=\"vaeflow\")\n", + "\n", + "config = StyleFastPitchConfig( # This is the config that is saved for the future use\n", + " style_encoder_config = style_config,\n", + " audio=audio_config,\n", + " batch_size=64,\n", + " eval_batch_size=16,\n", + " num_loader_workers=4,\n", + " num_eval_loader_workers=4,\n", + " run_eval=True,\n", + " test_delay_epochs=-1,\n", + " r=6,\n", + "# gradual_training=[[0, 6, 64], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]],\n", + "# double_decoder_consistency=True,\n", + " epochs=1000,\n", + " text_cleaner=\"phoneme_cleaners\",\n", + " use_phonemes=True,\n", + " phoneme_language=\"en-us\",\n", + " phoneme_cache_path=os.path.join(output_path, \"phoneme_cache\"),\n", + " print_step=25,\n", + " print_eval=True,\n", + " mixed_precision=False,\n", + " output_path=output_path,\n", + " datasets=[dataset_config],\n", + ")\n", + "# init audio processor\n", + "ap = AudioProcessor(**config.audio.to_dict())\n", + "\n", + "# load training samples\n", + "train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)\n", + "\n", + "# init model\n", + "model = StyleForwardTTS(config)\n", + "\n", + "# # init the trainer and 🚀\n", + "# trainer = Trainer(\n", + "# TrainingArgs(),\n", + "# config,\n", + "# output_path,\n", + "# model=model,\n", + "# train_samples=train_samples,\n", + "# eval_samples=eval_samples,\n", + "# training_assets={\"audio_processor\": ap},\n", + "# )\n", + "# # Data loader\n", + "# trainer.train_loader = trainer.get_train_dataloader(trainer.training_assets,trainer.train_samples, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "StyleForwardTTS(\n", + " (emb): Embedding(130, 384)\n", + " (encoder): Encoder(\n", + " (encoder): FFTransformerBlock(\n", + " (fft_layers): ModuleList(\n", + " (0): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (1): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (2): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, 
inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (3): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (4): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (5): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (style_encoder_layer): StyleEncoder(\n", + " (layer): VAEFlowStyleEncoder(\n", + " (ref_encoder): ReferenceEncoder(\n", + " (convs): ModuleList(\n", + " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " (bns): ModuleList(\n", + " (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (recurrence): GRU(256, 128, batch_first=True)\n", + " )\n", + " (q_z_layers_pre): ModuleList(\n", + " (0): Linear(in_features=128, out_features=300, bias=True)\n", + " (1): Linear(in_features=300, out_features=300, bias=True)\n", + " )\n", + " (q_z_layers_gate): ModuleList(\n", + " (0): Linear(in_features=128, out_features=300, bias=True)\n", + " (1): Linear(in_features=300, out_features=300, bias=True)\n", + " )\n", + " (q_z_mean): Linear(in_features=300, out_features=128, 
bias=True)\n", + " (q_z_logvar): Linear(in_features=300, out_features=128, bias=True)\n", + " (v_layers): ModuleList(\n", + " (0): Linear(in_features=300, out_features=128, bias=True)\n", + " (1): Linear(in_features=128, out_features=128, bias=True)\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " (3): Linear(in_features=128, out_features=128, bias=True)\n", + " (4): Linear(in_features=128, out_features=128, bias=True)\n", + " (5): Linear(in_features=128, out_features=128, bias=True)\n", + " (6): Linear(in_features=128, out_features=128, bias=True)\n", + " (7): Linear(in_features=128, out_features=128, bias=True)\n", + " (8): Linear(in_features=128, out_features=128, bias=True)\n", + " (9): Linear(in_features=128, out_features=128, bias=True)\n", + " (10): Linear(in_features=128, out_features=128, bias=True)\n", + " (11): Linear(in_features=128, out_features=128, bias=True)\n", + " (12): Linear(in_features=128, out_features=128, bias=True)\n", + " (13): Linear(in_features=128, out_features=128, bias=True)\n", + " (14): Linear(in_features=128, out_features=128, bias=True)\n", + " (15): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " (sigmoid): Sigmoid()\n", + " (Gate): Gate()\n", + " (HF): HF()\n", + " )\n", + " )\n", + " (pos_encoder): PositionalEncoding()\n", + " (decoder): Decoder(\n", + " (decoder): FFTransformerDecoder(\n", + " (transformer_block): FFTransformerBlock(\n", + " (fft_layers): ModuleList(\n", + " (0): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (1): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (2): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (3): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), 
eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (4): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (5): FFTransformer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): _LinearWithBias(in_features=384, out_features=384, bias=True)\n", + " )\n", + " (conv1): Conv1d(384, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (conv2): Conv1d(1024, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (postnet): Conv1d(384, 80, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (duration_predictor): DurationPredictor(\n", + " (drop): Dropout(p=0.1, inplace=False)\n", + " (conv_1): Conv1d(384, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm_1): LayerNorm()\n", + " (conv_2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm_2): LayerNorm()\n", + " (proj): Conv1d(256, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (pitch_predictor): DurationPredictor(\n", + " (drop): Dropout(p=0.1, inplace=False)\n", + " (conv_1): Conv1d(384, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm_1): LayerNorm()\n", + " (conv_2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (norm_2): LayerNorm()\n", + " (proj): Conv1d(256, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (pitch_emb): Conv1d(1, 384, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (aligner): AlignmentNetwork(\n", + " (softmax): Softmax(dim=3)\n", + " (log_softmax): LogSoftmax(dim=3)\n", + " (key_layer): Sequential(\n", + " (0): Conv1d(384, 768, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): ReLU()\n", + " (2): Conv1d(768, 80, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (query_layer): Sequential(\n", + " (0): Conv1d(80, 160, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): ReLU()\n", + " (2): Conv1d(160, 80, kernel_size=(1,), stride=(1,))\n", + " (3): ReLU()\n", + " (4): Conv1d(80, 80, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/debug/config.json b/debug/config.json index c0f1415f8c..3cde797fa5 100644 --- a/debug/config.json +++ b/debug/config.json @@ -75,7 +75,8 @@ "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? 
", "punctuations": "!'(),-.:;? ", "phonemes": "iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u02b2\u025a\u02de\u026b", - "unique": true + "unique": true, + "split_by_space": false }, "batch_group_size": 0, "loss_masking": true, @@ -116,15 +117,31 @@ "Prior to November 22, 1963." ], "style_encoder_config": { - "se_type": "gst", - "input_wav": null, + "se_type": "vaeflow", "num_mel": 80, - "style_embedding_dim": 256, + "style_embedding_dim": 128, "use_speaker_embedding": false, "gst_style_input_weights": null, "gst_num_heads": 4, "gst_num_style_tokens": 10, - "vae_latent_dim": 256 + "vae_latent_dim": 128, + "use_cyclical_annealing": true, + "vae_loss_alpha": 1.0, + "vae_cycle_period": 5000, + "use_nonlinear_proj": false, + "vaeflow_intern_dim": 300, + "vaeflow_number_of_flows": 16, + "diff_num_timesteps": 25, + "diff_schedule_type": "cosine", + "diff_loss_type": "l1", + "diff_ref_online": true, + "diff_step_dim": 128, + "diff_in_out_ch": 1, + "diff_num_heads": 1, + "diff_hidden_channels": 128, + "diff_num_blocks": 5, + "diff_dropout": 0.1, + "diff_loss_alpha": 0.75 }, "num_speakers": 1, "num_chars": 130, diff --git a/debug/config_forward.json b/debug/config_forward.json new file mode 100644 index 0000000000..bb203f4193 --- /dev/null +++ b/debug/config_forward.json @@ -0,0 +1,193 @@ +{ + "model": "stylefastpitch", + "run_name": "coqui_tts", + "run_description": "", + "epochs": 1000, + "batch_size": 64, + "eval_batch_size": 16, + "mixed_precision": false, + "scheduler_after_epoch": false, + "run_eval": true, + "test_delay_epochs": -1, + "print_eval": true, + "dashboard_logger": "tensorboard", + "print_step": 25, + "plot_step": 100, + "model_param_stats": false, + "project_name": null, + "log_model_step": null, + "wandb_entity": null, + "save_step": 10000, + "checkpoint": true, + "keep_all_best": false, + "keep_after": 10000, + "num_loader_workers": 4, + "num_eval_loader_workers": 4, + "use_noise_augment": false, + "use_language_weighted_sampler": false, + "output_path": "./", + "distributed_backend": "nccl", + "distributed_url": "tcp://localhost:54321", + "audio": { + "fft_size": 1024, + "win_length": 1024, + "hop_length": 256, + "frame_shift_ms": null, + "frame_length_ms": null, + "stft_pad_mode": "reflect", + "sample_rate": 22050, + "resample": false, + "preemphasis": 0.0, + "ref_level_db": 20, + "do_sound_norm": false, + "log_func": "np.log", + "do_trim_silence": true, + "trim_db": 60.0, + "do_rms_norm": false, + "db_level": null, + "power": 1.5, + "griffin_lim_iters": 60, + "num_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": 8000, + "spec_gain": 1.0, + "do_amp_to_db_linear": true, + "do_amp_to_db_mel": true, + "signal_norm": false, + "min_level_db": -100, + "symmetric_norm": true, + "max_norm": 4.0, + "clip_norm": true, + "stats_path": null + }, + "use_phonemes": true, + "use_espeak_phonemes": true, + "phoneme_language": "en-us", + "compute_input_seq_cache": false, + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, + "test_sentences_file": "", + "phoneme_cache_path": 
"./phoneme_cache", + "characters": { + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations": "!'(),-.:;? ", + "phonemes": "iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u02b2\u025a\u02de\u026b", + "unique": true + }, + "batch_group_size": 0, + "loss_masking": true, + "sort_by_audio_len": false, + "min_seq_len": 1, + "max_seq_len": Infinity, + "compute_f0": false, + "compute_linear_spec": false, + "add_blank": false, + "datasets": [ + { + "name": "ljspeech", + "path": "D:/Mestrado/Emotion Audio Synthesis (TTS)/repo_final/pt_etts/data/LJSpeech\\LJSpeech-1.1", + "meta_file_train": "metadata.csv", + "ignored_speakers": null, + "language": "", + "meta_file_val": "", + "meta_file_attn_mask": "" + } + ], + "optimizer": "RAdam", + "optimizer_params": { + "betas": [ + 0.9, + 0.998 + ], + "weight_decay": 1e-06 + }, + "lr_scheduler": "NoamLR", + "lr_scheduler_params": { + "warmup_steps": 4000 + }, + "test_sentences": [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963." 
+ ], + "style_encoder_config": { + "se_type": "gst", + "input_wav": null, + "num_mel": 80, + "style_embedding_dim": 256, + "use_speaker_embedding": false, + "gst_style_input_weights": null, + "gst_num_heads": 4, + "gst_num_style_tokens": 10, + "vae_latent_dim": 256 + }, + "num_speakers": 1, + "num_chars": 130, + "r": 6, + "gradual_training": [ + [ + 0, + 6, + 64 + ], + [ + 10000, + 4, + 32 + ], + [ + 50000, + 3, + 32 + ], + [ + 100000, + 2, + 32 + ] + ], + "memory_size": -1, + "prenet_type": "original", + "prenet_dropout": true, + "prenet_dropout_at_inference": false, + "stopnet": true, + "separate_stopnet": true, + "stopnet_pos_weight": 10.0, + "max_decoder_steps": 500, + "encoder_in_features": 512, + "decoder_in_features": 512, + "decoder_output_dim": 80, + "out_channels": 80, + "attention_type": "original", + "attention_heads": null, + "attention_norm": "sigmoid", + "attention_win": false, + "windowing": false, + "use_forward_attn": false, + "forward_attn_mask": false, + "transition_agent": false, + "location_attn": true, + "bidirectional_decoder": false, + "double_decoder_consistency": true, + "ddc_r": 6, + "use_speaker_embedding": false, + "speaker_embedding_dim": 512, + "use_d_vector_file": false, + "d_vector_file": false, + "d_vector_dim": null, + "lr": 0.0001, + "grad_clip": 5.0, + "seq_len_norm": false, + "decoder_loss_alpha": 0.25, + "postnet_loss_alpha": 0.25, + "postnet_diff_spec_alpha": 0.25, + "decoder_diff_spec_alpha": 0.25, + "decoder_ssim_alpha": 0.25, + "postnet_ssim_alpha": 0.25, + "ga_alpha": 5.0 +} \ No newline at end of file diff --git a/debug/coqui_tts-May-09-2022_03+01PM-0000000/config.json b/debug/coqui_tts-May-09-2022_03+01PM-0000000/config.json new file mode 100644 index 0000000000..3cde797fa5 --- /dev/null +++ b/debug/coqui_tts-May-09-2022_03+01PM-0000000/config.json @@ -0,0 +1,210 @@ +{ + "model": "tacotron2", + "run_name": "coqui_tts", + "run_description": "", + "epochs": 1000, + "batch_size": 64, + "eval_batch_size": 16, + "mixed_precision": false, + "scheduler_after_epoch": false, + "run_eval": true, + "test_delay_epochs": -1, + "print_eval": true, + "dashboard_logger": "tensorboard", + "print_step": 25, + "plot_step": 100, + "model_param_stats": false, + "project_name": null, + "log_model_step": null, + "wandb_entity": null, + "save_step": 10000, + "checkpoint": true, + "keep_all_best": false, + "keep_after": 10000, + "num_loader_workers": 4, + "num_eval_loader_workers": 4, + "use_noise_augment": false, + "use_language_weighted_sampler": false, + "output_path": "./", + "distributed_backend": "nccl", + "distributed_url": "tcp://localhost:54321", + "audio": { + "fft_size": 1024, + "win_length": 1024, + "hop_length": 256, + "frame_shift_ms": null, + "frame_length_ms": null, + "stft_pad_mode": "reflect", + "sample_rate": 22050, + "resample": false, + "preemphasis": 0.0, + "ref_level_db": 20, + "do_sound_norm": false, + "log_func": "np.log", + "do_trim_silence": true, + "trim_db": 60.0, + "do_rms_norm": false, + "db_level": null, + "power": 1.5, + "griffin_lim_iters": 60, + "num_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": 8000, + "spec_gain": 1.0, + "do_amp_to_db_linear": true, + "do_amp_to_db_mel": true, + "signal_norm": false, + "min_level_db": -100, + "symmetric_norm": true, + "max_norm": 4.0, + "clip_norm": true, + "stats_path": null + }, + "use_phonemes": true, + "use_espeak_phonemes": true, + "phoneme_language": "en-us", + "compute_input_seq_cache": false, + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, + "test_sentences_file": "", + 
"phoneme_cache_path": "./phoneme_cache", + "characters": { + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations": "!'(),-.:;? ", + "phonemes": "iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u02b2\u025a\u02de\u026b", + "unique": true, + "split_by_space": false + }, + "batch_group_size": 0, + "loss_masking": true, + "sort_by_audio_len": false, + "min_seq_len": 1, + "max_seq_len": Infinity, + "compute_f0": false, + "compute_linear_spec": false, + "add_blank": false, + "datasets": [ + { + "name": "ljspeech", + "path": "D:/Mestrado/Emotion Audio Synthesis (TTS)/repo_final/pt_etts/data/LJSpeech\\LJSpeech-1.1", + "meta_file_train": "metadata.csv", + "ignored_speakers": null, + "language": "", + "meta_file_val": "", + "meta_file_attn_mask": "" + } + ], + "optimizer": "RAdam", + "optimizer_params": { + "betas": [ + 0.9, + 0.998 + ], + "weight_decay": 1e-06 + }, + "lr_scheduler": "NoamLR", + "lr_scheduler_params": { + "warmup_steps": 4000 + }, + "test_sentences": [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963." 
+ ], + "style_encoder_config": { + "se_type": "vaeflow", + "num_mel": 80, + "style_embedding_dim": 128, + "use_speaker_embedding": false, + "gst_style_input_weights": null, + "gst_num_heads": 4, + "gst_num_style_tokens": 10, + "vae_latent_dim": 128, + "use_cyclical_annealing": true, + "vae_loss_alpha": 1.0, + "vae_cycle_period": 5000, + "use_nonlinear_proj": false, + "vaeflow_intern_dim": 300, + "vaeflow_number_of_flows": 16, + "diff_num_timesteps": 25, + "diff_schedule_type": "cosine", + "diff_loss_type": "l1", + "diff_ref_online": true, + "diff_step_dim": 128, + "diff_in_out_ch": 1, + "diff_num_heads": 1, + "diff_hidden_channels": 128, + "diff_num_blocks": 5, + "diff_dropout": 0.1, + "diff_loss_alpha": 0.75 + }, + "num_speakers": 1, + "num_chars": 130, + "r": 6, + "gradual_training": [ + [ + 0, + 6, + 64 + ], + [ + 10000, + 4, + 32 + ], + [ + 50000, + 3, + 32 + ], + [ + 100000, + 2, + 32 + ] + ], + "memory_size": -1, + "prenet_type": "original", + "prenet_dropout": true, + "prenet_dropout_at_inference": false, + "stopnet": true, + "separate_stopnet": true, + "stopnet_pos_weight": 10.0, + "max_decoder_steps": 500, + "encoder_in_features": 512, + "decoder_in_features": 512, + "decoder_output_dim": 80, + "out_channels": 80, + "attention_type": "original", + "attention_heads": null, + "attention_norm": "sigmoid", + "attention_win": false, + "windowing": false, + "use_forward_attn": false, + "forward_attn_mask": false, + "transition_agent": false, + "location_attn": true, + "bidirectional_decoder": false, + "double_decoder_consistency": true, + "ddc_r": 6, + "use_speaker_embedding": false, + "speaker_embedding_dim": 512, + "use_d_vector_file": false, + "d_vector_file": false, + "d_vector_dim": null, + "lr": 0.0001, + "grad_clip": 5.0, + "seq_len_norm": false, + "decoder_loss_alpha": 0.25, + "postnet_loss_alpha": 0.25, + "postnet_diff_spec_alpha": 0.25, + "decoder_diff_spec_alpha": 0.25, + "decoder_ssim_alpha": 0.25, + "postnet_ssim_alpha": 0.25, + "ga_alpha": 5.0 +} \ No newline at end of file
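The JSON configs in this diff differ mainly in their `style_encoder_config` block: `debug/config_forward.json` keeps the default GST settings (`se_type: "gst"`, 256-dim style embedding), while the modified config and the logged run config under `debug/coqui_tts-May-09-2022_03+01PM-0000000/` switch to the `vaeflow` encoder (128-dim style embedding and latent, cyclical KL annealing with a 5000-step period, 16 normalizing flows, plus the `diff_*` diffusion options). Below is a minimal sketch of building that block programmatically and round-tripping one of the dumped files; it assumes this fork's `StyleEncoderConfig` dataclass accepts the same field names that appear in the JSON and that the `stylefastpitch` model name is registered with Coqui's config loader, neither of which is verified here.

```python
# Sketch only: the keyword arguments below are copied from the
# "style_encoder_config" blocks in the JSON above; the exact
# StyleEncoderConfig signature is an assumption, not verified.
from TTS.config import load_config  # stock Coqui helper for reading a JSON/YAML config
from TTS.style_encoder.configs.style_encoder_config import StyleEncoderConfig

# Re-create the "vaeflow" block used by the logged run config.
style_cfg = StyleEncoderConfig(
    se_type="vaeflow",
    num_mel=80,
    style_embedding_dim=128,
    vae_latent_dim=128,
    use_cyclical_annealing=True,   # cyclical KL annealing, 5000-step period
    vae_cycle_period=5000,
    vae_loss_alpha=1.0,
    use_nonlinear_proj=False,
    vaeflow_intern_dim=300,
    vaeflow_number_of_flows=16,
)
print(style_cfg.to_json())  # Coqpit dataclasses serialize back to JSON

# Round-trip one of the dumped configs. This relies on the fork registering
# the "stylefastpitch" model name with Coqui's config registry.
cfg = load_config("debug/config_forward.json")
print(cfg.model, cfg.lr_scheduler, cfg.lr)   # stylefastpitch NoamLR 0.0001
print(cfg.style_encoder_config)              # the GST block kept by that file
```

Building the block in Python and letting the trainer dump it is one way to keep the hand-edited `debug/config_forward.json` and the auto-saved run config from drifting apart, rather than editing the large serialized files directly.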