evaluate_trained_model.py
from typing import Dict, List
import torch
from torch import nn
import csv
import argparse

from perceptual_advex.utilities import add_dataset_model_arguments, \
    get_dataset_model
from perceptual_advex.attacks import *

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Adversarial training evaluation')
    add_dataset_model_arguments(parser, include_checkpoint=True)
    parser.add_argument('attacks', metavar='attack', type=str, nargs='+',
                        help='attack names')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of examples/minibatch')
    parser.add_argument('--parallel', type=int, default=1,
                        help='number of GPUs to evaluate on')
    parser.add_argument('--num_batches', type=int, required=False,
                        help='number of batches (default: entire dataset)')
    parser.add_argument('--per_example', action='store_true', default=False,
                        help='output per-example accuracy')
    parser.add_argument('--output', type=str, help='output CSV')

    args = parser.parse_args()

    dataset, model = get_dataset_model(args)
    _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True)

    model.eval()
    if torch.cuda.is_available():
        model.cuda()

    # Each attack argument is a Python expression (e.g. a constructor call)
    # evaluated with `model` and `dataset` in scope.
    attack_names: List[str] = args.attacks
    attacks = [eval(attack_name) for attack_name in attack_names]

    # Parallelize the model and the attacks across the requested GPUs.
    if torch.cuda.is_available():
        device_ids = list(range(args.parallel))
        model = nn.DataParallel(model, device_ids)
        attacks = [nn.DataParallel(attack, device_ids) for attack in attacks]

    # Per-attack lists of boolean tensors marking which examples in each
    # batch were classified correctly under that attack.
    batches_correct: Dict[str, List[torch.Tensor]] = \
        {attack_name: [] for attack_name in attack_names}

    for batch_index, (inputs, labels) in enumerate(val_loader):
        print(f'BATCH {batch_index:05d}')

        if (
            args.num_batches is not None and
            batch_index >= args.num_batches
        ):
            break

        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()

        for attack_name, attack in zip(attack_names, attacks):
            adv_inputs = attack(inputs, labels)
            with torch.no_grad():
                adv_logits = model(adv_inputs)
            batch_correct = (adv_logits.argmax(1) == labels).detach()

            batch_accuracy = batch_correct.float().mean().item()
            print(f'ATTACK {attack_name}',
                  f'accuracy = {batch_accuracy * 100:.1f}',
                  sep='\t')
            batches_correct[attack_name].append(batch_correct)

    # Aggregate per-batch results into an overall accuracy for each attack.
    print('OVERALL')
    accuracies = []
    attacks_correct: Dict[str, torch.Tensor] = {}
    for attack_name in attack_names:
        attacks_correct[attack_name] = torch.cat(batches_correct[attack_name])
        accuracy = attacks_correct[attack_name].float().mean().item()
        print(f'ATTACK {attack_name}',
              f'accuracy = {accuracy * 100:.1f}',
              sep='\t')
        accuracies.append(accuracy)

    # Only write a CSV if an output path was given.
    if args.output:
        with open(args.output, 'w') as out_file:
            out_csv = csv.writer(out_file)
            out_csv.writerow(attack_names)
            if args.per_example:
                # One row per validation example, with a 0/1 correctness
                # flag for each attack.
                for example_correct in zip(*[
                    attacks_correct[attack_name]
                    for attack_name in attack_names
                ]):
                    out_csv.writerow(
                        [int(attack_correct.item())
                         for attack_correct in example_correct])
            out_csv.writerow(accuracies)
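
The CSV written via --output follows directly from the code above: the first row lists the attack names, the middle rows (present only when --per_example is passed) give a 0/1 correctness flag per attack for each validation example, and the final row gives the overall accuracy per attack. A minimal sketch of reading such a file back, assuming a hypothetical output path evaluation.csv:

import csv

with open('evaluation.csv') as f:           # hypothetical path passed to --output
    rows = list(csv.reader(f))

attack_names = rows[0]                      # header row: one column per attack
per_example = rows[1:-1]                    # empty unless --per_example was used
overall = [float(x) for x in rows[-1]]      # final row: overall accuracies

for name, accuracy in zip(attack_names, overall):
    print(f'{name}: {accuracy * 100:.1f}% adversarial accuracy')

With per-example rows, accuracy against the combination of all listed attacks can be computed by counting the examples whose flags are 1 in every column.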