diff --git a/torchattacks/attacks/fab.py b/torchattacks/attacks/fab.py
index 315af190..4ce83f8f 100644
--- a/torchattacks/attacks/fab.py
+++ b/torchattacks/attacks/fab.py
@@ -202,7 +202,7 @@ def projection_linf(self, t2, w2, b2):
         lmbd_opt = torch.unsqueeze(m, -1)
         d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2] + torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2])  # nopep8
 
-        return d * (w != 0).type(torch.FloatTensor)
+        return d * (w != 0).type(torch.FloatTensor).to(self.device)
 
     def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter):
         a1 = clean_images.clone()
diff --git a/torchattacks/attacks/fabl2.py b/torchattacks/attacks/fabl2.py
index db3068c5..2e3414f2 100644
--- a/torchattacks/attacks/fabl2.py
+++ b/torchattacks/attacks/fabl2.py
@@ -198,7 +198,7 @@ def projection_l2(self, t2, w2, b2):
             c5 = (alpha.unsqueeze(-1) > r[c2]).float()
             d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)
 
-        return d * (w != 0).type(torch.FloatTensor)
+        return d * (w != 0).type(torch.FloatTensor).to(self.device)
 
     def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter):
         a1 = clean_images.clone()