-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
140 lines (116 loc) · 5.26 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# -*- coding: utf-8 -*-
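# @Description: Evaluate an adversarial attack (chosen with --method) against a pretrained CIFAR-10 classifier and report accuracy before and after the attack.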
import argparse
import warnings
from itertools import islice

import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

from attack import *
from models import IndentifyModel
# Command-line interface. Arguments are parsed at import time and exposed as
# the module-level `args`, which main() reads.
parser = argparse.ArgumentParser()
# type=str.upper is applied before the `choices` check, so the method name is
# accepted case-insensitively (e.g. "fgsm" -> "FGSM").
parser.add_argument('-m', '--method', required=True, type=str.upper,
                    choices=['L-BFGS', "FGSM", 'I-FGSM', 'JSMA', 'ONE-PIXEL', 'C&W', 'DEEPFOOL', 'MI-FGSM', 'UPSET'],
                    help="Test method (case-insensitive): L-BFGS, FGSM, I-FGSM, JSMA, ONE-PIXEL, C&W, DEEPFOOL, "
                         "MI-FGSM, UPSET")
# %(default)s is filled in by argparse, so the help text can never disagree
# with the real default again (it previously claimed 500).
parser.add_argument('-c', '--count', default=1000, type=int,
                    help="Number of tests (default is %(default)s), but if the number of test datasets is less "
                         "than this number, the number of test datasets prevails")
args = parser.parse_args()
def main():
    """Evaluate the attack selected by ``--method`` on the CIFAR-10 test split.

    Loads the pretrained ``IndentifyModel`` checkpoint, builds the chosen
    attacker, then for exactly ``min(--count, number of batches)`` batches:
    predicts on the clean images, generates perturbed images with
    ``attacker.attack``, predicts again, and finally prints the clean
    ("initial") and post-attack accuracies.

    Returns early, after printing a message, if the method is unknown or there
    is nothing to test.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_datasets = datasets.CIFAR10("./datasets", train=False, transform=transform_test)
    # Some attack methods support batch_size != 1; set it per method.
    # If unsure, keep it at 1.
    dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False, num_workers=4, drop_last=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss().to(device)
    model = IndentifyModel().to(device)
    # -------------------------------------------
    # Load the trained model parameters.
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on a GPU also loads on a CPU-only machine instead of raising.
    model.load_state_dict(
        torch.load(f"./parameter/{model.__class__.__name__}/100.pth", map_location=device)
    )
    print("预训练模型加载完成")
    # -------------------------------------------
    # The commented-out lines show example tuning arguments for each attack.
    # NOTE(review): they use a `parameter=` keyword while the live calls use
    # `model=` — check the current signatures in attack.py before using them.
    method = args.method.upper()
    if method == "L-BFGS":
        attacker = L_BFGS(model=model, criterion=criterion)
        # attacker = L_BFGS(parameter=parameter, criterion=criterion, iters=2, epsilon=0.2)
    elif method == "FGSM":
        attacker = FGSM(model=model, criterion=criterion)
        # attacker = FGSM(parameter=parameter, criterion=criterion, epsilon=0.2)
    elif method == "I-FGSM":
        attacker = I_FGSM(model=model, criterion=criterion)
        # attacker = I_FGSM(parameter=parameter, criterion=criterion)
    elif method == "JSMA":
        attacker = JSMA(model=model)
        # attacker = JSMA(parameter=parameter, alpha=6, gamma=6, iters=50)
    elif method == "ONE-PIXEL":
        attacker = ONE_PIXEL(model=model)
        # attacker = ONE_PIXEL(parameter=parameter)
    elif method == "C&W":
        attacker = CW(model=model, criterion=criterion)
        # attacker = CW(parameter=parameter, criterion=criterion, iters=1000)
    elif method == "DEEPFOOL":
        attacker = DeepFool(model=model)
        # attacker = DeepFool(parameter=parameter, overshoot=2, iters=100)
    elif method == "MI-FGSM":
        attacker = MI_FGSM(model=model, criterion=criterion)
        # attacker = MI_FGSM(parameter=parameter, criterion=criterion)
    elif method == "UPSET":
        # NOTE(review): ResidualModel is not imported by name in this file;
        # presumably it comes from `from attack import *` — verify.
        # NOTE(review): the attacker wraps residual_model, not `model`; confirm
        # that UPSET.forward still classifies with the target model, otherwise
        # the accuracies printed below are not the classifier's.
        residual_model = ResidualModel().to(device)
        warnings.warn(f"You Must Load The Parameter of Model: {residual_model.__class__.__name__}")
        # residual_model.load_state_dict(torch.load("./parameter/UPSET/target_0/1.pth"))
        attacker = UPSET(model=residual_model)
    else:
        print(f"Unknown Method: {method}")
        return
    # -------------------------------------------
    # Begin testing: use the smaller of --count and the number of batches.
    max_counter = min(args.count, len(dataloader))
    if max_counter <= 0:
        # Without this guard a non-positive --count led to a division by zero.
        print(f"Nothing to test: --count must be positive (got {args.count})")
        return
    print(f"Total Test Num: {max_counter}")
    # Running totals, kept as Python ints so the final print shows plain
    # numbers rather than `tensor(..., device=...)`.
    total_num = 0
    total_origin_correct = 0
    total_attack_correct = 0
    model.eval()
    # islice stops after exactly max_counter batches. The previous check on
    # tqdm's `.n` was unreliable: tqdm only syncs `.n` when it redraws the bar,
    # so extra batches could be processed while the accuracy was still divided
    # by max_counter.
    progress = tqdm(islice(dataloader, max_counter), desc="Test", total=max_counter)
    for image, target in progress:
        image, target = image.to(device), target.to(device)
        batch_n = target.size(0)
        # Predictions on the clean (unattacked) images.
        original_output = attacker.forward(image)
        # Generate the adversarial images, then predict on them.
        pert_image = attacker.attack(image, target)
        attack_output = attacker.forward(pert_image)
        # argmax(1) picks the predicted class per sample, as in the original.
        origin_correct = int((original_output.argmax(1) == target).sum())
        attack_correct = int((attack_output.argmax(1) == target).sum())
        total_num += batch_n
        total_origin_correct += origin_correct
        total_attack_correct += attack_correct
        progress.set_postfix(AttackAcc=f"{attack_correct / batch_n}",
                             OriginAcc=f"{origin_correct / batch_n}")
    print(f"{attacker.__class__.__name__} "
          f"初始正确率: {total_origin_correct / total_num} "
          f"攻击后正确率: {total_attack_correct / total_num} ")


if __name__ == "__main__":
    main()