Skip to content

Commit

Permalink
Fix cuda error
Browse files Browse the repository at this point in the history
  • Loading branch information
rikonaka committed Mar 31, 2024
1 parent 8dedd74 commit 1ca9fe3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchattacks/attacks/fab.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def projection_linf(self, t2, w2, b2):
lmbd_opt = torch.unsqueeze(m, -1)
d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2] + torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2]) # nopep8

return d * (w != 0).type(torch.FloatTensor)
return d * (w != 0).type(torch.FloatTensor).to(self.device)

def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter):
a1 = clean_images.clone()
Expand Down
2 changes: 1 addition & 1 deletion torchattacks/attacks/fabl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def projection_l2(self, t2, w2, b2):
c5 = (alpha.unsqueeze(-1) > r[c2]).float()
d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)

return d * (w != 0).type(torch.FloatTensor)
return d * (w != 0).type(torch.FloatTensor).to(self.device)

def linear_approximation_search(self, clean_images, clean_labels, adv_images, niter):
a1 = clean_images.clone()
Expand Down

0 comments on commit 1ca9fe3

Please sign in to comment.