-
Notifications
You must be signed in to change notification settings - Fork 7
/
train_rhythm_model.py
62 lines (48 loc) · 1.64 KB
/
train_rhythm_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import logging
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
from urhythmic.rhythm import RhythmModelFineGrained, RhythmModelGlobal
# Constants passed to the rhythm model constructor as hop_length / sample_rate.
# Presumably the hop is measured in samples (320 / 16000 Hz = 20 ms frames);
# NOTE(review): confirm units against urhythmic.rhythm.
SAMPLE_RATE = 16000
HOP_LENGTH = 320

# Script-level logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _load_utterances(dataset_dir):
    """Collect ``(segments, boundaries)`` pairs from every ``*.npz`` under *dataset_dir*.

    Files are visited in sorted path order so repeated runs see the same
    utterance order regardless of filesystem enumeration order.

    Args:
        dataset_dir: directory (``Path``) searched recursively for ``*.npz``
            files containing ``segments`` and ``boundaries`` arrays.

    Returns:
        A list of ``(segments, boundaries)`` tuples, each element a Python
        list built from the corresponding stored array.
    """
    utterances = []
    for path in tqdm(sorted(dataset_dir.rglob("*.npz"))):
        # NOTE(review): allow_pickle=True permits object arrays, but it also
        # lets a crafted .npz run arbitrary code on load. Only use trusted data.
        # np.load on an .npz returns an NpzFile backed by an open zip file;
        # the `with` block closes it so handles don't pile up across files.
        with np.load(path, allow_pickle=True) as npz:
            segments = list(npz["segments"])
            boundaries = list(npz["boundaries"])
        utterances.append((segments, boundaries))
    return utterances


def train_rhythm_model(args):
    """Fit a rhythm model on a directory of segmented speech and save the result.

    Args:
        args: namespace with
            ``model`` -- ``"fine"`` selects ``RhythmModelFineGrained``; any
                other value selects ``RhythmModelGlobal``;
            ``dataset_dir`` -- directory (``str`` or ``Path``) searched
                recursively for ``*.npz`` files holding ``segments`` and
                ``boundaries`` arrays;
            ``checkpoint_path`` -- file (``str`` or ``Path``) that the value
                returned by the model's ``_fit`` is written to via
                ``torch.save``. Missing parent directories are created.

    Raises:
        FileNotFoundError: if no ``*.npz`` files exist under ``dataset_dir``.
    """
    dataset_dir = Path(args.dataset_dir)
    checkpoint_path = Path(args.checkpoint_path)

    logger.info("Training %s rhythm model on %s", args.model, dataset_dir)
    model_type = RhythmModelFineGrained if args.model == "fine" else RhythmModelGlobal
    rhythm_model = model_type(hop_length=HOP_LENGTH, sample_rate=SAMPLE_RATE)

    # Create the output directory before the (possibly long) fit, so a bad
    # or unwritable path fails immediately instead of after training.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    utterances = _load_utterances(dataset_dir)
    if not utterances:
        # Fitting on an empty dataset cannot produce meaningful results.
        raise FileNotFoundError(f"No .npz files found under {dataset_dir}")
    logger.info("Loaded %d utterances", len(utterances))

    # `_fit` is a private urhythmic method; its return value (named `dists`
    # here) is exactly what gets checkpointed.
    dists = rhythm_model._fit(utterances)

    logger.info("Saving checkpoint to %s", checkpoint_path)
    torch.save(dists, checkpoint_path)
if __name__ == "__main__":
    # Command-line entry point: three positional arguments, in order:
    # model type, input directory, output checkpoint file.
    parser = argparse.ArgumentParser(
        description="Train the FineGrained or Global rhythm model."
    )
    parser.add_argument(
        "model",
        type=str,
        choices=["fine", "global"],
        help="type of rhythm model (fine-grained or global).",
    )
    # Both remaining positionals are paths; they differ only in name/help.
    for dest, meta, text in (
        ("dataset_dir", "dataset-dir", "path to the directory of segmented speech."),
        ("checkpoint_path", "checkpoint-path", "path to save checkpoint."),
    ):
        parser.add_argument(dest, metavar=meta, type=Path, help=text)

    args = parser.parse_args()
    train_rhythm_model(args)