
Commit

Use piper_phonemize as text tokenizer in ljspeech recipe (#1511)
* use piper_phonemize as text tokenizer in ljspeech recipe

* modify usage of tokenizer in vits/train.py

* update docs
yaozengwei authored Feb 29, 2024
1 parent 291d060 commit d89f4ea
Showing 11 changed files with 107 additions and 101 deletions.
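The heart of the change is the G2P front end: the LJSpeech recipe drops `g2p_en` in favour of espeak-based phonemization from `piper_phonemize`. A minimal sketch of the two front ends side by side (not part of the commit; it assumes both packages are installed and, as the new recipe code does, that `phonemize_espeak` returns one phoneme list per detected sentence):

```python
# Illustrative comparison of the old and new G2P front ends (not from the commit).
import g2p_en
from piper_phonemize import phonemize_espeak

text = "How their system works, end quote."

# Old front end: ARPAbet-style tokens from g2p_en, returned as one flat list.
g2p = g2p_en.G2p()
old_tokens = g2p(text)

# New front end: espeak-ng phonemes, returned as one list per sentence,
# which the recipe flattens into a single token sequence per utterance.
new_tokens = []
for sentence in phonemize_espeak(text, "en-us"):
    new_tokens.extend(sentence)

print(old_tokens)
print(new_tokens)
```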
6 changes: 3 additions & 3 deletions docs/source/recipes/TTS/ljspeech/vits.rst
@@ -1,11 +1,11 @@
VITS
VITS-LJSpeech
===============

This tutorial shows you how to train a VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

.. note::

TTS related recipes require packages in ``requirements-tts.txt``.

.. note::
@@ -120,4 +120,4 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:

- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
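If you only want the released checkpoint, a hedged sketch using `huggingface_hub` (the docs just link the repository; the download method is up to you and `huggingface_hub` is an assumption here):

```python
# Hedged sketch: fetch the updated pretrained model repository with huggingface_hub.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="Zengwei/icefall-tts-ljspeech-vits-2024-02-28",
    local_dir="icefall-tts-ljspeech-vits-2024-02-28",
)
print(f"Pretrained VITS files downloaded to {local_dir}")
```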
4 changes: 2 additions & 2 deletions docs/source/recipes/TTS/vctk/vits.rst
@@ -1,11 +1,11 @@
VITS
VITS-VCTK
===============

This tutorial shows you how to train a VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

.. note::

TTS related recipes require packages in ``requirements-tts.txt``.

.. note::
66 changes: 12 additions & 54 deletions egs/ljspeech/TTS/local/prepare_token_file.py
@@ -17,88 +17,46 @@


"""
This file reads the texts in given manifest and generates the file that maps tokens to IDs.
This file generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

from lhotse import load_manifest
from piper_phonemize import get_espeak_map


def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
help="Path to the manifest file",
)

parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
help="Path to the dict that maps the text tokens to IDs",
)

return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
def get_token2id(filename: Path) -> Dict[str, int]:
"""Get a dict that maps token to IDs, and save it to the given filename."""
all_tokens = get_espeak_map() # token: [token_id]
all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
# sort by token_id
all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])

Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()

cut_set = load_manifest(manifest_file)

for cut in cut_set:
# Each cut only contain one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)

all_tokens = extra_tokens + list(all_tokens)

token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
for token, token_id in all_tokens:
f.write(f"{token} {token_id}\n")


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)

token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)
get_token2id(out_file)
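With this rewrite, `tokens.txt` is no longer derived from the training manifest: it is simply the espeak phoneme inventory exported by `piper_phonemize.get_espeak_map()`, written one `<token> <id>` pair per line, sorted by id. A sketch of what the new script boils down to (the concrete inventory depends on the installed piper_phonemize build):

```python
# Sketch mirroring the rewritten local/prepare_token_file.py.
from pathlib import Path
from piper_phonemize import get_espeak_map

out_file = Path("data/tokens.txt")

token2ids = get_espeak_map()                              # token -> [token_id]
token2id = {tok: ids[0] for tok, ids in token2ids.items()}

with open(out_file, "w", encoding="utf-8") as f:
    # Sort by id so the file can be read back as a plain symbol table.
    for tok, tok_id in sorted(token2id.items(), key=lambda x: x[1]):
        f.write(f"{tok} {tok_id}\n")

# Among the entries are the special symbols the recipe relies on later
# (see https://github.com/rhasspy/piper/blob/master/TRAINING.md):
#   "_" padding, "^" beginning of utterance, "$" end of utterance, " " word separator.
```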
11 changes: 7 additions & 4 deletions egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -23,9 +23,9 @@
import logging
from pathlib import Path

import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak


def prepare_tokens_ljspeech():
@@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
partition = "all"

cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()

new_cuts = []
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
tokens_list = phonemize_espeak(text, "en-us")
tokens = []
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
new_cuts.append(cut)

new_cut_set = CutSet.from_cuts(new_cuts)
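Because `phonemize_espeak` returns one phoneme list per sentence, the recipe flattens them before storing `cut.tokens`. A small sketch of the per-utterance normalization and phonemization step (illustrative; the output symbols depend on the installed espeak-ng data):

```python
# Sketch of the per-cut token preparation in local/prepare_tokens_ljspeech.py.
import tacotron_cleaner.cleaners          # from espnet_tts_frontend
from piper_phonemize import phonemize_espeak

raw = "Mr. Smith bought cheapsite.com for 1.5 million dollars."

# Text normalization (abbreviations, numbers, casing, whitespace).
text = tacotron_cleaner.cleaners.custom_english_cleaners(raw)

# Phonemize, then flatten the per-sentence lists into one token sequence.
tokens = []
for sentence_phonemes in phonemize_espeak(text, "en-us"):
    tokens.extend(sentence_phonemes)

print(tokens)  # e.g. ['m', 'ˈ', 'ɪ', 's', 't', 'ɚ', ...]
```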
16 changes: 10 additions & 6 deletions egs/ljspeech/TTS/prepare.sh
@@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
else
log "monotonic_align lib already built"
fi
fi
@@ -80,6 +80,11 @@ fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py
mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
@@ -113,13 +118,12 @@ fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
--tokens data/tokens.txt
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
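Stage 5 now only needs the output path, since the token inventory no longer depends on the manifest. As a convenience, a hedged pre-flight check (not part of the recipe) that the dependencies mentioned in the comments above are importable before running Stages 3 and 5:

```python
# Optional dependency check before Stages 3 and 5 (not part of the recipe).
def check_tts_deps() -> None:
    try:
        from piper_phonemize import get_espeak_map, phonemize_espeak  # noqa: F401
    except ImportError as e:
        raise SystemExit(
            "piper_phonemize is missing; pre-built wheels: "
            "https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5"
        ) from e
    try:
        import tacotron_cleaner.cleaners  # noqa: F401  (espnet_tts_frontend)
    except ImportError as e:
        raise SystemExit(
            "espnet_tts_frontend is missing; pip install espnet_tts_frontend"
        ) from e


if __name__ == "__main__":
    check_tts_deps()
    print("piper_phonemize and espnet_tts_frontend are available.")
```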
3 changes: 1 addition & 2 deletions egs/ljspeech/TTS/vits/export-onnx.py
@@ -218,8 +218,7 @@ def main():
params.update(vars(args))

tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

logging.info(params)
9 changes: 5 additions & 4 deletions egs/ljspeech/TTS/vits/infer.py
@@ -130,14 +130,16 @@ def _save_worker(
batch_size = len(batch["tokens"])

tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
@@ -201,8 +203,7 @@ def main():
device = torch.device("cuda", 0)

tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

logging.info(f"Device: {device}")
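After the change, inference builds token id sequences with the pad token interspersed and `^`/`$` added, then pads the ragged batch with `pad_id` (the `_` symbol) instead of the old `<blk>`. A hedged sketch of that batching step, using only the API shown in this diff (it assumes the recipe layout, i.e. running from `egs/ljspeech/TTS/vits` with `data/tokens.txt` present):

```python
# Sketch of the token batching now done in vits/infer.py.
import k2
from tokenizer import Tokenizer  # vits/tokenizer.py

tokenizer = Tokenizer("data/tokens.txt")

# A toy batch of two utterances, already phonemized (one token list per utterance).
# Symbols missing from tokens.txt are simply skipped with a warning.
token_batch = [["h", "ə", "l", "ˈ", "o", "ʊ"], ["h", "ˈ", "a", "ɪ"]]

token_ids = tokenizer.tokens_to_token_ids(
    token_batch, intersperse_blank=True, add_sos=True, add_eos=True
)

ragged = k2.RaggedTensor(token_ids)
row_splits = ragged.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]      # per-utterance lengths

# Dense (B, T) tensor padded with "_" (pad_id) rather than the former blank_id.
tokens = ragged.pad(mode="constant", padding_value=tokenizer.pad_id)
```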
4 changes: 3 additions & 1 deletion egs/ljspeech/TTS/vits/test_onnx.py
@@ -108,7 +108,9 @@ def main():
model = OnnxModel(args.model_filename)

text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
77 changes: 57 additions & 20 deletions egs/ljspeech/TTS/vits/tokenizer.py
@@ -1,4 +1,4 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List

import g2p_en
import tacotron_cleaner.cleaners
from piper_phonemize import phonemize_espeak
from utils import intersperse


@@ -38,21 +39,37 @@ def __init__(self, tokens: str):
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id

self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
self.vocab_size = len(self.token2id)
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.sos_id = self.token2id["^"] # beginning of an utterance (bos)
self.eos_id = self.token2id["$"] # end of an utterance (eos)
self.space_id = self.token2id[" "] # word separator (whitespace)

self.g2p = g2p_en.G2p()
self.vocab_size = len(self.token2id)

def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
def texts_to_token_ids(
self,
texts: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.
lang:
Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
@@ -63,30 +80,46 @@ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
tokens_list = phonemize_espeak(text, lang)
tokens = []
for t in tokens_list:
tokens.extend(t)

token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])

if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]

token_ids_list.append(token_ids)

return token_ids_list

def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
self,
tokens_list: List[str],
intersperse_blank: bool = True,
add_sos: bool = False,
add_eos: bool = False,
) -> List[List[int]]:
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
add_sos:
Whether to add sos token at the start.
add_eos:
Whether to add eos token at the end.
Returns:
Return a list of token id list [utterance][token_id]
@@ -96,13 +129,17 @@ def tokens_to_token_ids(
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])

if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids = intersperse(token_ids, self.pad_id)
if add_sos:
token_ids = [self.sos_id] + token_ids
if add_eos:
token_ids = token_ids + [self.eos_id]

token_ids_list.append(token_ids)

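Putting it together, the tokenizer now goes text → espeak phonemes → ids, with `_` optionally interspersed, `^`/`$` marking utterance boundaries, and OOV symbols skipped with a warning instead of being mapped to `<unk>`. A usage sketch under the same layout assumption as above (run from `egs/ljspeech/TTS/vits` with `data/tokens.txt` generated):

```python
# Usage sketch for the updated vits/tokenizer.py.
from tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")

# Special ids now come from piper's conventions rather than <blk>/<sos/eos>/<unk>.
print(tokenizer.pad_id, tokenizer.sos_id, tokenizer.eos_id, tokenizer.vocab_size)

ids = tokenizer.texts_to_token_ids(
    ["Hello world.", "This is the LJSpeech VITS recipe."],
    intersperse_blank=True,   # put "_" between consecutive tokens
    add_sos=True,             # prepend "^"
    add_eos=True,             # append "$"
    lang="en-us",             # passed through to phonemize_espeak()
)
print([len(seq) for seq in ids])  # one id sequence per input text
```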
