Commit

Add missing tests dir

braun-steven committed Dec 5, 2023
1 parent f75e73d commit 53a8630
Showing 9 changed files with 379 additions and 0 deletions.
Empty file added tests/__init__.py
Empty file.
Empty file added tests/layers/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions tests/layers/test_einsum_layer.py
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.einsum import EinsumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestEinsumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = EinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        # Normalized log-weights must sum to one, i.e., logsumexp to zero, along the normalized dim.
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            # Differentiable sampling returns one-hot relaxations, adding a choice dim of size num_sums_in.
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        # Each einsum node merges a pair of children, hence num_sums_in**2 selected weights.
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in**2))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        # Conditioning on evidence must preserve normalization over the sum dim.
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
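The num_sums_in**2 expected in test__select_weights reflects that each einsum node combines a pair of child partitions. A minimal standalone sketch of that shape arithmetic (illustrative only, not the layer's actual implementation):

import torch

num_sums_in = 3
left = torch.randn(num_sums_in).log_softmax(dim=0)   # log-weights of the left child
right = torch.randn(num_sums_in).log_softmax(dim=0)  # log-weights of the right child

# An outer sum in log-space enumerates every (left, right) state pair,
# which is why the einsum layer selects num_sums_in**2 weights per output.
pairwise = (left.unsqueeze(1) + right.unsqueeze(0)).reshape(-1)
assert pairwise.numel() == num_sums_in ** 2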
72 changes: 72 additions & 0 deletions tests/layers/test_linsum_layer.py
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.linsum import LinsumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestLinsumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = LinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        # Linsum sum nodes weight each child once, hence num_sums_in selected weights
        # (the einsum layer above selects num_sums_in**2).
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
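For contrast with the einsum layer, a rough standalone sketch of a linear-complexity sum-then-product step that reproduces the shapes asserted in test_forward_shape and test__select_weights (an assumption about the LinsumLayer's semantics, not its actual code):

import torch

bs, d, s_in, s_out, r = 2, 4, 3, 2, 5
x = torch.randn(bs, d, s_in, r)                        # child log-densities
log_w = torch.randn(d, s_out, s_in, r).log_softmax(dim=2)

# Weighted sum over the s_in children, carried out in log-space.
y = torch.logsumexp(x.unsqueeze(2) + log_w, dim=3)     # (bs, d, s_out, r)

# Product of adjacent feature pairs is a sum in log-space and halves d.
out = y[:, 0::2] + y[:, 1::2]                          # (bs, d // 2, s_out, r)
assert out.shape == (bs, d // 2, s_out, r)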
70 changes: 70 additions & 0 deletions tests/layers/test_mixing_layer.py
@@ -0,0 +1,70 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.mixing import MixingLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestMixingLayer(TestCase):
    def setUp(self) -> None:
        # The mixing layer takes no num_repetitions argument: it mixes over that axis itself.
        self.layer = MixingLayer(num_features=1, num_sums_in=3, num_sums_out=2)

    def test_logits_to_log_weights(self):
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
        out = self.layer(x)
        # Mixing reduces the last input dim, leaving (batch, features, sums_out).
        self.assertEqual(out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
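A small standalone sketch of log-space mixing that matches the shapes in test_forward_shape above (a sketch of the idea, assuming the layer reduces its last input dim with normalized log-weights):

import torch

bs, d, s_out, r = 2, 1, 2, 3                       # as in setUp: 1 feature, 2 sums out, 3 inputs
x = torch.randn(bs, d, s_out, r)                   # per-input log-densities
log_w = torch.randn(s_out, r).log_softmax(dim=-1)  # mixture log-weights per output sum
mixed = torch.logsumexp(x + log_w, dim=-1)         # weighted log-sum-exp over the last dim
assert mixed.shape == (bs, d, s_out)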
72 changes: 72 additions & 0 deletions tests/layers/test_sum_layer.py
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.sum import SumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestSumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = SumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
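The normalization property checked by test_logits_to_log_weights in each of these files can be reproduced with plain log_softmax (assuming logits_to_log_weights normalizes via a log-softmax over the given dim):

import torch

logits = torch.randn(4, 3, 2, 5)           # same shape family as the layer logits
log_weights = logits.log_softmax(dim=1)    # normalize over dim 1
total = log_weights.logsumexp(dim=1)       # ~0 everywhere: weights sum to one
assert torch.allclose(total, torch.zeros_like(total), atol=1e-5)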
27 changes: 27 additions & 0 deletions tests/layers/test_utils.py
@@ -0,0 +1,27 @@
import torch
from torch.nn import functional as F

from simple_einet.sampling_utils import SamplingContext


def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
    """Build a random SamplingContext for `layer`, with either hard integer
    indices or differentiable one-hot float indices."""
    if is_differentiable:
        indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
        # Gradients can only flow into float one-hot indices, not integer indices.
        one_hot_indices_out.requires_grad_(True)
        one_hot_indices_repetition.requires_grad_(True)
        return SamplingContext(
            num_samples=num_samples,
            indices_out=one_hot_indices_out,
            indices_repetition=one_hot_indices_repetition,
            is_differentiable=True,
        )
    else:
        return SamplingContext(
            num_samples=num_samples,
            indices_out=torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out)),
            indices_repetition=torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,)),
            is_differentiable=False,
        )
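The split between hard and one-hot indices above is what makes sampling differentiable: integer indexing blocks gradients, while multiplying by a float one-hot mask selects the same entries and keeps the autograd graph intact. A self-contained sketch in plain torch, independent of index_one_hot:

import torch
from torch.nn import functional as F

x = torch.randn(2, 4, 3)                   # (batch, features, choices)
idx = torch.randint(0, 3, size=(2, 4))     # hard indices into the last dim

hard = x.gather(dim=-1, index=idx.unsqueeze(-1)).squeeze(-1)

one_hot = F.one_hot(idx, num_classes=3).float().requires_grad_(True)
soft = (x * one_hot).sum(dim=-1)           # selects the same values as `hard`

assert torch.allclose(hard, soft)
soft.mean().backward()                     # gradients reach the one-hot mask
assert one_hot.grad is not None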
66 changes: 66 additions & 0 deletions tests/test_einet.py
@@ -0,0 +1,66 @@
from itertools import product
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.binomial import Binomial


class TestEinet(TestCase):
    def make_einet(self, num_classes, num_repetitions):
        config = EinetConfig(
            num_features=self.num_features,
            num_channels=self.num_channels,
            depth=self.depth,
            num_sums=self.num_sums,
            num_leaves=self.num_leaves,
            num_repetitions=num_repetitions,
            num_classes=num_classes,
            leaf_type=self.leaf_type,
            leaf_kwargs=self.leaf_kwargs,
            layer_type="linsum",
            dropout=0.0,
        )
        return Einet(config)

    def setUp(self) -> None:
        self.num_features = 8
        self.num_channels = 3
        self.num_sums = 5
        self.num_leaves = 2
        self.depth = 3
        self.leaf_type = Binomial
        self.leaf_kwargs = {"total_count": 255}

    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
    def test_sampling_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
        N = 2

        # Sample without evidence
        samples = model.sample(num_samples=N, is_differentiable=differentiable)
        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))

        # Sample with evidence
        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
        samples = model.sample(evidence=evidence, is_differentiable=differentiable)
        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))

    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
    def test_mpe_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
        N = 2

        # MPE without evidence yields a single most-probable completion, hence batch size 1
        mpe = model.mpe(is_differentiable=differentiable)
        self.assertEqual(mpe.shape, (1, self.num_channels, self.num_features))

        # MPE with evidence
        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
        mpe = model.mpe(evidence=evidence, is_differentiable=differentiable)
        self.assertEqual(mpe.shape, (N, self.num_channels, self.num_features))
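A minimal end-to-end usage sketch built from the same config as setUp. Only the sample shape is asserted by the tests above; the forward call, its output shape, and the input dtype are assumptions:

import torch
from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.binomial import Binomial

config = EinetConfig(
    num_features=8,
    num_channels=3,
    depth=3,
    num_sums=5,
    num_leaves=2,
    num_repetitions=4,
    num_classes=3,
    leaf_type=Binomial,
    leaf_kwargs={"total_count": 255},
    layer_type="linsum",
    dropout=0.0,
)
model = Einet(config)

x = torch.randint(0, 256, size=(2, 3, 8))  # Binomial counts in [0, 255]; preprocessing may differ
log_likelihoods = model(x)                 # assumed: per-class log-likelihoods, shape (2, 3)
samples = model.sample(num_samples=2)      # shape (2, 3, 8), as tested above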
