forked from PeterWang512/FALdetector
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
119 lines (96 loc) · 3.92 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import glob
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import average_precision_score, accuracy_score

from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
def load_global_classifier(model_path, gpu_id):
    """Load the global (whole-image) classifier ``DRNSub(1)`` in eval mode.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            ``DRNSub`` state dict.
        gpu_id: CUDA device index, given as an int or a numeric string (the
            command-line parser may pass a string). A negative value such as
            -1 selects the CPU; so does a machine without CUDA.

    Returns:
        The model moved to the selected device, with a ``device`` attribute
        recording that device string for callers placing input tensors.
    """
    # Normalise first: comparing the string '-1' with the int -1 is always
    # unequal, which previously made it impossible to force the CPU.
    gpu_id = int(gpu_id)
    if torch.cuda.is_available() and gpu_id >= 0:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = DRNSub(1)
    # Deserialise on the CPU; the model is moved to `device` below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model
def load_local_detector(model_path, gpu_id):
    """Load the local warp detector ``DRNSeg(2)`` in eval mode.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            ``DRNSeg`` state dict.
        gpu_id: CUDA device index, given as an int or a numeric string (the
            command-line parser may pass a string). A negative value such as
            -1 selects the CPU; so does a machine without CUDA.

    Returns:
        The model moved to the selected device, with a ``device`` attribute
        recording that device string for callers placing input tensors.
    """
    # Without the sign check a CPU request (-1) on a CUDA machine would build
    # the invalid device string 'cuda:-1'.
    gpu_id = int(gpu_id)
    if torch.cuda.is_available() and gpu_id >= 0:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = DRNSeg(2)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model
# Image preprocessing used by load_data: PIL image -> float tensor, followed by
# per-channel normalisation. The mean/std values are the ImageNet statistics
# commonly used with torchvision pretrained backbones.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def load_data(img_path, device):
    """Read an image, resize it and turn it into a normalised tensor.

    Args:
        img_path: Path of the image file to read.
        device: Device string (e.g. ``'cpu'`` or ``'cuda:0'``) that the
            returned tensor is moved to.

    Returns:
        ``(face_tens, face)``: the ``tf``-transformed tensor on ``device`` and
        the resized RGB ``PIL.Image`` it was computed from.
    """
    # Close the file deterministically; convert() returns a new, fully loaded
    # image, so it stays valid after the handle is released.
    with Image.open(img_path) as raw:
        face = raw.convert('RGB')
    # resize_shorter_side (utils.tools) presumably scales the shorter side to
    # 400 px; only its first return value is used here.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)
    return face_tens, face
def classify_fake(model, img_path):
    """Score one image with the global classifier.

    Args:
        model: Classifier from ``load_global_classifier``; its ``device``
            attribute decides where the input tensor is placed.
        img_path: Path of the image to score.

    Returns:
        The sigmoid of the model's first output as a Python float in [0, 1].
        The evaluation script treats scores above 0.5 as "modified".
    """
    tensor, _ = load_data(img_path, model.device)
    batch = tensor.unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():
        score = model(batch)[0].sigmoid()
    return score.cpu().item()
def calc_psnr(img0, img1, mask=None):
    """Compute the PSNR between two images assuming a peak value of 1.

    The result is ``-10 * log10(MSE + 1e-6)``, i.e. PSNR for a peak signal of
    1.0, so inputs should be scaled to [0, 1] (callers in this file divide
    8-bit pixel arrays by 255). The 1e-6 term keeps identical inputs from
    producing an infinite value; they score 60 dB instead.

    Args:
        img0: First image as an array-like.
        img1: Second image, same (or broadcast-compatible) shape as ``img0``.
        mask: Accepted for API compatibility; currently ignored.

    Returns:
        The PSNR in decibels as a NumPy float.
    """
    # Promote to float64 before subtracting: with integer inputs (e.g. uint8)
    # both the difference and its square would silently wrap around.
    diff = np.asarray(img0, dtype=np.float64) - np.asarray(img1, dtype=np.float64)
    return -10 * np.log10(np.mean(diff ** 2) + 1e-6)
def _reference_path(modified_path):
    """Map an image inside a ``modified`` directory to its ``reference`` twin.

    Only the image's immediate parent directory is renamed, so a dataset root
    or a file name that merely contains "modified" is left intact. Paths whose
    parent directory is not named ``modified`` fall back to replacing every
    occurrence of the substring, as before.
    """
    parent, fname = os.path.split(modified_path)
    root, subdir = os.path.split(parent)
    if subdir == 'modified':
        return os.path.join(root, 'reference', fname)
    return modified_path.replace('modified', 'reference')


def detect_warp(model, img_path):
    """Undo the predicted warp on a modified image and measure the PSNR gain.

    The local detector predicts a warping field for the image at ``img_path``.
    The field is resized to the image and applied with ``warp`` to produce a
    "reversed" image. Both the modified and the reversed image are then
    compared with the matching reference image.

    Args:
        model: Detector from ``load_local_detector``; its ``device`` attribute
            decides where the input tensor is placed.
        img_path: Path of an image inside a ``modified`` directory. Its
            counterpart is read from the sibling ``reference`` directory.

    Returns:
        ``(psnr_before, psnr_after)``: PSNR of reference vs. modified and of
        reference vs. reversed, both computed on pixels scaled by 1/255.
    """
    img, modified = load_data(img_path, model.device)

    # Warping field prediction
    with torch.no_grad():
        flow = model(img.unsqueeze(0))[0].cpu().numpy()
    # Move axis 0 to the end, (C, H, W) -> (H, W, C), for the NumPy-side
    # helpers (C is presumably the 2 flow components of DRNSeg(2) -- confirm).
    flow = np.transpose(flow, (1, 2, 0))

    # Undoing the warps
    flow = flow_resize(flow, modified.size)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)

    # Resize the reference to the (already resized) modified image so the
    # arrays line up pixel for pixel.
    with Image.open(_reference_path(img_path)) as ref:
        original = ref.convert('RGB')
    original_np = np.asarray(original.resize(modified.size, Image.BICUBIC))

    psnr_before = calc_psnr(original_np / 255, modified_np / 255)
    psnr_after = calc_psnr(original_np / 255, reverse_np / 255)
    return psnr_before, psnr_after
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataroot", required=True, help='the root to the dataset')
    parser.add_argument(
        "--global_pth", required=True, help="path to the global model")
    parser.add_argument(
        "--local_pth", required=True, help="path to the local model")
    # type=int so that "-1" reaches the loaders as the integer -1 rather than
    # as a string that never compares equal to it.
    parser.add_argument(
        "--gpu_id", type=int, default=0,
        help="the id of the gpu to run model on")
    args = parser.parse_args()

    glb_model = load_global_classifier(args.global_pth, args.gpu_id)
    lcl_model = load_local_detector(args.local_pth, args.gpu_id)

    # Escape the root so glob metacharacters in it ('[', '*', '?') match
    # literally; sort so the evaluation order is reproducible.
    root = glob.escape(args.dataroot)
    original_paths = sorted(glob.glob(os.path.join(root, 'original', '*')))
    modified_paths = sorted(glob.glob(os.path.join(root, 'modified', '*')))
    if not modified_paths:
        # Without modified images the PSNR means are NaN and the metrics
        # degenerate, so fail early with a clear message.
        parser.error('no images found in {}'.format(
            os.path.join(args.dataroot, 'modified')))

    pred_prob, gt_prob, psnr_before, psnr_after = [], [], [], []
    # Unmodified images form the negative class (label 0).
    for img_path in original_paths:
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(0)
    # Modified images form the positive class (label 1). They are also run
    # through the local detector to measure how much of the warp it undoes.
    for img_path in modified_paths:
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(1)
        before, after = detect_warp(lcl_model, img_path)
        psnr_before.append(before)
        psnr_after.append(after)

    pred_prob = np.array(pred_prob)
    gt_prob = np.array(gt_prob)
    psnr_before = np.array(psnr_before)
    psnr_after = np.array(psnr_after)

    acc = accuracy_score(gt_prob, pred_prob > 0.5)
    avg_precision = average_precision_score(gt_prob, pred_prob)
    delta_psnr = psnr_after.mean() - psnr_before.mean()
    print("Accuracy: ", acc)
    print("Average precision: ", avg_precision)
    print("PSNR increase: ", delta_psnr)