Support different tts model types. #1541

Merged · 7 commits · Mar 12, 2024
Changes from 1 commit
8 changes: 4 additions & 4 deletions egs/ljspeech/TTS/README.md
@@ -1,10 +1,10 @@
# Introduction

-This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
-A transcription is provided for each clip.
+This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
+A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.

-The texts were published between 1884 and 1964, and are in the public domain.
+The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.

The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,4 @@ To inference, use:
--exp-dir vits/exp \
--epoch 1000 \
--tokens data/tokens.txt
-```
+```
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/prepare.sh
@@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
-# to $dl_dir/LJSpeech
+# to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
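For context, a sketch of satisfying that assumption manually (the tarball URL is the standard LJSpeech-1.1 download; dl_dir is illustrative, and prepare.sh may already handle the download in an earlier stage):

```
dl_dir=./download
mkdir -p "$dl_dir"
# Fetch and unpack so that $dl_dir/LJSpeech-1.1 exists.
wget -P "$dl_dir" https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar -xjf "$dl_dir/LJSpeech-1.1.tar.bz2" -C "$dl_dir"
```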
27 changes: 11 additions & 16 deletions egs/ljspeech/TTS/vits/export-onnx.py
@@ -25,9 +25,8 @@
--exp-dir vits/exp \
--tokens data/tokens.txt

-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
  - vits-epoch-1000.onnx
-  - vits-epoch-1000.int8.onnx (quantizated model)

See ./test_onnx.py for how to use the exported ONNX models.
"""
@@ -40,7 +39,6 @@
import onnx
import torch
import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params

@@ -75,6 +73,15 @@ def get_parser():
help="""Path to vocabulary.""",
)

parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)

return parser


@@ -240,7 +247,7 @@ def main():
model = OnnxModel(model=model)

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

suffix = f"epoch-{params.epoch}"

@@ -256,18 +263,6 @@
)
logging.info(f"Exported generator to {model_filename}")

-# Generate int8 quantization models
-# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-logging.info("Generate int8 quantization models")
-
-model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-quantize_dynamic(
-    model_input=model_filename,
-    model_output=model_filename_int8,
-    weight_type=QuantType.QUInt8,
Comment from the PR author: Quantizing using quint8 is very slow at run time, so we removed it.
-)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
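With the new flag, a hypothetical export invocation looks like this (epoch and paths mirror the script's docstring; `medium` is illustrative, and the value should match the one used for training):

```
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type medium
```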
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/vits/generator.py
@@ -189,7 +189,7 @@ def __init__(
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
-assert global_channels > 0
+assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
10 changes: 10 additions & 0 deletions egs/ljspeech/TTS/vits/infer.py
@@ -72,6 +72,15 @@ def get_parser():
help="""Path to vocabulary.""",
)

parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)

return parser


@@ -94,6 +103,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""

# Background worker save audios to disk.
def _save_worker(
batch_size: int,
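infer.py accepts the same flag, so an inference run under the same assumptions would be:

```
./vits/infer.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type medium
```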
51 changes: 51 additions & 0 deletions egs/ljspeech/TTS/vits/test_model.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS


def test_model_type(model_type):
tokens = "./data/tokens.txt"

params = get_params()

tokenizer = Tokenizer(tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = model_type

model = get_model(params)
generator = model.generator

num_param = sum([p.numel() for p in generator.parameters()])
print(
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
)


def main():
test_model_type("high") # 35.63 M
test_model_type("low") # 7.55 M
test_model_type("medium") # 23.61 M
test_model_type("") # 35.63 M


if __name__ == "__main__":
main()
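Since the script reads ./data/tokens.txt and imports from vits/, it is presumably run from the recipe directory after prepare.sh has generated the token table (and with the recipe's dependencies installed):

```
cd egs/ljspeech/TTS
python3 ./vits/test_model.py
```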
7 changes: 4 additions & 3 deletions egs/ljspeech/TTS/vits/text_encoder.py
@@ -92,9 +92,9 @@ def forward(
x_lengths (Tensor): Length tensor (B,).

Returns:
-Tensor: Encoded hidden representation (B, attention_dim, T_text).
-Tensor: Projected mean tensor (B, attention_dim, T_text).
-Tensor: Projected scale tensor (B, attention_dim, T_text).
+Tensor: Encoded hidden representation (B, embed_dim, T_text).
+Tensor: Projected mean tensor (B, embed_dim, T_text).
+Tensor: Projected scale tensor (B, embed_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).

"""
@@ -108,6 +108,7 @@

# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# Note: attention_dim == embed_dim

# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
21 changes: 11 additions & 10 deletions egs/ljspeech/TTS/vits/train.py
@@ -153,6 +153,15 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)

return parser


@@ -189,15 +198,6 @@ def get_params() -> AttributeDict:

- feature_dim: The model input dim. It has to match the one used
in computing features.

-- subsampling_factor: The subsampling factor for the model.
-
-- encoder_dim: Hidden dim for multi-head attention model.
-
-- num_decoder_layers: Number of decoder layer of transformer decoder.
-
-- warm_step: The warmup period that dictates the decay of the
-  scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@@ -363,7 +364,7 @@
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device

-# used to summary the stats over iterations in one epoch
+# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()

saved_bad_model = False
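Training selects a preset through the same flag; a minimal sketch (all other training flags are omitted and follow the recipe's usual invocation):

```
./vits/train.py \
  --exp-dir vits/exp \
  --model-type low
```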
53 changes: 52 additions & 1 deletion egs/ljspeech/TTS/vits/vits.py
@@ -5,6 +5,7 @@

"""VITS module for GAN-TTS task."""

import copy
from typing import Any, Dict, Optional, Tuple

import torch
@@ -38,6 +39,36 @@
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}

LOW_CONFIG = {
"hidden_channels": 96,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}

MEDIUM_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}

HIGH_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 2, 2),
"decoder_channels": 512,
"decoder_upsample_kernel_sizes": (16, 16, 4, 4),
"decoder_resblock_kernel_sizes": (3, 7, 11),
"decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
"text_encoder_cnn_module_kernel": 3,
}


class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@@ -49,6 +80,7 @@ def __init__(
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
model_type: str = "",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
@@ -155,12 +187,13 @@ def __init__(
"""Initialize VITS module.

Args:
-idim (int): Input vocabrary size.
+idim (int): Input vocabulary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
model_type (str): If not empty, must be one of: low, medium, high
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@
"""
super().__init__()

generator_params = copy.deepcopy(generator_params)
discriminator_params = copy.deepcopy(discriminator_params)
generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
mel_loss_params = copy.deepcopy(mel_loss_params)

if model_type != "":
assert model_type in ("low", "medium", "high"), model_type
if model_type == "low":
generator_params.update(LOW_CONFIG)
elif model_type == "medium":
generator_params.update(MEDIUM_CONFIG)
elif model_type == "high":
generator_params.update(HIGH_CONFIG)
else:
raise ValueError(f"Unknown model_type: {model_type}")

# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
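Putting the three presets together with the generator sizes printed by test_model.py above:

| model_type | hidden_channels | decoder_channels | decoder_upsample_scales | generator params |
|------------|-----------------|------------------|-------------------------|------------------|
| low        | 96              | 256              | (8, 8, 4)               | 7.55 M           |
| medium     | 192             | 256              | (8, 8, 4)               | 23.61 M          |
| high       | 192             | 512              | (8, 8, 2, 2)            | 35.63 M          |

An empty --model-type keeps the built-in defaults, which have the same parameter count as high (35.63 M).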