diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index d08aa0f470..323d0adfc8 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -1,11 +1,11 @@
-VITS
+VITS-LJSpeech
===============
This tutorial shows you how to train an VITS model
with the `LJSpeech `_ dataset.
.. note::
-
+
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::
@@ -120,4 +120,4 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- - ``_
+ - ``_
diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst
index 34024a5ea5..45ae9d9d20 100644
--- a/docs/source/recipes/TTS/vctk/vits.rst
+++ b/docs/source/recipes/TTS/vctk/vits.rst
@@ -1,11 +1,11 @@
-VITS
+VITS-VCTK
===============
This tutorial shows you how to train an VITS model
with the `VCTK `_ dataset.
.. note::
-
+
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::
diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
index df976804ab..5b048b600b 100755
--- a/egs/ljspeech/TTS/local/prepare_token_file.py
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -17,7 +17,7 @@
"""
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
+This file generates the file that maps tokens to IDs.
"""
import argparse
@@ -25,80 +25,38 @@
from pathlib import Path
from typing import Dict
-from lhotse import load_manifest
+from piper_phonemize import get_espeak_map
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--manifest-file",
- type=Path,
- default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
- help="Path to the manifest file",
- )
-
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
- help="Path to the tokens",
+ help="Path to the dict that maps the text tokens to IDs",
)
return parser.parse_args()
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
- """Write a symbol to ID mapping to a file.
+def get_token2id(filename: Path) -> Dict[str, int]:
+ """Get a dict that maps token to IDs, and save it to the given filename."""
+ all_tokens = get_espeak_map() # token: [token_id]
+ all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+ # sort by token_id
+ all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
- Note:
- No need to implement `read_mapping` as it can be done
- through :func:`k2.SymbolTable.from_file`.
-
- Args:
- filename:
- Filename to save the mapping.
- sym2id:
- A dict mapping symbols to IDs.
- Returns:
- Return None.
- """
with open(filename, "w", encoding="utf-8") as f:
- for sym, i in sym2id.items():
- f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
- """Return a dict that maps token to IDs."""
- extra_tokens = [
- "", # 0 for blank
- "", # 1 for sos and eos symbols.
- "", # 2 for OOV
- ]
- all_tokens = set()
-
- cut_set = load_manifest(manifest_file)
-
- for cut in cut_set:
- # Each cut only contain one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
- for t in cut.tokens:
- all_tokens.add(t)
-
- all_tokens = extra_tokens + list(all_tokens)
-
- token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
- return token2id
+ for token, token_id in all_tokens:
+ f.write(f"{token} {token_id}\n")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
- manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
-
- token2id = get_token2id(manifest_file)
- write_mapping(out_file, token2id)
+ get_token2id(out_file)
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
index fcd0137a08..08fe7430ef 100755
--- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -23,9 +23,9 @@
import logging
from pathlib import Path
-import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak
def prepare_tokens_ljspeech():
@@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
- g2p = g2p_en.G2p()
new_cuts = []
for cut in cut_set:
# Each cut only contains one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
- cut.tokens = g2p(text)
+ tokens_list = phonemize_espeak(text, "en-us")
+ tokens = []
+ for t in tokens_list:
+ tokens.extend(t)
+ cut.tokens = tokens
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index ed0a07f5e2..cbf27bd423 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
- else
+ else
log "monotonic_align lib already built"
fi
fi
@@ -80,6 +80,11 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
+ # If not, please install them with:
+ # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   you can install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py
mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
@@ -113,13 +118,12 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
- # We assume you have installed g2p_en and espnet_tts_frontend.
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
- # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+ # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   you can install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
- ./local/prepare_token_file.py \
- --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
- --tokens data/tokens.txt
+ ./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index f82f9dbe9b..c607f0114b 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -218,8 +218,7 @@ def main():
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
index cf0d20ae23..9e7c71c6dc 100755
--- a/egs/ljspeech/TTS/vits/infer.py
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -130,14 +130,16 @@ def _save_worker(
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
@@ -201,8 +203,7 @@ def main():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(f"Device: {device}")
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
index fcbc1d6632..4f46e8e6c5 100755
--- a/egs/ljspeech/TTS/vits/test_onnx.py
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -108,7 +108,9 @@ def main():
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
- tokens = tokenizer.texts_to_token_ids([text])
+ tokens = tokenizer.texts_to_token_ids(
+ [text], intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index b0afc6a044..9a5a9090ec 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import Dict, List
-import g2p_en
import tacotron_cleaner.cleaners
+from piper_phonemize import phonemize_espeak
from utils import intersperse
@@ -38,21 +39,37 @@ def __init__(self, tokens: str):
id = int(info[0])
else:
token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
self.token2id[token] = id
- self.blank_id = self.token2id[""]
- self.oov_id = self.token2id[""]
- self.vocab_size = len(self.token2id)
+ # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+ self.pad_id = self.token2id["_"] # padding
+ self.sos_id = self.token2id["^"] # beginning of an utterance (bos)
+ self.eos_id = self.token2id["$"] # end of an utterance (eos)
+ self.space_id = self.token2id[" "] # word separator (whitespace)
- self.g2p = g2p_en.G2p()
+ self.vocab_size = len(self.token2id)
- def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ intersperse_blank: bool = True,
+ add_sos: bool = False,
+ add_eos: bool = False,
+ lang: str = "en-us",
+ ) -> List[List[int]]:
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
+ add_sos:
+ Whether to add sos token at the start.
+ add_eos:
+ Whether to add eos token at the end.
+ lang:
+ Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
@@ -63,30 +80,46 @@ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
- tokens = self.g2p(text)
+ tokens_list = phonemize_espeak(text, lang)
+ tokens = []
+ for t in tokens_list:
+ tokens.extend(t)
+
token_ids = []
for t in tokens:
- if t in self.token2id:
- token_ids.append(self.token2id[t])
- else:
- token_ids.append(self.oov_id)
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
if intersperse_blank:
- token_ids = intersperse(token_ids, self.blank_id)
+ token_ids = intersperse(token_ids, self.pad_id)
+ if add_sos:
+ token_ids = [self.sos_id] + token_ids
+ if add_eos:
+ token_ids = token_ids + [self.eos_id]
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
- self, tokens_list: List[str], intersperse_blank: bool = True
- ):
+ self,
+ tokens_list: List[str],
+ intersperse_blank: bool = True,
+ add_sos: bool = False,
+ add_eos: bool = False,
+ ) -> List[List[int]]:
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
+ add_sos:
+ Whether to add sos token at the start.
+ add_eos:
+ Whether to add eos token at the end.
Returns:
Return a list of token id list [utterance][token_id]
@@ -96,13 +129,17 @@ def tokens_to_token_ids(
for tokens in tokens_list:
token_ids = []
for t in tokens:
- if t in self.token2id:
- token_ids.append(self.token2id[t])
- else:
- token_ids.append(self.oov_id)
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
if intersperse_blank:
- token_ids = intersperse(token_ids, self.blank_id)
+ token_ids = intersperse(token_ids, self.pad_id)
+ if add_sos:
+ token_ids = [self.sos_id] + token_ids
+ if add_eos:
+ token_ids = token_ids + [self.eos_id]
token_ids_list.append(token_ids)
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index 71c4224fa7..6589b75ff6 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens
@@ -742,8 +744,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
diff --git a/requirements-tts.txt b/requirements-tts.txt
index c30e23d549..eae50ba7b5 100644
--- a/requirements-tts.txt
+++ b/requirements-tts.txt
@@ -3,4 +3,5 @@ matplotlib==3.8.2
cython==3.0.6
numba==0.58.1
g2p_en==2.1.0
-espnet_tts_frontend==0.0.3
\ No newline at end of file
+espnet_tts_frontend==0.0.3
+# piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize, could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5