Commit
Update the documentation of fast_gp:GaussianProcess class and its log_prob
method.

PiperOrigin-RevId: 610437286
ThomasColthurst authored and tensorflower-gardener committed Feb 26, 2024
1 parent aff7da4 commit d291dca
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions tensorflow_probability/python/experimental/fastgp/fast_gp.py
@@ -85,15 +85,11 @@ class GaussianProcess(distribution.Distribution):
   """Fast, JAX-only implementation of a GP distribution class.
 
   See tfd.distributions.GaussianProcess for a description and parameter
-  documentation. Currently only supports log_prob and posterior_predictive
-  (the only two methods used by smc.py).
+  documentation, but note that not all of that class's methods are supported.
 
   The default parameters are tuned to give a good time / error trade-off
   in the n > 15,000 regime where this class gives a substantial speed-up
-  over tfd.distributions.GaussianProcess. In particular, it is tuned to
-  give a trade-off in the case where you care about the accuracy of both
-  log_prob and its derivative. If you care only about log_prob, it is
-  recommended to use log_det_algorithm='slq' with preconditioner_num_iters=1.
+  over tfd.distributions.GaussianProcess.
   """
 
   def __init__(
@@ -311,7 +307,26 @@ def get_preconditioner(cov):

   @jax.named_call
   def log_prob(self, value, key, is_missing=None) -> Array:
-    """log P(value | GP)."""
+    """log P(value | GP).
+
+    Args:
+      value: `float` or `double` jax.Array.
+      key: A jax KeyArray. This method uses stochastic methods to quickly
+        estimate the log probability of `value`, and `key` is needed to
+        generate the stochasticity. `key` is also used when computing the
+        derivative of this function. In some circumstances it is acceptable
+        and in fact even necessary to pass the same value of `key` to multiple
+        invocations of log_prob; for example if the log_prob is being
+        optimized by an algorithm that assumes it is deterministic.
+      is_missing: Optional `bool` jax.Array of shape `[..., e]` where `e` is
+        the number of index points in each batch. Represents a batch of
+        Boolean masks. When not `None`, the returned log_prob is for the
+        *marginal* distribution in which all dimensions with `is_missing==True`
+        have been marginalized out.
+
+    Returns:
+      A stochastic approximation to log P(value | GP).
+    """
     empty_sample_batch_shape = value.ndim == 1
     if empty_sample_batch_shape:
       value = value[jnp.newaxis]
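For context, here is a minimal usage sketch of the signature documented above. Only log_prob(value, key, is_missing=None) comes from this diff; the constructor arguments (kernel, index_points, observation_noise_variance), the import paths, and the kernel class are assumptions modeled on tfd.GaussianProcess and may not match the actual fastgp API.

# Hypothetical sketch, not part of the commit: constructor and imports are assumed.
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.fastgp import fast_gp
from tensorflow_probability.substrates import jax as tfp

# Assumed setup mirroring tfd.GaussianProcess: a PSD kernel and index points.
kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
    amplitude=1.0, length_scale=0.5)
index_points = jnp.linspace(-1.0, 1.0, 100)[:, jnp.newaxis]
gp = fast_gp.GaussianProcess(  # Assumed constructor signature.
    kernel, index_points, observation_noise_variance=0.1)

value = jnp.sin(3.0 * index_points[:, 0])
key = jax.random.PRNGKey(0)

# log_prob is a stochastic estimate, so it takes an explicit PRNG key.
lp = gp.log_prob(value, key=key)

# Reusing the same key keeps repeated evaluations (and their gradients)
# consistent, e.g. inside an optimizer that assumes a deterministic objective.
grad = jax.grad(lambda v: gp.log_prob(v, key=key))(value)

# Marginalize out masked index points: is_missing has shape [..., e].
is_missing = jnp.zeros(100, dtype=bool).at[::10].set(True)
lp_marginal = gp.log_prob(value, key=key, is_missing=is_missing)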
