Skip to content

Commit

Permalink
Bug fix: MaternBNN should use student_t distribution in its weight prior
Browse files Browse the repository at this point in the history
and not just in its weight initialization.
Bug fix:  in bnn_tree.list_of_all, don't include WeightedSums where both leaves are the same kernel type.
Also, add ExponentialBNN as a base kernel.

PiperOrigin-RevId: 603051403
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Jan 31, 2024
1 parent eb67a7f commit 4a08cd2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
LEAVES = [
kernels.ExponentiatedQuadraticBNN,
kernels.MaternBNN,
kernels.ExponentialBNN,
kernels.LinearBNN,
kernels.QuadraticBNN,
kernels.PeriodicBNN,
Expand All @@ -49,6 +50,7 @@
NON_PERIODIC_KERNELS = [
kernels.ExponentiatedQuadraticBNN,
kernels.MaternBNN,
kernels.ExponentialBNN,
kernels.LinearBNN,
kernels.QuadraticBNN,
kernels.OneLayerBNN,
Expand Down Expand Up @@ -86,7 +88,7 @@ def list_of_all(
# Abelian operators that aren't Multiply.
if include_sums:
for i, c1 in enumerate(non_multiply_children):
for j in range(i + 1):
for j in range(i):
c2 = non_multiply_children[j]
# Add is also abelian, but WeightedSum is more general.
all_bnns.append(
Expand Down
28 changes: 15 additions & 13 deletions tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,25 @@

class TreeTest(parameterized.TestCase):

def test_list_of_all(self):
def test_list_of_all_depth0(self):
l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0)
# With no periods, there should be five kernels.
self.assertLen(l0, 5)
# With no periods, there should be six kernels.
self.assertLen(l0, 6)
for k in l0:
self.assertFalse(k.going_to_be_multiplied)

l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True)
self.assertLen(l0, 7)
self.assertLen(l0, 8)
for k in l0:
self.assertTrue(k.going_to_be_multiplied)

def test_list_of_all_depth1(self):
l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1)
# With no periods, there should be
# 15 trees with a Multiply top node,
# 15 trees with a WeightedSum top node, and
# 25 trees with a LearnableChangePoint top node.
self.assertLen(l1, 55)
# choose(6+1, 2) = 21 trees with a Multiply top node,
# choose(6, 2) = 15 trees with a WeightedSum top node, and
# 6*6 = 36 trees with a LearnableChangePoint top node.
self.assertLen(l1, 72)

# Check that all of the BNNs in the tree can be trained.
for k in l1:
Expand All @@ -63,12 +64,13 @@ def test_list_of_all(self):
# nodes, with 7*8/2 = 28 trees.
self.assertLen(l1, 28)

def test_list_of_all_depth2(self):
l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2)
# With no periods, there should be
# 15*16/2 = 120 trees with a Multiply top node,
# 55*56/2 = 1540 trees with a WeightedSum top node, and
# 55*55 = 3025 trees with a LearnableChangePoint top node.
self.assertLen(l2, 4685)
# There are 66 trees of depth 1, of which 15 are safe to multiply.
# choose(15+1, 2) = 120 trees with a Multiply top node,
# choose(66, 2) = 2145 trees with a WeightedSum top node, and
# 66*66 = 4356 trees with a LearnableChangePoint top node.
self.assertLen(l2, 7860)

@parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :(
def test_weighted_sum_of_all(self, depth):
Expand Down
14 changes: 14 additions & 0 deletions tensorflow_probability/python/experimental/autobnn/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,25 @@ def kernel_init(seed, shape, unused_dtype):
self.kernel_init = kernel_init
super().setup()

def distributions(self):
d = super().distributions()
d['dense1']['kernel'] = student_t_lib.StudentT(
df=2.0 * self.degrees_of_freedom, loc=0.0, scale=1.0)
return d

def summarize(self, params=None, full: bool = False) -> str:
"""Return a string summarizing the structure of the BNN."""
return f'{self.shortname()}({self.degrees_of_freedom})'


class ExponentialBNN(MaternBNN):
"""Matern(0.5), also known as the absolute exponential kernel."""
degrees_of_freedom: float = 0.5

def summarize(self, params=None, full: bool = False) -> str:
return self.shortname()


class PolynomialBNN(OneLayerBNN):
"""A BNN where samples are polynomial functions."""
degree: int = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
kernels.OneLayerBNN,
kernels.ExponentiatedQuadraticBNN,
kernels.MaternBNN,
kernels.ExponentialBNN,
kernels.PeriodicBNN,
kernels.PolynomialBNN,
kernels.LinearBNN,
Expand Down Expand Up @@ -139,6 +140,7 @@ def test_likelihood(self, kernel):
(kernels.OneLayerBNN(width=10), 'OneLayer'),
(kernels.ExponentiatedQuadraticBNN(width=5), 'RBF'),
(kernels.MaternBNN(width=5), 'Matern(2.5)'),
(kernels.ExponentialBNN(width=20), 'Exponential'),
(kernels.PeriodicBNN(period=10, width=10), 'Periodic(period=10.00)'),
(kernels.PolynomialBNN(degree=3, width=2), 'Polynomial(degree=3)'),
(kernels.LinearBNN(width=5), 'Linear'),
Expand Down

0 comments on commit 4a08cd2

Please sign in to comment.