Commit 53a8630 (1 parent: f75e73d)
Showing 9 changed files with 379 additions and 0 deletions.
Three of the nine changed files are new empty files.
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.einsum import EinsumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestEinsumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = EinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        # Normalized log-weights must sum to one, i.e., logsumexp == 0, along every dim.
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            # Differentiable sampling returns one-hot relaxations over the input sums.
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        # Einsum layers mix the cross product of left/right child sums: num_sums_in**2 entries.
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in**2))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        # Conditioning on evidence must preserve normalization.
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        # Index a dummy tensor with the sampled one-hot indices and backprop to the logits.
        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
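Aside: the num_sums_in**2 in test__select_weights is the signature of an einsum layer, which mixes the full cross product of its left and right children's input sums for every output sum. A minimal linear-domain sketch of that contraction, with illustrative shapes and names (the actual layer works in log space):

import torch

N, D, I, O = 2, 2, 3, 2                # batch, output features, sums_in, sums_out
left = torch.rand(N, D, I)             # child likelihoods per output feature
right = torch.rand(N, D, I)
weights = torch.rand(D, O, I, I)
weights = weights / weights.sum(dim=(-2, -1), keepdim=True)    # normalize over the I x I grid
out = torch.einsum("ndi,ndj,doij->ndo", left, right, weights)  # mix all (i, j) child pairs

In log space the same weighted sums are computed with logsumexp, which is also why the normalization checks above assert logsumexp(log_weights) == 0 rather than sum(weights) == 1.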
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.linsum import LinsumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestLinsumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = LinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        # Normalized log-weights must satisfy logsumexp == 0 along every dim.
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        # Conditioning on evidence must preserve normalization.
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            # Differentiable sampling returns one-hot relaxations over the input sums.
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        # Linsum pairs child sums elementwise, so only num_sums_in entries (not num_sums_in**2).
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        # Index a dummy tensor with the sampled one-hot indices and backprop to the logits.
        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
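Aside: compared to the einsum tests, test__select_weights here expects only num_sums_in entries per output feature. A linsum layer pairs the i-th sum of the left child with the i-th sum of the right child instead of taking the full cross product, so its parameter count grows linearly rather than quadratically in num_sums_in. A linear-domain sketch under the same illustrative assumptions as above:

import torch

N, D, I, O = 2, 2, 3, 2
left = torch.rand(N, D, I)
right = torch.rand(N, D, I)
weights = torch.rand(D, O, I)
weights = weights / weights.sum(dim=-1, keepdim=True)         # normalize over sums_in
out = torch.einsum("ndi,ndi,doi->ndo", left, right, weights)  # pair children elementwise, then mix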
@@ -0,0 +1,70 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.mixing import MixingLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestMixingLayer(TestCase):
    def setUp(self) -> None:
        self.layer = MixingLayer(num_features=1, num_sums_in=3, num_sums_out=2)

    def test_logits_to_log_weights(self):
        # Normalized log-weights must satisfy logsumexp == 0 along every dim.
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        # Note the input layout: the axis being mixed (num_sums_in) comes last.
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
        out = self.layer(x)
        self.assertEqual(out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        # Conditioning on evidence must preserve normalization.
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            # Differentiable sampling returns one-hot relaxations over the input sums.
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        # Index a dummy tensor with the sampled one-hot indices and backprop to the logits.
        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
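Aside: MixingLayer is the only layer here built without num_repetitions, and its forward input carries num_sums_in in the last axis. This matches its role as the root-level mixture that combines the results of the network's repetitions, each treated as one input sum. A linear-domain sketch of that convex combination (illustrative shapes, not the library source):

import torch

N, D, O, R = 2, 1, 2, 3                 # batch, features, sums_out, inputs being mixed
x = torch.rand(N, D, O, R)              # one result per mixed input (e.g., per repetition)
weights = torch.rand(D, O, R)
weights = weights / weights.sum(dim=-1, keepdim=True)  # normalize over the mixed inputs
out = torch.einsum("ndor,dor->ndo", x, weights)        # convex combination per output sum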
@@ -0,0 +1,72 @@
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.sum import SumLayer
from simple_einet.sampling_utils import index_one_hot
from tests.layers.test_utils import get_sampling_context


class TestSumLayer(TestCase):
    def setUp(self) -> None:
        self.layer = SumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)

    def test_logits_to_log_weights(self):
        # Normalized log-weights must satisfy logsumexp == 0 along every dim.
        for dim in range(self.layer.logits.dim()):
            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
            sums = log_weights.logsumexp(dim=dim)
            target = torch.zeros_like(sums)
            self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test_forward_shape(self):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        out = self.layer(x)
        self.assertEqual(
            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
        )

    @parameterized.expand([(False,), (True,)])
    def test__sample_from_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        indices = self.layer._sample_from_weights(ctx, log_weights)
        if differentiable:
            # Differentiable sampling returns one-hot relaxations over the input sums.
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
        else:
            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))

    @parameterized.expand([(False,), (True,)])
    def test__select_weights(self, differentiable: bool):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
        weights = self.layer._select_weights(ctx, self.layer.logits)
        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

    @parameterized.expand([(False,), (True,)])
    def test__condition_weights_on_evidence(self, differentiable: bool):
        bs = 2
        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        self.layer._enable_input_cache()
        self.layer(x)

        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
        log_weights = self.layer._select_weights(ctx, self.layer.logits)
        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
        # Conditioning on evidence must preserve normalization.
        sums = log_weights.logsumexp(dim=2)
        target = torch.zeros_like(sums)
        self.assertTrue(torch.allclose(sums, target, atol=1e-5))

    def test__differentiable_sampling_has_grads(self):
        N = 2
        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
        ctx = self.layer.sample(ctx)

        # Index a dummy tensor with the sampled one-hot indices and backprop to the logits.
        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
        sample.mean().backward()
        self.assertTrue(self.layer.logits.grad is not None)
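Aside: SumLayer is the plain mixture case: independently per feature and repetition, each of the num_sums_out outputs is a convex combination of the num_sums_in inputs. Since the layers operate on log-likelihoods, that weighted sum is computed as a logsumexp. A sketch of the log-domain step (illustrative, not the library source):

import torch

N, D, I, O, R = 2, 4, 3, 2, 5
log_x = torch.randn(N, D, I, R)                            # input log-likelihoods
log_w = torch.log_softmax(torch.randn(D, O, I, R), dim=2)  # log-weights, normalized over sums_in
# log sum_i exp(log_w_i + log_x_i): a weighted sum computed in log space
out = torch.logsumexp(log_w.unsqueeze(0) + log_x.unsqueeze(2), dim=3)  # (N, D, O, R)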
@@ -0,0 +1,27 @@
import torch
from torch.nn import functional as F

from simple_einet.sampling_utils import SamplingContext


def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
    """Build a random SamplingContext for `layer`, with either integer indices or
    differentiable one-hot float indices."""
    if is_differentiable:
        # One-hot float indices with requires_grad=True let gradients flow through sampling.
        indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
        one_hot_indices_out.requires_grad_(True)
        one_hot_indices_repetition.requires_grad_(True)
        return SamplingContext(
            num_samples=num_samples,
            indices_out=one_hot_indices_out,
            indices_repetition=one_hot_indices_repetition,
            is_differentiable=True,
        )
    else:
        # Plain integer indices for standard (non-differentiable) sampling.
        return SamplingContext(
            num_samples=num_samples,
            indices_out=torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out)),
            indices_repetition=torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,)),
            is_differentiable=False,
        )
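This helper encodes the trick behind differentiable sampling: integer indices would block gradients, so the context instead carries one-hot float tensors with requires_grad=True. Presumably index_one_hot then selects entries by multiplying with the one-hot index and reducing over the indexed dimension; a minimal sketch of that mechanism (a guess at the idea, not the library's code):

import torch

def index_one_hot_sketch(x: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor:
    # Differentiable "gather": weight entries by a one-hot (or relaxed) index, then reduce.
    return (x * index).sum(dim=dim)

x = torch.randn(2, 4, 3, requires_grad=True)
idx = torch.nn.functional.one_hot(torch.tensor([0, 2]), num_classes=3).float()  # (2, 3)
out = index_one_hot_sketch(x, idx.unsqueeze(1), dim=-1)  # (2, 4); gradients flow through the product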
@@ -0,0 +1,66 @@
from itertools import product
from unittest import TestCase

import torch
from parameterized import parameterized

from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.binomial import Binomial


class TestEinet(TestCase):
    def setUp(self) -> None:
        self.num_features = 8
        self.num_channels = 3
        self.num_sums = 5
        self.num_leaves = 2
        self.depth = 3
        self.leaf_type = Binomial
        self.leaf_kwargs = {"total_count": 255}

    def make_einet(self, num_classes, num_repetitions):
        # Build a small linsum-based Einet with Binomial leaves.
        config = EinetConfig(
            num_features=self.num_features,
            num_channels=self.num_channels,
            depth=self.depth,
            num_sums=self.num_sums,
            num_leaves=self.num_leaves,
            num_repetitions=num_repetitions,
            num_classes=num_classes,
            leaf_type=self.leaf_type,
            leaf_kwargs=self.leaf_kwargs,
            layer_type="linsum",
            dropout=0.0,
        )
        return Einet(config)

    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
    def test_sampling_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
        N = 2

        # Sampling without evidence yields N unconditional samples.
        samples = model.sample(num_samples=N, is_differentiable=differentiable)
        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))

        # Sampling with evidence yields one completion per evidence row.
        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
        samples = model.sample(evidence=evidence, is_differentiable=differentiable)
        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))

    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
    def test_mpe_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
        N = 2

        # MPE without evidence yields the single most probable explanation.
        mpe = model.mpe(is_differentiable=differentiable)
        self.assertEqual(mpe.shape, (1, self.num_channels, self.num_features))

        # MPE with evidence completes each evidence row.
        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
        mpe = model.mpe(evidence=evidence, is_differentiable=differentiable)
        self.assertEqual(mpe.shape, (N, self.num_channels, self.num_features))
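The tests above cover sample and mpe; for orientation, this is how the same configuration would typically be driven for density estimation. This is a sketch only: it assumes Einet's forward pass accepts inputs shaped like the evidence tensors above and returns per-class log-likelihoods, which these tests do not verify.

import torch
from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.binomial import Binomial

config = EinetConfig(
    num_features=8,
    num_channels=3,
    depth=3,
    num_sums=5,
    num_leaves=2,
    num_repetitions=4,
    num_classes=3,
    leaf_type=Binomial,
    leaf_kwargs={"total_count": 255},
    layer_type="linsum",
    dropout=0.0,
)
model = Einet(config)

x = torch.randint(0, 256, size=(2, 3, 8)).float()  # pixel-like inputs for the Binomial leaves
log_likelihoods = model(x)                         # assumed shape: (2, num_classes=3)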