diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst
index aa891c072e..80d67a2f3d 100644
--- a/docs/source/recipes/TTS/index.rst
+++ b/docs/source/recipes/TTS/index.rst
@@ -5,3 +5,4 @@ TTS
:maxdepth: 2
ljspeech/vits
+ vctk/vits
\ No newline at end of file
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index 385fd3c705..d08aa0f470 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -4,6 +4,10 @@ VITS
This tutorial shows you how to train an VITS model
with the `LJSpeech `_ dataset.
+.. note::
+
+ TTS-related recipes require the packages listed in ``requirements-tts.txt``.
+
.. note::
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
@@ -27,6 +31,12 @@ To run stage 1 to stage 5, use
Build Monotonic Alignment Search
--------------------------------
+.. code-block:: bash
+
+ $ ./prepare.sh --stage -1 --stop_stage -1
+
+or
+
.. code-block:: bash
$ cd vits/monotonic_align
@@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
- --tokens data/tokens.txt
+ --tokens data/tokens.txt \
--max-duration 500
.. note::
diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst
new file mode 100644
index 0000000000..34024a5ea5
--- /dev/null
+++ b/docs/source/recipes/TTS/vctk/vits.rst
@@ -0,0 +1,125 @@
+VITS
+===============
+
+This tutorial shows you how to train a VITS model
+with the `VCTK `_ dataset.
+
+.. note::
+
+ TTS-related recipes require the packages listed in ``requirements-tts.txt``.
+
+.. note::
+
+ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/vctk/TTS
+ $ ./prepare.sh
+
+To run stage 1 to stage 6, use
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 1 --stop_stage 6
+
+
+Build Monotonic Alignment Search
+--------------------------------
+
+To build the monotonic alignment search, use the following commands:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage -1 --stop_stage -1
+
+or
+
+.. code-block:: bash
+
+ $ cd vits/monotonic_align
+ $ python setup.py build_ext --inplace
+ $ cd ../../
+
+
+Training
+--------
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt \
+ --max-duration 350
+
+.. note::
+
+ You can adjust the hyper-parameters to control the size of the VITS model and
+ the training configurations. For more details, please run ``./vits/train.py --help``.
+
+.. note::
+
+ The training can take a long time (usually a couple of days).
+
+Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
+
+
+Inference
+---------
+
+The inference part uses checkpoints saved by the training part, so you have to run the
+training part first. It will save the ground-truth and generated wavs to the directory
+``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ For more details, please run ``./vits/infer.py --help``.
+
+
+Export models
+-------------
+
+Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
+``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+
+.. code-block:: bash
+
+ $ ./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+You can test the exported ONNX model with:
+
+.. code-block:: bash
+
+ $ ./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
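+
+If you want to call the exported model from your own code, the following is a
+minimal sketch (not part of the recipe) that runs it with ``onnxruntime``
+directly. It assumes the input and output names defined in
+``vits/export-onnx.py`` (``tokens``, ``tokens_lens``, ``noise_scale``,
+``noise_scale_dur``, ``speaker``, ``alpha`` and ``audio``); the token IDs and
+speaker index below are placeholders and should be produced from
+``data/tokens.txt`` and ``data/speakers.txt`` via the recipe's ``Tokenizer``.
+
+.. code-block:: python
+
+   import numpy as np
+   import onnxruntime as ort
+
+   session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
+
+   # Placeholder token IDs and speaker index; real inputs come from the
+   # recipe's Tokenizer and speaker map.
+   tokens = np.array([[10, 4, 25, 4, 7]], dtype=np.int64)
+   inputs = {
+       "tokens": tokens,
+       "tokens_lens": np.array([tokens.shape[1]], dtype=np.int64),
+       "noise_scale": np.array([0.667], dtype=np.float32),
+       "noise_scale_dur": np.array([0.8], dtype=np.float32),
+       "speaker": np.array([0], dtype=np.int64),
+       "alpha": np.array([1.0], dtype=np.float32),
+   }
+   # The output is a (1, T_wav) float32 waveform at the training sample rate.
+   audio = session.run(["audio"], inputs)[0]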
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following link:
+
+ - ``_
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index 8ee40896e1..ed0a07f5e2 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
-nj=1
-stage=-1
+stage=0
stop_stage=100
dl_dir=$PWD/download
@@ -25,6 +24,17 @@ log() {
log "dl_dir: $dl_dir"
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: build monotonic_align lib"
+ if [ ! -d vits/monotonic_align/build ]; then
+ cd vits/monotonic_align
+ python setup.py build_ext --inplace
+ cd ../../
+ else
+ log "monotonic_align lib already built"
+ fi
+fi
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
@@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--tokens data/tokens.txt
fi
fi
-
-
diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py
index c29a28479a..1a8190014a 100644
--- a/egs/ljspeech/TTS/vits/duration_predictor.py
+++ b/egs/ljspeech/TTS/vits/duration_predictor.py
@@ -14,7 +14,6 @@
import torch
import torch.nn.functional as F
-
from flow import (
ConvFlow,
DilatedDepthSeparableConv,
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 154de4bf42..2068adeea0 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -180,7 +180,13 @@ def export_model_onnx(
model_filename,
verbose=False,
opset_version=opset_version,
- input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+ input_names=[
+ "tokens",
+ "tokens_lens",
+ "noise_scale",
+ "noise_scale_dur",
+ "alpha",
+ ],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py
index 206bd5e3e5..2b84f64340 100644
--- a/egs/ljspeech/TTS/vits/flow.py
+++ b/egs/ljspeech/TTS/vits/flow.py
@@ -13,7 +13,6 @@
from typing import Optional, Tuple, Union
import torch
-
from transform import piecewise_rational_quadratic_transform
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
index efb0e254cf..66c8cedb19 100644
--- a/egs/ljspeech/TTS/vits/generator.py
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -16,9 +16,6 @@
import numpy as np
import torch
import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@
from text_encoder import TextEncoder
from utils import get_random_segments
+from icefall.utils import make_pad_mask
+
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
index 91a35e3602..cf0d20ae23 100755
--- a/egs/ljspeech/TTS/vits/infer.py
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -36,13 +36,12 @@
import torch
import torch.nn as nn
import torchaudio
-
-from train import get_model, get_params
from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
def get_parser():
@@ -107,12 +106,12 @@ def _save_worker(
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
- audio[i:i + 1, :audio_lens[i]],
+ audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
- audio_pred[i:i + 1, :audio_lens_pred[i]],
+ audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
@@ -144,14 +143,24 @@ def _save_worker(
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
- audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+ audio_pred, _, durations = model.inference_batch(
+ text=tokens, text_lengths=tokens_lens
+ )
audio_pred = audio_pred.detach().cpu()
# convert to samples
- audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ audio_lens_pred = (
+ (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ )
futures.append(
executor.submit(
- _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+ _save_worker,
+ batch_size,
+ cut_ids,
+ audio,
+ audio_pred,
+ audio_lens,
+ audio_lens_pred,
)
)
@@ -160,7 +169,9 @@ def _save_worker(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
# return results
for f in futures:
f.result()
diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py
index 21aaad6e75..2f4dc9bc05 100644
--- a/egs/ljspeech/TTS/vits/loss.py
+++ b/egs/ljspeech/TTS/vits/loss.py
@@ -14,7 +14,6 @@
import torch
import torch.distributions as D
import torch.nn.functional as F
-
from lhotse.features.kaldi import Wav2LogFilterBank
diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py
index 6b8a5be52f..1104fb864c 100644
--- a/egs/ljspeech/TTS/vits/posterior_encoder.py
+++ b/egs/ljspeech/TTS/vits/posterior_encoder.py
@@ -12,9 +12,9 @@
from typing import Optional, Tuple
import torch
+from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):
diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py
index 2d6807cb7c..f9a2a3786e 100644
--- a/egs/ljspeech/TTS/vits/residual_coupling.py
+++ b/egs/ljspeech/TTS/vits/residual_coupling.py
@@ -12,7 +12,6 @@
from typing import Optional, Tuple, Union
import torch
-
from flow import FlipFlow
from wavenet import WaveNet
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
index 8acca7c026..686fee2a03 100755
--- a/egs/ljspeech/TTS/vits/test_onnx.py
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -28,10 +28,10 @@
import argparse
import logging
+
import onnxruntime as ort
import torch
import torchaudio
-
from tokenizer import Tokenizer
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
index 9f337e45bb..fcbae7103f 100644
--- a/egs/ljspeech/TTS/vits/text_encoder.py
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -169,9 +169,7 @@ def forward(
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- x = self.encoder(
- x, pos_emb, key_padding_mask=key_padding_mask
- ) # (T, N, C)
+ x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
@@ -207,7 +205,9 @@ def __init__(
nn.Linear(dim_feedforward, d_model),
)
- self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+ self.self_attn = RelPositionMultiheadAttention(
+ d_model, num_heads, dropout=dropout
+ )
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -242,7 +242,9 @@ def forward(
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
- src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+ src = src + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(self.norm_ff_macaron(src))
+ )
# multi-head self-attention module
src_attn = self.self_attn(
@@ -490,11 +492,17 @@ def forward(
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
- v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ v = (
+ v.contiguous()
+ .view(seq_len, batch_size * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
- p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+ p = self.linear_pos(pos_emb).view(
+ pos_emb.size(0), -1, self.num_heads, self.head_dim
+ )
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
@@ -506,15 +514,23 @@ def forward(
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
- matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
+ matrix_ac = torch.matmul(
+ q_with_bias_u, k
+ ) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
- matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
- matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
+ matrix_bd = torch.matmul(
+ q_with_bias_v, p
+ ) # (batch_size, num_head, seq_len, 2*seq_len-1)
+ matrix_bd = self.rel_shift(
+ matrix_bd
+ ) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
- attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+ attn_output_weights = attn_output_weights.view(
+ batch_size * self.num_heads, seq_len, seq_len
+ )
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
@@ -536,10 +552,16 @@ def forward(
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
- assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+ assert attn_output.shape == (
+ batch_size * self.num_heads,
+ seq_len,
+ self.head_dim,
+ )
attn_output = (
- attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+ attn_output.transpose(0, 1)
+ .contiguous()
+ .view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 0678b26fe0..70f1240b4a 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -78,7 +78,9 @@ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
return token_ids_list
- def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+ def tokens_to_token_ids(
+ self, tokens_list: List[str], intersperse_blank: bool = True
+ ):
"""
Args:
tokens_list:
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index eb43a4cc93..71c4224fa7 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -18,21 +18,25 @@
import argparse
import logging
-import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
+import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -385,11 +384,12 @@ def save_bad_model(suffix: str = ""):
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+ batch, tokenizer, device
+ )
loss_info = MetricsTracker()
- loss_info['samples'] = batch_size
+ loss_info["samples"] = batch_size
try:
with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def save_bad_model(suffix: str = ""):
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
- if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+ if cur_grad_scale < 8.0 or (
+ cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+ ):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@@ -482,9 +484,7 @@ def save_bad_model(suffix: str = ""):
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@ def save_bad_model(suffix: str = ""):
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
- "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+ "train/speech_hat_",
+ speech_hat_,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_audio(
- "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+ "train/speech_",
+ speech_,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_image(
- "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+ "train/mel_hat_",
+ plot_feature(mel_hat_),
+ params.batch_idx_train,
+ dataformats="HWC",
)
tb_writer.add_image(
- "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+ "train/mel_",
+ plot_feature(mel_),
+ params.batch_idx_train,
+ dataformats="HWC",
)
- if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+ if (
+ params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
@@ -523,10 +538,16 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
- "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+ "train/valdi_speech_hat",
+ speech_hat,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_audio(
- "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+ "train/valdi_speech",
+ speech,
+ params.batch_idx_train,
+ params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ ) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
- loss_info['samples'] = batch_size
+ loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
@@ -596,12 +623,17 @@ def compute_validation_loss(
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
- text=tokens[0, :tokens_lens[0].item()]
+ text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
- audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
- assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
- audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+ audio_len_pred = (
+ (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ )
+ assert audio_len_pred == len(audio_pred), (
+ audio_len_pred,
+ len(audio_pred),
+ )
+ audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+ batch, tokenizer, device
+ )
try:
# for discriminator
with autocast(enabled=params.use_fp16):
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 0fcbb92c16..81bb9ed130 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -29,10 +29,10 @@
CutConcatenate,
CutMix,
DynamicBucketingSampler,
- SpeechSynthesisDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
+ SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
index 2a3dae9007..6a067f5961 100644
--- a/egs/ljspeech/TTS/vits/utils.py
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
+
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
- mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
- interpolation='none')
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
index d5e20a5787..b4f0c21e6d 100644
--- a/egs/ljspeech/TTS/vits/vits.py
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -9,8 +9,7 @@
import torch
import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@
KLDivergenceLoss,
MelSpectrogramLoss,
)
+from torch.cuda.amp import autocast
from utils import get_segments
-from generator import VITSGenerator
-
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@
class VITS(nn.Module):
- """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
- """
+ """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
def __init__(
self,
diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py
index fbe1be52b0..5db461d5ce 100644
--- a/egs/ljspeech/TTS/vits/wavenet.py
+++ b/egs/ljspeech/TTS/vits/wavenet.py
@@ -9,9 +9,8 @@
"""
-import math
import logging
-
+import math
from typing import Optional, Tuple
import torch
diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py
new file mode 100755
index 0000000000..440ac12451
--- /dev/null
+++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the VCTK dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ LilcomChunkyWriter,
+ Spectrogram,
+ SpectrogramConfig,
+ load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_vctk():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/spectrogram")
+ num_jobs = min(32, os.cpu_count())
+
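+ # A 1024-sample analysis window and a 256-sample hop at 22.05 kHz,
+ # converted to seconds as expected by SpectrogramConfig.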
+ sampling_rate = 22050
+ frame_length = 1024 / sampling_rate # (in second)
+ frame_shift = 256 / sampling_rate # (in second)
+ use_fft_mag = True
+
+ prefix = "vctk"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ recordings = load_manifest(
+ src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+ ).resample(sampling_rate=sampling_rate)
+ supervisions = load_manifest(
+ src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+ )
+
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ use_fft_mag=use_fft_mag,
+ )
+ extractor = Spectrogram(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ return
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_spectrogram_vctk()
diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py
new file mode 100755
index 0000000000..0472e2cea3
--- /dev/null
+++ b/egs/vctk/TTS/local/display_manifest_statistics.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in vits/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ path = "./data/spectrogram/vctk_cuts_all.jsonl.gz"
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Cut statistics:
+╒═══════════════════════════╤══════════╕
+│ Cuts count: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Total duration (hh:mm:ss) │ 41:02:18 │
+├───────────────────────────┼──────────┤
+│ mean │ 3.4 │
+├───────────────────────────┼──────────┤
+│ std │ 1.2 │
+├───────────────────────────┼──────────┤
+│ min │ 1.2 │
+├───────────────────────────┼──────────┤
+│ 25% │ 2.6 │
+├───────────────────────────┼──────────┤
+│ 50% │ 3.1 │
+├───────────────────────────┼──────────┤
+│ 75% │ 3.8 │
+├───────────────────────────┼──────────┤
+│ 99% │ 8.0 │
+├───────────────────────────┼──────────┤
+│ 99.5% │ 9.1 │
+├───────────────────────────┼──────────┤
+│ 99.9% │ 12.1 │
+├───────────────────────────┼──────────┤
+│ max │ 16.6 │
+├───────────────────────────┼──────────┤
+│ Recordings available: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Features available: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Supervisions available: │ 43873 │
+╘═══════════════════════════╧══════════╛
+SUPERVISION custom fields:
+Speech duration statistics:
+╒══════════════════════════════╤══════════╤══════════════════════╕
+│ Total speech duration │ 41:02:18 │ 100.00% of recording │
+├──────────────────────────────┼──────────┼──────────────────────┤
+│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │
+├──────────────────────────────┼──────────┼──────────────────────┤
+│ Total silence duration │ 00:00:01 │ 0.00% of recording │
+╘══════════════════════════════╧══════════╧══════════════════════╛
+"""
diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py
new file mode 100755
index 0000000000..c6636c3ad6
--- /dev/null
+++ b/egs/vctk/TTS/local/prepare_token_file.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+from lhotse import load_manifest
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-file",
+ type=Path,
+ default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
+ help="Path to the manifest file",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens.txt"),
+ help="Path to the tokens",
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+ """Return a dict that maps token to IDs."""
+ extra_tokens = [
+ "", # 0 for blank
+ "", # 1 for sos and eos symbols.
+ "", # 2 for OOV
+ ]
+ all_tokens = set()
+
+ cut_set = load_manifest(manifest_file)
+
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ for t in cut.tokens:
+ all_tokens.add(t)
+
+ all_tokens = extra_tokens + list(all_tokens)
+
+ token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
+ return token2id
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ manifest_file = Path(args.manifest_file)
+ out_file = Path(args.tokens)
+
+ token2id = get_token2id(manifest_file)
+ write_mapping(out_file, token2id)
diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py
new file mode 100755
index 0000000000..32e1c7dfad
--- /dev/null
+++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+from tqdm.auto import tqdm
+
+
+def prepare_tokens_vctk():
+ output_dir = Path("data/spectrogram")
+ prefix = "vctk"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+ g2p = g2p_en.G2p()
+
+ new_cuts = []
+ for cut in tqdm(cut_set):
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ text = cut.supervisions[0].text
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ cut.tokens = g2p(text)
+ new_cuts.append(cut)
+
+ new_cut_set = CutSet.from_cuts(new_cuts)
+ new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ prepare_tokens_vctk()
diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py
new file mode 100755
index 0000000000..cd466303ed
--- /dev/null
+++ b/egs/vctk/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/spectrogram/vctk_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh
new file mode 100755
index 0000000000..87150ad315
--- /dev/null
+++ b/egs/vctk/TTS/prepare.sh
@@ -0,0 +1,131 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=0
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: build monotonic_align lib"
+ if [ ! -d vits/monotonic_align/build ]; then
+ cd vits/monotonic_align
+ python setup.py build_ext --inplace
+ cd ../../
+ else
+ log "monotonic_align lib already built"
+ fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/VCTK,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/VCTK $dl_dir/VCTK
+ #
+ if [ ! -d $dl_dir/VCTK ]; then
+ lhotse download vctk $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare VCTK manifest"
+ # We assume that you have downloaded the VCTK corpus
+ # to $dl_dir/VCTK
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.vctk.done ]; then
+ lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests
+ touch data/manifests/.vctk.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute spectrogram for VCTK"
+ mkdir -p data/spectrogram
+ if [ ! -e data/spectrogram/.vctk.done ]; then
+ ./local/compute_spectrogram_vctk.py
+ touch data/spectrogram/.vctk.done
+ fi
+
+ if [ ! -e data/spectrogram/.vctk-validated.done ]; then
+ log "Validating data/fbank for VCTK"
+ ./local/validate_manifest.py \
+ data/spectrogram/vctk_cuts_all.jsonl.gz
+ touch data/spectrogram/.vctk-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare phoneme tokens for VCTK"
+ if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
+ ./local/prepare_tokens_vctk.py
+ mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_all.jsonl.gz
+ touch data/spectrogram/.vctk_with_token.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Split the VCTK cuts into train, valid and test sets"
+ if [ ! -e data/spectrogram/.vctk_split.done ]; then
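+ # Hold out the last 600 cuts: the first 100 of those become the valid set,
+ # the remaining 500 the test set; everything else is used for training.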
+ lhotse subset --last 600 \
+ data/spectrogram/vctk_cuts_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz
+ lhotse subset --first 100 \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz \
+ data/spectrogram/vctk_cuts_valid.jsonl.gz
+ lhotse subset --last 500 \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz \
+ data/spectrogram/vctk_cuts_test.jsonl.gz
+
+ rm data/spectrogram/vctk_cuts_validtest.jsonl.gz
+
+ n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 ))
+ lhotse subset --first $n \
+ data/spectrogram/vctk_cuts_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_train.jsonl.gz
+ touch data/spectrogram/.vctk_split.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ # We assume you have installed g2p_en and espnet_tts_frontend.
+ # If not, please install them with:
+ # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+ # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ if [ ! -e data/tokens.txt ]; then
+ ./local/prepare_token_file.py \
+ --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
+ --tokens data/tokens.txt
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Generate speakers file"
+ if [ ! -e data/speakers.txt ]; then
+ gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \
+ | jq '.speaker' | sed 's/"//g' \
+ | sort | uniq > data/speakers.txt
+ fi
+fi
diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared
new file mode 120000
index 0000000000..4c5e91438c
--- /dev/null
+++ b/egs/vctk/TTS/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py
new file mode 120000
index 0000000000..9972b476f9
--- /dev/null
+++ b/egs/vctk/TTS/vits/duration_predictor.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/duration_predictor.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
new file mode 100755
index 0000000000..7c9664cc14
--- /dev/null
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a VITS model from PyTorch to ONNX.
+
+Export the model to ONNX:
+./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+It will generate two files inside vits/exp:
+ - vits-epoch-1000.onnx
+ - vits-epoch-1000.int8.onnx (quantized model)
+
+See ./test_onnx.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for VITS generator."""
+
+ def __init__(self, model: nn.Module):
+ """
+ Args:
+ model:
+ A VITS generator.
+ frame_shift:
+ The frame shift in samples.
+ """
+ super().__init__()
+ self.model = model
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ speaker: int = 20,
+ alpha: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of VITS.inference_batch
+
+ Args:
+ tokens:
+ Input text token indexes (1, T_text)
+ tokens_lens:
+ Number of tokens of shape (1,)
+ noise_scale (float):
+ Noise scale parameter for flow.
+ noise_scale_dur (float):
+ Noise scale parameter for duration predictor.
+ speaker (int):
+ Speaker ID.
+ alpha (float):
+ Alpha parameter to control the speed of generated speech.
+
+ Returns:
+ Return a tuple containing:
+ - audio, generated waveform tensor, (B, T_wav)
+ """
+ audio, _, _ = self.model.inference(
+ text=tokens,
+ text_lengths=tokens_lens,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ sids=speaker,
+ alpha=alpha,
+ )
+ return audio
+
+
+def export_model_onnx(
+ model: nn.Module,
+ model_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given generator model to ONNX format.
+ The exported model has the following inputs:
+
+ - tokens, a tensor of shape (1, T_text); dtype is torch.int64
+ - tokens_lens, a tensor of shape (1,); dtype is torch.int64
+ - noise_scale, noise_scale_dur and alpha, tensors of shape (1,); dtype is torch.float32
+ - speaker, a tensor of shape (1,); dtype is torch.int64
+
+ and it has one output:
+
+ - audio, a tensor of shape (1, T'); dtype is torch.float32
+
+ Args:
+ model:
+ The VITS generator.
+ model_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
+ noise_scale = torch.tensor([1], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([1], dtype=torch.float32)
+ alpha = torch.tensor([1], dtype=torch.float32)
+ speaker = torch.tensor([1], dtype=torch.int64)
+
+ torch.onnx.export(
+ model,
+ (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
+ model_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=[
+ "tokens",
+ "tokens_lens",
+ "noise_scale",
+ "noise_scale_dur",
+ "speaker",
+ "alpha",
+ ],
+ output_names=["audio"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "tokens_lens": {0: "N"},
+ "audio": {0: "N", 1: "T"},
+ "speaker": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "VITS",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "VITS generator",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ with open(args.speakers) as f:
+ speaker_map = {line.strip(): i for i, line in enumerate(f)}
+ params.num_spks = len(speaker_map)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model = model.generator
+ model.to("cpu")
+ model.eval()
+
+ model = OnnxModel(model=model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"generator parameters: {num_param}")
+
+ suffix = f"epoch-{params.epoch}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ model_filename = params.exp_dir / f"vits-{suffix}.onnx"
+ export_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported generator to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ weight_type=QuantType.QUInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py
new file mode 120000
index 0000000000..e65d91ea75
--- /dev/null
+++ b/egs/vctk/TTS/vits/flow.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/flow.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py
new file mode 120000
index 0000000000..611679bfa8
--- /dev/null
+++ b/egs/vctk/TTS/vits/generator.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/generator.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py
new file mode 120000
index 0000000000..5ac025de72
--- /dev/null
+++ b/egs/vctk/TTS/vits/hifigan.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/hifigan.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py
new file mode 100755
index 0000000000..06c25f02eb
--- /dev/null
+++ b/egs/vctk/TTS/vits/infer.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script performs model inference on the test and validation sets.
+
+Usage:
+./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir ./vits/exp \
+ --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+import torch.nn as nn
+import torchaudio
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import VctkTtsDataModule
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def infer_dataset(
+ dl: torch.utils.data.DataLoader,
+ subset: str,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ speaker_map: Dict[str, int],
+) -> None:
+ """Decode dataset.
+ The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ Used to convert text to phonemes.
+ """
+
+ # Background worker that saves audio files to disk.
+ def _save_worker(
+ subset: str,
+ batch_size: int,
+ cut_ids: List[str],
+ audio: torch.Tensor,
+ audio_pred: torch.Tensor,
+ audio_lens: List[int],
+ audio_lens_pred: List[int],
+ ):
+ for i in range(batch_size):
+ torchaudio.save(
+ str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
+ audio[i : i + 1, : audio_lens[i]],
+ sample_rate=params.sampling_rate,
+ )
+ torchaudio.save(
+ str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
+ audio_pred[i : i + 1, : audio_lens_pred[i]],
+ sample_rate=params.sampling_rate,
+ )
+
+ device = next(model.parameters()).device
+ num_cuts = 0
+ log_interval = 5
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ futures = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ for batch_idx, batch in enumerate(dl):
+ batch_size = len(batch["tokens"])
+
+ tokens = batch["tokens"]
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ speakers = (
+ torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+ .int()
+ .to(device)
+ )
+
+ audio = batch["audio"]
+ audio_lens = batch["audio_lens"].tolist()
+ cut_ids = [cut.id for cut in batch["cut"]]
+
+ audio_pred, _, durations = model.inference_batch(
+ text=tokens,
+ text_lengths=tokens_lens,
+ sids=speakers,
+ )
+ audio_pred = audio_pred.detach().cpu()
+ # convert to samples
+ audio_lens_pred = (
+ (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ )
+
+ futures.append(
+ executor.submit(
+ _save_worker,
+ subset,
+ batch_size,
+ cut_ids,
+ audio,
+ audio_pred,
+ audio_lens,
+ audio_lens_pred,
+ )
+ )
+
+ num_cuts += batch_size
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ # return results
+ for f in futures:
+ f.result()
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ VctkTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.suffix = f"epoch-{params.epoch}"
+
+ params.res_dir = params.exp_dir / "infer" / params.suffix
+ params.save_wav_dir = params.res_dir / "wav"
+ params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Infer started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ vctk = VctkTtsDataModule(args)
+ speaker_map = vctk.speakers()
+ params.num_spks = len(speaker_map)
+
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model.to(device)
+ model.eval()
+
+ num_param_g = sum([p.numel() for p in model.generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ test_cuts = vctk.test_cuts()
+ test_dl = vctk.test_dataloaders(test_cuts)
+
+ valid_cuts = vctk.valid_cuts()
+ valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+ infer_sets = {"test": test_dl, "valid": valid_dl}
+
+ for subset, dl in infer_sets.items():
+ save_wav_dir = params.res_dir / "wav" / subset
+ save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
+
+ infer_dataset(
+ dl=dl,
+ subset=subset,
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ speaker_map=speaker_map,
+ )
+
+ logging.info(f"Wav files are saved to {params.save_wav_dir}")
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py
new file mode 120000
index 0000000000..672e5ff68d
--- /dev/null
+++ b/egs/vctk/TTS/vits/loss.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/loss.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align
new file mode 120000
index 0000000000..71934e7cca
--- /dev/null
+++ b/egs/vctk/TTS/vits/monotonic_align
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/monotonic_align
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py
new file mode 120000
index 0000000000..41d64a3a66
--- /dev/null
+++ b/egs/vctk/TTS/vits/posterior_encoder.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/posterior_encoder.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py
new file mode 120000
index 0000000000..f979adbf00
--- /dev/null
+++ b/egs/vctk/TTS/vits/residual_coupling.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/residual_coupling.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py
new file mode 100755
index 0000000000..757e67fc1c
--- /dev/null
+++ b/egs/vctk/TTS/vits/test_onnx.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to test the ONNX model exported by vits/export-onnx.py
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+from pathlib import Path
+
+import onnxruntime as ort
+import torch
+import torchaudio
+from tokenizer import Tokenizer
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model.",
+ )
+
+ parser.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(self, model_filename: str):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+ def __call__(
+ self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ tokens:
+ A 2-D tensor of token IDs of shape (1, T).
+ tokens_lens:
+ A 1-D tensor of shape (1,) containing the number of tokens in `tokens`.
+ speaker:
+ A 1-D tensor of shape (1,) containing the speaker ID.
+ Returns:
+ A tensor of shape (1, T') containing the generated audio samples.
+ """
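+ # Inference hyper-parameters fed to the ONNX model: noise_scale controls the
+ # variance of the prior, noise_scale_dur the noise of the stochastic duration
+ # predictor, and alpha scales the predicted durations (speaking rate).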
+ noise_scale = torch.tensor([0.667], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+ alpha = torch.tensor([1.0], dtype=torch.float32)
+
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: tokens.numpy(),
+ self.model.get_inputs()[1].name: tokens_lens.numpy(),
+ self.model.get_inputs()[2].name: noise_scale.numpy(),
+ self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[4].name: speaker.numpy(),
+ self.model.get_inputs()[5].name: alpha.numpy(),
+ },
+ )[0]
+ return torch.from_numpy(out)
+
+
+def main():
+ args = get_parser().parse_args()
+
+ tokenizer = Tokenizer(args.tokens)
+
+ with open(args.speakers) as f:
+ speaker_map = {line.strip(): i for i, line in enumerate(f)}
+ args.num_spks = len(speaker_map)
+
+ logging.info("About to create onnx model")
+ model = OnnxModel(args.model_filename)
+
+ text = "I went there to see the land, the people and how their system works, end quote."
+ tokens = tokenizer.texts_to_token_ids([text])
+ tokens = torch.tensor(tokens) # (1, T)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
+ speaker = torch.tensor([1], dtype=torch.int64) # (1, )
+ audio = model(tokens, tokens_lens, speaker) # (1, T')
+
+ torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
+ logging.info("Saved to test_onnx.wav")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py
new file mode 120000
index 0000000000..0efba277e1
--- /dev/null
+++ b/egs/vctk/TTS/vits/text_encoder.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/text_encoder.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py
new file mode 120000
index 0000000000..057b0dc4b1
--- /dev/null
+++ b/egs/vctk/TTS/vits/tokenizer.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/tokenizer.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py
new file mode 100755
index 0000000000..56f167a178
--- /dev/null
+++ b/egs/vctk/TTS/vits/train.py
@@ -0,0 +1,1000 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from tokenizer import Tokenizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import VctkTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1000,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ parser.add_argument(
+ "--lr", type=float, default=2.0e-4, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=20,
+ help="""Save checkpoint after processing this number of epochs
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.cur_epoch % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+ Since it will take around 1000 epochs, we suggest using a large
+ save_every_n to save disk space.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to write statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim, i.e., the dimension of the
+ linear spectrogram (n_fft // 2 + 1). It has to match the
+ one used in computing features.
+ """
+ params = AttributeDict(
+ {
+ # training params
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": -1, # incremented to 0 at the start of the first batch
+ "log_interval": 50,
+ "valid_interval": 200,
+ "env_info": get_env_info(),
+ "sampling_rate": 22050,
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
+ "n_mels": 80,
+ "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
+ "lambda_mel": 45.0, # loss scaling coefficient for Mel loss
+ "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
+ "lambda_dur": 1.0, # loss scaling coefficient for duration loss
+ "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(filename, model=model)
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ mel_loss_params = {
+ "n_mels": params.n_mels,
+ "frame_length": params.frame_length,
+ "frame_shift": params.frame_shift,
+ }
+ generator_params = {
+ "hidden_channels": 192,
+ "spks": params.num_spks,
+ "langs": None,
+ "spk_embed_dim": None,
+ "global_channels": 256,
+ "segment_size": 32,
+ "text_encoder_attention_heads": 2,
+ "text_encoder_ffn_expand": 4,
+ "text_encoder_cnn_module_kernel": 5,
+ "text_encoder_blocks": 6,
+ "text_encoder_dropout_rate": 0.1,
+ "decoder_kernel_size": 7,
+ "decoder_channels": 512,
+ "decoder_upsample_scales": [8, 8, 2, 2],
+ "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
+ "decoder_resblock_kernel_sizes": [3, 7, 11],
+ "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "use_weight_norm_in_decoder": True,
+ "posterior_encoder_kernel_size": 5,
+ "posterior_encoder_layers": 16,
+ "posterior_encoder_stacks": 1,
+ "posterior_encoder_base_dilation": 1,
+ "posterior_encoder_dropout_rate": 0.0,
+ "use_weight_norm_in_posterior_encoder": True,
+ "flow_flows": 4,
+ "flow_kernel_size": 5,
+ "flow_base_dilation": 1,
+ "flow_layers": 4,
+ "flow_dropout_rate": 0.0,
+ "use_weight_norm_in_flow": True,
+ "use_only_mean_in_flow": True,
+ "stochastic_duration_predictor_kernel_size": 3,
+ "stochastic_duration_predictor_dropout_rate": 0.5,
+ "stochastic_duration_predictor_flows": 4,
+ "stochastic_duration_predictor_dds_conv_layers": 3,
+ }
+ model = VITS(
+ vocab_size=params.vocab_size,
+ feature_dim=params.feature_dim,
+ sampling_rate=params.sampling_rate,
+ generator_params=generator_params,
+ mel_loss_params=mel_loss_params,
+ lambda_adv=params.lambda_adv,
+ lambda_mel=params.lambda_mel,
+ lambda_feat_match=params.lambda_feat_match,
+ lambda_dur=params.lambda_dur,
+ lambda_kl=params.lambda_kl,
+ )
+ return model
+
+
+def prepare_input(
+ batch: dict,
+ tokenizer: Tokenizer,
+ device: torch.device,
+ speaker_map: Dict[str, int],
+):
+ """Parse batch data"""
+ audio = batch["audio"].to(device)
+ features = batch["features"].to(device)
+ audio_lens = batch["audio_lens"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ tokens = batch["tokens"]
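+ # Map speaker ids (strings) to the integer indices used by the speaker embedding.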
+ speakers = (
+ torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
+ )
+
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # a tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ optimizer_g: Optimizer,
+ optimizer_d: Optimizer,
+ scheduler_g: LRSchedulerType,
+ scheduler_d: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ speaker_map: Dict[str, int],
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to phonemes.
+ optimizer_g:
+ The optimizer for the generator.
+ optimizer_d:
+ The optimizer for the discriminator.
+ scheduler_g:
+ The learning rate scheduler for the generator; we call step() once per epoch.
+ scheduler_d:
+ The learning rate scheduler for the discriminator; we call step() once per epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ speaker_map:
+ Mapping from speaker name to an integer speaker index.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ params=params,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["tokens"])
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+
+ loss_info = MetricsTracker()
+ loss_info["samples"] = batch_size
+
+ try:
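+ # GAN-style alternating update: first the discriminator, then the generator.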
+ with autocast(enabled=params.use_fp16):
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+ # update discriminator
+ optimizer_d.zero_grad()
+ scaler.scale(loss_d).backward()
+ scaler.step(optimizer_d)
+
+ with autocast(enabled=params.use_fp16):
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ return_sample=params.batch_idx_train % params.log_interval == 0,
+ )
+ for k, v in stats_g.items():
+ if "returned_sample" not in k:
+ loss_info[k] = v * batch_size
+ # update generator
+ optimizer_g.zero_grad()
+ scaler.scale(loss_g).backward()
+ scaler.step(optimizer_g)
+ scaler.update()
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+ except: # noqa
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (
+ cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr_g = max(scheduler_g.get_last_lr())
+ cur_lr_d = max(scheduler_d.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+ )
+ tb_writer.add_scalar(
+ "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+ if "returned_sample" in stats_g:
+ speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+ tb_writer.add_audio(
+ "train/speech_hat_",
+ speech_hat_,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_audio(
+ "train/speech_",
+ speech_,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_image(
+ "train/mel_hat_",
+ plot_feature(mel_hat_),
+ params.batch_idx_train,
+ dataformats="HWC",
+ )
+ tb_writer.add_image(
+ "train/mel_",
+ plot_feature(mel_),
+ params.batch_idx_train,
+ dataformats="HWC",
+ )
+
+ if (
+ params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info, (speech_hat, speech) = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ speaker_map=speaker_map,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ tb_writer.add_audio(
+ "train/valid_speech_hat",
+ speech_hat,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_audio(
+ "train/valid_speech",
+ speech,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ speaker_map: Dict[str, int],
+ world_size: int = 1,
+ rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+ """Run the validation process."""
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+ returned_sample = None
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(valid_dl):
+ batch_size = len(batch["tokens"])
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+
+ loss_info = MetricsTracker()
+ loss_info["samples"] = batch_size
+
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ assert loss_d.requires_grad is False
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ )
+ assert loss_g.requires_grad is False
+ for k, v in stats_g.items():
+ loss_info[k] = v * batch_size
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+
+ # infer for first batch:
+ if batch_idx == 0 and rank == 0:
+ inner_model = model.module if isinstance(model, DDP) else model
+ audio_pred, _, duration = inner_model.inference(
+ text=tokens[0, : tokens_lens[0].item()],
+ sids=speakers[0],
+ )
+ audio_pred = audio_pred.data.cpu().numpy()
+ audio_len_pred = (
+ (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ )
+ assert audio_len_pred == len(audio_pred), (
+ audio_len_pred,
+ len(audio_pred),
+ )
+ audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+ returned_sample = (audio_pred, audio_gt)
+
+ if world_size > 1:
+ tot_loss.reduce(device)
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ tokenizer: Tokenizer,
+ optimizer_g: torch.optim.Optimizer,
+ optimizer_d: torch.optim.Optimizer,
+ speaker_map: Dict[str, int],
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+ try:
+ # for discriminator
+ with autocast(enabled=params.use_fp16):
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ optimizer_d.zero_grad()
+ loss_d.backward()
+ # for generator
+ with autocast(enabled=params.use_fp16):
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ )
+ optimizer_g.zero_grad()
+ loss_g.backward()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ vctk = VctkTtsDataModule(args)
+
+ train_cuts = vctk.train_cuts()
+ speaker_map = vctk.speakers()
+ params.num_spks = len(speaker_map)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+ generator = model.generator
+ discriminator = model.discriminator
+
+ num_param_g = sum([p.numel() for p in generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
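+ # find_unused_parameters=True is needed because the discriminator and generator
+ # are updated in alternating forward passes, so some parameters are unused in each pass.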
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer_g = torch.optim.AdamW(
+ generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+ optimizer_d = torch.optim.AdamW(
+ discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
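+ # Optimizer and scheduler settings (AdamW with betas=(0.8, 0.99), per-epoch
+ # exponential LR decay) follow the original VITS training setup.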
+
+ if checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer_g" in checkpoints:
+ logging.info("Loading optimizer_g state dict")
+ optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+ if "optimizer_d" in checkpoints:
+ logging.info("Loading optimizer_d state dict")
+ optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+ # load state_dict for schedulers
+ if "scheduler_g" in checkpoints:
+ logging.info("Loading scheduler_g state dict")
+ scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+ if "scheduler_d" in checkpoints:
+ logging.info("Loading scheduler_d state dict")
+ scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_dl = vctk.train_dataloaders(train_cuts)
+
+ valid_cuts = vctk.valid_cuts()
+ valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ speaker_map=speaker_map,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ speaker_map=speaker_map,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ # step per epoch
+ scheduler_g.step()
+ scheduler_d.step()
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ VctkTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py
new file mode 120000
index 0000000000..962647408b
--- /dev/null
+++ b/egs/vctk/TTS/vits/transform.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/transform.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py
new file mode 100644
index 0000000000..8b2a96b099
--- /dev/null
+++ b/egs/vctk/TTS/vits/tts_datamodule.py
@@ -0,0 +1,338 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+ SpeechSynthesisDataset,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class VctkTtsDataModule:
+ """
+ DataModule for TTS experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders.
+
+ It contains all the common data pipeline modules used in TTS
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in TTS tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/spectrogram"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ " (you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
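+ # Linear spectrogram config matching the precomputed features:
+ # 1024-sample FFT window and 256-sample hop at 22.05 kHz.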
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get validation cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+ @lru_cache()
+ def speakers(self) -> Dict[str, int]:
+ logging.info("About to get speakers")
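+ # speakers.txt is expected to contain one speaker id per line;
+ # the line index becomes the integer speaker index.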
+ with open(self.args.speakers) as f:
+ speakers = {line.strip(): i for i, line in enumerate(f)}
+ return speakers
diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py
new file mode 120000
index 0000000000..085e764b43
--- /dev/null
+++ b/egs/vctk/TTS/vits/utils.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/utils.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py
new file mode 120000
index 0000000000..1f58cf6fea
--- /dev/null
+++ b/egs/vctk/TTS/vits/vits.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/vits.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py
new file mode 120000
index 0000000000..28f0a78eeb
--- /dev/null
+++ b/egs/vctk/TTS/vits/wavenet.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/wavenet.py
\ No newline at end of file
diff --git a/requirements-tts.txt b/requirements-tts.txt
new file mode 100644
index 0000000000..c30e23d549
--- /dev/null
+++ b/requirements-tts.txt
@@ -0,0 +1,6 @@
+# for TTS recipes
+matplotlib==3.8.2
+cython==3.0.6
+numba==0.58.1
+g2p_en==2.1.0
+espnet_tts_frontend==0.0.3
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9502fcbd26..a1a46ae647 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,5 @@ tensorboard
typeguard
dill
black==22.3.0
+onnx==1.15.0
+onnxruntime==1.16.3
\ No newline at end of file