Fix bug with square attack
chawins committed Feb 17, 2023
1 parent 0cdf991 commit 218af0c
Showing 4 changed files with 62 additions and 51 deletions.
27 changes: 19 additions & 8 deletions autoattack_modified/autoattack.py
@@ -255,34 +255,46 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250, **kwargs_orig):
# apgd on cross-entropy loss
self.apgd.loss = "ce"
self.apgd.seed = self.get_seed()
adv_curr = self.apgd.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.apgd.perturb(
x, y, **kwargs
) # cheap=True
elif attack == "apgd-dlr":
# apgd on dlr loss
self.apgd.loss = "dlr"
self.apgd.seed = self.get_seed()
adv_curr = self.apgd.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.apgd.perturb(
x, y, **kwargs
) # cheap=True

elif attack == "fab":
# fab
self.fab.targeted = False
self.fab.seed = self.get_seed()
adv_curr = self.fab.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.fab.perturb(
x, y, **kwargs
) # cheap=True

elif attack == "square":
# square
self.square.seed = self.get_seed()
adv_curr = self.square.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.square.perturb(
x, y, **kwargs
) # cheap=True

elif attack == "apgd-t":
# targeted apgd
self.apgd_targeted.seed = self.get_seed()
adv_curr = self.apgd_targeted.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.apgd_targeted.perturb(
x, y, **kwargs
) # cheap=True
elif attack == "fab-t":
# fab targeted
self.fab.targeted = True
self.fab.n_restarts = 1
self.fab.seed = self.get_seed()
adv_curr = self.fab.perturb(x, y, **kwargs) # cheap=True
adv_curr = self.fab.perturb(
x, y, **kwargs
) # cheap=True

else:
raise ValueError("Attack not supported")
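
For orientation, a hedged usage sketch of the entry point shown in this hunk. The constructor follows the upstream AutoAttack API; `model` (a callable that must accept the extra keyword arguments), `x`, `y`, `masks`, and `dino_targets` are placeholders, the last two being the kwargs this modified version slices per batch with `mask_kwargs` and forwards to every attack's `perturb`:

from autoattack_modified.autoattack import AutoAttack  # path assumed from this repo

adversary = AutoAttack(model, norm="Linf", eps=8 / 255, version="standard")
x_adv = adversary.run_standard_evaluation(
    x, y, bs=250, masks=masks, dino_targets=dino_targets
)
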
@@ -432,7 +444,6 @@ def set_version(self, version="standard"):
self.attacks_to_run = ["apgd-ce", "apgd-t", "fab-t", "square"]
else:
self.attacks_to_run = ["apgd-ce", "fab", "square"]
# self.attacks_to_run = ['fab']
if self.norm in ["Linf", "L2"]:
self.apgd.n_restarts = 1
self.apgd_targeted.n_target_classes = 9
@@ -465,7 +476,7 @@ def set_version(self, version="standard"):
self.fab.n_target_classes = 9
self.apgd_targeted.n_target_classes = 9
self.square.n_queries = 5000
if not self.norm in ["Linf", "L2"]:
if self.norm not in ["Linf", "L2"]:
print(
'"{}" version is used with {} norm: please check'.format(
version, self.norm
56 changes: 33 additions & 23 deletions autoattack_modified/other_utils.py
@@ -1,67 +1,77 @@
import os
import torch
import copy
import os

class Logger():

class Logger:
def __init__(self, log_path):
self.log_path = log_path

def log(self, str_to_log):
print(str_to_log)
if not self.log_path is None:
with open(self.log_path, 'a') as f:
f.write(str_to_log + '\n')
with open(self.log_path, "a") as f:
f.write(str_to_log + "\n")
f.flush()



def check_imgs(adv, x, norm):
delta = (adv - x).view(adv.shape[0], -1)
if norm == 'Linf':
if norm == "Linf":
res = delta.abs().max(dim=1)[0]
elif norm == 'L2':
res = (delta ** 2).sum(dim=1).sqrt()
elif norm == 'L1':
elif norm == "L2":
res = (delta**2).sum(dim=1).sqrt()
elif norm == "L1":
res = delta.abs().sum(dim=1)

str_det = 'max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}'.format(
norm, res.max(), (adv != adv).sum(), adv.max(), adv.min())
str_det = "max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}".format(
norm, res.max(), (adv != adv).sum(), adv.max(), adv.min()
)
print(str_det)

return str_det


def L1_norm(x, keepdim=False):
z = x.abs().view(x.shape[0], -1).sum(-1)
if keepdim:
z = z.view(-1, *[1]*(len(x.shape) - 1))
z = z.view(-1, *[1] * (len(x.shape) - 1))
return z


def L2_norm(x, keepdim=False):
z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
z = (x**2).view(x.shape[0], -1).sum(-1).sqrt()
if keepdim:
z = z.view(-1, *[1]*(len(x.shape) - 1))
z = z.view(-1, *[1] * (len(x.shape) - 1))
return z


def L0_norm(x):
return (x != 0.).view(x.shape[0], -1).sum(-1)
return (x != 0.0).view(x.shape[0], -1).sum(-1)


def makedir(path):
if not os.path.exists(path):
os.makedirs(path)


def get_pred(output):
if output.size(-1) == 1:
return (output >= 0).squeeze().float()
return output.argmax(1)


def mask_kwargs(kwargs_orig, batch_datapoint_idcs):
kwargs = {}
if 'dino_targets' in kwargs_orig:
if "dino_targets" in kwargs_orig:
kwargs = copy.deepcopy(kwargs_orig)
kwargs['masks'] = kwargs_orig['masks'][batch_datapoint_idcs]
kwargs['dino_targets'] = []
kwargs["masks"] = kwargs_orig["masks"][batch_datapoint_idcs]
kwargs["dino_targets"] = []

if batch_datapoint_idcs.ndim == 0:
kwargs['dino_targets'].append(kwargs_orig['dino_targets'][batch_datapoint_idcs])
kwargs["dino_targets"].append(
kwargs_orig["dino_targets"][batch_datapoint_idcs]
)
else:
for i in batch_datapoint_idcs:
kwargs['dino_targets'].append(kwargs_orig['dino_targets'][i])
kwargs["dino_targets"].append(kwargs_orig["dino_targets"][i])
return kwargs
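
A hedged usage sketch of `mask_kwargs`: given the full kwargs for a batch, it selects the rows of `masks` and the entries of `dino_targets` that match the surviving datapoint indices (import path assumed from this repo's layout; shapes and values illustrative):

import torch
from autoattack_modified.other_utils import mask_kwargs

kwargs_orig = {
    "masks": torch.zeros(4, 32, 32),        # one mask per datapoint
    "dino_targets": ["a", "b", "c", "d"],   # one target per datapoint
}
idx = torch.tensor([0, 2])                  # datapoints still being attacked
kwargs = mask_kwargs(kwargs_orig, idx)
print(kwargs["masks"].shape, kwargs["dino_targets"])
# torch.Size([2, 32, 32]) ['a', 'c']
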
26 changes: 8 additions & 18 deletions autoattack_modified/square.py
@@ -72,11 +72,11 @@ def __init__(
self.device = device
self.return_all = False

@torch.no_grad()
def margin_and_loss(self, x, y, **kwargs):
"""
:param y: correct labels if untargeted else target labels
"""

logits = self.predict(x, **kwargs)
# EDIT: binary classification
if logits.size(-1) == 1:
@@ -88,10 +88,8 @@ def margin_and_loss(self, x, y, **kwargs):
if not self.targeted:
if self.loss == "ce":
return y_corr, -1.0 * xent
elif self.loss == "margin":
return y_corr, y_corr
else:
return -y_corr, xent
return y_corr, y_corr
return -y_corr, xent

xent = F.cross_entropy(logits, y, reduction="none")
u = torch.arange(x.shape[0])
@@ -102,10 +100,8 @@ def margin_and_loss(self, x, y, **kwargs):
if not self.targeted:
if self.loss == "ce":
return y_corr - y_others, -1.0 * xent
elif self.loss == "margin":
return y_corr - y_others, y_corr - y_others
else:
return y_others - y_corr, xent
return y_corr - y_others, y_corr - y_others
return y_others - y_corr, xent

def init_hyperparam(self, x):
assert self.norm in ["Linf", "L2", "L1"]
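
In the untargeted branches above, the margin is the gap between the true-class logit and the largest rival logit, so a positive margin means the example is still classified correctly. A short sketch of that computation with illustrative numbers, mirroring the indexing pattern used in `margin_and_loss`:

import torch

logits = torch.tensor([[2.0, 0.5, 1.5]])  # one sample, three classes
y = torch.tensor([0])                     # true class
u = torch.arange(logits.shape[0])
y_corr = logits[u, y].clone()             # true-class logit: 2.0
logits[u, y] = -float("inf")              # mask the true class out
y_others = logits.max(dim=-1)[0]          # best rival logit: 1.5
print(y_corr - y_others)                  # tensor([0.5000]) -> not yet fooled
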
@@ -254,7 +250,6 @@ def p_selection(self, it):

def attack_single_run(self, x, y, **kwargs_orig):
with torch.no_grad():
adv = x.clone()
c, h, w = x.shape[1:]
n_features = c * h * w
n_ex_total = x.shape[0]
@@ -273,6 +268,7 @@ def attack_single_run(self, x, y, **kwargs_orig):

for i_iter in range(self.n_queries):
idx_to_fool = (margin_min > 0.0).nonzero().squeeze()
idx_to_fool = idx_to_fool.view(-1)

kwargs = mask_kwargs(kwargs_orig, idx_to_fool)
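
The added `view(-1)` looks like the heart of this fix: when exactly one example is left to fool, `nonzero().squeeze()` collapses to a 0-dim tensor, which breaks batched indexing (and the bookkeeping in `mask_kwargs`) downstream. A minimal torch-only reproduction:

import torch

margin_min = torch.tensor([-1.0, 0.5, -2.0])   # one example still unfooled
idx = (margin_min > 0.0).nonzero().squeeze()
print(idx.shape)            # torch.Size([])  -- 0-dim scalar, not a batch of 1
print(idx.view(-1).shape)   # torch.Size([1]) -- a proper 1-D index tensor
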

@@ -531,7 +527,6 @@ def attack_single_run(self, x, y, **kwargs_orig):
x_best = x + delta_init + r_best
margin_min, loss_min = self.margin_and_loss(x_best, y, **kwargs)
n_queries = torch.ones(x.shape[0]).to(self.device)
s_init = int(math.sqrt(self.p_init * n_features / c))

for i_iter in range(self.n_queries):
idx_to_fool = (margin_min > 0.0).nonzero().squeeze()
@@ -664,12 +659,8 @@ def attack_single_run(self, x, y, **kwargs_orig):
"- max pert={:.3f}".format(
norms_image.max().item()
),
#'- old pert={:.3f}'.format(norms_image_old.max().item())
)

assert (x_new != x_new).sum() == 0
assert (x_best != x_best).sum() == 0

if ind_succ.numel() == n_ex_total:
break

@@ -766,6 +757,5 @@ def perturb(self, x, y=None, **kwargs_orig):

if not self.return_all:
return adv
else:
print("returning final points")
return adv_all
print("returning final points")
return adv_all
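
The tail of `perturb` above returns a single adversarial batch by default, or all final points when `return_all` is set (it is initialized to `False` in `__init__` earlier in this file). A hedged usage sketch; the `SquareAttack` class name and constructor arguments are assumed from the upstream autoattack package, and `model`, `x`, `y` are placeholders:

from autoattack_modified.square import SquareAttack

attack = SquareAttack(model, norm="Linf", eps=8 / 255, n_queries=5000)
attack.return_all = True        # opt in to receiving the final points
adv_all = attack.perturb(x, y)  # logs "returning final points"
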
4 changes: 2 additions & 2 deletions main.py
@@ -202,8 +202,8 @@ def main() -> None:
# Running evaluation
for attack in eval_attack:
# import pdb; pdb.set_trace()
# TODO: remove next line; only for debugging
if attack[0] == "no_attack": continue
# DEBUG: remove next line; only for debugging
# if attack[0] == "no_attack": continue

# Use DataParallel (not distributed) model for AutoAttack.
# Otherwise, DDP model can get timeout or c10d failure.
