Skip to content

Commit

Permalink
🐛 Fix Stochastic model
Browse files Browse the repository at this point in the history
✅ Improve tests
  • Loading branch information
o-laurent committed Aug 25, 2023
1 parent 2a85366 commit d3d499c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
24 changes: 21 additions & 3 deletions tests/layers/test_bayesian_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
BayesConv3d,
BayesLinear,
)
from torch_uncertainty.layers.bayesian_layers.sampler import (
TrainableDistribution,
)


# fmt:on
Expand Down Expand Up @@ -83,7 +86,9 @@ def test_conv1(self, feat_input_odd: torch.Tensor) -> None:
assert out.shape == torch.Size([2, 10])

def test_conv1_even(self, feat_input_even: torch.Tensor) -> None:
layer = BayesConv1d(8, 2, kernel_size=1, sigma_init=0)
layer = BayesConv1d(
8, 2, kernel_size=1, sigma_init=0, padding_mode="reflect"
)
out = layer(feat_input_even)
assert out.shape == torch.Size([2, 10])

Expand All @@ -105,7 +110,9 @@ def test_conv2(self, img_input_odd: torch.Tensor) -> None:
assert out.shape == torch.Size([5, 2, 3, 3])

def test_conv2_even(self, img_input_even: torch.Tensor) -> None:
layer = BayesConv2d(10, 2, kernel_size=1, sigma_init=0)
layer = BayesConv2d(
10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect"
)
out = layer(img_input_even)
assert out.shape == torch.Size([8, 2, 3, 3])

Expand All @@ -127,9 +134,20 @@ def test_conv3(self, cube_input_odd: torch.Tensor) -> None:
assert out.shape == torch.Size([1, 2, 3, 3, 3])

def test_conv3_even(self, cube_input_even: torch.Tensor) -> None:
layer = BayesConv3d(10, 2, kernel_size=1, sigma_init=0)
layer = BayesConv3d(
10, 2, kernel_size=1, sigma_init=0, padding_mode="reflect"
)
out = layer(cube_input_even)
assert out.shape == torch.Size([2, 2, 3, 3, 3])

layer.freeze()
out = layer(cube_input_even)


class TestTrainableDistribution:
"""Testing the TrainableDistribution class."""

def test_error(self) -> None:
sampler = TrainableDistribution(torch.ones(1), torch.ones(1))
with pytest.raises(ValueError):
sampler.log_posterior()
13 changes: 13 additions & 0 deletions tests/models/test_mlps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# fmt: off
from torch_uncertainty.models.mlp import bayesian_mlp, packed_mlp


# fmt: on
class TestMLPModel:
"""Testing the mlp models."""

def test_packed(self):
packed_mlp(1, 1, hidden_dims=[])

def test_bayesian(self):
bayesian_mlp(1, 1, hidden_dims=[1, 1, 1])
26 changes: 22 additions & 4 deletions tests/models/test_stochastic_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
import torch

from torch_uncertainty.layers import BayesLinear
from torch_uncertainty.layers import BayesConv2d, BayesLinear
from torch_uncertainty.models.utils import StochasticModel


@StochasticModel
class DummyModel(torch.nn.Module):
class DummyModelLinear(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.layer = BayesLinear(1, 10, 1)

def forward(self, x):
return self.layer(x)


@StochasticModel
class DummyModelConv(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.layer = BayesConv2d(1, 10, 1)

def forward(self, x):
return self.layer(x)


class TestStochasticModel:
"""Testing the ResNet std class."""

def test_main(self):
model = DummyModel()
def test_main_linear(self):
model = DummyModelLinear()
model.freeze()
model.unfreeze()

def test_main_conv(self):
model = DummyModelConv()
model.freeze()
model.unfreeze()
14 changes: 9 additions & 5 deletions torch_uncertainty/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
# fmt: off
from torch import nn

from ..layers.bayesian_layers import bayesian_modules


def StochasticModel(Model):
# fmt: on
def StochasticModel(Model: nn.Module):
"""Decorator for stochastic models. When applied to a model, it adds the
freeze and unfreeze methods to the model. Use freeze to obtain
deterministic outputs. Use unfreeze to obtain stochastic outputs.
"""

def freeze(self):
def freeze(self) -> None:
for module in self.modules():
if isinstance(module, bayesian_modules):
module.freeze = True
module.freeze()

setattr(Model, "freeze", freeze)

def unfreeze(self):
def unfreeze(self) -> None:
for module in self.modules():
if isinstance(module, bayesian_modules):
module.freeze = False
module.unfreeze()

setattr(Model, "unfreeze", unfreeze)

Expand Down

0 comments on commit d3d499c

Please sign in to comment.