Skip to content

Commit

Permalink
fix errors and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maurapintor committed Oct 13, 2024
1 parent bf31901 commit 4c58f2f
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest
pytest-cov
foolbox
git+https://github.com/jeromerony/adversarial-library
adv-lib
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
extras_require={
"foolbox": ["foolbox>=3.3.0"],
"tensorboard": ["tensorboard"],
"adv_lib": ["adv_lib"],
},
python_requires=">=3.7",
)
16 changes: 12 additions & 4 deletions src/secmlt/optimization/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.models.data_processing.data_processing import DataProcessing
from secmlt.models.data_processing.identity_data_processing import (
IdentityDataProcessing,
)


class Constraint(ABC):
Expand Down Expand Up @@ -47,6 +50,8 @@ def __init__(self, preprocessing: DataProcessing) -> None:
preprocessing : DataProcessing
Preprocessing to invert to apply the constraint on the input space.
"""
if preprocessing is None:
preprocessing = IdentityDataProcessing()
self.preprocessing = preprocessing

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
Expand Down Expand Up @@ -349,7 +354,7 @@ class QuantizationConstraint(InputSpaceConstraint):
def __init__(
self,
preprocessing: DataProcessing = None,
levels: Union[list, torch.Tensor, int] = 255,
levels: Union[list[float], torch.Tensor, int] = 255,
) -> None:
"""
Create the QuantizationConstraint.
Expand All @@ -358,15 +363,18 @@ def __init__(
----------
preprocessing: DataProcessing
Preprocessing to apply the constraint in the input space.
levels : int | torch.Tensor, int
levels : int, list[float] | torch.Tensor
Number of levels or specified levels.
"""
if isinstance(levels, int):
if isinstance(levels, int | float):
if levels < 2: # noqa: PLR2004
msg = "Number of levels must be at least 2."
raise ValueError(msg)
if int(levels) != levels:
msg = "Pass an integer number of levels."
raise ValueError(msg)
# create uniform levels if an integer is provided
self.levels = torch.linspace(0, 1, levels)
self.levels = torch.linspace(0, 1, int(levels))
elif isinstance(levels, list):
self.levels = torch.tensor(levels, dtype=torch.float32)
elif isinstance(levels, torch.Tensor):
Expand Down
15 changes: 13 additions & 2 deletions src/secmlt/tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,13 @@ def test_l0_constraint_invalid_radius():
),
(
torch.tensor([[0.1, 0.9], [0.3, 0.6]]),
3,
torch.tensor([[0.0, 1.0], [0.5, 0.5]]),
torch.Tensor([0.1, 0.2, 0.5]),
torch.tensor([[0.1, 0.5], [0.2, 0.5]]),
),
(
torch.tensor([[0.1, 0.9], [0.3, 0.6]]),
[0.1, 0.2, 0.5],
torch.tensor([[0.1, 0.5], [0.2, 0.5]]),
),
],
)
Expand All @@ -116,6 +121,12 @@ def test_quantization_constraint_invalid_levels():
QuantizationConstraint(levels=2.5)


def test_quantization_constraint_not_enough_levels():
# test that passing a number of levels < 2 values raises an error
with pytest.raises(ValueError): # noqa: PT011
QuantizationConstraint(levels=1)


@pytest.mark.parametrize(
"x, mask, expected",
[
Expand Down
2 changes: 1 addition & 1 deletion src/secmlt/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class MockConstraint(Constraint):
def __init__(self, mock_return):
self.mock_return = mock_return

def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def _apply_constraint(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
return self.mock_return


Expand Down

0 comments on commit 4c58f2f

Please sign in to comment.