forked from SVDDChallenge/CtrSVDD2024_Baseline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
52 lines (43 loc) · 1.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import random
import numpy as np
import torch
def seed_worker(worker_id):
    """
    Seed the ``random`` and NumPy generators of a DataLoader worker.

    Intended as the ``worker_init_fn`` of ``torch.utils.data.DataLoader``.
    ``worker_id`` is accepted to satisfy that callback signature and is not
    used; the seed is derived from ``torch.initial_seed()`` instead.
    """
    # Fold torch's seed into the 32-bit range before handing it on.
    derived_seed = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(derived_seed)
def set_seed(seed):
    """
    Seed every random number generator used by the project.

    Seeds Python's ``random``, NumPy and torch (CPU) with ``seed``. When CUDA
    is available, additionally seeds all CUDA devices and switches cuDNN to
    deterministic mode with benchmarking turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def compute_det_curve(target_scores, nontarget_scores):
    """
    Compute the points of a detection error trade-off (DET) curve.

    Args:
        target_scores: scores of the target trials. Any array-like is
            accepted (list, tuple, ndarray of any shape); it is flattened.
        nontarget_scores: scores of the nontarget trials, same conventions.

    Returns:
        A tuple ``(frr, far, thresholds)`` of 1-D arrays, each of length
        ``n_target + n_nontarget + 1``:

        * ``frr[k]`` is the fraction of target scores among the ``k`` lowest
          scores (false rejection rate), starting at 0.
        * ``far[k]`` is the fraction of nontarget scores that are not among
          the ``k`` lowest scores (false acceptance rate), starting at 1.
        * ``thresholds`` holds the scores in ascending order, preceded by a
          value 0.001 below the lowest score.

    Both score sets must be non-empty: each rate is normalised by the size
    of its own set.
    """
    # Coerce to flat arrays so lists and multi-dimensional inputs line up
    # with the 1-D label vector built below.
    target_scores = np.asarray(target_scores).ravel()
    nontarget_scores = np.asarray(nontarget_scores).ravel()

    n_target = target_scores.size
    n_nontarget = nontarget_scores.size
    n_scores = n_target + n_nontarget

    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(n_target), np.zeros(n_nontarget)))

    # Sort labels based on scores; mergesort is stable, so tied scores keep
    # their target-before-nontarget order.
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]
    sorted_scores = all_scores[indices]

    # Running count of targets among the k lowest scores, and of nontargets
    # that remain above them.
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = n_nontarget - (np.arange(1, n_scores + 1) - tar_trial_sums)

    # Prepend the operating point that lies below every score.
    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / n_target))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / n_nontarget))  # false acceptance rates
    thresholds = np.concatenate((np.atleast_1d(sorted_scores[0] - 0.001), sorted_scores))

    return frr, far, thresholds
def compute_eer(target_scores, nontarget_scores):
    """
    Compute the equal error rate (EER) and the threshold it occurs at.

    Both score collections are converted to flat NumPy arrays and passed to
    ``compute_det_curve``. The operating point where the false rejection and
    false acceptance rates are closest is selected; the EER is the mean of
    the two rates at that point.

    Returns:
        A tuple ``(eer, threshold)``.
    """
    frr, far, thresholds = compute_det_curve(
        np.array(target_scores).flatten(),
        np.array(nontarget_scores).flatten(),
    )
    # Index of the point where the two error rates are nearest to each other.
    crossing = np.argmin(np.abs(frr - far))
    eer = np.mean((frr[crossing], far[crossing]))
    return eer, thresholds[crossing]