Skip to content

Commit

Permalink
✅ Improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Aug 25, 2023
1 parent d3d499c commit 98aa65b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/layers/test_bayesian_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_conv3_even(self, cube_input_even: torch.Tensor) -> None:
class TestTrainableDistribution:
"""Testing the TrainableDistribution class."""

def test_error(self) -> None:
def test_log_posterior(self) -> None:
sampler = TrainableDistribution(torch.ones(1), torch.ones(1))
with pytest.raises(ValueError):
sampler.log_posterior()
5 changes: 5 additions & 0 deletions tests/models/test_deep_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def test_list_and_num_estimators(self):
with pytest.raises(ValueError):
deep_ensembles([model_1, model_2], num_estimators=2)

def test_list_singleton(self):
model_1 = dummy_model(1, 10, 1)
with pytest.raises(ValueError):
deep_ensembles([model_1], num_estimators=2)

def test_model_and_no_num_estimator(self):
model_1 = dummy_model(1, 10, 1)
with pytest.raises(ValueError):
Expand Down
8 changes: 5 additions & 3 deletions torch_uncertainty/models/vgg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ def _init_weights(self):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
if m.bias is not None: # coverage: ignore
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if m.bias is not None: # coverage: ignore
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear) or isinstance(m, PackedLinear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
if m.bias is not None: # coverage: ignore
nn.init.constant_(m.bias, 0)


def _vgg(
Expand Down

0 comments on commit 98aa65b

Please sign in to comment.