-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtest.py
155 lines (143 loc) · 5.86 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import tensorflow as tf

# Hide every GPU from TensorFlow so it does not reserve GPU memory;
# all inference below runs through torch only.
try:
    tf.config.set_visible_devices([], 'GPU')
    visible_devices = tf.config.get_visible_devices()
    for device in visible_devices:
        assert device.device_type != 'GPU'
except Exception:
    # Invalid device, or visible devices cannot be modified once the TF
    # runtime has been initialized. Best-effort: keep going either way.
    # (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
    pass

import math
import os
import tqdm
import glob
import datetime
import copy
import pickle
from threading import main_thread
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utilities.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
from utilities.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms
from utilities.config_parser import parse_config, get_data_info, get_config
from models.classifier import Classifier
from utilities.training_utils import setup_dataloaders, optimization_helper
import argparse
from torch.utils.data import DataLoader
from utilities.data.raw_dataset import RawWaveformDataset as SpectrogramDataset
import wandb
from utilities.data.mixup import do_mixup, mixup_criterion
# Merged: these three names previously came from two separate import lines
# of the same module.
from utilities.metrics_helper import calculate_mAP, calculate_stats, d_prime
from utilities.data.raw_transforms import Compose, PeakNormalization, PadToSize
from sklearn.metrics import accuracy_score

# NOTE(review): disabling cuDNN slows conv inference noticeably — presumably
# done to avoid a cuDNN kernel issue; confirm it is intentional.
torch.backends.cudnn.enabled = False
def get_val_acc(x):
    """Extract the validation accuracy encoded in a checkpoint filename.

    Expects paths like ".../epoch=3-val_acc=0.8712.pth" and returns the
    number after the last "val_acc=" as a float. Used as a sort key when
    ranking checkpoints.
    """
    basename = x.rsplit("/", 1)[-1]
    stem = basename.replace(".pth", "")
    return float(stem.rsplit("val_acc=", 1)[-1])
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser()
_cli_options = [
    ("--test_csv_name", dict(type=str)),       # CSV with the test split
    ("--exp_dir", dict(type=str)),             # experiment dir holding ckpts/ and hparams.pickle
    ("--meta_dir", dict(type=str)),            # dir holding the csv and lbl_map.json
    ("--gpu_id", dict(type=int, default=0)),
    ("--metrics", dict(type=str, default="multiclass")),
    ("--separator", dict(type=str, default=",")),
]
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)
def pad_input(signal, sr):
    """Split one waveform into a batch of consecutive 1-second clips.

    Args:
        signal: batched waveform tensor; only the first element is used.
            Assumes ``signal[0]`` is 2-D with time on axis 1, i.e. the input
            is (1, C, T) — TODO confirm against the collate_fn output.
        sr: sample rate in Hz; each output clip is exactly ``sr`` samples.

    Returns:
        Tensor of shape (num_clips, 1, sr): the waveform, edge-replicated
        up to a whole number of seconds and chopped into 1-second pieces.
    """
    signal = signal[0]
    # Round the length up to a whole number of seconds.
    target_len = int(math.ceil(signal.shape[1] / sr) * sr)
    padding = target_len - signal.shape[1]
    offset = padding // 2
    # Center the waveform: split the deficit between both ends, replicating
    # edge samples. (Previously a 2-tuple-of-pairs was built but only its
    # second element used — the dead half is removed.)
    signal = torch.nn.functional.pad(signal, (offset, padding - offset), "replicate")
    # Chop into a batch of non-overlapping mono 1-second clips.
    signal = signal.unsqueeze(0)
    return signal.reshape(-1, 1, int(sr * 1))
if __name__ == '__main__':
    args = parser.parse_args()
    hparams_path = os.path.join(args.exp_dir, "hparams.pickle")
    # Sort ascending by val_acc so the last entry is the best checkpoint.
    ckpts = sorted(glob.glob(os.path.join(args.exp_dir, "ckpts", "*")), key=get_val_acc)
    print(ckpts)
    if len(ckpts) == 0:
        print(f"Well, no checkpoints found in {args.exp_dir}. Exiting...")
        raise SystemExit
    ckpt_path = ckpts[-1]
    print(ckpt_path)
    # Load on CPU first: without map_location, torch.load tries to restore
    # tensors onto the device they were saved from, which fails on hosts
    # that lack that GPU. Weights are moved to the target device below.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # exp_dir name and the ckpts/<file> suffix, used in the results row.
    fname = ckpt_path.split("/")[-3]
    ckpt_ext = "/".join(ckpt_path.split("/")[-3:])
    # Never overwrite an existing results file.
    res = os.path.join(args.exp_dir, "results.txt")
    if os.path.exists(res):
        print(f"{res} files exists.. exiting")
        raise SystemExit
    # NOTE(review): pickle.load on a repo-controlled file; do not point
    # --exp_dir at untrusted data.
    with open(hparams_path, "rb") as fp:
        hparams = pickle.load(fp)
    model = Classifier(hparams.cfg)
    device = torch.device(f"cuda:{args.gpu_id}")
    # load_state_dict returns missing/unexpected key lists — print them.
    print(model.load_state_dict(checkpoint['model_state_dict']))
    model = model.to(device).eval()
    ac = hparams.cfg['audio_config']
    print(ac)
    # Validation transform: peak-normalize only; padding/chunking is done
    # per-sample by pad_input.
    val_tfs = Compose([
        PeakNormalization(sr=ac['sample_rate'])
    ])
    sr = ac['sample_rate']
    val_set = SpectrogramDataset(os.path.join(args.meta_dir, args.test_csv_name),
                                 os.path.join(args.meta_dir, "lbl_map.json"),
                                 hparams.cfg['audio_config'], mode=args.metrics,
                                 transform=val_tfs, is_val=True, delimiter=args.separator
                                 )
    collate_fn = _collate_fn_raw_multiclass if args.metrics == "multiclass" else _collate_fn_raw
    # batch_size=1: each file is expanded into its own batch of 1-second
    # clips by pad_input, and clip predictions are averaged.
    loader = DataLoader(val_set, batch_size=1, num_workers=2, collate_fn=collate_fn)
    all_preds = []
    all_gts = []
    for batch in tqdm.tqdm(loader):
        x, _, y = batch
        o = pad_input(x, sr)
        o = o.to(device)
        with torch.no_grad():
            preds = model(o)
        # Average the per-clip logits into one prediction per file.
        preds = torch.mean(preds, dim=0, keepdim=True)
        if args.metrics == "multiclass":
            preds = torch.argmax(preds, 1).detach().item()
            all_preds.append(preds)
            all_gts.append(y.detach().cpu().float().item())
        else:
            y_pred_sigmoid = torch.sigmoid(preds)
            all_preds.append(y_pred_sigmoid.detach().cpu().float())
            all_gts.append(y.detach().cpu().float())
    if args.metrics == "multiclass":
        acc = accuracy_score(np.asarray(all_gts), np.asarray(all_preds))
        print("Accuracy: {:.4f}".format(acc))
        with open(res, "w") as fd:
            # write(), not writelines(): these are single strings.
            fd.write("model,acc,ckpt_ext\n")
            fd.write("{},{},{}\n".format(fname, acc, ckpt_ext))
    else:
        # Multilabel path: macro mAP, mean AUC and d-prime over all classes.
        macro_mAP = calculate_mAP(all_preds, all_gts, mode='macro')
        all_preds = torch.cat(all_preds).detach().cpu().numpy()
        all_gts = torch.cat(all_gts).detach().cpu().numpy()
        stats = calculate_stats(all_preds, all_gts)
        mAUC = np.mean([stat['auc'] for stat in stats])
        dp = d_prime(mAUC)
        print("mAP: {:.5f}".format(macro_mAP))
        print("mAUC: {:.5f}".format(mAUC))
        print("dprime: {:.5f}".format(dp))