Skip to content

Commit

Permalink
Added KF with sparse sites and related test (#18)
Browse files Browse the repository at this point in the history
* Added KF with sparse sites and related test

* Update tests/integration/test_kalman_filter_with_sparse_sites.py

Co-authored-by: Vincent Adam <[email protected]>

* Incorporated suggested changes

* Update markovflow/kalman_filter.py

Co-authored-by: Vincent Adam <[email protected]>

* Update markovflow/kalman_filter.py

Co-authored-by: Vincent Adam <[email protected]>

* Incorporated PR changes

* sparse observations saved

* passing sparse sites rather than dense sites

Co-authored-by: Vincent Adam <[email protected]>
  • Loading branch information
prakharverma and vincentadam87 authored Sep 3, 2022
1 parent 08dac0b commit 06f21e0
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 2 deletions.
121 changes: 119 additions & 2 deletions markovflow/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def _k_inv_post(self):
# The emission matrix is tiled across the time_points, so for a time invariant matrix
# this is equivalent to Gᵀ Σ⁻¹ G = (I_N ⊗ HᵀR⁻¹H),
likelihood_precision = SymmetricBlockTriDiagonal(h_t_r_h)
_k_inv_prior = self.prior_ssm.precision
# K⁻¹ + GᵀΣ⁻¹G
return _k_inv_prior + likelihood_precision
return self._k_inv_prior + likelihood_precision

@property
def _log_det_observation_precision(self):
Expand Down Expand Up @@ -495,3 +494,121 @@ def _log_det_observation_precision(self):
def observations(self):
""" Observation vector """
return self.sites.means


@tf_scope_class_decorator
class KalmanFilterWithSparseSites(BaseKalmanFilter):
r"""
Performs a Kalman filter on a :class:`~markovflow.state_space_model.StateSpaceModel`
and :class:`~markovflow.emission_model.EmissionModel`, with Gaussian sites, over a time grid.
"""

def __init__(self, state_space_model: StateSpaceModel, emission_model: EmissionModel, sites: GaussianSites,
num_grid_points: int, observations_index: tf.Tensor, observations: tf.Tensor):
"""
:param state_space_model: Parameterises the latent chain.
:param emission_model: Maps the latent chain to the observations.
:param sites: Gaussian sites over the observations.
:param num_grid_points: number of grid points.
:param observations_index: Index of the observations in the time grid with shape (N,).
:param observations: Sparse observations with shape (N, output_dim).
"""
self.sites = sites
self.observations_index = observations_index
self.sparse_observations = observations
self.grid_shape = tf.TensorShape((num_grid_points, 1))
super().__init__(state_space_model, emission_model)

@property
def _r_inv(self):
"""
Precisions of the observation model over the time grid.
"""
data_sites_precision = self.sites.precisions
return self.sparse_to_dense(data_sites_precision, output_shape=self.grid_shape + (1,))

@property
def _log_det_observation_precision(self):
"""
Sum of log determinant of the precisions of the observation model. It only calculates for the data_sites as
other sites precision is anyways zero.
"""
return tf.reduce_sum(tf.linalg.logdet(self._r_inv_data), axis=-1)

@property
def observations(self):
""" Sparse observation vector """
return self.sparse_observations

@property
def _r_inv_data(self):
"""
Precisions of the observation model for only the data sites.
"""
return self.sites.precisions

def sparse_to_dense(self, tensor: tf.Tensor, output_shape: tf.TensorShape) -> tf.Tensor:
"""
Convert a sparse tensor to a dense one on the basis of observations index, output tensor is of the output_shape.
"""
return tf.scatter_nd(self.observations_index, tensor, output_shape)

def dense_to_sparse(self, tensor: tf.Tensor) -> tf.Tensor:
"""
Convert a dense tensor to a sparse one on the basis of observations index.
"""
tensor_shape = tensor.shape
expand_dims = len(tensor_shape) == 3

tensor = tf.gather_nd(tf.reshape(tensor, (-1, 1)), self.observations_index)
if expand_dims:
tensor = tf.expand_dims(tensor, axis=-1)
return tensor

def log_likelihood(self) -> tf.Tensor:
r"""
Construct a TensorFlow function to compute the likelihood.
For more mathematical details, look at the log_likelihood function of the parent class.
The main difference from the parent class are that the vector of observations is now sparse.
:return: The likelihood as a scalar tensor (we sum over the `batch_shape`).
"""
# K⁻¹ + GᵀΣ⁻¹G = LLᵀ.
l_post = self._k_inv_post.cholesky
num_data = self.observations_index.shape[0]

# Hμ [..., num_transitions + 1, output_dim]
marginal = self.emission.project_state_to_f(self.prior_ssm.marginal_means)

# y = obs - Hμ [..., num_transitions + 1, output_dim]
disp = self.sparse_to_dense(self.observations, marginal.shape) - marginal
disp_data = self.sparse_observations - self.dense_to_sparse(marginal)

# cst is the constant term for a gaussian log likelihood
cst = (
-0.5 * np.log(2 * np.pi) * tf.cast(self.emission.output_dim * num_data, default_float())
)

term1 = -0.5 * tf.reduce_sum(
input_tensor=tf.einsum("...op,...p,...o->...o", self._r_inv_data, disp_data, disp_data), axis=[-1, -2]
)

# term 2 is: ½|L⁻¹(GᵀΣ⁻¹)y|²
# (GᵀΣ⁻¹)y [..., num_transitions + 1, state_dim]
obs_proj = self._back_project_y_to_state(disp)

# ½|L⁻¹(GᵀΣ⁻¹)y|² [...]
term2 = 0.5 * tf.reduce_sum(
input_tensor=tf.square(l_post.solve(obs_proj, transpose_left=False)), axis=[-1, -2]
)

## term 3 is: ½log |K⁻¹| - log |L| + ½ log |Σ⁻¹|
# where log |Σ⁻¹| = num_data * log|R⁻¹|
term3 = (
0.5 * self.prior_ssm.log_det_precision()
- l_post.abs_log_det()
+ 0.5 * self._log_det_observation_precision
)

return tf.reduce_sum(cst + term1 + term2 + term3)
69 changes: 69 additions & 0 deletions tests/integration/test_kalman_filter_with_sparse_sites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np
import tensorflow as tf
import pytest
from gpflow.config import default_float

from markovflow.kernels.matern import Matern12
from markovflow.mean_function import LinearMeanFunction
from markovflow.models.gaussian_process_regression import GaussianProcessRegression
from markovflow.kalman_filter import KalmanFilterWithSparseSites, UnivariateGaussianSitesNat
from markovflow.likelihoods import MultivariateGaussian

@pytest.fixture(
name="time_step_homogeneous", params=[(0.01, True), (0.01, False), (0.001, True), (0.001, False)],
)
def _time_step_homogeneous_fixture(request):
return request.param


@pytest.fixture(name="kalman_gpr_setup")
def _setup(batch_shape, time_step_homogeneous):
"""
Create a Gaussian Process model and an equivalent kalman filter model
with more latent states than observations.
FIXME: Currently batch_shape isn't used.
"""
dt, homogeneous = time_step_homogeneous

time_grid = np.arange(0.0, 1.0, dt)
if not homogeneous:
time_grid = np.sort(np.random.choice(time_grid, 50, replace=False))

time_points = time_grid[::10]
observations = np.sin(12 * time_points[..., None]) + np.random.randn(len(time_points), 1) * 0.1

input_data = (
tf.constant(time_points, dtype=default_float()),
tf.constant(observations, dtype=default_float()),
)

observation_covariance = 1.0 # Same as GPFlow default
kernel = Matern12(lengthscale=1.0, variance=1.0, output_dim=observations.shape[-1])
kernel.set_state_mean(tf.random.normal((1,), dtype=default_float()))
gpr_model = GaussianProcessRegression(
input_data=input_data,
kernel=kernel,
mean_function=LinearMeanFunction(1.1),
chol_obs_covariance=tf.constant([[np.sqrt(observation_covariance)]], dtype=default_float()),
)

prior_ssm = kernel.state_space_model(time_grid)
emission_model = kernel.generate_emission_model(time_grid)
observations_index = tf.where(tf.equal(time_grid[..., None], time_points))[:, 0][..., None]

observations -= gpr_model.mean_function(time_points)

nat1 = observations / observation_covariance
nat2 = (-0.5 / observation_covariance) * tf.ones_like(nat1)[..., None]
lognorm = tf.zeros_like(nat1)
sites = UnivariateGaussianSitesNat(nat1=nat1, nat2=nat2, log_norm=lognorm)

kf_sparse_sites = KalmanFilterWithSparseSites(prior_ssm, emission_model, sites, time_grid.shape[0],
observations_index, observations)

return gpr_model, kf_sparse_sites

def test_kalman_loglikelihood(with_tf_random_seed, kalman_gpr_setup):
gpr_model, kf_sparse_sites = kalman_gpr_setup

np.testing.assert_allclose(gpr_model.log_likelihood(), kf_sparse_sites.log_likelihood())

0 comments on commit 06f21e0

Please sign in to comment.