eval_gen_lp_norms.py
import argparse

import foolbox as fb
import torch
import torchvision
from foolbox import PyTorchModel
from torchvision import transforms
from tqdm import tqdm

from sub_sources.DBAT.models.wideresnet import WideResNet
from normalize_utils import NormalizeByChannelMeanStd
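# NOTE: `NormalizeByChannelMeanStd` comes from the local `normalize_utils` module,
# which is not shown here. A minimal sketch of what such a layer typically looks
# like (an assumption, not the actual implementation): an nn.Module that applies
# per-channel mean/std normalization inside the model, so attacks can operate on
# raw [0, 1] images.
#
# class NormalizeByChannelMeanStd(torch.nn.Module):
#     def __init__(self, mean, std):
#         super().__init__()
#         self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
#         self.register_buffer('std', torch.tensor(std).view(1, -1, 1, 1))
#
#     def forward(self, x):
#         return (x - self.mean) / self.std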
parser = argparse.ArgumentParser(description='Evaluation code for general Lp norms')
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='batch size for testing (default: 128)')
parser.add_argument('--num_classes', type=int, default=10,
                    help='number of classes (default: 10 for CIFAR-10)')
parser.add_argument('--model-dir', type=str, default='model_cifar10',
                    help='path to the directory of the model we wish to load')
parser.add_argument('--checkpoint', default='checkpoint.pt', type=str,
                    help='filename of the pretrained checkpoint')
parser.add_argument('--norm', default='l2', type=str, help='lp norm / attack to use for evaluation',
                    choices=['l2', 'l1', 'linf', 'L2DeepFoolAttack', 'LinfDeepFoolAttack'])
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
# setup data loader
transform_test = transforms.Compose([
    transforms.ToTensor(),
])
testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, **kwargs)
def main():
    # CIFAR-10 per-channel statistics; normalization is applied inside the model,
    # so the attacks below can work directly on images in [0, 1]
    dataset_normalization = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    model = WideResNet(num_classes=args.num_classes, use_normalize=True,
                       normalize_layer=dataset_normalization, eval_mode=True)
    model = model.to(device)

    assert args.checkpoint != ''
    checkpoint = torch.load(f'{args.model_dir}/{args.checkpoint}', map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        print(f'test_accuracy={checkpoint["test_accuracy"]}')
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print('read checkpoint {}'.format(args.checkpoint))

    robust_acc, natural_acc = eval_test(model, test_loader, norm=args.norm)
    print("For attack norm {}, Robust Accuracy: {}, Natural Accuracy: {}".format(args.norm, robust_acc, natural_acc))
def eval_test(model, test_loader, norm):
    model.eval()
    fb_model = PyTorchModel(model, bounds=(0, 1), device=device)

    # pick the attack and perturbation budget for the requested norm
    if norm == 'l2':
        epsilon, num_steps = 0.5, 20
        adversary = fb.attacks.L2PGD(steps=num_steps)  # , abs_stepsize=step_size)
    elif norm == 'l1':
        epsilon, num_steps = 12, 20
        adversary = fb.attacks.L1PGD(steps=num_steps)  # , abs_stepsize=step_size)
    elif norm == 'LinfDeepFoolAttack':
        epsilon, num_steps = 0.02, 50
        adversary = fb.attacks.LinfDeepFoolAttack(steps=num_steps)
    elif norm == 'L2DeepFoolAttack':
        epsilon, num_steps = 0.02, 50
        adversary = fb.attacks.L2DeepFoolAttack(steps=num_steps)
    else:
        # note: 'linf' is accepted by the argument parser but has no branch here
        raise NotImplementedError(f'Requested norm is not implemented: {norm}')
    print(f'Using adversary with norm={norm}')

    robust_err_total = 0
    natural_err_total = 0
    num_test_samples = len(test_loader.dataset)
    for data, target in tqdm(test_loader):
        data, target = data.to(device), target.to(device)
        raw_advs, clipped_advs, success = adversary(fb_model, data, target, epsilons=epsilon)
        # natural (clean) errors on the unperturbed batch
        out = model(data)
        err_clean = (out.data.max(1)[1] != target.data).float().sum()
        # adversarial errors: `success` marks samples where the attack succeeded
        # err_adv = (model(clipped_advs).data.max(1)[1] != target.data).float().sum()
        err_adv = success.float().sum()
        robust_err_total += err_adv
        natural_err_total += err_clean

    print('natural_err_total: ', natural_err_total)
    print('robust_err_total: ', robust_err_total)
    natural_acc = round((num_test_samples - natural_err_total.item()) / num_test_samples * 100, 3)
    print(f'Natural acc total: {natural_acc}')
    robust_acc = round((num_test_samples - robust_err_total.item()) / num_test_samples * 100, 3)
    print(f'Robust acc total: {robust_acc}')
    return robust_acc, natural_acc
if __name__ == '__main__':
main()
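
# Example invocation (a sketch; the model directory and checkpoint filename are
# assumptions that depend on how training was run, shown here with the parser
# defaults):
#
#   python eval_gen_lp_norms.py --model-dir model_cifar10 \
#       --checkpoint checkpoint.pt --norm l2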