From 85ad11dbf2fb496b3b8a92fb4695154f2a8a0f05 Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 21 Apr 2022 10:38:36 +0200 Subject: [PATCH] Implement whitened representation of q(u) (#8) - new feature: add `whiten` keyword argument to `VBPP` constructor with which the whitened representation of q(u) can be turned on. - bugfix: add `full_output_cov` keyword argument to `predict_f` (though only `False` is implemented) so that the `predict_f_samples` method can be used - docs: add note on matrix inversion issues (solution: fix inducing point locations). - improved tests --- demo/demo_coal_1d.py | 5 ++- tests/test_model.py | 93 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_smoke.py | 21 +++++----- vbpp/model.py | 35 ++++++++++++++--- 4 files changed, 139 insertions(+), 15 deletions(-) create mode 100644 tests/test_model.py diff --git a/demo/demo_coal_1d.py b/demo/demo_coal_1d.py index 90a3097..c9e9a73 100644 --- a/demo/demo_coal_1d.py +++ b/demo/demo_coal_1d.py @@ -32,11 +32,14 @@ def build_model(events, domain, M=20): kernel = gpflow.kernels.SquaredExponential() Z = domain_grid(domain, M) feature = gpflow.inducing_variables.InducingPoints(Z) + gpflow.set_trainable(feature, False) q_mu = np.zeros(M) q_S = np.eye(M) num_events = len(events) beta0 = np.sqrt(num_events / domain_area(domain)) - model = VBPP(feature, kernel, domain, q_mu, q_S, beta0=beta0, num_events=num_events) + model = VBPP( + feature, kernel, domain, q_mu, q_S, beta0=beta0, num_events=num_events, whiten=True + ) return model diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..a2b1386 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,93 @@ +# Copyright (C) 2022 ST John +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import tensorflow as tf +from gpflow.inducing_variables import InducingPoints +from gpflow.kernels import SquaredExponential +import gpflow +from vbpp import VBPP + +rng = np.random.RandomState(0) + + +class Data: + domain = np.array([[0.0, 10.0]]) + events = rng.uniform(0, 10, size=20)[:, None] + Z = np.linspace(0, 10, 17)[:, None] + Xtest = np.linspace(-2, 12, 37)[:, None] + + +@pytest.mark.parametrize("whiten", [True, False]) +def test_elbo_terms_at_initialization(whiten): + kernel = SquaredExponential() + feature = InducingPoints(Data.Z) + M = feature.num_inducing + m_init = np.zeros(M) + S_init = np.eye(M) if whiten else kernel(Data.Z, full_cov=True) + m = VBPP(feature, kernel, Data.domain, m_init, S_init, whiten=whiten) + + Kuu = m.compute_Kuu() + assert np.allclose(m.prior_kl(Kuu).numpy(), 0.0) + assert np.allclose(m._elbo_integral_term(Kuu).numpy(), -m.total_area) + + +def test_equivalence_of_whitening(): + kernel = SquaredExponential() + feature = InducingPoints(Data.Z) + + M = feature.num_inducing + np.random.seed(42) + m_init = np.random.randn(M) + S_init = (lambda A: A @ A.T)(np.random.randn(M, M)) + + Kuu = kernel(Data.Z) + L = np.linalg.cholesky(Kuu.numpy()) + + beta0 = 1.234 + m_whitened = VBPP(feature, kernel, Data.domain, m_init, S_init, whiten=True, beta0=beta0) + m_unwhitened = VBPP( + feature, kernel, Data.domain, L @ m_init, L @ S_init @ L.T, whiten=False, beta0=beta0 + ) + + Xnew = np.linspace(-3, 13, 17)[:, None] + f_mean_whitened, f_var_whitened = m_whitened.predict_f(Xnew) + f_mean_unwhitened, f_var_unwhitened = m_unwhitened.predict_f(Xnew) + np.testing.assert_allclose(f_mean_whitened, f_mean_unwhitened, rtol=1e-3) + np.testing.assert_allclose(f_var_whitened, f_var_unwhitened, rtol=2e-3) + + np.testing.assert_allclose( + m_whitened.elbo(Data.events), m_unwhitened.elbo(Data.events), rtol=1e-6 + ) + + +@pytest.mark.parametrize("whiten", [True, False]) +def test_lambda_predictions(whiten): + kernel = SquaredExponential() + feature = InducingPoints(Data.Z) + + M = feature.num_inducing + np.random.seed(42) + m_init = np.random.randn(M) + S_init = (lambda A: A @ A.T)(np.random.randn(M, M)) + beta0 = 1.234 + + m = VBPP(feature, kernel, Data.domain, m_init, S_init, whiten=whiten, beta0=beta0) + + mean, lower, upper = m.predict_lambda_and_percentiles(Data.Xtest) + mean_again = m.predict_lambda(Data.Xtest) + np.testing.assert_allclose(mean, mean_again) + np.testing.assert_array_less(lower, mean) + np.testing.assert_array_less(mean, upper) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index fc98584..dc24310 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1,3 +1,4 @@ +# Copyright (C) 2022 ST John # Copyright (C) Secondmind Ltd 2017 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import numpy as np -import tensorflow as tf from gpflow.inducing_variables import InducingPoints from gpflow.kernels import SquaredExponential import gpflow @@ -22,20 +23,22 @@ rng = np.random.RandomState(0) -def test_smoke(): +@pytest.mark.parametrize("whiten", [True, False]) +def test_smoke_optimize_and_predict(whiten): domain = np.array([[0.0, 10.0]]) - kernel = SquaredExponential() events = rng.uniform(0, 10, size=20)[:, None] - feature = InducingPoints(np.linspace(0, 10, 20)[:, None]) - M = len(feature) - m = VBPP(feature, kernel, domain, np.zeros(M), np.eye(M)) - Kuu = m.compute_Kuu() - m.q_sqrt.assign(np.linalg.cholesky(Kuu)) - assert np.allclose(m.prior_kl(tf.identity(Kuu)).numpy(), 0.0) + kernel = SquaredExponential() + Z = np.linspace(0, 10, 17)[:, None] + feature = InducingPoints(Z) + M = feature.num_inducing + m = VBPP(feature, kernel, domain, np.zeros(M), np.eye(M), whiten=whiten) def objective_closure(): return -m.elbo(events) opt = gpflow.optimizers.Scipy() opt.minimize(objective_closure, m.trainable_variables, options=dict(maxiter=2)) + + X = np.linspace(-1, 11, 19)[:, None] + _ = m.predict_f_samples(X) diff --git a/vbpp/model.py b/vbpp/model.py index 54c63cf..5b07e78 100644 --- a/vbpp/model.py +++ b/vbpp/model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2022 ST John # Copyright (C) Secondmind Ltd 2017-2020 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -61,6 +62,11 @@ class VBPP(gpflow.models.GPModel, gpflow.models.ExternalDataTrainingLossMixin): Implementation of the "Variational Bayes for Point Processes" model by Lloyd et al. (2015), with capability for multiple observations and the constant offset `beta0` from John and Hensman (2018). + + Note: If you encounter "Input matrix is not invertible." errors during + training, this may be due to inducing points moving too close to each + other, especially in 1D. You may want to consider fixing the inducing + points, e.g. on a grid. """ def __init__( @@ -74,6 +80,7 @@ def __init__( beta0: float = 1e-6, num_observations: int = 1, num_events: Optional[int] = None, + whiten: bool = False, ): """ D = number of dimensions @@ -100,6 +107,10 @@ def __init__( :param num_events: total number of events, defaults to events.shape[0] (relevant when feeding in minibatches) + + :param whiten: whether to use the whitened representation of q(u). + When whiten=True, we parametrise q(v) = N(q_mu, q_S) instead, and u = L v, + where L is the lower-triangular Cholesky factor of the kernel matrix Kuu. """ super().__init__( kernel, @@ -138,6 +149,8 @@ def __init__( self.psi_jitter = 0.0 + self.whiten = whiten + def _Psi_matrix(self): Ψ = tf_calc_Psi_matrix(self.kernel, self.inducing_variable, self.domain) psi_jitter_matrix = self.psi_jitter * tf.eye( @@ -149,11 +162,13 @@ def _Psi_matrix(self): def total_area(self): return np.prod(self.domain[:, 1] - self.domain[:, 0]) - def predict_f(self, Xnew, full_cov=False, *, Kuu=None): + def predict_f(self, Xnew, full_cov=False, *, full_output_cov=False, Kuu=None): """ VBPP-specific conditional on the approximate posterior q(u), including a constant mean function. """ + if full_output_cov: + raise NotImplementedError("only supports single-output models") mean, var = conditional( Xnew, self.inducing_variable, @@ -161,6 +176,7 @@ def predict_f(self, Xnew, full_cov=False, *, Kuu=None): self.q_mu[:, None], full_cov=full_cov, q_sqrt=self.q_sqrt[None, :, :], + white=self.whiten, ) # TODO make conditional() use Kuu if available @@ -201,7 +217,10 @@ def _elbo_integral_term(self, Kuu): # Kzz⁻¹ m = R^-T R⁻¹ m # Rinv_m = R⁻¹ m - Rinv_m = tf.linalg.triangular_solve(R, self.q_mu[:, None], lower=True) + if self.whiten: + Rinv_m = self.q_mu[:, None] + else: + Rinv_m = tf.linalg.triangular_solve(R, self.q_mu[:, None], lower=True) # R⁻¹ Ψ R^-T # = (R⁻¹ Ψ) R^-T @@ -211,8 +230,11 @@ def _elbo_integral_term(self, Kuu): int_mean_f_sqr = tf_vec_mat_vec_mul(Rinv_m, Rinv_Ψ_RinvT, Rinv_m) - Rinv_L = tf.linalg.triangular_solve(R, self.q_sqrt, lower=True) - Rinv_L_LT_RinvT = tf.matmul(Rinv_L, Rinv_L, transpose_b=True) + if self.whiten: + Rinv_L_LT_RinvT = tf.matmul(self.q_sqrt, self.q_sqrt, transpose_b=True) + else: + Rinv_L = tf.linalg.triangular_solve(R, self.q_sqrt, lower=True) + Rinv_L_LT_RinvT = tf.matmul(Rinv_L, Rinv_L, transpose_b=True) # int_var_fx = γ |T| + trace_terms # trace_terms = - Tr(Kzz⁻¹ Ψ) + Tr(Kzz⁻¹ S Kzz⁻¹ Ψ) @@ -244,7 +266,10 @@ def prior_kl(self, Kuu): """ KL divergence between p(u) = N(0, Kuu) and q(u) = N(μ, S) """ - return kullback_leiblers.gauss_kl(self.q_mu[:, None], self.q_sqrt[None, :, :], Kuu) + if self.whiten: + return kullback_leiblers.gauss_kl(self.q_mu[:, None], self.q_sqrt[None, :, :]) + else: + return kullback_leiblers.gauss_kl(self.q_mu[:, None], self.q_sqrt[None, :, :], Kuu) def maximum_log_likelihood_objective(self, events): return self.elbo(events)