Skip to content

Commit

Permalink
Improve backprop performance through experimental kalman filter, by c…
Browse files Browse the repository at this point in the history
…hanging out MVN log_prob calculation.

PiperOrigin-RevId: 579184615
  • Loading branch information
srvasude authored and tensorflower-gardener committed Nov 3, 2023
1 parent 6ccdb1e commit 806394f
Showing 1 changed file with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.distributions import mvn_tril
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.math import linalg
Expand Down Expand Up @@ -625,11 +626,8 @@ def kalman_filter(transition_matrix,
axis=0),
added_cov=time_dep.observation_cov)

# TODO(srvasude): The JVP for this can be implemented more efficiently.
log_likelihoods = mvn_tril.MultivariateNormalTriL(
loc=observation_means,
scale_tril=tf.linalg.cholesky(observation_covs)).log_prob(
observation.y)
log_likelihoods = _mvn_log_prob(
observation_means, observation_covs, observation.y)
if observation.mask is not None:
log_likelihoods = tf.where(observation.mask,
tf.zeros([], dtype=log_likelihoods.dtype),
Expand All @@ -644,6 +642,17 @@ def kalman_filter(transition_matrix,
observation_covs)


def _mvn_log_prob(mean, covariance, y):
cholesky_matrix = tf.linalg.cholesky(covariance)
log_prob = -0.5 * linalg.hpsd_quadratic_form_solvevec(
covariance, y - mean, cholesky_matrix=cholesky_matrix)
log_prob = log_prob - 0.5 * linalg.hpsd_logdet(
covariance, cholesky_matrix=cholesky_matrix)
event_dims = ps.shape(mean)[-1]
return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype(
mean.dtype)(np.log(2 * np.pi))


def _extract_batch_shape(x, sample_ndims, event_ndims):
"""Slice out the batch component of `x`'s shape."""
if x is None:
Expand Down

0 comments on commit 806394f

Please sign in to comment.