# ECAPAModel.py
'''
This part is used to train the speaker model and evaluate its performance.
'''
import os
import sys
import time

import numpy
import soundfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from tools import *
from loss import AAMsoftmax
from model import ECAPA_TDNN
class ECAPAModel(nn.Module):
    def __init__(self, lr, lr_decay, C, n_class, m, s, test_step, **kwargs):
        super(ECAPAModel, self).__init__()
        ## ECAPA-TDNN speaker encoder
        self.speaker_encoder = ECAPA_TDNN(C=C).cuda()
        ## AAM-softmax classifier
        self.speaker_loss = AAMsoftmax(n_class=n_class, m=m, s=s).cuda()
        self.optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=2e-5)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size=test_step, gamma=lr_decay)
        print(time.strftime("%m-%d %H:%M:%S") + " Model parameter number = %.2f" % (sum(param.numel() for param in self.speaker_encoder.parameters()) / 1024 / 1024))
    def train_network(self, epoch, loader):
        self.train()
        ## Read the current learning rate; the scheduler updates it after each epoch
        index, top1, loss = 0, 0, 0
        lr = self.optim.param_groups[0]['lr']
        for num, (data, labels) in enumerate(loader, start=1):
            self.zero_grad()
            labels = labels.cuda()
            speaker_embedding = self.speaker_encoder.forward(data.cuda())
            output = self.speaker_loss.forward(speaker_embedding, labels)
            nloss, prec = self.speaker_loss.evaluation(output, labels)
            nloss.backward()
            self.optim.step()
            index += len(labels)
            top1 += prec
            loss += nloss.detach().cpu().numpy()
            sys.stderr.write("Train: " + time.strftime("%m-%d %H:%M:%S") +
                " [%2d] Lr: %5f, Training: %.2f%%, " % (epoch, lr, 100 * (num / len(loader))) +
                " Loss: %.5f, ACC: %2.2f%% \r" % (loss / num, top1 / index * len(labels)))
            sys.stderr.flush()
        sys.stdout.write("\n")
        self.scheduler.step()
        return loss / num, lr, top1 / index * len(labels)
    def eval_network(self, eval_path, valid_pair):
        self.eval()
        files = []
        embeddings = {}
        for pairs in valid_pair:
            for pair in pairs:
                files += [pair[0], pair[1]]
        setfiles = list(set(files))
        setfiles.sort()
        for file in tqdm.tqdm(setfiles, total=len(setfiles)):
            # The speaker id is encoded in the filename, e.g. "spk001_...": strip the "spk" prefix
            label = int(file.split('_')[0][3:])
            filename = os.path.join(eval_path, "spk{:03d}".format(label), file)
            audio, _ = soundfile.read(filename)
            # Full utterance
            data_1 = torch.FloatTensor(numpy.stack([audio], axis=0)).cuda()
            # Split utterance matrix: 300 frames * 160 hop + 240 window samples,
            # i.e. about 3.01 s assuming 16 kHz audio
            max_audio = 300 * 160 + 240
            if audio.shape[0] <= max_audio:
                shortage = max_audio - audio.shape[0]
                audio = numpy.pad(audio, (0, shortage), 'wrap')
            feats = []
            startframe = numpy.linspace(0, audio.shape[0] - max_audio, num=5)
            for asf in startframe:
                feats.append(audio[int(asf):int(asf) + max_audio])
            feats = numpy.stack(feats, axis=0).astype(float)
            data_2 = torch.FloatTensor(feats).cuda()
            # Speaker embeddings, L2-normalized so the dot product below is a cosine score
            with torch.no_grad():
                embedding_1 = self.speaker_encoder.forward(data_1)
                embedding_1 = F.normalize(embedding_1, p=2, dim=1)
                embedding_2 = self.speaker_encoder.forward(data_2)
                embedding_2 = F.normalize(embedding_2, p=2, dim=1)
            embeddings[file] = [embedding_1, embedding_2]
        scores, labels = [], []
        for pairs in valid_pair:
            for pair in pairs:
                embedding_11, embedding_12 = embeddings[pair[0]]
                embedding_21, embedding_22 = embeddings[pair[1]]
                # Compute the scores; higher means more likely the same speaker
                score_1 = torch.mean(torch.matmul(embedding_11, embedding_21.T))
                score_2 = torch.mean(torch.matmul(embedding_12, embedding_22.T))
                score = (score_1 + score_2) / 2
                score = score.detach().cpu().numpy()
                scores.append(score)
                # Trial label: 1 if both files come from the same speaker, else 0
                labels.append(int(pair[0].split('_')[0][3:] == pair[1].split('_')[0][3:]))
        # Compute EER and minDCF
        EER = tuneThresholdfromScore(scores, labels, [1, 0.1])[1]
        fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
        minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1)
        return EER, minDCF
    def save_parameters(self, path):
        torch.save(self.state_dict(), path)

    def load_parameters(self, path):
        self_state = self.state_dict()
        loaded_state = torch.load(path)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                # Strip the DataParallel prefix if the checkpoint was saved from a wrapped model
                name = name.replace("module.", "")
                if name not in self_state:
                    print("%s is not in the model." % origname)
                    continue
            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter size: %s, model: %s, loaded: %s" % (origname, self_state[name].size(), loaded_state[origname].size()))
                continue
            self_state[name].copy_(param)
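
# Minimal usage sketch: the hyperparameter values, `train_loader`, `valid_pair`,
# and the paths below are illustrative assumptions, not values from this repo.
if __name__ == "__main__":
    # Requires a CUDA device, since the encoder and loss are moved to the GPU above
    model = ECAPAModel(lr=0.001, lr_decay=0.97, C=1024, n_class=100, m=0.2, s=30, test_step=1)
    # A typical loop would feed a DataLoader yielding (waveform_batch, speaker_labels):
    # for epoch in range(1, 81):
    #     loss, lr, acc = model.train_network(epoch, train_loader)
    #     EER, minDCF = model.eval_network("data/eval", valid_pair)
    #     model.save_parameters("exps/model_%04d.model" % epoch)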