Skip to content

Commit

Permalink
✅ Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Aug 24, 2023
1 parent 2386247 commit f6051c8
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 26 deletions.
17 changes: 5 additions & 12 deletions tests/layers/test_packed_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ def feat_input_one_rearrange() -> torch.Tensor:
return feat


@pytest.fixture
def feat_input_two_rearrange() -> torch.Tensor:
feat = torch.rand((2 * 3, 5))
return feat


@pytest.fixture
def img_input() -> torch.Tensor:
img = torch.rand((5, 6, 3, 3))
Expand All @@ -50,12 +44,11 @@ def test_linear_one_estimator_rearrange(
out = layer(feat_input_one_rearrange)
assert out.shape == torch.Size([3, 2])

def test_linear_two_estimator_rearrange(
self, feat_input_two_rearrange: torch.Tensor
):
layer = PackedLinear(5, 2, alpha=1, num_estimators=1, rearrange=True)
out = layer(feat_input_two_rearrange)
assert out.shape == torch.Size([6, 2])
def test_linear_two_estimator_rearrange_not_divisible(self):
feat = torch.rand((2 * 3, 3))
layer = PackedLinear(5, 1, alpha=1, num_estimators=2, rearrange=True)
out = layer(feat)
assert out.shape == torch.Size([6, 1])

def test_linear_extend(self):
_ = PackedConv2d(
Expand Down
20 changes: 20 additions & 0 deletions tests/models/test_stochastic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch

from torch_uncertainty.layers import BayesLinear
from torch_uncertainty.models.utils import StochasticModel


@StochasticModel
class DummyModel(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.layer = BayesLinear(1, 10, 1)


class TestStochasticModel:
"""Testing the ResNet std class."""

def test_main(self):
model = DummyModel()
model.freeze()
model.unfreeze()
8 changes: 4 additions & 4 deletions tests/routines/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestClassificationSingle:

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py", "--logits"):
with ArgvContext("file.py --logits"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -46,7 +46,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
with ArgvContext("file.py --evaluate_ood --entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand Down Expand Up @@ -78,7 +78,7 @@ class TestClassificationEnsemble:

def test_cli_main_dummy_binary(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py", "--mutual_information"):
with ArgvContext("file.py --mutual_information"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand All @@ -100,7 +100,7 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"):
with ArgvContext("file.py --evaluate_ood --variation_ratio"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)
Expand Down
6 changes: 3 additions & 3 deletions tests/routines/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TestRegressionSingle:

def test_cli_main_dummy_dist(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
with ArgvContext("file.py"):
args = init_args(DummyRegressionBaseline, DummyRegressionDataModule)

# datamodule
Expand All @@ -37,7 +37,7 @@ def test_cli_main_dummy_dist(self):

def test_cli_main_dummy(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
with ArgvContext("file.py"):
args = init_args(DummyRegressionBaseline, DummyRegressionDataModule)

# datamodule
Expand All @@ -61,7 +61,7 @@ class TestRegressionEnsemble:

def test_cli_main_dummy(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
with ArgvContext("file.py"):
args = init_args(DummyRegressionBaseline, DummyRegressionDataModule)

# datamodule
Expand Down
44 changes: 37 additions & 7 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# fmt: off
from pathlib import Path

import pytest
import torch.nn as nn
from cli_test_helpers import ArgvContext

Expand All @@ -22,7 +23,7 @@ class TestCLI:

def test_cli_main_resnet(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("--style cifar"):
with ArgvContext("file.py"):
args = init_args(ResNet, CIFAR10DataModule)

# datamodule
Expand All @@ -35,6 +36,7 @@ def test_cli_main_resnet(self):
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
style="cifar",
loss=nn.CrossEntropyLoss,
optimization_procedure=optim_cifar10_resnet18,
**vars(args),
Expand All @@ -45,7 +47,7 @@ def test_cli_main_resnet(self):
def test_cli_main_other_arguments(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(
"--seed 42 --max_epochs 1 --channels_last --style cifar"
"file.py --seed 42 --max_epochs 1 --channels_last",
):
args = init_args(ResNet, CIFAR10DataModule)

Expand All @@ -59,6 +61,7 @@ def test_cli_main_other_arguments(self):
model = ResNet(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
style="cifar",
loss=nn.CrossEntropyLoss,
optimization_procedure=optim_cifar10_resnet18,
**vars(args),
Expand All @@ -68,7 +71,7 @@ def test_cli_main_other_arguments(self):

def test_cli_main_wideresnet(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("--style cifar"):
with ArgvContext("file.py"):
args = init_args(WideResNet, CIFAR10DataModule)

# datamodule
Expand All @@ -89,7 +92,7 @@ def test_cli_main_wideresnet(self):

def test_cli_main_vgg(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("--style cifar"):
with ArgvContext("file.py"):
args = init_args(VGG, CIFAR10DataModule)

# datamodule
Expand All @@ -109,12 +112,12 @@ def test_cli_main_vgg(self):
cli_main(model, dm, root, "std", args)

def test_cli_main_mlp(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext(""):
root = str(Path(__file__).parent.absolute().parents[0])
with ArgvContext("file.py"):
args = init_args(MLP, UCIDataModule)

# datamodule
args.root = root / "data"
args.root = root + "/data"
dm = UCIDataModule(
dataset_name="kin8nm", input_shape=(1, 5), **vars(args)
)
Expand All @@ -132,3 +135,30 @@ def test_cli_main_mlp(self):
)

cli_main(model, dm, root, "std", args)

def test_cli_other_training_task(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py"):
args = init_args(MLP, UCIDataModule)

# datamodule
args.root = root / "/data"
dm = UCIDataModule(
dataset_name="kin8nm", input_shape=(1, 5), **vars(args)
)

dm.training_task = "time-series-regression"

args.summary = True

model = MLP(
num_outputs=1,
in_features=5,
hidden_dims=[],
dist_estimation=False,
loss=nn.MSELoss,
optimization_procedure=optim_regression,
**vars(args),
)
with pytest.raises(ValueError):
cli_main(model, dm, root, "std", args)

0 comments on commit f6051c8

Please sign in to comment.