-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathvalidate.py
63 lines (53 loc) · 2.43 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import torch
import numpy as np
from data import AVLip
import torch.utils.data
from models import build_model
from sklearn.metrics import average_precision_score, confusion_matrix, accuracy_score
def validate(model, loader, gpu_id):
print("validating...")
device = torch.device(f"cuda:{gpu_id[0]}" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
y_true, y_pred = [], []
for img, crops, label in loader:
img_tens = img.to(device)
crops_tens = [[t.to(device) for t in sublist] for sublist in crops]
features = model.get_features(img_tens).to(device)
y_pred.extend(model(crops_tens, features)[0].sigmoid().flatten().tolist())
y_true.extend(label.flatten().tolist())
y_true = np.array(y_true)
y_pred = np.where(np.array(y_pred) >= 0.5, 1, 0)
# Get AP
ap = average_precision_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)
tp, fn, fp, tn = cm.ravel()
fnr = fn / (fn + tp)
fpr = fp / (fp + tn)
acc = accuracy_score(y_true, y_pred)
return ap, fpr, fnr, acc
if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--real_list_path", type=str, default="./datasets/val/0_real")
parser.add_argument("--fake_list_path", type=str, default="./datasets/val/1_fake")
parser.add_argument("--max_sample", type=int, default=1000, help="max number of validate samples")
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--data_label", type=str, default="val")
parser.add_argument("--arch", type=str, default="CLIP:ViT-L/14")
parser.add_argument("--ckpt", type=str, default="./checkpoints/ckpt.pth")
parser.add_argument("--gpu", type=int, default=0)
opt = parser.parse_args()
device = torch.device(f"cuda:{opt.gpu}" if torch.cuda.is_available() else "cpu")
print(f"Using cuda {opt.gpu} for inference.")
model = build_model(opt.arch)
state_dict = torch.load(opt.ckpt, map_location="cpu")
model.load_state_dict(state_dict["model"])
print("Model loaded.")
model.eval()
model.to(device)
dataset = AVLip(opt)
loader = data_loader = torch.utils.data.DataLoader(
dataset, batch_size=opt.batch_size, shuffle=True
)
ap, fpr, fnr, acc = validate(model, loader, gpu_id=[opt.gpu])
print(f"acc: {acc} ap: {ap} fpr: {fpr} fnr: {fnr}")