Skip to content

Commit

Permalink
✅ Add tests for new transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
alafage committed Aug 26, 2023
1 parent 3f49f11 commit 13a6439
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 60 deletions.
56 changes: 56 additions & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# fmt:off
from typing import Tuple

import pytest
import torch
from PIL import Image
Expand All @@ -10,7 +12,9 @@
Color,
Contrast,
Equalize,
MIMOBatchFormat,
Posterize,
RepeatTarget,
Rotation,
Sharpness,
Shear,
Expand All @@ -27,6 +31,12 @@ def img_input() -> torch.Tensor:
return im


@pytest.fixture
def batch_input() -> Tuple[torch.Tensor, torch.Tensor]:
imgs = torch.rand(2, 3, 28, 28)
return imgs, torch.tensor([0, 1])


class TestAutoContrast:
"""Testing the AutoContrast transform."""

Expand Down Expand Up @@ -159,3 +169,49 @@ def test_failures(self, img_input):
aug = Color()
with pytest.raises(ValueError):
_ = aug(img_input, -1)


class TestRepeatTarget:
"""Testing the RepeatTarget transform."""

def test_batch(self, batch_input):
fn = RepeatTarget(3)
_, target = fn(batch_input)
assert target.shape == (6,)

def test_failures(self):
with pytest.raises(ValueError):
_ = RepeatTarget(1.2)

with pytest.raises(ValueError):
_ = RepeatTarget(0)


class TestMIMOBatchFormat:
"""Testing the MIMOBatchFormat transform."""

def test_batch(self, batch_input):
b, c, h, w = batch_input[0].shape

fn = MIMOBatchFormat(1, 0, 1)
imgs, target = fn(batch_input)
assert imgs.shape == (b, c, h, w)
assert target.shape == (b,)

fn = MIMOBatchFormat(4, 0, 2)
imgs, target = fn(batch_input)
assert imgs.shape == (b * 4 * 2, 3, 28, 28)
assert target.shape == (b * 4 * 2,)

def test_failures(self):
with pytest.raises(ValueError):
_ = MIMOBatchFormat(0, 0, 1)

with pytest.raises(ValueError):
_ = MIMOBatchFormat(1, -1, 1)

with pytest.raises(ValueError):
_ = MIMOBatchFormat(1, 1.2, 1)

with pytest.raises(ValueError):
_ = MIMOBatchFormat(1, 0, 0)
2 changes: 1 addition & 1 deletion torch_uncertainty/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# flake8: noqa
from .cutout import Cutout
from .mimo_format import MIMOBatchFormat
from .transforms import (
AutoContrast,
Brightness,
Color,
Contrast,
Equalize,
MIMOBatchFormat,
Posterize,
RepeatTarget,
Rotation,
Expand Down
59 changes: 0 additions & 59 deletions torch_uncertainty/transforms/mimo_format.py

This file was deleted.

67 changes: 67 additions & 0 deletions torch_uncertainty/transforms/transforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# fmt: off
from typing import List, Optional, Tuple, Union

import torch
import torchvision.transforms.functional as F
from einops import rearrange
from PIL import Image, ImageEnhance
from torch import Tensor, nn

Expand Down Expand Up @@ -255,8 +257,73 @@ def forward(
class RepeatTarget(nn.Module):
def __init__(self, num_repeats: int) -> None:
super().__init__()

if not isinstance(num_repeats, int):
raise ValueError("num_repeats must be an integer.")
if num_repeats <= 0:
raise ValueError("num_repeats must be greater than 0.")

self.num_repeats = num_repeats

def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
inputs, targets = batch
return inputs, targets.repeat(self.num_repeats)


class MIMOBatchFormat(nn.Module):
def __init__(
self, num_estimators: int, rho: float = 0.0, batch_repeat: int = 1
) -> None:
super().__init__()

if num_estimators <= 0:
raise ValueError("num_estimators must be greater than 0.")
if not (0.0 <= rho <= 1.0):
raise ValueError("rho must be between 0 and 1.")
if batch_repeat <= 0:
raise ValueError("batch_repeat must be greater than 0.")

self.num_estimators = num_estimators
self.rho = rho
self.batch_repeat = batch_repeat

def shuffle(self, inputs: Tensor):
idx = torch.randperm(inputs.nelement(), device=inputs.device)
return inputs.view(-1)[idx].view(inputs.size())

def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
inputs, targets = batch
indexes = torch.arange(
0, inputs.shape[0], device=inputs.device, dtype=torch.int64
).repeat(self.batch_repeat)
main_shuffle = self.shuffle(indexes)
threshold_shuffle = int(main_shuffle.shape[0] * (1.0 - self.rho))
shuffle_indices = [
torch.concat(
[
self.shuffle(main_shuffle[:threshold_shuffle]),
main_shuffle[threshold_shuffle:],
],
axis=0,
)
for _ in range(self.num_estimators)
]
inputs = torch.stack(
[
torch.index_select(inputs, dim=0, index=indices)
for indices in shuffle_indices
],
axis=0,
)
targets = torch.stack(
[
torch.index_select(targets, dim=0, index=indices)
for indices in shuffle_indices
],
axis=0,
)
inputs = rearrange(
inputs, "m b c h w -> (m b) c h w", m=self.num_estimators
)
targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators)
return inputs, targets

0 comments on commit 13a6439

Please sign in to comment.