Add FlatDirichlet distribution.
PiperOrigin-RevId: 613223674
ThomasColthurst authored and tensorflower-gardener committed Mar 6, 2024
1 parent ae11c4c commit e3cf1f4
Showing 7 changed files with 314 additions and 180 deletions.
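
For context, here is a minimal usage sketch of the new distribution (illustrative only, not part of the commit; it assumes a TFP build that includes this change, with the constructor signature taken from the diff below):

import tensorflow_probability as tfp

# Uniform over the unit simplex in R^3, i.e. a Dirichlet whose
# concentrations are all 1.
fd = tfp.distributions.FlatDirichlet(concentration_shape=(3,))
samples = fd.sample(4, seed=42)  # shape [4, 3]; each row sums to 1
log_p = fd.log_prob(samples)     # constant: log Gamma(3) = log 2
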
3 changes: 2 additions & 1 deletion tensorflow_probability/python/distributions/BUILD
@@ -16,9 +16,9 @@
# Contains ops for statistical distributions (with pdf, cdf, sample, etc...).
# APIs here are meant to evolve over time.

# Placeholder: py_binary
# Placeholder: py_library
# Placeholder: py_test
# Placeholder: py_binary
load(
    "//tensorflow_probability/python:build_defs.bzl",
    "multi_substrate_py_library",
@@ -507,6 +507,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:parameter_properties",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:samplers",
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
],
1 change: 1 addition & 0 deletions tensorflow_probability/python/distributions/TAXONOMY.md
@@ -180,6 +180,7 @@ ZeroInflatedNegativeBinomial,
### Distributions over Unit Simplex in R^n

[Dirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Dirichlet)
[FlatDirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/FlatDirichlet)
[RelaxedOneHotCategorical](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical)

### Distributions over Matrices
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/__init__.py
@@ -36,6 +36,7 @@
from tensorflow_probability.python.distributions.deterministic import Deterministic
from tensorflow_probability.python.distributions.deterministic import VectorDeterministic
from tensorflow_probability.python.distributions.dirichlet import Dirichlet
from tensorflow_probability.python.distributions.dirichlet import FlatDirichlet
from tensorflow_probability.python.distributions.dirichlet_multinomial import DirichletMultinomial
from tensorflow_probability.python.distributions.distribution import AutoCompositeTensorDistribution
from tensorflow_probability.python.distributions.distribution import Distribution
@@ -197,6 +198,7 @@
    'ExponentiallyModifiedGaussian',
    'ExpRelaxedOneHotCategorical',
    'FiniteDiscrete',
    'FlatDirichlet',
    'FULLY_REPARAMETERIZED',
    'Gamma',
    'GammaGamma',
122 changes: 121 additions & 1 deletion tensorflow_probability/python/distributions/dirichlet.py
@@ -17,7 +17,6 @@
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import softmax_centered as softmax_centered_bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.distributions import distribution
@@ -29,12 +28,14 @@
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util


__all__ = [
    'Dirichlet',
    'FlatDirichlet',
]


@@ -450,3 +451,122 @@ def _kl_dirichlet_dirichlet(d1, d2, name=None):
  return (
      tf.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
      tf.math.lbeta(concentration1) + tf.math.lbeta(concentration2))


class FlatDirichlet(Dirichlet):
  """Special case of Dirichlet for concentration = 1.

  This case is common and admits a more efficient sampling algorithm.
  """

  def __init__(
      self,
      concentration_shape,
      dtype=tf.float32,
      validate_args=False,
      allow_nan_stats=True,
      force_probs_to_zero_outside_support=False,
      name='FlatDirichlet',
  ):
    """Initialize a batch of FlatDirichlet distributions.

    Args:
      concentration_shape: Integer `Tensor` shape of the concentration
        parameter.
      dtype: The dtype of the distribution.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      force_probs_to_zero_outside_support: If `True`, force `prob(x) == 0` and
        `log_prob(x) == -inf` for values of x outside the distribution support.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    self._concentration_shape = tensor_util.convert_nonref_to_tensor(
        concentration_shape,
        dtype=tf.int32,
        name='concentration_shape',
        as_shape_tensor=True,
    )
    self._concentration_shape_static = tensorshape_util.constant_value_as_shape(
        self._concentration_shape
    )
    concentration = tf.ones(concentration_shape, dtype=dtype)
    super(FlatDirichlet, self).__init__(
        concentration=concentration,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        force_probs_to_zero_outside_support=force_probs_to_zero_outside_support,
        name=name,
    )
    self._parameters = parameters

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    return dict(
        concentration_shape=parameter_properties.ShapeParameterProperties()
    )

  @property
  def concentration_shape(self):
    return self._concentration_shape

  def _batch_shape_tensor(self):
    return tf.constant(self._concentration_shape[:-1], dtype=tf.int32)

  def _batch_shape(self):
    return tf.TensorShape(self._concentration_shape_static[:-1])

  def _event_shape_tensor(self):
    return tf.constant(self._concentration_shape[-1], dtype=tf.int32)

  def _event_shape(self):
    return tf.TensorShape([self._concentration_shape_static[-1]])

  def _log_prob(self, x):
    # The pdf of a flat Dirichlet is the constant Gamma(n) = (n - 1)!: with
    # every concentration equal to 1, the x**(alpha - 1) factors and the
    # Gamma(1) terms in the normalizer are all 1.
    n = tf.cast(self._concentration_shape[-1], dtype=tf.float32)
    lp = tf.math.lgamma(n)
    if self._force_probs_to_zero_outside_support:
      eps = np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps
      in_support = (
          tf.reduce_all(x >= 0, axis=-1) &
          # Reusing the tolerance logic of tf.debugging.assert_near: the
          # coordinates must sum to 1 within 10 * np.finfo.eps.
          (tf.math.abs(tf.reduce_sum(x, axis=-1) - 1.) < 10 * eps))
      return tf.where(in_support, lp, -float('inf'))
    return lp

  def _sample_n(self, n, seed=None):
    # Sample as the spacings between sorted uniform draws; see
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#When_each_alpha_is_1
    tshape = self._concentration_shape
    # rand_shape = [n] + tshape[:-1] + [tshape[-1] - 1]
    rand_shape = ps.tensor_scatter_nd_sub(
        ps.concat([[n], tshape], 0), indices=[-1], updates=[1]
    )
    rand_values = samplers.uniform(
        rand_shape,
        minval=dtype_util.as_numpy_dtype(self.dtype)(0.0),
        maxval=dtype_util.as_numpy_dtype(self.dtype)(1.0),
        dtype=self.dtype,
        seed=seed,
    )
    # sentinel_shape = [n] + tshape[:-1] + [1]
    sentinel_shape = ps.tensor_scatter_nd_update(
        ps.concat([[n], tshape], 0), indices=[-1], updates=[1]
    )
    # Prepend a 0 and append a 1 along the event axis, sort, and take
    # consecutive differences: the n spacings are nonnegative and sum to 1.
    padded_values = tf.concat(
        [
            tf.zeros(sentinel_shape, dtype=self.dtype),
            rand_values,
            tf.ones(sentinel_shape, dtype=self.dtype),
        ],
        axis=-1,
    )
    sorted_values = tf.sort(padded_values, axis=-1)
    value_diffs = sorted_values[..., 1:] - sorted_values[..., :-1]
    return value_diffs
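
The sampler above uses the classical spacings construction: sort n - 1 independent U(0, 1) draws, pad them with the sentinels 0 and 1, and take consecutive differences; the result is uniform on the simplex. A standalone NumPy sketch of the same idea (illustrative only, not part of the commit):

import numpy as np

def flat_dirichlet_sample(n, rng):
  """Returns one draw from Dirichlet(1, ..., 1) of dimension n."""
  cuts = np.sort(rng.uniform(size=n - 1))        # n - 1 sorted uniforms
  padded = np.concatenate([[0.0], cuts, [1.0]])  # add the 0 and 1 sentinels
  return np.diff(padded)                         # n spacings; they sum to 1

x = flat_dirichlet_sample(5, np.random.default_rng(0))
assert np.isclose(x.sum(), 1.0) and (x >= 0).all()
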
74 changes: 73 additions & 1 deletion tensorflow_probability/python/distributions/dirichlet_test.py
@@ -14,10 +14,10 @@
# ============================================================================

# Dependency imports
from absl.testing import parameterized
import numpy as np
from scipy import special as sp_special
from scipy import stats as sp_stats

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import exp
from tensorflow_probability.python.distributions import dirichlet
@@ -339,5 +339,77 @@ def testAssertions(self):
    self.evaluate(d.entropy())


@test_util.test_all_tf_execution_regimes
class FlatDirichletTest(test_util.TestCase):

  @parameterized.parameters(
      {'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
  )
  def testSamplesHaveRightShape(self, tshape):
    fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
    self.assertAllEqual(fd.batch_shape, tshape[:-1])
    self.assertAllEqual(fd.event_shape, tshape[-1:])
    sample = fd.sample(1, seed=test_util.test_seed())
    self.assertAllEqual([1] + list(tshape), sample.shape)
    sample2 = fd.sample([4, 5], seed=test_util.test_seed())
    self.assertAllEqual([4, 5] + list(tshape), sample2.shape)

  @parameterized.parameters(
      {'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
  )
  def testSamplesSumToOne(self, tshape):
    fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
    sample = fd.sample(1, seed=test_util.test_seed())
    self.assertAllClose(
        tf.math.reduce_sum(sample, axis=-1),
        tf.ones(shape=[1] + list(tshape)[:-1]),
    )

  @test_util.disable_test_for_backend(
      disable_numpy=True, reason='Uses jit_compile'
  )
  def testSampleNJits(self):
    @tf.function(jit_compile=True)
    def f(x):
      fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
      sample = fd.sample(1, seed=test_util.test_seed())
      return sample + x

    self.assertAllEqual([1, 5], f(0.1).shape)

  def testSampleMoments(self):
    fd = dirichlet.FlatDirichlet(concentration_shape=(3,))
    samples = fd.sample(1000, seed=test_util.test_seed())
    mean = tf.math.reduce_mean(samples, axis=0)
    self.assertAllClose(mean, [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], atol=2e-2)
    centered = samples - tf.ones(shape=(1, 3)) / 3.0
    var = tf.math.reduce_mean(centered * centered, axis=0)
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#Properties says
    # Var = alpha_i (alpha_0 - alpha_i) / (alpha_0^2 (alpha_0 + 1))
    #     = (n - 1) / (n^2 (n + 1)), which is 2 / 36 = 1 / 18 for n = 3.
    self.assertAllClose(var, [1.0 / 18.0, 1.0 / 18.0, 1.0 / 18.0], atol=2e-2)

  def testLogProb(self):
    fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
    # The density is the constant Gamma(5) = 4! = 24.
    self.assertAllClose(
        fd.log_prob(tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])),
        tf.math.log(24.0)
    )

  def testLogProbOutsideSupport(self):
    # tf.ones(shape=(5,)) sums to 5, which lies outside the simplex.
    fd = dirichlet.FlatDirichlet(concentration_shape=(5,),
                                 force_probs_to_zero_outside_support=True)
    self.assertAllEqual(fd.log_prob(tf.ones(shape=(5,))), -float('inf'))

  @parameterized.parameters(
      {'n': 2}, {'n': 3}, {'n': 4}, {'n': 5}, {'n': 6},
  )
  def testLogProbSameAsDirichlet(self, n):
    fd = dirichlet.FlatDirichlet(concentration_shape=(n,))
    d = dirichlet.Dirichlet(concentration=tf.ones(shape=(n,)))
    p = tf.ones(shape=n) / float(n)
    self.assertAllClose(d.log_prob(p), fd.log_prob(p))


if __name__ == '__main__':
  test_util.main()
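
For reference, the constants asserted in these tests follow from the standard Dirichlet formulas with every \alpha_i = 1, so \alpha_0 = \sum_i \alpha_i = n:

p(x) = \frac{\Gamma(\alpha_0)}{\prod_i \Gamma(\alpha_i)} \prod_i x_i^{\alpha_i - 1}
     = \Gamma(n) = (n - 1)!,
\qquad
\operatorname{Var}(X_i) = \frac{\alpha_i (\alpha_0 - \alpha_i)}{\alpha_0^2 (\alpha_0 + 1)}
                        = \frac{n - 1}{n^2 (n + 1)}.

Hence testLogProb checks log 4! = log 24 for n = 5, and testSampleMoments checks 2 / 36 = 1 / 18 for n = 3.
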
@@ -431,9 +431,11 @@ def testCanConstructAndSampleDistribution(self, data):
        'rtol',
        'eigenvectors',  # TODO(b/171872834): DeterminantalPointProcess
        'total_count',
        'concentration_shape',
        'num_samples',
        'df',  # Can't represent constraint that Wishart df > dimension.
        'mean_direction',
    )  # TODO(b/118492439): Add `UnitVector` bijector.
    non_trainable_non_tensor_params = (
        'batch_shape',  # SphericalUniform, at least, has explicit batch shape
        'dimension',