|
17 | 17 | # Dependency imports
|
18 | 18 | import numpy as np
|
19 | 19 | import tensorflow.compat.v2 as tf
|
20 |
| - |
21 | 20 | from tensorflow_probability.python.bijectors import softmax_centered as softmax_centered_bijector
|
22 | 21 | from tensorflow_probability.python.bijectors import softplus as softplus_bijector
|
23 | 22 | from tensorflow_probability.python.distributions import distribution
|
|
29 | 28 | from tensorflow_probability.python.internal import parameter_properties
|
30 | 29 | from tensorflow_probability.python.internal import prefer_static as ps
|
31 | 30 | from tensorflow_probability.python.internal import reparameterization
|
| 31 | +from tensorflow_probability.python.internal import samplers |
32 | 32 | from tensorflow_probability.python.internal import tensor_util
|
33 | 33 | from tensorflow_probability.python.internal import tensorshape_util
|
34 | 34 |
|
35 | 35 |
|
36 | 36 | __all__ = [
|
37 | 37 | 'Dirichlet',
|
| 38 | + 'FlatDirichlet', |
38 | 39 | ]
|
39 | 40 |
|
40 | 41 |
|
@@ -450,3 +451,122 @@ def _kl_dirichlet_dirichlet(d1, d2, name=None):
|
450 | 451 | return (
|
451 | 452 | tf.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
|
452 | 453 | tf.math.lbeta(concentration1) + tf.math.lbeta(concentration2))
|
| 454 | + |
| 455 | + |
class FlatDirichlet(Dirichlet):
  """Special case of Dirichlet for concentration = 1.

  This case is both frequent and admits a more efficient sampling algorithm:
  instead of drawing Gamma variates, a sample is obtained from the gaps
  between sorted uniform draws on the unit interval.

  The distribution is parameterized by the *shape* of the (implicit) all-ones
  concentration tensor rather than by the concentration values themselves. The
  last entry of `concentration_shape` is the number of classes (event size);
  the leading entries form the batch shape.
  """

  def __init__(
      self,
      concentration_shape,
      dtype=tf.float32,
      validate_args=False,
      allow_nan_stats=True,
      force_probs_to_zero_outside_support=False,
      name='FlatDirichlet',
  ):
    """Initialize a batch of FlatDirichlet distributions.

    Args:
      concentration_shape: Integer `Tensor` shape of the concentration
        parameter.
      dtype: The dtype of the distribution.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      force_probs_to_zero_outside_support: If `True`, force `prob(x) == 0` and
        `log_prob(x) == -inf` for values of x outside the distribution support.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must stay the first statement so that it captures only the constructor
    # arguments.
    parameters = dict(locals())
    self._concentration_shape = tensor_util.convert_nonref_to_tensor(
        concentration_shape,
        dtype=tf.int32,
        name='concentration_shape',
        as_shape_tensor=True,
    )
    # Static counterpart of the shape tensor, used by the static shape methods.
    self._concentration_shape_static = tensorshape_util.constant_value_as_shape(
        self._concentration_shape
    )
    concentration = tf.ones(concentration_shape, dtype=dtype)
    super(FlatDirichlet, self).__init__(
        concentration=concentration,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        force_probs_to_zero_outside_support=force_probs_to_zero_outside_support,
        name=name,
    )
    # The parent constructor records its own arguments (`concentration`, ...);
    # replace them with the arguments of this constructor.
    self._parameters = parameters

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    return dict(
        concentration_shape=parameter_properties.ShapeParameterProperties()
    )

  @property
  def concentration_shape(self):
    """Integer `Tensor` shape of the implicit all-ones concentration."""
    return self._concentration_shape

  def _batch_shape_tensor(self):
    # `_concentration_shape` is already an int32 Tensor; slice it directly.
    # Wrapping it in `tf.constant` fails for symbolic (graph-mode) tensors.
    return self._concentration_shape[:-1]

  def _batch_shape(self):
    return tf.TensorShape(self._concentration_shape_static[:-1])

  def _event_shape_tensor(self):
    # Slice with `-1:` (not `-1`) so the result is a rank-1 shape vector
    # rather than a scalar.
    return self._concentration_shape[-1:]

  def _event_shape(self):
    return tf.TensorShape([self._concentration_shape_static[-1]])

  def _log_prob(self, x):
    # The pdf of a flat dirichlet is just Gamma(n), with n the number of
    # classes. Compute it in the distribution's dtype so that non-float32
    # distributions do not silently return a float32 log-density.
    n = tf.cast(self._concentration_shape[-1], dtype=self.dtype)
    lp = tf.math.lgamma(n)
    if self._force_probs_to_zero_outside_support:
      eps = np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps
      in_support = (
          tf.reduce_all(x >= 0, axis=-1) &
          # Reusing the logic of tf.debugging.assert_near, 10 * np.finfo.eps
          (tf.math.abs(tf.reduce_sum(x, axis=-1) - 1.) < 10 * eps))
      return tf.where(in_support, lp, -float('inf'))
    return lp

  def _sample_n(self, n, seed=None):
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#When_each_alpha_is_1
    # Draw (num_classes - 1) uniform cut points in [0, 1), sort them together
    # with the end points 0 and 1, and return the gaps between neighbours.
    tshape = self._concentration_shape
    batch_dims = tshape[:-1]
    num_classes = tshape[-1]
    # Build the shapes by concatenation rather than by scatter-updating index
    # -1: scatter-nd ops do not accept negative indices.
    # rand_shape = [n] + tshape[:-1] + [tshape[-1] - 1]
    rand_shape = ps.concat([[n], batch_dims, [num_classes - 1]], 0)
    np_dtype = dtype_util.as_numpy_dtype(self.dtype)
    rand_values = samplers.uniform(
        rand_shape,
        minval=np_dtype(0.0),
        maxval=np_dtype(1.0),
        dtype=self.dtype,
        seed=seed,
    )
    # sentinel_shape = [n] + tshape[:-1] + [1]
    sentinel_shape = ps.concat([[n], batch_dims, [1]], 0)
    padded_values = tf.concat(
        [
            tf.zeros(sentinel_shape, dtype=self.dtype),
            rand_values,
            tf.ones(sentinel_shape, dtype=self.dtype),
        ],
        axis=-1,
    )
    sorted_values = tf.sort(padded_values, axis=-1)
    # num_classes + 1 sorted points yield num_classes non-negative gaps that
    # sum to one.
    value_diffs = sorted_values[..., 1:] - sorted_values[..., :-1]
    return value_diffs
0 commit comments