Merge pull request #107 from pralab/105-step-size-is-not-passed-to-optimizer-in-modular-attack

approve after #104 - 105 step size is not passed to optimizer in modular attack
zangobot authored Oct 14, 2024
2 parents bfad19e + 3abd378 commit 847942f
Showing 7 changed files with 118 additions and 49 deletions.
10 changes: 3 additions & 7 deletions README.md
@@ -23,16 +23,12 @@ The library can be installed together with other plugins that enable further fun

* [Foolbox](https://github.com/bethgelab/foolbox), a Python toolbox to create adversarial examples.
* [Tensorboard](https://www.tensorflow.org/tensorboard), a visualization toolkit for machine learning experimentation.
* [Adversarial Library](https://github.com/jeromerony/adversarial-library), a powerful library of various adversarial attack resources in PyTorch.

Install one or more extras with the command:
```bash
pip install secml-torch[foolbox,tensorboard]
```

To enable the `adv_lib` extra, you have to manually install the library from the original repository:

Install one or more extras with the command:
```bash
pip install git+https://github.com/jeromerony/adversarial-library
pip install secml-torch[foolbox,tensorboard,adv_lib]
```

## Key Features
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -1,4 +1,4 @@
pytest
pytest-cov
foolbox
git+https://github.com/jeromerony/adversarial-library
adv-lib
1 change: 1 addition & 0 deletions setup.py
@@ -61,6 +61,7 @@
extras_require={
"foolbox": ["foolbox>=3.3.0"],
"tensorboard": ["tensorboard"],
"adv_lib": ["adv_lib"],
},
python_requires=">=3.7",
)
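
With the `adv_lib` extra now registered in `extras_require`, the optional dependency can be pulled in at install time. A quick sketch (assumes the extras metadata has been published to PyPI; the quotes guard against shell bracket expansion):

```bash
pip install "secml-torch[adv_lib]"
```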
6 changes: 4 additions & 2 deletions src/secmlt/adv/evasion/modular_attack.py
@@ -148,7 +148,7 @@ def _init_perturbation_constraints(self) -> list[Constraint]:
raise NotImplementedError(msg)

def _create_optimizer(self, delta: torch.Tensor, **kwargs) -> Optimizer:
return self.optimizer_cls([delta], **kwargs)
return self.optimizer_cls([delta], lr=self.step_size, **kwargs)

def forward_loss(
self, model: BaseModel, x: torch.Tensor, target: torch.Tensor
@@ -181,8 +181,10 @@ def _run(
samples: torch.Tensor,
labels: torch.Tensor,
init_deltas: torch.Tensor = None,
**optim_kwargs,
optim_kwargs: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if optim_kwargs is None:
optim_kwargs = {}
multiplier = 1 if self.y_target is not None else -1
target = (
torch.zeros_like(labels) + self.y_target
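
The one-line change in `_create_optimizer` is the heart of this PR: previously `optimizer_cls([delta])` was constructed without a learning rate, so the attack's configured step size was silently ignored in favor of the optimizer's default. A minimal standalone sketch of the difference (pure PyTorch; `delta` and `step_size` are illustrative stand-ins, not the library's objects):

```python
import torch

delta = torch.zeros(3, 32, 32, requires_grad=True)  # perturbation to optimize
step_size = 0.05                                     # the attack's step size

# Before the fix: step_size never reaches the optimizer, which silently
# falls back to its default learning rate.
opt_before = torch.optim.Adam([delta])               # lr defaults to 1e-3

# After the fix: the step size is forwarded explicitly; any remaining
# optim_kwargs (now an optional dict, defaulted to {} inside _run) can
# still be passed alongside it.
opt_after = torch.optim.Adam([delta], lr=step_size)

print(opt_before.param_groups[0]["lr"])  # 0.001 -- ignores step_size
print(opt_after.param_groups[0]["lr"])   # 0.05
```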
131 changes: 95 additions & 36 deletions src/secmlt/optimization/constraints.py
@@ -1,15 +1,19 @@
"""Constraints for tensors and the corresponding batch-wise projections."""

from abc import ABC, abstractmethod
from typing import Union

import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.models.data_processing.data_processing import DataProcessing
from secmlt.models.data_processing.identity_data_processing import (
IdentityDataProcessing,
)


class Constraint(ABC):
"""Generic constraint."""

@abstractmethod
def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Project onto the constraint.
@@ -24,7 +28,50 @@ def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
torch.Tensor
Tensor projected onto the constraint.
"""
...
x_transformed = x.detach().clone()
return self._apply_constraint(x_transformed)

@abstractmethod
def _apply_constraint(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: ...


class InputSpaceConstraint(Constraint, ABC):
"""Input space constraint.
Reverts the preprocessing, applies the constraint, and re-applies the preprocessing.
"""

def __init__(self, preprocessing: DataProcessing) -> None:
"""
Create InputSpaceConstraint.
Parameters
----------
preprocessing : DataProcessing
Preprocessing to invert to apply the constraint on the input space.
"""
if preprocessing is None:
preprocessing = IdentityDataProcessing()
self.preprocessing = preprocessing

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Project onto the constraint in the input space.
Parameters
----------
x : torch.Tensor
Input tensor.
Returns
-------
torch.Tensor
Tensor projected onto the constraint.
"""
x_transformed = x.detach().clone()
x_transformed = self.preprocessing.invert(x_transformed)
x_transformed = self._apply_constraint(x_transformed)
return self.preprocessing(x_transformed)


class ClipConstraint(Constraint):
@@ -44,7 +91,7 @@ def __init__(self, lb: float = 0.0, ub: float = 1.0) -> None:
self.lb = lb
self.ub = ub

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def _apply_constraint(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Call the projection function.
@@ -103,7 +150,7 @@ def project(self, x: torch.Tensor) -> torch.Tensor:
"""
...

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def _apply_constraint(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Project the samples onto the Lp constraint.
@@ -301,44 +348,56 @@ def project(self, x: torch.Tensor) -> torch.Tensor:
return proj.view_as(x)


class QuantizationConstraint(Constraint):
class QuantizationConstraint(InputSpaceConstraint):
"""Constraint for ensuring quantized outputs into specified levels."""

def __init__(self, levels: int) -> None:
def __init__(
self,
preprocessing: DataProcessing = None,
levels: Union[list[float], torch.Tensor, int] = 255,
) -> None:
"""
Create the QuantizationConstraint.
Parameters
----------
levels : int
Number of levels
"""
if int(levels) != levels:
msg = (
f"Pass either an integer or a float with no decimals for "
f"the number of levels (current value: {levels})."
)
raise ValueError(msg)
self.levels = levels
super().__init__()

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Enforce the quantization constraint.
Parameters
----------
x : torch.Tensor
Non-quantized input tensor.
Returns
-------
torch.Tensor
Input with values quantized on the specified
number of levels.
"""
# the -1 there is to count for the 0
return (x * (self.levels - 1)).round() / (self.levels - 1)
preprocessing: DataProcessing
Preprocessing to apply the constraint in the input space.
levels : int | list[float] | torch.Tensor
Number of levels or specified levels.
"""
if isinstance(levels, int | float):
if levels < 2: # noqa: PLR2004
msg = "Number of levels must be at least 2."
raise ValueError(msg)
if int(levels) != levels:
msg = "Pass an integer number of levels."
raise ValueError(msg)
# create uniform levels if an integer is provided
self.levels = torch.linspace(0, 1, int(levels))
elif isinstance(levels, list):
self.levels = torch.tensor(levels, dtype=torch.float32)
elif isinstance(levels, torch.Tensor):
self.levels = levels.type(torch.float32)
if len(self.levels) < 2: # noqa: PLR2004
msg = "Number of custom levels must be at least 2."
raise ValueError(msg)
else:
msg = "Levels must be an integer, list, or torch.Tensor."
raise TypeError(msg)
# sort levels to ensure they are in ascending order
self.levels = self.levels.sort().values # noqa: PD011
super().__init__(preprocessing)

def _apply_constraint(self, x: torch.Tensor) -> torch.Tensor:
# reshape x to facilitate broadcasting with custom levels
x_expanded = x.unsqueeze(-1)
# calculate the absolute difference between x and each custom level
distances = torch.abs(x_expanded - self.levels)
# find the index of the closest custom level
nearest_indices = torch.argmin(distances, dim=-1)
# quantize x to the nearest custom level
return self.levels[nearest_indices]


class MaskConstraint(Constraint):
@@ -356,7 +415,7 @@ def __init__(self, mask: torch.Tensor) -> None:
self.mask = mask.type(torch.bool)
super().__init__()

def __call__(self, x: torch.Tensor) -> torch.Tensor:
def _apply_constraint(self, x: torch.Tensor) -> torch.Tensor:
"""
Enforce the mask constraint.
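
Two patterns are worth noting in the refactored module: concrete constraints now implement `_apply_constraint` while the base `__call__` handles detaching and cloning, and `InputSpaceConstraint` additionally wraps the projection in an invert-then-reapply preprocessing round trip. The sketch below illustrates that round trip in pure torch with a hypothetical 255-scaling preprocessing (not a secml-torch class), showing why constraints like quantization must act on input-space values:

```python
import torch

# Hypothetical preprocessing: maps input-space pixels [0, 255] to [0, 1].
def preprocess(x: torch.Tensor) -> torch.Tensor:
    return x / 255.0

def invert(x: torch.Tensor) -> torch.Tensor:
    return x * 255.0

# The constraint only makes sense on input-space values: snap pixels
# to integers.
def apply_constraint(x: torch.Tensor) -> torch.Tensor:
    return x.round()

x_model = torch.tensor([0.2510, 0.5020])  # sample already preprocessed
# invert -> constrain -> re-apply, mirroring InputSpaceConstraint.__call__
x_projected = preprocess(apply_constraint(invert(x_model)))
print(x_projected)  # values are now exact multiples of 1/255
```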
15 changes: 13 additions & 2 deletions src/secmlt/tests/test_constraints.py
@@ -99,8 +99,13 @@ def test_l0_constraint_invalid_radius():
),
(
torch.tensor([[0.1, 0.9], [0.3, 0.6]]),
3,
torch.tensor([[0.0, 1.0], [0.5, 0.5]]),
torch.Tensor([0.1, 0.2, 0.5]),
torch.tensor([[0.1, 0.5], [0.2, 0.5]]),
),
(
torch.tensor([[0.1, 0.9], [0.3, 0.6]]),
[0.1, 0.2, 0.5],
torch.tensor([[0.1, 0.5], [0.2, 0.5]]),
),
],
)
@@ -116,6 +121,12 @@ def test_quantization_constraint_invalid_levels():
QuantizationConstraint(levels=2.5)


def test_quantization_constraint_not_enough_levels():
# test that passing a number of levels < 2 values raises an error
with pytest.raises(ValueError): # noqa: PT011
QuantizationConstraint(levels=1)


@pytest.mark.parametrize(
"x, mask, expected",
[
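
The parametrized cases above pin down the nearest-level arithmetic. A standalone sketch of the same computation (import path taken from the module shown earlier; if the package is installed, this should reproduce the expected tensor):

```python
import torch
from secmlt.optimization.constraints import QuantizationConstraint

x = torch.tensor([[0.1, 0.9], [0.3, 0.6]])
constraint = QuantizationConstraint(levels=[0.1, 0.2, 0.5])

# Each entry snaps to the closest level:
# 0.9 -> 0.5, 0.3 -> 0.2 (distance 0.1 beats 0.2), 0.6 -> 0.5
print(constraint(x))  # tensor([[0.1, 0.5], [0.2, 0.5]])
```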
2 changes: 1 addition & 1 deletion src/secmlt/tests/test_manipulations.py
@@ -8,7 +8,7 @@ class MockConstraint(Constraint):
def __init__(self, mock_return):
self.mock_return = mock_return

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def _apply_constraint(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
return self.mock_return


