Skip to content

Commit e3cf1f4

Browse files
Add FlatDirichlet distribution.
PiperOrigin-RevId: 613223674
1 parent ae11c4c commit e3cf1f4

File tree

7 files changed

+314
-180
lines changed

7 files changed

+314
-180
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
# Contains ops for statistical distributions (with pdf, cdf, sample, etc...).
1717
# APIs here are meant to evolve over time.
1818

19+
# Placeholder: py_binary
1920
# Placeholder: py_library
2021
# Placeholder: py_test
21-
# Placeholder: py_binary
2222
load(
2323
"//tensorflow_probability/python:build_defs.bzl",
2424
"multi_substrate_py_library",
@@ -507,6 +507,7 @@ multi_substrate_py_library(
507507
"//tensorflow_probability/python/internal:parameter_properties",
508508
"//tensorflow_probability/python/internal:prefer_static",
509509
"//tensorflow_probability/python/internal:reparameterization",
510+
"//tensorflow_probability/python/internal:samplers",
510511
"//tensorflow_probability/python/internal:tensor_util",
511512
"//tensorflow_probability/python/internal:tensorshape_util",
512513
],

tensorflow_probability/python/distributions/TAXONOMY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ ZeroInflatedNegativeBinomial,
180180
### Distributions over Unit Simplex in R^n
181181

182182
[Dirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Dirichlet)
183+
[FlatDirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/FlatDirichlet)
183184
[RelaxedOneHotCategorical](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical)
184185

185186
### Distributions over Matrices

tensorflow_probability/python/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from tensorflow_probability.python.distributions.deterministic import Deterministic
3737
from tensorflow_probability.python.distributions.deterministic import VectorDeterministic
3838
from tensorflow_probability.python.distributions.dirichlet import Dirichlet
39+
from tensorflow_probability.python.distributions.dirichlet import FlatDirichlet
3940
from tensorflow_probability.python.distributions.dirichlet_multinomial import DirichletMultinomial
4041
from tensorflow_probability.python.distributions.distribution import AutoCompositeTensorDistribution
4142
from tensorflow_probability.python.distributions.distribution import Distribution
@@ -197,6 +198,7 @@
197198
'ExponentiallyModifiedGaussian',
198199
'ExpRelaxedOneHotCategorical',
199200
'FiniteDiscrete',
201+
'FlatDirichlet',
200202
'FULLY_REPARAMETERIZED',
201203
'Gamma',
202204
'GammaGamma',

tensorflow_probability/python/distributions/dirichlet.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
# Dependency imports
1818
import numpy as np
1919
import tensorflow.compat.v2 as tf
20-
2120
from tensorflow_probability.python.bijectors import softmax_centered as softmax_centered_bijector
2221
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
2322
from tensorflow_probability.python.distributions import distribution
@@ -29,12 +28,14 @@
2928
from tensorflow_probability.python.internal import parameter_properties
3029
from tensorflow_probability.python.internal import prefer_static as ps
3130
from tensorflow_probability.python.internal import reparameterization
31+
from tensorflow_probability.python.internal import samplers
3232
from tensorflow_probability.python.internal import tensor_util
3333
from tensorflow_probability.python.internal import tensorshape_util
3434

3535

3636
__all__ = [
3737
'Dirichlet',
38+
'FlatDirichlet',
3839
]
3940

4041

@@ -450,3 +451,122 @@ def _kl_dirichlet_dirichlet(d1, d2, name=None):
450451
return (
451452
tf.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
452453
tf.math.lbeta(concentration1) + tf.math.lbeta(concentration2))
454+
455+
456+
class FlatDirichlet(Dirichlet):
457+
"""Special case of Dirichlet for concentration = 1.
458+
459+
This case is both frequent and admits a more efficient sampling algorithm.
460+
"""
461+
462+
def __init__(
463+
self,
464+
concentration_shape,
465+
dtype=tf.float32,
466+
validate_args=False,
467+
allow_nan_stats=True,
468+
force_probs_to_zero_outside_support=False,
469+
name='FlatDirichlet',
470+
):
471+
"""Initialize a batch of FlatDirichlet distributions.
472+
473+
Args:
474+
concentration_shape: Integer `Tensor` shape of the concentration
475+
parameter.
476+
dtype: The dtype of the distribution.
477+
validate_args: Python `bool`, default `False`. When `True` distribution
478+
parameters are checked for validity despite possibly degrading runtime
479+
performance. When `False` invalid inputs may silently render incorrect
480+
outputs.
481+
allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
482+
(e.g., mean, mode, variance) use the value "`NaN`" to indicate the
483+
result is undefined. When `False`, an exception is raised if one or more
484+
of the statistic's batch members are undefined.
485+
force_probs_to_zero_outside_support: If `True`, force `prob(x) == 0` and
486+
`log_prob(x) == -inf` for values of x outside the distribution support.
487+
name: Python `str` name prefixed to Ops created by this class.
488+
"""
489+
parameters = dict(locals())
490+
self._concentration_shape = tensor_util.convert_nonref_to_tensor(
491+
concentration_shape,
492+
dtype=tf.int32,
493+
name='concentration_shape',
494+
as_shape_tensor=True,
495+
)
496+
self._concentration_shape_static = tensorshape_util.constant_value_as_shape(
497+
self._concentration_shape
498+
)
499+
concentration = tf.ones(concentration_shape, dtype=dtype)
500+
super(FlatDirichlet, self).__init__(
501+
concentration=concentration,
502+
validate_args=validate_args,
503+
allow_nan_stats=allow_nan_stats,
504+
force_probs_to_zero_outside_support=force_probs_to_zero_outside_support,
505+
name=name,
506+
)
507+
self._parameters = parameters
508+
509+
@classmethod
510+
def _parameter_properties(cls, dtype, num_classes=None):
511+
return dict(
512+
concentration_shape=parameter_properties.ShapeParameterProperties()
513+
)
514+
515+
@property
516+
def concentration_shape(self):
517+
return self._concentration_shape
518+
519+
def _batch_shape_tensor(self):
520+
return tf.constant(self._concentration_shape[:-1], dtype=tf.int32)
521+
522+
def _batch_shape(self):
523+
return tf.TensorShape(self._concentration_shape_static[:-1])
524+
525+
def _event_shape_tensor(self):
526+
return tf.constant(self._concentration_shape[-1], dtype=tf.int32)
527+
528+
def _event_shape(self):
529+
return tf.TensorShape([self._concentration_shape_static[-1]])
530+
531+
def _log_prob(self, x):
532+
# The pdf of a flat dirichlet is just Gamma(n).
533+
n = tf.cast(self._concentration_shape[-1], dtype=tf.float32)
534+
lp = tf.math.lgamma(n)
535+
if self._force_probs_to_zero_outside_support:
536+
eps = np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps
537+
in_support = (
538+
tf.reduce_all(x >= 0, axis=-1) &
539+
# Reusing the logic of tf.debugging.assert_near, 10 * np.finfo.eps
540+
(tf.math.abs(tf.reduce_sum(x, axis=-1) - 1.) < 10 * eps))
541+
return tf.where(in_support, lp, -float('inf'))
542+
return lp
543+
544+
def _sample_n(self, n, seed=None):
545+
# https://en.wikipedia.org/wiki/Dirichlet_distribution#When_each_alpha_is_1
546+
tshape = self._concentration_shape
547+
# rand_shape = [n] + tshape[:-1] + [tshape[-1] - 1]
548+
rand_shape = ps.tensor_scatter_nd_sub(
549+
ps.concat([[n], tshape], 0), indices=[-1], updates=[1]
550+
)
551+
rand_values = samplers.uniform(
552+
rand_shape,
553+
minval=dtype_util.as_numpy_dtype(self.dtype)(0.0),
554+
maxval=dtype_util.as_numpy_dtype(self.dtype)(1.0),
555+
dtype=self.dtype,
556+
seed=seed,
557+
)
558+
# sentinel_shape = [n] + tshape[:-1] + [1]
559+
sentinel_shape = ps.tensor_scatter_nd_update(
560+
ps.concat([[n], tshape], 0), indices=[-1], updates=[1]
561+
)
562+
padded_values = tf.concat(
563+
[
564+
tf.zeros(sentinel_shape, dtype=self.dtype),
565+
rand_values,
566+
tf.ones(sentinel_shape, dtype=self.dtype),
567+
],
568+
axis=-1,
569+
)
570+
sorted_values = tf.sort(padded_values, axis=-1)
571+
value_diffs = sorted_values[..., 1:] - sorted_values[..., :-1]
572+
return value_diffs

tensorflow_probability/python/distributions/dirichlet_test.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
# ============================================================================
1515

1616
# Dependency imports
17+
from absl.testing import parameterized
1718
import numpy as np
1819
from scipy import special as sp_special
1920
from scipy import stats as sp_stats
20-
2121
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.bijectors import exp
2323
from tensorflow_probability.python.distributions import dirichlet
@@ -339,5 +339,77 @@ def testAssertions(self):
339339
self.evaluate(d.entropy())
340340

341341

342+
@test_util.test_all_tf_execution_regimes
343+
class FlatDirichletTest(test_util.TestCase):
344+
345+
@parameterized.parameters(
346+
{'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
347+
)
348+
def testSamplesHaveRightShape(self, tshape):
349+
fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
350+
self.assertAllEqual(fd.batch_shape, tshape[:-1])
351+
self.assertAllEqual(fd.event_shape, tshape[-1:])
352+
sample = fd.sample(1, seed=test_util.test_seed())
353+
self.assertAllEqual([1] + list(tshape), sample.shape)
354+
sample2 = fd.sample([4, 5], seed=test_util.test_seed())
355+
self.assertAllEqual([4, 5] + list(tshape), sample2.shape)
356+
357+
@parameterized.parameters(
358+
{'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
359+
)
360+
def testSamplesSumToOne(self, tshape):
361+
fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
362+
sample = fd.sample(1, seed=test_util.test_seed())
363+
self.assertAllClose(
364+
tf.math.reduce_sum(sample, axis=-1),
365+
tf.ones(shape=[1] + list(tshape)[:-1]),
366+
)
367+
368+
@test_util.disable_test_for_backend(
369+
disable_numpy=True, reason='Uses jit_compile'
370+
)
371+
def testSampleNJits(self):
372+
@tf.function(jit_compile=True)
373+
def f(x):
374+
fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
375+
sample = fd.sample(1, seed=test_util.test_seed())
376+
return sample + x
377+
378+
self.assertAllEqual([1, 5], f(0.1).shape)
379+
380+
def testSampleMoments(self):
381+
fd = dirichlet.FlatDirichlet(concentration_shape=(3,))
382+
samples = fd.sample(1000, seed=test_util.test_seed())
383+
mean = tf.math.reduce_mean(samples, axis=0)
384+
self.assertAllClose(mean, [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], atol=2e-2)
385+
centered = samples - tf.ones(shape=(1, 3)) / 3.0
386+
var = tf.math.reduce_mean(centered * centered, axis=0)
387+
# https://en.wikipedia.org/wiki/Dirichlet_distribution#Properties says
388+
# Var = alpha_i (alpha_0 - alpha_i) / ( alpha_0^2 (alpha_0 + 1))
389+
# = (n - 1) / (n^2 (n+1))
390+
self.assertAllClose(var, [1.0 / 18.0, 1.0 / 18.0, 1.0 / 18.0], atol=2e-2)
391+
392+
def testLogProb(self):
393+
fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
394+
self.assertAllClose(
395+
fd.log_prob(tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])),
396+
tf.math.log(24.0)
397+
)
398+
399+
def testLogProbOutsideSupport(self):
400+
fd = dirichlet.FlatDirichlet(concentration_shape=(5,),
401+
force_probs_to_zero_outside_support=True)
402+
self.assertAllEqual(fd.log_prob(tf.ones(shape=(5,))), -float('inf'))
403+
404+
@parameterized.parameters(
405+
{'n': 2}, {'n': 3}, {'n': 4}, {'n': 5}, {'n': 6},
406+
)
407+
def testLogProbSameAsDirichlet(self, n):
408+
fd = dirichlet.FlatDirichlet(concentration_shape=(n,))
409+
d = dirichlet.Dirichlet(concentration=tf.ones(shape=(n,)))
410+
p = tf.ones(shape=n) / float(n)
411+
self.assertAllClose(d.log_prob(p), fd.log_prob(p))
412+
413+
342414
if __name__ == '__main__':
343415
test_util.main()

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,11 @@ def testCanConstructAndSampleDistribution(self, data):
431431
'rtol',
432432
'eigenvectors', # TODO(b/171872834): DeterminantalPointProcess
433433
'total_count',
434+
'concentration_shape',
434435
'num_samples',
435436
'df', # Can't represent constraint that Wishart df > dimension.
436-
'mean_direction') # TODO(b/118492439): Add `UnitVector` bijector.
437+
'mean_direction',
438+
) # TODO(b/118492439): Add `UnitVector` bijector.
437439
non_trainable_non_tensor_params = (
438440
'batch_shape', # SphericalUniform, at least, has explicit batch shape
439441
'dimension',

0 commit comments

Comments
 (0)