-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmetrics.py
104 lines (87 loc) · 4.1 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import torch
import yaml
import hyperpyyaml
import torch.nn.functional as F
import h5py
from scipy.signal import medfilt
from argparse import ArgumentParser
from pyannote.core import Segment, Timeline, Annotation
from pyannote.metrics.diarization import DiarizationErrorRate
from datasets.diarization_dataset import KaldiDiarizationDataset
def _add_speaker_segments(annotation, activity, scale=1):
    """Add one labelled segment per contiguous active run of each speaker column.

    Args:
        annotation: pyannote ``Annotation`` that is filled in place.
        activity: 2-D tensor indexed as (frame, speaker); non-zero means active.
        scale: Multiplier applied to frame indices before building a segment.
    """
    for spkid, frames in enumerate(activity.T):
        # Zero-pad both ends so runs touching the first/last frame still
        # produce an onset and an offset.
        padded = F.pad(frames, (1, 1), 'constant')
        changes, = torch.where(torch.diff(padded, dim=0) != 0)
        # Plain Python ints: Segment objects are hashed/compared inside
        # pyannote, and 0-d tensors hash by identity rather than by value.
        bounds = changes.tolist()
        for start, end in zip(bounds[::2], bounds[1::2]):
            annotation[Segment(start * scale, end * scale)] = str(spkid)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def gen_ref(configs, hyp_dir, metric, threshold=0.5, median=11, subsampling=10):
    """Score hypothesis ``.h5`` files against the validation-set labels.

    For every recording of the dataset built from ``configs["data"]``, a
    reference annotation is derived from the full label matrix and a
    hypothesis annotation from the ``T_hat`` dataset stored in
    ``<hyp_dir>/<rec>.h5``. Both are passed to ``metric`` and the detailed
    components are accumulated. Per-recording and aggregate rates are printed.

    Args:
        configs: Loaded configuration; the ``"data"`` section is read.
        hyp_dir: Directory holding one ``<rec>.h5`` hypothesis file per recording.
        metric: Callable invoked as ``metric(reference, hypothesis,
            detailed=True)`` returning a mapping with the keys ``'total'``,
            ``'confusion'``, ``'false alarm'`` and ``'missed detection'``.
        threshold: Scores strictly above this value count as active.
        median: Median-filter length along the frame axis; values <= 1
            disable filtering.
        subsampling: Factor applied to hypothesis frame indices before they
            are compared with the reference.
            NOTE(review): independent of ``configs["data"]["subsampling"]``;
            confirm the two are meant to agree.
    """
    test_set = KaldiDiarizationDataset(
        data_dir=configs["data"]["val_data_dir"],
        chunk_size=configs["data"]["chunk_size"],
        context_size=configs["data"]["context_recp"],
        input_transform=configs["data"]["feat_type"],
        frame_size=configs["data"]["feat"]["win_length"],
        frame_shift=configs["data"]["feat"]["hop_length"],
        subsampling=configs["data"]["subsampling"],
        rate=configs["data"]["feat"]["sample_rate"],
        label_delay=configs["data"]["label_delay"],
        n_speakers=configs["data"]["num_speakers"],
        use_last_samples=configs["data"]["use_last_samples"],
        shuffle=configs["data"]["shuffle"])

    diaerr = 0.0
    spkcon = 0.0
    falarm = 0.0
    miss = 0.0
    spkscore = 0.0

    for i in range(len(test_set)):
        label, rec = test_set.__getfulllabel__(i)

        # Reorder speaker columns by the frame of their first activity.
        n_frames, _ = label.shape
        frame_idx = torch.arange(1, n_frames + 1).unsqueeze(-1)
        label_idx = frame_idx * label
        if not label_idx.is_floating_point():
            # inf cannot be stored in an integer tensor.
            label_idx = label_idx.float()
        # Inactive frames (and speakers that never talk) become +inf so they
        # never win the per-speaker minimum below.
        label_idx = label_idx.masked_fill_(label_idx == 0, torch.inf)
        sort_idx = torch.argsort(torch.min(label_idx, dim=0)[0])
        label = label[:, sort_idx]

        reference = Annotation(uri=f'file{i+1}')
        _add_speaker_segments(reference, label)

        # Read the hypothesis scores; the context manager closes the file
        # handle for every recording instead of leaking it.
        filepath = os.path.join(hyp_dir, rec+".h5")
        with h5py.File(filepath, 'r') as data:
            scores = torch.from_numpy(data['T_hat'][:]).float()
        pred = torch.where(scores > threshold, 1, 0)
        if median > 1:
            # Smooth each speaker track along the frame axis only.
            pred = torch.from_numpy(medfilt(pred.numpy(), (median, 1))).float()

        hypothesis = Annotation(uri=f'file{i+1}')
        _add_speaker_segments(hypothesis, pred, scale=subsampling)

        res = metric(reference, hypothesis, detailed=True)
        file_err = res['confusion'] + res['false alarm'] + res['missed detection']
        spkscore += res['total']
        spkcon += res['confusion']
        falarm += res['false alarm']
        miss += res['missed detection']
        diaerr += file_err
        print(rec)
        # A recording without scored reference speech has total == 0.
        print("der: ", _safe_ratio(file_err, res['total']))

    print("speaker score: ", spkscore)
    print('mean der: ', _safe_ratio(diaerr, spkscore))
    print('mean speaker confusion rate: ', _safe_ratio(spkcon, spkscore))
    print('mean speaker false alarm rate: ', _safe_ratio(falarm, spkscore))
    print('mean speaker miss rate: ', _safe_ratio(miss, spkscore))
if __name__ == "__main__":
    # Command-line entry point: load the config, then score the hypotheses.
    parser = ArgumentParser()
    parser.add_argument('--configs', help='Configuration file path', required=True)
    parser.add_argument('--preds_dir', default=None, help='Hypothesis results dir')
    # "--thredshold" is the original (misspelled) flag; it stays as an alias
    # so existing command lines keep working.
    parser.add_argument("--threshold", "--thredshold", dest="threshold", type=float,
                        default=0.5, help="Threshold of decision")
    parser.add_argument("--median", type=int, default=11, help="Median filter parameter")
    setup = parser.parse_args()

    # The with-statement closes the file; no explicit close() is needed.
    with open(setup.configs, "r") as f:
        configs = hyperpyyaml.load_hyperpyyaml(f)

    # Honour --preds_dir when given; otherwise fall back to the previously
    # hard-coded location so runs without the flag behave as before.
    preds_dir = setup.preds_dir or "/mnt/home/liangdi/projects/pl_version/pl_eend/tsne_visual/data/onl_2spk_version_tfm_10w_ver_37/preds_h5"

    # NOTE(review): collar is in the same units as the segment boundaries
    # built by gen_ref (frame indices) - confirm 50 is the intended width.
    metric = DiarizationErrorRate(collar=50)
    gen_ref(configs, preds_dir, metric, threshold=setup.threshold, median=setup.median)