Skip to content

Commit

Permalink
mocks and fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
maurapintor committed Mar 17, 2024
1 parent 5df55e8 commit b7993d5
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 29 deletions.
4 changes: 4 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Configuration for tests."""
pytest_plugins = [
"secmlt.tests.fixtures",
]
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ignore = [
"ANN", # annotations for tests
"PT006", # mark parametrize
]
"*/tests/*" = ["D104"]
"*/tests/*.py" = ["D104"]
"setup.py" = ["D"]
"examples/*" = [
"D", # docstrings
Expand Down
File renamed without changes.
41 changes: 41 additions & 0 deletions src/secmlt/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Fixtures used for testing."""
import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset


@pytest.fixture(autouse=True, scope="session")
def data_loader() -> DataLoader[tuple[torch.Tensor]]:
"""
Create fake data loader.
Returns
-------
DataLoader[tuple[torch.Tensor]]
A loader with random samples and labels.
"""
# Create a dummy dataset loader for testing
data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
return DataLoader(dataset, batch_size=10)


@pytest.fixture(autouse=True, scope="session")
def adv_loaders() -> list[DataLoader[tuple[torch.Tensor, ...]]]:
"""
Create fake adversarial loaders.
Returns
-------
list[DataLoader[Tuple[torch.Tensor, ...]]]
A list of multiple loaders (with same ordered labels).
"""
# Create a list of dummy adversarial example loaders for testing
loaders = []
adv_labels = torch.randint(0, 10, (100,))
for _ in range(3):
adv_data = torch.randn(100, 3, 32, 32)
adv_dataset = TensorDataset(adv_data, adv_labels)
loaders.append(DataLoader(adv_dataset, batch_size=10))
return loaders
12 changes: 12 additions & 0 deletions src/secmlt/tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Mock classes for testing."""
import torch


class MockModel(torch.nn.Module):
"""Mock class for torch model."""

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return random outputs for classification and add fake gradients to x."""
# Mock output shape (batch_size, 10)
x.grad = torch.rand_like(x)
return torch.randn(x.size(0), 10)
Original file line number Diff line number Diff line change
@@ -1,37 +1,12 @@
import pytest
import torch
from secmlt.adv.evasion.aggregators.ensemble import (
FixedEpsilonEnsemble,
MinDistanceEnsemble,
)
from torch.utils.data import DataLoader, TensorDataset
from secmlt.tests.mocks import MockModel


class MockModel(torch.nn.Module):
def forward(self, x):
# Mock output shape (batch_size, 10)
return torch.randn(x.size(0), 10)


@pytest.fixture()
def data_loader():
# Create a dummy dataset loader for testing
data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
return DataLoader(dataset, batch_size=10)


@pytest.fixture()
def adv_loaders():
# Create a list of dummy adversarial example loaders for testing
adv_data = torch.randn(100, 3, 32, 32)
adv_labels = torch.randint(0, 10, (100,))
adv_dataset = TensorDataset(adv_data, adv_labels)
return [DataLoader(adv_dataset, batch_size=10) for _ in range(3)]


def test_min_distance_ensemble(data_loader, adv_loaders):
def test_min_distance_ensemble(data_loader, adv_loaders) -> None:
model = MockModel()
ensemble = MinDistanceEnsemble("l2")
result_loader = ensemble(model, data_loader, adv_loaders)
Expand All @@ -45,7 +20,7 @@ def test_min_distance_ensemble(data_loader, adv_loaders):
assert batch[1].shape == (10,) # Expected shape of original labels


def test_fixed_epsilon_ensemble(data_loader, adv_loaders):
def test_fixed_epsilon_ensemble(data_loader, adv_loaders) -> None:
model = MockModel()
loss_fn = torch.nn.CrossEntropyLoss()
ensemble = FixedEpsilonEnsemble(loss_fn)
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit b7993d5

Please sign in to comment.