-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain-test.py
executable file
·127 lines (101 loc) · 4.14 KB
/
main-test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
import torch
import torchaudio
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from scipy.stats import spearmanr
from ptpt.log import info
import tqdm
import argparse
import toml
import datetime
from pathlib import Path
from types import SimpleNamespace
from data import ConcatDataset, FirstChannelDataset, FeatureScoreDataset
from stoi import STOIPredictor
from utils import set_seed, get_device
def main(args):
    """Evaluate a trained STOIPredictor checkpoint on a test dataset.

    Builds the dataset selected by ``cfg.data['mode']``, runs the network
    over every batch, logs average MSE / LCC / SRCC, and (unless
    ``--no-save``) writes per-file predicted scores to a timestamped CSV.

    Args:
        args: parsed CLI namespace (see the argparse setup in __main__).

    Raises:
        ValueError: if ``cfg.data['mode']`` names an unknown dataset type.
    """
    torchaudio.set_audio_backend('soundfile')
    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    cfg = SimpleNamespace(**toml.load(args.cfg_path))
    device = get_device(args.no_cuda)
    args.data_root = Path(args.data_root)

    data_mode = cfg.data['mode']
    if data_mode in ['vqcpc']:
        test_dataset = FeatureScoreDataset(args.data_root, load_z=False, return_file=True)
    elif data_mode in ['concat']:
        test_dataset = ConcatDataset(args.data_root, cfg.data['sample_rate'], return_file=True)
    elif data_mode in ['single', 'rossbach']:
        test_dataset = FirstChannelDataset(args.data_root, cfg.data['sample_rate'], return_file=True)
    else:
        # Fail fast with a clear message instead of a NameError on
        # `test_dataset` further down.
        raise ValueError(f"unknown data mode: {data_mode!r}")

    def loss_fn(net, batch):
        """Run one batch; return (mse_loss, LCC, SRCC, per-item predictions)."""
        x, stoi, end = batch
        # Move inputs to the network's device (no-op when device is CPU).
        # The original code never did this, which breaks CUDA runs.
        x, stoi = x.to(device), stoi.to(device)
        stoi_pred = net(x)
        if not cfg.model['pool']:
            # NOTE(review): `masked_mean` is not imported in this file —
            # confirm it is provided by utils (or elsewhere) before running
            # unpooled models. `end` (lengths) is left on CPU; verify
            # masked_mean accepts that.
            stoi_pred = masked_mean(stoi_pred, end)
        loss = F.mse_loss(stoi_pred, stoi)
        stoi, stoi_pred = stoi.detach().cpu(), stoi_pred.detach().cpu()
        return loss, np.corrcoef(stoi, stoi_pred)[0, 1], spearmanr(stoi.T, stoi_pred.T)[0], stoi_pred

    def collate_fn(data):
        """Pad variable-length inputs; keep per-item lengths and filenames."""
        x, stoi, name = zip(*data)
        lengths = torch.LongTensor([v.shape[0] for v in x])
        stoi = torch.FloatTensor(stoi)
        X = pad_sequence(x, batch_first=True)
        return X, stoi, lengths, name

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.nb_workers,
        pin_memory=True,
        collate_fn=collate_fn
    )

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    chk = torch.load(args.checkpoint_path, map_location=device)
    net = STOIPredictor(**cfg.model).to(device)
    net.load_state_dict(chk['net'])
    net.eval()

    result_f = None
    if not args.no_save:
        ctime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        result_f = open(f"results_{cfg.trainer['exp_name']}_{args.data_root.name}_{ctime}.csv", mode='w')

    # Honour the --no-tqdm flag (previously parsed but ignored).
    pb = tqdm.tqdm(test_loader, disable=args.no_tqdm)
    total_loss, total_lcc, total_srcc = 0.0, 0.0, 0.0
    try:
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=not args.no_amp and not args.no_cuda):
            for batch in pb:
                name = batch[-1]
                loss, lcc, srcc, score = loss_fn(net, batch[:-1])
                total_loss += loss.item()
                total_lcc += lcc
                total_srcc += srcc
                if result_f is not None:
                    for n, s in zip(name, score):
                        result_f.write(f"{n},{s.item()}\n")
    finally:
        # Close the CSV even if evaluation raises mid-loop.
        if result_f is not None:
            result_f.close()

    avg_loss = total_loss / len(test_loader)
    avg_lcc = total_lcc / len(test_loader)
    avg_srcc = total_srcc / len(test_loader)
    info(f"average MSE loss: {avg_loss}")
    info(f"average LCC: {avg_lcc}")
    info(f"average SRCC: {avg_srcc}")
if __name__ == '__main__':
    # CLI: two positionals (checkpoint, data root) plus optional overrides.
    cli = argparse.ArgumentParser()
    arg_specs = [
        ('checkpoint_path', dict(type=str)),
        ('data_root', dict(type=str)),
        ('--cfg-path', dict(type=str, default='config/vqcpc/stoi-gru128-pool-kmean.toml')),
        ('--no-save', dict(action='store_true')),
        ('--no-cuda', dict(action='store_true')),
        ('--no-amp', dict(action='store_true')),
        ('--no-tqdm', dict(action='store_true')),
        ('--nb-workers', dict(type=int, default=8)),
        ('--detect-anomaly', dict(action='store_true')),  # menacing aura!
        ('--batch-size', dict(type=int, default=256)),
        ('--seed', dict(type=int, default=12345)),
    ]
    for flag, kwargs in arg_specs:
        cli.add_argument(flag, **kwargs)
    parsed = cli.parse_args()
    set_seed(parsed.seed)
    main(parsed)