-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathextract_ssl_s2.py
69 lines (58 loc) · 2.09 KB
/
extract_ssl_s2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import math
import multiprocessing
import argparse
from random import shuffle
import torch.multiprocessing as mp
import torch
from glob import glob
from tqdm import tqdm
import utils
from data_conf import data_root
from feature_extractor import content_module_map
import logging
# Raise the threshold of the "numba" logger to WARNING so only warnings and
# errors from it are emitted. This is done before `import librosa` —
# presumably because librosa pulls in numba, whose debug/info logging is
# noisy; TODO confirm.
logging.getLogger("numba").setLevel(logging.WARNING)
import librosa
def process_one(file_path, model, device, content_module):
    """Extract SSL content features for one audio file and save them to disk.

    The audio is loaded with ``librosa.load(..., sr=16000)``, moved to
    ``device`` as a tensor and passed to ``content_module.get_content``.
    The result is saved next to the input as ``<stem>.ssl.pt``, converted to
    a half-precision CPU tensor.

    Args:
        file_path: path of the input audio file (globbed as ``*.wav`` by the
            script entry point).
        model: content model passed through to ``content_module.get_content``
            (in this script, the object returned by ``get_model()``).
        device: torch device the waveform tensor is moved to.
        content_module: extractor exposing
            ``get_content(model, wav_16k_tensor=...)``.

    A failure on a single file is reported as ``skip <path> (<Error>: <msg>)``
    and swallowed, so one bad file does not abort the whole batch.
    KeyboardInterrupt and SystemExit are not caught and still propagate.
    """
    import os

    # Replace only the final extension. ``str.replace(".wav", ...)`` would
    # also rewrite ".wav" elsewhere in the path. It would also leave the path
    # unchanged for e.g. ".WAV", making torch.save overwrite the audio itself.
    ssl_path = os.path.splitext(file_path)[0] + ".ssl.pt"
    try:
        wav16k, _ = librosa.load(file_path, sr=16000)
        wav16k = torch.from_numpy(wav16k).to(device)
        # Pure feature extraction: no autograd graph is needed.
        with torch.no_grad():
            ssl_content = content_module.get_content(model, wav_16k_tensor=wav16k)
        torch.save(ssl_content.cpu().half(), ssl_path)
    except Exception as e:  # best-effort per file; report and continue
        print("skip", file_path, f"({type(e).__name__}: {e})")
def process_batch(filenames, content_module):
    """Run SSL feature extraction over ``filenames`` inside one worker process.

    Args:
        filenames: audio paths this worker is responsible for.
        content_module: key into ``content_module_map`` selecting the content
            feature extractor (the script passes ``hps.content_module``).

    The worker picks a CUDA device as ``rank % torch.cuda.device_count()``,
    so sibling workers are spread across the visible GPUs. When no CUDA
    device is visible, it uses the CPU instead of failing with
    ZeroDivisionError.
    """
    extractor = content_module_map[content_module]
    print("Loading content model...")
    # ``_identity`` is a private multiprocessing attribute. The original code
    # already handles it being empty (then rank 0). For started child
    # processes it is presumably a tuple like (k,), whose first item serves
    # as the worker index.
    identity = mp.current_process()._identity
    rank = identity[0] if len(identity) > 0 else 0
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        device = torch.device(f"cuda:{rank % num_gpus}")
    else:
        device = torch.device("cpu")
    print(device)
    ssl_model = extractor.get_model().to(device)
    print("Loaded content model.")
    for filename in tqdm(filenames):
        process_one(filename, ssl_model, device, extractor)
if __name__ == "__main__":
    # Script entry point: find every .wav under data_root, split the shuffled
    # list into roughly equal chunks, and extract features for each chunk in
    # its own worker process (see process_batch).
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="configs/s2.json", help="path to config"
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=8,
        help="number of worker processes to split the files across",
    )
    args = parser.parse_args()
    if args.num_processes < 1:
        parser.error("--num_processes must be at least 1")

    filenames = glob(f"{data_root}/**/*.wav", recursive=True)
    hps = utils.get_hparams_from_file(args.config)
    if not filenames:
        # With no files, chunk_size would be 0 and range(..., 0) would raise
        # "ValueError: range() arg 3 must not be zero".
        print(f"No .wav files found under {data_root}; nothing to do.")
        sys.exit(0)
    # Shuffle so that each worker gets a similar mix of files.
    shuffle(filenames)

    # "spawn" start method: PyTorch requires spawn/forkserver for using CUDA
    # in subprocesses.
    multiprocessing.set_start_method("spawn", force=True)
    chunk_size = math.ceil(len(filenames) / args.num_processes)
    chunks = [
        filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size)
    ]
    print([len(c) for c in chunks])
    processes = [
        multiprocessing.Process(target=process_batch, args=(chunk, hps.content_module))
        for chunk in chunks
    ]
    for p in processes:
        p.start()
    # Wait for all workers explicitly and report failures via the exit status.
    for p in processes:
        p.join()
    failed = [p.exitcode for p in processes if p.exitcode != 0]
    if failed:
        print(f"{len(failed)} worker(s) exited with non-zero status: {failed}")
        sys.exit(1)