From 218af0c23172f282715b250415f6064326f78e65 Mon Sep 17 00:00:00 2001
From: Chawin Sitawarin
Date: Fri, 17 Feb 2023 10:46:59 -0800
Subject: [PATCH] Fix bug with square attack

---
 autoattack_modified/autoattack.py  | 27 +++++++++-----
 autoattack_modified/other_utils.py | 56 ++++++++++++++++++------------
 autoattack_modified/square.py      | 26 +++++---------
 main.py                            |  4 +--
 4 files changed, 62 insertions(+), 51 deletions(-)

diff --git a/autoattack_modified/autoattack.py b/autoattack_modified/autoattack.py
index bdb952d..383a82a 100644
--- a/autoattack_modified/autoattack.py
+++ b/autoattack_modified/autoattack.py
@@ -255,34 +255,46 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250, **kwargs_orig):
                         # apgd on cross-entropy loss
                         self.apgd.loss = "ce"
                         self.apgd.seed = self.get_seed()
-                        adv_curr = self.apgd.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.apgd.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
                     elif attack == "apgd-dlr":
                         # apgd on dlr loss
                         self.apgd.loss = "dlr"
                         self.apgd.seed = self.get_seed()
-                        adv_curr = self.apgd.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.apgd.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
 
                     elif attack == "fab":
                         # fab
                         self.fab.targeted = False
                         self.fab.seed = self.get_seed()
-                        adv_curr = self.fab.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.fab.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
 
                     elif attack == "square":
                         # square
                         self.square.seed = self.get_seed()
-                        adv_curr = self.square.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.square.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
                     elif attack == "apgd-t":
                         # targeted apgd
                         self.apgd_targeted.seed = self.get_seed()
-                        adv_curr = self.apgd_targeted.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.apgd_targeted.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
 
                     elif attack == "fab-t":
                         # fab targeted
                         self.fab.targeted = True
                         self.fab.n_restarts = 1
                         self.fab.seed = self.get_seed()
-                        adv_curr = self.fab.perturb(x, y, **kwargs)  # cheap=True
+                        adv_curr = self.fab.perturb(
+                            x, y, **kwargs
+                        )  # cheap=True
 
                     else:
                         raise ValueError("Attack not supported")
@@ -432,7 +444,6 @@ def set_version(self, version="standard"):
                 self.attacks_to_run = ["apgd-ce", "apgd-t", "fab-t", "square"]
             else:
                 self.attacks_to_run = ["apgd-ce", "fab", "square"]
-                # self.attacks_to_run = ['fab']
             if self.norm in ["Linf", "L2"]:
                 self.apgd.n_restarts = 1
                 self.apgd_targeted.n_target_classes = 9
@@ -465,7 +476,7 @@ def set_version(self, version="standard"):
             self.fab.n_target_classes = 9
             self.apgd_targeted.n_target_classes = 9
             self.square.n_queries = 5000
-        if not self.norm in ["Linf", "L2"]:
+        if self.norm not in ["Linf", "L2"]:
             print(
                 '"{}" version is used with {} norm: please check'.format(
                     version, self.norm
diff --git a/autoattack_modified/other_utils.py b/autoattack_modified/other_utils.py
index 138bbdd..9164246 100644
--- a/autoattack_modified/other_utils.py
+++ b/autoattack_modified/other_utils.py
@@ -1,67 +1,77 @@
-import os
-import torch
 import copy
+import os
 
+
-class Logger():
+class Logger:
     def __init__(self, log_path):
         self.log_path = log_path
-    
+
     def log(self, str_to_log):
         print(str_to_log)
         if not self.log_path is None:
-            with open(self.log_path, 'a') as f:
-                f.write(str_to_log + '\n')
+            with open(self.log_path, "a") as f:
+                f.write(str_to_log + "\n")
                 f.flush()
-    
+
+
 def check_imgs(adv, x, norm):
     delta = (adv - x).view(adv.shape[0], -1)
-    if norm == 'Linf':
+    if norm == "Linf":
         res = delta.abs().max(dim=1)[0]
-    elif norm == 'L2':
-        res = (delta ** 2).sum(dim=1).sqrt()
-    elif norm == 'L1':
+    elif norm == "L2":
+        res = (delta**2).sum(dim=1).sqrt()
+    elif norm == "L1":
         res = delta.abs().sum(dim=1)
-    str_det = 'max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}'.format(
-        norm, res.max(), (adv != adv).sum(), adv.max(), adv.min())
+    str_det = "max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}".format(
+        norm, res.max(), (adv != adv).sum(), adv.max(), adv.min()
+    )
     print(str_det)
-    
+
     return str_det
 
+
 def L1_norm(x, keepdim=False):
     z = x.abs().view(x.shape[0], -1).sum(-1)
     if keepdim:
-        z = z.view(-1, *[1]*(len(x.shape) - 1))
+        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z
 
+
 def L2_norm(x, keepdim=False):
-    z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
+    z = (x**2).view(x.shape[0], -1).sum(-1).sqrt()
     if keepdim:
-        z = z.view(-1, *[1]*(len(x.shape) - 1))
+        z = z.view(-1, *[1] * (len(x.shape) - 1))
     return z
 
+
 def L0_norm(x):
-    return (x != 0.).view(x.shape[0], -1).sum(-1)
+    return (x != 0.0).view(x.shape[0], -1).sum(-1)
 
+
 def makedir(path):
     if not os.path.exists(path):
         os.makedirs(path)
 
+
 def get_pred(output):
     if output.size(-1) == 1:
         return (output >= 0).squeeze().float()
     return output.argmax(1)
 
+
 def mask_kwargs(kwargs_orig, batch_datapoint_idcs):
     kwargs = {}
-    if 'dino_targets' in kwargs_orig:
+    if "dino_targets" in kwargs_orig:
         kwargs = copy.deepcopy(kwargs_orig)
-        kwargs['masks'] = kwargs_orig['masks'][batch_datapoint_idcs]
-        kwargs['dino_targets'] = []
+        kwargs["masks"] = kwargs_orig["masks"][batch_datapoint_idcs]
+        kwargs["dino_targets"] = []
         if batch_datapoint_idcs.ndim == 0:
-            kwargs['dino_targets'].append(kwargs_orig['dino_targets'][batch_datapoint_idcs])
+            kwargs["dino_targets"].append(
+                kwargs_orig["dino_targets"][batch_datapoint_idcs]
+            )
         else:
             for i in batch_datapoint_idcs:
-                kwargs['dino_targets'].append(kwargs_orig['dino_targets'][i])
+                kwargs["dino_targets"].append(kwargs_orig["dino_targets"][i])
     return kwargs
diff --git a/autoattack_modified/square.py b/autoattack_modified/square.py
index 1abad31..92faa61 100644
--- a/autoattack_modified/square.py
+++ b/autoattack_modified/square.py
@@ -72,11 +72,11 @@ def __init__(
         self.device = device
         self.return_all = False
 
+    @torch.no_grad()
     def margin_and_loss(self, x, y, **kwargs):
         """
         :param y: correct labels if untargeted else target labels
         """
-
         logits = self.predict(x, **kwargs)
         # EDIT: binary classification
         if logits.size(-1) == 1:
@@ -88,10 +88,8 @@ def margin_and_loss(self, x, y, **kwargs):
             if not self.targeted:
                 if self.loss == "ce":
                     return y_corr, -1.0 * xent
-                elif self.loss == "margin":
-                    return y_corr, y_corr
-            else:
-                return -y_corr, xent
+                return y_corr, y_corr
+            return -y_corr, xent
 
         xent = F.cross_entropy(logits, y, reduction="none")
         u = torch.arange(x.shape[0])
@@ -102,10 +100,8 @@ def margin_and_loss(self, x, y, **kwargs):
         if not self.targeted:
             if self.loss == "ce":
                 return y_corr - y_others, -1.0 * xent
-            elif self.loss == "margin":
-                return y_corr - y_others, y_corr - y_others
-        else:
-            return y_others - y_corr, xent
+            return y_corr - y_others, y_corr - y_others
+        return y_others - y_corr, xent
 
     def init_hyperparam(self, x):
         assert self.norm in ["Linf", "L2", "L1"]
@@ -254,7 +250,6 @@ def p_selection(self, it):
 
     def attack_single_run(self, x, y, **kwargs_orig):
         with torch.no_grad():
-            adv = x.clone()
             c, h, w = x.shape[1:]
             n_features = c * h * w
             n_ex_total = x.shape[0]
@@ -273,6 +268,7 @@ def attack_single_run(self, x, y, **kwargs_orig):
 
             for i_iter in range(self.n_queries):
                 idx_to_fool = (margin_min > 0.0).nonzero().squeeze()
+                idx_to_fool = idx_to_fool.view(-1)
                 kwargs = mask_kwargs(kwargs_orig, idx_to_fool)
 
 
@@ -531,7 +527,6 @@ def attack_single_run(self, x, y, **kwargs_orig):
             x_best = x + delta_init + r_best
             margin_min, loss_min = self.margin_and_loss(x_best, y, **kwargs)
             n_queries = torch.ones(x.shape[0]).to(self.device)
-            s_init = int(math.sqrt(self.p_init * n_features / c))
 
             for i_iter in range(self.n_queries):
                 idx_to_fool = (margin_min > 0.0).nonzero().squeeze()
@@ -664,12 +659,8 @@ def attack_single_run(self, x, y, **kwargs_orig):
                         "- max pert={:.3f}".format(
                             norms_image.max().item()
                         ),
-                        #'- old pert={:.3f}'.format(norms_image_old.max().item())
                     )
 
-                assert (x_new != x_new).sum() == 0
-                assert (x_best != x_best).sum() == 0
-
                 if ind_succ.numel() == n_ex_total:
                     break
 
@@ -766,6 +757,5 @@ def perturb(self, x, y=None, **kwargs_orig):
 
         if not self.return_all:
             return adv
-        else:
-            print("returning final points")
-            return adv_all
+        print("returning final points")
+        return adv_all
diff --git a/main.py b/main.py
index 80089a4..863b11a 100644
--- a/main.py
+++ b/main.py
@@ -202,8 +202,8 @@ def main() -> None:
 
     # Running evaluation
    for attack in eval_attack:
         # import pdb; pdb.set_trace()
-        # TODO: remove next line; only for debugging
-        if attack[0] == "no_attack": continue
+        # DEBUG: remove next line; only for debugging
+        # if attack[0] == "no_attack": continue
         # Use DataParallel (not distributed) model for AutoAttack.
         # Otherwise, DDP model can get timeout or c10d failure.
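
Note on the core fix: the bug named in the subject line appears to be fixed by
the added idx_to_fool = idx_to_fool.view(-1) in square.py. When exactly one
example is still unfooled, (margin_min > 0.0).nonzero().squeeze() collapses to
a 0-dim tensor: calling len() on it or iterating over it raises a TypeError,
and using it as an index silently drops the batch dimension. This is also why
mask_kwargs() in other_utils.py special-cases batch_datapoint_idcs.ndim == 0.
A minimal standalone sketch of the pitfall, not taken from the repo (values
and variable names are illustrative; only PyTorch is assumed):

    import torch

    margin_min = torch.tensor([-1.0, 0.5, -2.0])  # one unfooled example left
    x = torch.randn(3, 3, 32, 32)                 # dummy image batch

    idx = (margin_min > 0.0).nonzero().squeeze()
    print(idx.shape)     # torch.Size([]) -- 0-dim; len(idx) raises TypeError
    print(x[idx].shape)  # torch.Size([3, 32, 32]) -- batch dim dropped

    idx = idx.view(-1)   # the patch's fix: force a 1-D index tensor
    print(idx.shape)     # torch.Size([1])
    print(x[idx].shape)  # torch.Size([1, 3, 32, 32]) -- batch dim kept

With the index normalized to 1-D, the masking in mask_kwargs() and the
per-example bookkeeping in attack_single_run() see a consistent batch
dimension regardless of how many examples remain.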