Skip to content

Commit

Permalink
✅ Improve tests
Browse files Browse the repository at this point in the history
🔨 Add safeties
  • Loading branch information
o-laurent committed Aug 24, 2023
1 parent f6051c8 commit 41773ee
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 7 deletions.
25 changes: 25 additions & 0 deletions tests/models/test_vggs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch

from torch_uncertainty.models.vgg.packed import (
packed_vgg11,
packed_vgg16,
packed_vgg19,
)
from torch_uncertainty.models.vgg.std import vgg13, vgg19


class TestStdVGG:
"""Testing the VGG std class."""

def test_main(self):
vgg13(1, 10, style="cifar")
vgg19(1, 10, norm=torch.nn.BatchNorm2d)


class TestPackedVGG:
"""Testing the VGG packed class."""

def test_main(self):
packed_vgg11(2, 10, 2, 2, 1)
packed_vgg16(2, 10, 2, 2, 1)
packed_vgg19(2, 10, 2, 2, 1)
24 changes: 24 additions & 0 deletions tests/models/test_wideresnets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from torch_uncertainty.models.wideresnet.batched import batched_wideresnet28x10
from torch_uncertainty.models.wideresnet.masked import masked_wideresnet28x10
from torch_uncertainty.models.wideresnet.packed import packed_wideresnet28x10


class TestPackedResnet:
"""Testing the WideResNet packed class."""

def test_main(self):
packed_wideresnet28x10(1, 2, 2, 1, 1, 10, style="cifar")


class TestMaskedWide:
"""Testing the WideResNet masked class."""

def test_main(self):
masked_wideresnet28x10(1, 2, 2, 1, 10, style="cifar")


class TestBatchedWide:
"""Testing the WideResNet batched class."""

def test_main(self):
batched_wideresnet28x10(1, 2, 1, 10, style="cifar")
4 changes: 4 additions & 0 deletions torch_uncertainty/baselines/classification/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ def __new__(
use_variation_ratio=use_variation_ratio,
**kwargs,
)
else:
raise ValueError(
f"{version} is not in {cls.single} nor {cls.ensemble}."
)

@classmethod
def load_from_checkpoint(
Expand Down
4 changes: 4 additions & 0 deletions torch_uncertainty/baselines/classification/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def __new__(
use_variation_ratio=use_variation_ratio,
**kwargs,
)
else:
raise ValueError(
f"{version} is not in {cls.single} nor {cls.ensemble}."
)

@classmethod
def load_from_checkpoint(
Expand Down
4 changes: 4 additions & 0 deletions torch_uncertainty/baselines/classification/wideresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def __new__(
use_variation_ratio=use_variation_ratio,
**kwargs,
)
else:
raise ValueError(
f"{version} is not in {cls.single} nor {cls.ensemble}."
)

@classmethod
def load_from_checkpoint(
Expand Down
5 changes: 1 addition & 4 deletions torch_uncertainty/models/resnet/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,7 @@ def __init__(
self.in_channels = in_channels
self.num_estimators = num_estimators
self.in_planes = 64
if self.in_planes % self.num_estimators:
self.in_planes += (
self.num_estimators - self.in_planes % self.num_estimators
)

block_planes = self.in_planes

if style == "imagenet":
Expand Down
1 change: 0 additions & 1 deletion torch_uncertainty/models/resnet/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(

# As in timm
self.dropout = nn.Dropout2d(p=dropout_rate)

self.conv2 = nn.Conv2d(
planes,
planes,
Expand Down
6 changes: 4 additions & 2 deletions torch_uncertainty/models/wideresnet/batched.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# fmt: off
from typing import Type

import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -81,7 +83,7 @@ def __init__(
self.in_planes = 16

assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4."
n = (depth - 4) / 6
n = (depth - 4) // 6
k = widen_factor

nStages = [16, 16 * k, 32 * k, 64 * k]
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(

def _wide_layer(
self,
block: nn.Module,
block: Type[nn.Module],
planes: int,
num_blocks: int,
dropout_rate: float,
Expand Down

0 comments on commit 41773ee

Please sign in to comment.