diff --git a/torchattacks/attacks/fab.py b/torchattacks/attacks/fab.py index 315af190..4ce83f8f 100644 --- a/torchattacks/attacks/fab.py +++ b/torchattacks/attacks/fab.py @@ -202,7 +202,7 @@ def projection_linf(self, t2, w2, b2): lmbd_opt = torch.unsqueeze(m, -1) d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2] + torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2]) # nopep8 - return d * (w != 0).type(torch.FloatTensor) + return d * (w != 0).type(torch.FloatTensor).to(self.device) def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter): a1 = clean_images.clone() diff --git a/torchattacks/attacks/fabl2.py b/torchattacks/attacks/fabl2.py index db3068c5..2e3414f2 100644 --- a/torchattacks/attacks/fabl2.py +++ b/torchattacks/attacks/fabl2.py @@ -198,7 +198,7 @@ def projection_l2(self, t2, w2, b2): c5 = (alpha.unsqueeze(-1) > r[c2]).float() d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5) - return d * (w != 0).type(torch.FloatTensor) + return d * (w != 0).type(torch.FloatTensor).to(self.device) def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter): a1 = clean_images.clone()