Skip to content

Commit

Permalink
Handle local measures in TransformedDistribution.
Browse files Browse the repository at this point in the history
This change continues to set up the framework for tracking base measures and computing corrections on transformed densities. In `TransformedDistribution` we update `log_prob` to call a version of `experimental_local_measure` that keeps track of the base measure. We introduce a backwards-compatibility argument to control this rollout.

PiperOrigin-RevId: 385616650
  • Loading branch information
DistraxDev authored and DistraxDev committed Jan 20, 2022
1 parent a1c5d43 commit b7d457d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
29 changes: 29 additions & 0 deletions distrax/_src/bijectors/tfp_compatible_bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from distrax._src.utils import math
import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental import tangent_spaces
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

Array = chex.Array
Bijector = bijector.Bijector
TangentSpace = tangent_spaces.TangentSpace


def tfp_compatible_bijector(
Expand Down Expand Up @@ -175,4 +177,31 @@ def _check_shape(
f"{event_shape} which has only {len(event_shape)} "
f"dimensions instead.")

def experimental_compute_density_correction(
self,
x: Array,
tangent_space: TangentSpace,
backward_compat: bool = True,
**kwargs):
"""Density correction for this transform wrt the tangent space, at x.
See `tfb.bijector.Bijector.experimental_compute_density_correction`, and
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
Solution”, https://arxiv.org/abs/2010.09647.
Args:
x: `float` or `double` `Array`.
tangent_space: `TangentSpace` or one of its subclasses. The tangent to
the support manifold at `x`.
backward_compat: unused
**kwargs: Optional keyword arguments forwarded to tangent space methods.
Returns:
density_correction: `Array` representing the density correction---in log
space---under the transformation that this Bijector denotes. Assumes
the Bijector is dimension-preserving.
"""
del backward_compat
return tangent_space.transform_dimension_preserving(x, self, **kwargs)

return TFPCompatibleBijector()
29 changes: 28 additions & 1 deletion distrax/_src/distributions/tfp_compatible_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# ==============================================================================
"""Wrapper to adapt a Distrax distribution for use in TFP."""

from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import chex
from distrax._src.distributions import distribution
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.python.experimental import tangent_spaces
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
Expand All @@ -29,6 +30,7 @@
Distribution = distribution.Distribution
IntLike = distribution.IntLike
PRNGKey = chex.PRNGKey
TangentSpace = tangent_spaces.TangentSpace


def tfp_compatible_distribution(
Expand Down Expand Up @@ -136,4 +138,29 @@ def sample(self,
sample_shape = tuple(sample_shape)
return base_distribution.sample(sample_shape=sample_shape, seed=seed)

def experimental_local_measure(
self,
value: Array,
backward_compat: bool = True,
**unused_kwargs) -> Tuple[Array, TangentSpace]:
"""Returns a log probability density together with a `TangentSpace`.
See `tfd.distribution.Distribution.experimental_local_measure`, and
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
Solution”, https://arxiv.org/abs/2010.09647.
Args:
value: `float` or `double` `Array`.
backward_compat: unused
**unused_kwargs: unused
Returns:
log_prob: see `log_prob`.
tangent_space: `tangent_spaces.FullSpace()`, representing R^n with the
standard basis.
"""
del backward_compat
del unused_kwargs
return self.log_prob(value), tangent_spaces.FullSpace()

return TFPCompatibleDistribution()

0 comments on commit b7d457d

Please sign in to comment.