From 49653dd1bc1e1c8b66275e707f5ec227b7624ae7 Mon Sep 17 00:00:00 2001 From: Fabio Brau Date: Thu, 31 Oct 2024 10:37:55 +0100 Subject: [PATCH] Bug fix: Projection on the l2 ball --- src/secmlt/optimization/constraints.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/secmlt/optimization/constraints.py b/src/secmlt/optimization/constraints.py index df77dd5..acbc54e 100644 --- a/src/secmlt/optimization/constraints.py +++ b/src/secmlt/optimization/constraints.py @@ -205,7 +205,9 @@ def project(self, x: torch.Tensor) -> torch.Tensor: """ flat_x = x.flatten(start_dim=1) diff_norm = flat_x.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-12) - flat_x = torch.where(diff_norm <= 1, flat_x, flat_x / diff_norm) * self.radius + flat_x = torch.where( + diff_norm <= self.radius, flat_x, self.radius * flat_x / diff_norm + ) return flat_x.reshape(x.shape)