From 6912a9c015d74e11f1b122494c708c83257a3781 Mon Sep 17 00:00:00 2001
From: Steven Braun
Date: Tue, 5 Dec 2023 09:29:33 +0100
Subject: [PATCH] Add/update/fix tests

---
 simple_einet/sampling_utils.py                |  1 +
 simple_einet/tests/__init__.py                |  0
 simple_einet/tests/layers/__init__.py         |  0
 .../tests/layers/distributions/__init__.py    |  0
 .../tests/layers/test_einsum_layer.py         | 72 -------------------
 .../tests/layers/test_linsum_layer.py         | 72 -------------------
 .../tests/layers/test_mixing_layer.py         | 70 ------------------
 simple_einet/tests/layers/test_sum_layer.py   | 72 -------------------
 simple_einet/tests/layers/test_utils.py       | 27 -------
 9 files changed, 1 insertion(+), 313 deletions(-)
 delete mode 100644 simple_einet/tests/__init__.py
 delete mode 100644 simple_einet/tests/layers/__init__.py
 delete mode 100644 simple_einet/tests/layers/distributions/__init__.py
 delete mode 100644 simple_einet/tests/layers/test_einsum_layer.py
 delete mode 100644 simple_einet/tests/layers/test_linsum_layer.py
 delete mode 100644 simple_einet/tests/layers/test_mixing_layer.py
 delete mode 100644 simple_einet/tests/layers/test_sum_layer.py
 delete mode 100644 simple_einet/tests/layers/test_utils.py

diff --git a/simple_einet/sampling_utils.py b/simple_einet/sampling_utils.py
index 1f299c8..845302b 100644
--- a/simple_einet/sampling_utils.py
+++ b/simple_einet/sampling_utils.py
@@ -303,6 +303,7 @@ def init_einet_stats(einet: "Einet", dataloader: torch.utils.data.DataLoader):
 
     # Compute mean and std
     from tqdm import tqdm
+
     for batch in tqdm(dataloader, desc="Leaf Parameter Initialization"):
         data, label = batch
         if stats_mean == None:
diff --git a/simple_einet/tests/__init__.py b/simple_einet/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/simple_einet/tests/layers/__init__.py b/simple_einet/tests/layers/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/simple_einet/tests/layers/distributions/__init__.py b/simple_einet/tests/layers/distributions/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/simple_einet/tests/layers/test_einsum_layer.py b/simple_einet/tests/layers/test_einsum_layer.py
deleted file mode 100644
index 01b3766..0000000
--- a/simple_einet/tests/layers/test_einsum_layer.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from unittest import TestCase
-
-import torch
-from parameterized import parameterized
-
-from simple_einet.abstract_layers import logits_to_log_weights
-from simple_einet.layers.einsum import EinsumLayer
-from simple_einet.sampling_utils import index_one_hot
-from simple_einet.tests.layers.test_utils import get_sampling_context
-
-
-class TestEinsumLayer(TestCase):
-    def setUp(self) -> None:
-        self.layer = EinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
-
-    def test_logits_to_log_weights(self):
-        for dim in range(self.layer.logits.dim()):
-            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
-            sums = log_weights.logsumexp(dim=dim)
-            target = torch.zeros_like(sums)
-            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test_forward_shape(self):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        out = self.layer(x)
-        self.assertEqual(
-            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
-        )
-
-    @parameterized.expand([(False,), (True,)])
-    def test__sample_from_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        indices = self.layer._sample_from_weights(ctx, log_weights)
-        if differentiable:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
-        else:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__select_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        weights = self.layer._select_weights(ctx, self.layer.logits)
-        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in**2))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__condition_weights_on_evidence(self, differentiable: bool):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        self.layer._enable_input_cache()
-        self.layer(x)
-
-        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
-        sums = log_weights.logsumexp(dim=2)
-        target = torch.zeros_like(sums)
-        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test__differentiable_sampling_has_grads(self):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
-        ctx = self.layer.sample(ctx)
-
-        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
-        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
-        sample.mean().backward()
-        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_linsum_layer.py b/simple_einet/tests/layers/test_linsum_layer.py
deleted file mode 100644
index 9ea0de8..0000000
--- a/simple_einet/tests/layers/test_linsum_layer.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from unittest import TestCase
-
-import torch
-from parameterized import parameterized
-
-from simple_einet.abstract_layers import logits_to_log_weights
-from simple_einet.layers.linsum import LinsumLayer
-from simple_einet.sampling_utils import index_one_hot
-from simple_einet.tests.layers.test_utils import get_sampling_context
-
-
-class TestLinsumLayer(TestCase):
-    def setUp(self) -> None:
-        self.layer = LinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
-
-    def test_logits_to_log_weights(self):
-        for dim in range(self.layer.logits.dim()):
-            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
-            sums = log_weights.logsumexp(dim=dim)
-            target = torch.zeros_like(sums)
-            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test_forward_shape(self):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        out = self.layer(x)
-        self.assertEqual(
-            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
-        )
-
-    @parameterized.expand([(False,), (True,)])
-    def test__condition_weights_on_evidence(self, differentiable: bool):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        self.layer._enable_input_cache()
-        self.layer(x)
-
-        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
-        sums = log_weights.logsumexp(dim=2)
-        target = torch.zeros_like(sums)
-        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__sample_from_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        indices = self.layer._sample_from_weights(ctx, log_weights)
-        if differentiable:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
-        else:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__select_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        weights = self.layer._select_weights(ctx, self.layer.logits)
-        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
-
-    def test__differentiable_sampling_has_grads(self):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
-        ctx = self.layer.sample(ctx)
-
-        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
-        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
-        sample.mean().backward()
-        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_mixing_layer.py b/simple_einet/tests/layers/test_mixing_layer.py
deleted file mode 100644
index 7bb59a0..0000000
--- a/simple_einet/tests/layers/test_mixing_layer.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from unittest import TestCase
-
-import torch
-from parameterized import parameterized
-
-from simple_einet.abstract_layers import logits_to_log_weights
-from simple_einet.layers.mixing import MixingLayer
-from simple_einet.sampling_utils import index_one_hot
-from simple_einet.tests.layers.test_utils import get_sampling_context
-
-
-class TestMixingLayer(TestCase):
-    def setUp(self) -> None:
-        self.layer = MixingLayer(num_features=1, num_sums_in=3, num_sums_out=2)
-
-    def test_logits_to_log_weights(self):
-        for dim in range(self.layer.logits.dim()):
-            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
-            sums = log_weights.logsumexp(dim=dim)
-            target = torch.zeros_like(sums)
-            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test_forward_shape(self):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, 1, self.layer.num_sums_in)
-        out = self.layer(x)
-        self.assertEqual(out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__condition_weights_on_evidence(self, differentiable: bool):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, 1, self.layer.num_sums_in)
-        self.layer._enable_input_cache()
-        self.layer(x)
-
-        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
-        sums = log_weights.logsumexp(dim=2)
-        target = torch.zeros_like(sums)
-        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__sample_from_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        indices = self.layer._sample_from_weights(ctx, log_weights)
-        if differentiable:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
-        else:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__select_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        weights = self.layer._select_weights(ctx, self.layer.logits)
-        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
-
-    def test__differentiable_sampling_has_grads(self):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
-        ctx = self.layer.sample(ctx)
-
-        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
-        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
-        sample.mean().backward()
-        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_sum_layer.py b/simple_einet/tests/layers/test_sum_layer.py
deleted file mode 100644
index b9bd10b..0000000
--- a/simple_einet/tests/layers/test_sum_layer.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from unittest import TestCase
-
-import torch
-from parameterized import parameterized
-
-from simple_einet.abstract_layers import logits_to_log_weights
-from simple_einet.layers.sum import SumLayer
-from simple_einet.sampling_utils import index_one_hot
-from simple_einet.tests.layers.test_utils import get_sampling_context
-
-
-class TestSumLayer(TestCase):
-    def setUp(self) -> None:
-        self.layer = SumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
-
-    def test_logits_to_log_weights(self):
-        for dim in range(self.layer.logits.dim()):
-            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
-            sums = log_weights.logsumexp(dim=dim)
-            target = torch.zeros_like(sums)
-            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test_forward_shape(self):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        out = self.layer(x)
-        self.assertEqual(
-            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
-        )
-
-    @parameterized.expand([(False,), (True,)])
-    def test__sample_from_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        indices = self.layer._sample_from_weights(ctx, log_weights)
-        if differentiable:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
-        else:
-            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__select_weights(self, differentiable: bool):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
-        weights = self.layer._select_weights(ctx, self.layer.logits)
-        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
-
-    @parameterized.expand([(False,), (True,)])
-    def test__condition_weights_on_evidence(self, differentiable: bool):
-        bs = 2
-        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        self.layer._enable_input_cache()
-        self.layer(x)
-
-        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
-        log_weights = self.layer._select_weights(ctx, self.layer.logits)
-        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
-        sums = log_weights.logsumexp(dim=2)
-        target = torch.zeros_like(sums)
-        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
-
-    def test__differentiable_sampling_has_grads(self):
-        N = 2
-        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
-        ctx = self.layer.sample(ctx)
-
-        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
-        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
-        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
-        sample.mean().backward()
-        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_utils.py b/simple_einet/tests/layers/test_utils.py
deleted file mode 100644
index 3feecd0..0000000
--- a/simple_einet/tests/layers/test_utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import torch
-from torch.nn import functional as F
-
-from simple_einet.sampling_utils import SamplingContext
-
-
-def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
-    if is_differentiable:
-        indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
-        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
-        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
-        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
-        one_hot_indices_out.requires_grad_(True)
-        one_hot_indices_repetition.requires_grad_(True)
-        return SamplingContext(
-            num_samples=num_samples,
-            indices_out=one_hot_indices_out,
-            indices_repetition=one_hot_indices_repetition,
-            is_differentiable=True,
-        )
-    else:
-        return SamplingContext(
-            num_samples=num_samples,
-            indices_out=torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out)),
-            indices_repetition=torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,)),
-            is_differentiable=False,
-        )
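
Note (not part of the patch): the deleted layer tests all drove the private sampling API through the same three-step pattern -- build a SamplingContext, select the weights for that context, then sample indices from those weights. Below is a minimal sketch of the non-differentiable path, using only names that appear verbatim in the hunks above (SumLayer, SamplingContext, logits_to_log_weights); treat it as an illustration of the removed test pattern under those assumptions, not as the library's documented API.

    import torch

    from simple_einet.abstract_layers import logits_to_log_weights
    from simple_einet.layers.sum import SumLayer
    from simple_einet.sampling_utils import SamplingContext

    # A sum layer sized like the deleted test fixtures.
    layer = SumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    # Invariant checked by test_logits_to_log_weights: normalizing the logits
    # along a dim yields log-weights whose logsumexp along that dim is zero.
    sums = logits_to_log_weights(layer.logits, dim=0).logsumexp(dim=0)
    assert torch.allclose(sums, torch.zeros_like(sums), atol=1e-5)

    # Non-differentiable sampling: integer indices choose one output sum node
    # per feature and one repetition per sample (cf. the deleted
    # get_sampling_context helper above).
    num_samples = 2
    ctx = SamplingContext(
        num_samples=num_samples,
        indices_out=torch.randint(0, layer.num_sums_out, size=(num_samples, layer.num_features_out)),
        indices_repetition=torch.randint(0, layer.num_repetitions, size=(num_samples,)),
        is_differentiable=False,
    )
    weights = layer._select_weights(ctx, layer.logits)  # (num_samples, num_features_out, num_sums_in)
    indices = layer._sample_from_weights(ctx, weights)  # (num_samples, num_features)

In the differentiable variant, the integer indices are replaced by one-hot float tensors created with requires_grad_(True), and index_one_hot then routes gradients back to layer.logits -- the property asserted by test__differentiable_sampling_has_grads.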