Skip to content

Commit

Permalink
Merge pull request #109 from fabiobrau/main
Browse files Browse the repository at this point in the history
Bug fix: Projection on the l2 ball
  • Loading branch information
maurapintor authored Nov 2, 2024
2 parents 1e7f856 + 49653dd commit 5eafb33
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/secmlt/optimization/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def project(self, x: torch.Tensor) -> torch.Tensor:
"""
flat_x = x.flatten(start_dim=1)
diff_norm = flat_x.norm(p=2, dim=1, keepdim=True).clamp_(min=1e-12)
flat_x = torch.where(diff_norm <= 1, flat_x, flat_x / diff_norm) * self.radius
flat_x = torch.where(
diff_norm <= self.radius, flat_x, self.radius * flat_x / diff_norm
)
return flat_x.reshape(x.shape)


Expand Down

0 comments on commit 5eafb33

Please sign in to comment.