diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD
index 228f0a6532..26f80786dd 100644
--- a/tensorflow_probability/python/distributions/BUILD
+++ b/tensorflow_probability/python/distributions/BUILD
@@ -16,9 +16,9 @@
 # Contains ops for statistical distributions (with pdf, cdf, sample, etc...).
 # APIs here are meant to evolve over time.
 
+# Placeholder: py_binary
 # Placeholder: py_library
 # Placeholder: py_test
-# Placeholder: py_binary
 load(
     "//tensorflow_probability/python:build_defs.bzl",
     "multi_substrate_py_library",
@@ -507,6 +507,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/internal:parameter_properties",
         "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:reparameterization",
+        "//tensorflow_probability/python/internal:samplers",
         "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
     ],
diff --git a/tensorflow_probability/python/distributions/TAXONOMY.md b/tensorflow_probability/python/distributions/TAXONOMY.md
index bc79313d3e..070a9cad7f 100644
--- a/tensorflow_probability/python/distributions/TAXONOMY.md
+++ b/tensorflow_probability/python/distributions/TAXONOMY.md
@@ -180,6 +180,7 @@ ZeroInflatedNegativeBinomial,
 ### Distributions over Unit Simplex in R^n
 
 [Dirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Dirichlet)
+[FlatDirichlet](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/FlatDirichlet)
 [RelaxedOneHotCategorical](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical)
 
 ### Distributions over Matrices
diff --git a/tensorflow_probability/python/distributions/__init__.py b/tensorflow_probability/python/distributions/__init__.py
index 3c36089cc8..6160c7d2f6 100644
--- a/tensorflow_probability/python/distributions/__init__.py
+++ b/tensorflow_probability/python/distributions/__init__.py
@@ -36,6 +36,7 @@
 from tensorflow_probability.python.distributions.deterministic import Deterministic
 from tensorflow_probability.python.distributions.deterministic import VectorDeterministic
 from tensorflow_probability.python.distributions.dirichlet import Dirichlet
+from tensorflow_probability.python.distributions.dirichlet import FlatDirichlet
 from tensorflow_probability.python.distributions.dirichlet_multinomial import DirichletMultinomial
 from tensorflow_probability.python.distributions.distribution import AutoCompositeTensorDistribution
 from tensorflow_probability.python.distributions.distribution import Distribution
@@ -197,6 +198,7 @@
     'ExponentiallyModifiedGaussian',
     'ExpRelaxedOneHotCategorical',
     'FiniteDiscrete',
+    'FlatDirichlet',
     'FULLY_REPARAMETERIZED',
     'Gamma',
     'GammaGamma',
diff --git a/tensorflow_probability/python/distributions/dirichlet.py b/tensorflow_probability/python/distributions/dirichlet.py
index 8f2a39d8e6..3b6eb57ca6 100644
--- a/tensorflow_probability/python/distributions/dirichlet.py
+++ b/tensorflow_probability/python/distributions/dirichlet.py
@@ -17,7 +17,6 @@
 # Dependency imports
 import numpy as np
 import tensorflow.compat.v2 as tf
-
 from tensorflow_probability.python.bijectors import softmax_centered as softmax_centered_bijector
 from tensorflow_probability.python.bijectors import softplus as softplus_bijector
 from tensorflow_probability.python.distributions import distribution
@@ -29,12 +28,14 @@
 from tensorflow_probability.python.internal import parameter_properties
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import reparameterization
+from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
 
 
 __all__ = [
     'Dirichlet',
+    'FlatDirichlet',
 ]
@@ -450,3 +451,122 @@ def _kl_dirichlet_dirichlet(d1, d2, name=None):
     return (
         tf.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
         tf.math.lbeta(concentration1) + tf.math.lbeta(concentration2))
+
+
+class FlatDirichlet(Dirichlet):
+  """Special case of Dirichlet with all concentrations equal to 1.
+
+  This case is common in practice and admits a more efficient sampling
+  algorithm.
+  """
+
+  def __init__(
+      self,
+      concentration_shape,
+      dtype=tf.float32,
+      validate_args=False,
+      allow_nan_stats=True,
+      force_probs_to_zero_outside_support=False,
+      name='FlatDirichlet',
+  ):
+    """Initialize a batch of FlatDirichlet distributions.
+
+    Args:
+      concentration_shape: Integer `Tensor` shape of the concentration
+        parameter.
+      dtype: The dtype of the distribution.
+      validate_args: Python `bool`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+        result is undefined. When `False`, an exception is raised if one or
+        more of the statistic's batch members are undefined.
+      force_probs_to_zero_outside_support: If `True`, force `prob(x) == 0` and
+        `log_prob(x) == -inf` for values of `x` outside the distribution
+        support.
+      name: Python `str` name prefixed to Ops created by this class.
+    """
+    parameters = dict(locals())
+    self._concentration_shape = tensor_util.convert_nonref_to_tensor(
+        concentration_shape,
+        dtype=tf.int32,
+        name='concentration_shape',
+        as_shape_tensor=True,
+    )
+    self._concentration_shape_static = tensorshape_util.constant_value_as_shape(
+        self._concentration_shape
+    )
+    concentration = tf.ones(concentration_shape, dtype=dtype)
+    super(FlatDirichlet, self).__init__(
+        concentration=concentration,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        force_probs_to_zero_outside_support=force_probs_to_zero_outside_support,
+        name=name,
+    )
+    self._parameters = parameters
+
+  @classmethod
+  def _parameter_properties(cls, dtype, num_classes=None):
+    return dict(
+        concentration_shape=parameter_properties.ShapeParameterProperties()
+    )
+
+  @property
+  def concentration_shape(self):
+    return self._concentration_shape
+
+  def _batch_shape_tensor(self):
+    return tf.constant(self._concentration_shape[:-1], dtype=tf.int32)
+
+  def _batch_shape(self):
+    return tf.TensorShape(self._concentration_shape_static[:-1])
+
+  def _event_shape_tensor(self):
+    # Event shape tensors are rank-1, matching `_event_shape` below.
+    return tf.constant([self._concentration_shape[-1]], dtype=tf.int32)
+
+  def _event_shape(self):
+    return tf.TensorShape([self._concentration_shape_static[-1]])
+
+  def _log_prob(self, x):
+    # The pdf of a flat Dirichlet is the constant Gamma(n).
+    n = tf.cast(self._concentration_shape[-1], dtype=self.dtype)
+    lp = tf.math.lgamma(n)
+    if self._force_probs_to_zero_outside_support:
+      eps = np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps
+      in_support = (
+          tf.reduce_all(x >= 0, axis=-1) &
+          # Reusing the logic of tf.debugging.assert_near: 10 * np.finfo.eps.
+          (tf.math.abs(tf.reduce_sum(x, axis=-1) - 1.) < 10 * eps))
+      return tf.where(in_support, lp, -float('inf'))
+    return lp
+
+  def _sample_n(self, n, seed=None):
+    # https://en.wikipedia.org/wiki/Dirichlet_distribution#When_each_alpha_is_1
+    tshape = self._concentration_shape
+    # rand_shape = [n] + tshape[:-1] + [tshape[-1] - 1]
+    # scatter_nd indices must have shape [num_updates, index_depth].
+    rand_shape = ps.tensor_scatter_nd_sub(
+        ps.concat([[n], tshape], 0), indices=[[-1]], updates=[1]
+    )
+    rand_values = samplers.uniform(
+        rand_shape,
+        minval=dtype_util.as_numpy_dtype(self.dtype)(0.0),
+        maxval=dtype_util.as_numpy_dtype(self.dtype)(1.0),
+        dtype=self.dtype,
+        seed=seed,
+    )
+    # sentinel_shape = [n] + tshape[:-1] + [1]
+    sentinel_shape = ps.tensor_scatter_nd_update(
+        ps.concat([[n], tshape], 0), indices=[[-1]], updates=[1]
+    )
+    padded_values = tf.concat(
+        [
+            tf.zeros(sentinel_shape, dtype=self.dtype),
+            rand_values,
+            tf.ones(sentinel_shape, dtype=self.dtype),
+        ],
+        axis=-1,
+    )
+    sorted_values = tf.sort(padded_values, axis=-1)
+    value_diffs = sorted_values[..., 1:] - sorted_values[..., :-1]
+    return value_diffs
diff --git a/tensorflow_probability/python/distributions/dirichlet_test.py b/tensorflow_probability/python/distributions/dirichlet_test.py
index 31c655f65d..bb96c2776e 100644
--- a/tensorflow_probability/python/distributions/dirichlet_test.py
+++ b/tensorflow_probability/python/distributions/dirichlet_test.py
@@ -14,10 +14,10 @@
 # ============================================================================
 
 # Dependency imports
+from absl.testing import parameterized
 import numpy as np
 from scipy import special as sp_special
 from scipy import stats as sp_stats
-
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.bijectors import exp
 from tensorflow_probability.python.distributions import dirichlet
@@ -339,5 +339,77 @@ def testAssertions(self):
     self.evaluate(d.entropy())
 
 
+@test_util.test_all_tf_execution_regimes
+class FlatDirichletTest(test_util.TestCase):
+
+  @parameterized.parameters(
+      {'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
+  )
+  def testSamplesHaveRightShape(self, tshape):
+    fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
+    self.assertAllEqual(fd.batch_shape, tshape[:-1])
+    self.assertAllEqual(fd.event_shape, tshape[-1:])
+    sample = fd.sample(1, seed=test_util.test_seed())
+    self.assertAllEqual([1] + list(tshape), sample.shape)
+    sample2 = fd.sample([4, 5], seed=test_util.test_seed())
+    self.assertAllEqual([4, 5] + list(tshape), sample2.shape)
+
+  @parameterized.parameters(
+      {'tshape': (3,)}, {'tshape': (2, 3)}, {'tshape': (5, 1, 10)}
+  )
+  def testSamplesSumToOne(self, tshape):
+    fd = dirichlet.FlatDirichlet(concentration_shape=tshape)
+    sample = fd.sample(1, seed=test_util.test_seed())
+    self.assertAllClose(
+        tf.math.reduce_sum(sample, axis=-1),
+        tf.ones(shape=[1] + list(tshape)[:-1]),
+    )
+
+  @test_util.disable_test_for_backend(
+      disable_numpy=True, reason='Uses jit_compile'
+  )
+  def testSampleNJits(self):
+    @tf.function(jit_compile=True)
+    def f(x):
+      fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
+      sample = fd.sample(1, seed=test_util.test_seed())
+      return sample + x
+
+    self.assertAllEqual([1, 5], f(0.1).shape)
+
+  def testSampleMoments(self):
+    fd = dirichlet.FlatDirichlet(concentration_shape=(3,))
+    samples = fd.sample(1000, seed=test_util.test_seed())
+    mean = tf.math.reduce_mean(samples, axis=0)
+    self.assertAllClose(mean, [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], atol=2e-2)
+    centered = samples - tf.ones(shape=(1, 3)) / 3.0
+    var = tf.math.reduce_mean(centered * centered, axis=0)
+    # https://en.wikipedia.org/wiki/Dirichlet_distribution#Properties says
+    # Var = alpha_i (alpha_0 - alpha_i) / (alpha_0^2 (alpha_0 + 1))
+    #     = (n - 1) / (n^2 (n + 1))
+    self.assertAllClose(var, [1.0 / 18.0, 1.0 / 18.0, 1.0 / 18.0], atol=2e-2)
+
+  def testLogProb(self):
+    fd = dirichlet.FlatDirichlet(concentration_shape=(5,))
+    self.assertAllClose(
+        fd.log_prob(tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])),
+        tf.math.log(24.0),
+    )
+
+  def testLogProbOutsideSupport(self):
+    fd = dirichlet.FlatDirichlet(concentration_shape=(5,),
+                                 force_probs_to_zero_outside_support=True)
+    self.assertAllEqual(fd.log_prob(tf.ones(shape=(5,))), -float('inf'))
+
+  @parameterized.parameters(
+      {'n': 2}, {'n': 3}, {'n': 4}, {'n': 5}, {'n': 6},
+  )
+  def testLogProbSameAsDirichlet(self, n):
+    fd = dirichlet.FlatDirichlet(concentration_shape=(n,))
+    d = dirichlet.Dirichlet(concentration=tf.ones(shape=(n,)))
+    p = tf.ones(shape=n) / float(n)
+    self.assertAllClose(d.log_prob(p), fd.log_prob(p))
+
+
 if __name__ == '__main__':
   test_util.main()
diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py
index 89f9974927..7fe221bbd8 100644
--- a/tensorflow_probability/python/distributions/distribution_properties_test.py
+++ b/tensorflow_probability/python/distributions/distribution_properties_test.py
@@ -431,9 +431,11 @@ def testCanConstructAndSampleDistribution(self, data):
         'rtol',
         'eigenvectors',  # TODO(b/171872834): DeterminantalPointProcess
         'total_count',
+        'concentration_shape',
         'num_samples',
         'df',  # Can't represent constraint that Wishart df > dimension.
-        'mean_direction')  # TODO(b/118492439): Add `UnitVector` bijector.
+        'mean_direction',
+    )  # TODO(b/118492439): Add `UnitVector` bijector.
     non_trainable_non_tensor_params = (
         'batch_shape',  # SphericalUniform, at least, has explicit batch shape
         'dimension',
diff --git a/tensorflow_probability/python/distributions/hypothesis_testlib.py b/tensorflow_probability/python/distributions/hypothesis_testlib.py
index 16a5899602..52eeaaa253 100644
--- a/tensorflow_probability/python/distributions/hypothesis_testlib.py
+++ b/tensorflow_probability/python/distributions/hypothesis_testlib.py
@@ -195,195 +195,131 @@ def fix_bates(d):
 
 CONSTRAINTS = {
-    'atol':
-        tf.math.softplus,
-    'rtol':
-        tf.math.softplus,
-    'Dirichlet.concentration':
-        tfp_hps.softplus_plus_eps(),
-    'concentration':
-        tfp_hps.softplus_plus_eps(),
-    'GeneralizedPareto.concentration':  # Permits +ve and -ve concentrations.
-        lambda x: tf.math.tanh(x) * 0.24,
-    'concentration0':
-        tfp_hps.softplus_plus_eps(),
-    'concentration1':
-        tfp_hps.softplus_plus_eps(),
-    'concentration0_numerator':
-        tfp_hps.softplus_plus_eps(),
-    'concentration1_numerator':
-        tfp_hps.softplus_plus_eps(1.),
-    'concentration0_denominator':
-        tfp_hps.softplus_plus_eps(),
-    'concentration1_denominator':
-        tfp_hps.softplus_plus_eps(1.),
-    'covariance_matrix':
-        tfp_hps.positive_definite,
-    'df':
-        tfp_hps.softplus_plus_eps(),
-    'DeterminantalPointProcess.eigenvalues':
-        tfp_hps.softplus_plus_eps(),
-    'eigenvectors':
-        tfp_hps.orthonormal,
-    'InverseGaussian.loc':
-        tfp_hps.softplus_plus_eps(),
-    'JohnsonSU.tailweight':
-        tfp_hps.softplus_plus_eps(),
-    'PowerSpherical.mean_direction':
-        lambda x: tf.math.l2_normalize(tf.math.sigmoid(x) + 1e-6, -1),
-    'VonMisesFisher.mean_direction':  # max ndims is 3 to avoid instability.
-        lambda x: tf.math.l2_normalize(tf.math.sigmoid(x[..., :3]) + 1e-6, -1),
-    'Categorical.probs':
-        tf.math.softmax,
-    'ExpRelaxedOneHotCategorical.probs':
-        tf.math.softmax,
-    'RelaxedOneHotCategorical.probs':
-        tf.math.softmax,
-    'FiniteDiscrete.probs':
-        tf.math.softmax,
-    'Multinomial.probs':
-        tf.math.softmax,
-    'OneHotCategorical.probs':
-        tf.math.softmax,
-    'RelaxedCategorical.probs':
-        tf.math.softmax,
+    'atol': tf.math.softplus,
+    'rtol': tf.math.softplus,
+    'Dirichlet.concentration': tfp_hps.softplus_plus_eps(),
+    'concentration': tfp_hps.softplus_plus_eps(),
+    'GeneralizedPareto.concentration': (  # Permits +ve and -ve concentrations.
+        lambda x: tf.math.tanh(x) * 0.24
+    ),
+    'concentration0': tfp_hps.softplus_plus_eps(),
+    'concentration1': tfp_hps.softplus_plus_eps(),
+    'concentration0_numerator': tfp_hps.softplus_plus_eps(),
+    'concentration1_numerator': tfp_hps.softplus_plus_eps(1.0),
+    'concentration0_denominator': tfp_hps.softplus_plus_eps(),
+    'concentration1_denominator': tfp_hps.softplus_plus_eps(1.0),
+    'covariance_matrix': tfp_hps.positive_definite,
+    'df': tfp_hps.softplus_plus_eps(),
+    'DeterminantalPointProcess.eigenvalues': tfp_hps.softplus_plus_eps(),
+    'eigenvectors': tfp_hps.orthonormal,
+    'InverseGaussian.loc': tfp_hps.softplus_plus_eps(),
+    'JohnsonSU.tailweight': tfp_hps.softplus_plus_eps(),
+    'PowerSpherical.mean_direction': lambda x: tf.math.l2_normalize(
+        tf.math.sigmoid(x) + 1e-6, -1
+    ),
+    'VonMisesFisher.mean_direction': (  # max ndims is 3 to avoid instability.
+        lambda x: tf.math.l2_normalize(tf.math.sigmoid(x[..., :3]) + 1e-6, -1)
+    ),
+    'Categorical.probs': tf.math.softmax,
+    'ExpRelaxedOneHotCategorical.probs': tf.math.softmax,
+    'RelaxedOneHotCategorical.probs': tf.math.softmax,
+    'FiniteDiscrete.probs': tf.math.softmax,
+    'Multinomial.probs': tf.math.softmax,
+    'OneHotCategorical.probs': tf.math.softmax,
+    'RelaxedCategorical.probs': tf.math.softmax,
     'Zipf.power':
-        # Strictly > 1. See also b/175929563 (rejection sampler
-        # iterates too much and emits `nan` for powers too close to 1).
-        tfp_hps.softplus_plus_eps(1 + 1e-4),
-    'ContinuousBernoulli.probs':
-        tf.sigmoid,
+        # Strictly > 1. See also b/175929563 (rejection sampler
+        # iterates too much and emits `nan` for powers too close to 1).
+        tfp_hps.softplus_plus_eps(1 + 1e-4),
+    'ContinuousBernoulli.probs': tf.sigmoid,
     'Geometric.logits':  # TODO(b/128410109): re-enable down to -50
-        # Capping at 15. so that probability is less than 1, and entropy is
-        # defined. b/147394924
-        lambda x: tf.minimum(tf.maximum(x, -16.), 15.),  # works around the bug
-    'Geometric.probs':
-        constrain_between_eps_and_one_minus_eps(),
-    'Binomial.probs':
-        tf.sigmoid,
+        # Capping at 15. so that probability is less than 1, and entropy is
+        # defined. b/147394924
+        lambda x: tf.minimum(tf.maximum(x, -16.0), 15.0),  # works around the bug
+    'Geometric.probs': constrain_between_eps_and_one_minus_eps(),
+    'Binomial.probs': tf.sigmoid,
     # Constrain probs away from 0 to avoid immense samples.
     # See b/178842153.
-    'NegativeBinomial.logits':
-        lambda x: tf.minimum(x, 15.),
-    'NegativeBinomial.probs':
-        constrain_between_eps_and_one_minus_eps(eps0=0., eps1=1e-6),
-    'Bernoulli.probs':
-        tf.sigmoid,
-    'PlackettLuce.scores':
-        tfp_hps.softplus_plus_eps(),
-    'ProbitBernoulli.probs':
-        tf.sigmoid,
-    'RelaxedBernoulli.probs':
-        tf.sigmoid,
+    'NegativeBinomial.logits': lambda x: tf.minimum(x, 15.0),
+    'NegativeBinomial.probs': constrain_between_eps_and_one_minus_eps(
+        eps0=0.0, eps1=1e-6
+    ),
+    'Bernoulli.probs': tf.sigmoid,
+    'PlackettLuce.scores': tfp_hps.softplus_plus_eps(),
+    'ProbitBernoulli.probs': tf.sigmoid,
+    'RelaxedBernoulli.probs': tf.sigmoid,
     'cutpoints':
-        # Permit values that aren't too large
-        lambda x: ascending.Ascending().forward(10 * tf.math.tanh(x)),
+        # Permit values that aren't too large
+        lambda x: ascending.Ascending().forward(10 * tf.math.tanh(x)),
     # Capping log_rate because of weird semantics of Poisson with very
     # large rates (see b/178842153).
-    'log_rate':
-        lambda x: tf.minimum(tf.maximum(x, -16.), 15.),
+    'log_rate': lambda x: tf.minimum(tf.maximum(x, -16.0), 15.0),
     # Capping log_rate1 and log_rate2 to 15. This is because if both are large
     # (meaning the rates are `inf`), then the Skellam distribution is undefined.
-    'log_rate1':
-        lambda x: tf.minimum(tf.maximum(x, -16.), 15.),
-    'log_rate2':
-        lambda x: tf.minimum(tf.maximum(x, -16.), 15.),
-    'log_scale':
-        lambda x: tf.maximum(x, -16.),
-    'mixing_concentration':
-        tfp_hps.softplus_plus_eps(),
-    'mixing_rate':
-        tfp_hps.softplus_plus_eps(),
-    'rate':
-        tfp_hps.softplus_plus_eps(),
-    'rate1':
-        tfp_hps.softplus_plus_eps(),
-    'rate2':
-        tfp_hps.softplus_plus_eps(),
-    'scale':
-        tfp_hps.softplus_plus_eps(),
-    'GeneralizedPareto.scale':  # Avoid underflow in support bijector.
-        tfp_hps.softplus_plus_eps(1e-2),
-    'Wishart.scale':
-        tfp_hps.positive_definite,
-    'scale_diag':
-        tfp_hps.softplus_plus_eps(),
-    'scale_tril':
-        tfp_hps.lower_tril_positive_definite,
-    'tailweight':
-        tfp_hps.softplus_plus_eps(),
-    'temperature':
-        tfp_hps.softplus_plus_eps(),
-    'total_count':
-        lambda x: tf.floor(tf.sigmoid(x / 100) * 100) + 1,
-    'Bates':
-        fix_bates,
-    'Bernoulli':
-        lambda d: dict(d, dtype=tf.float32),
-    'CholeskyLKJ':
-        fix_lkj,
-    'LKJ':
-        fix_lkj,
+    'log_rate1': lambda x: tf.minimum(tf.maximum(x, -16.0), 15.0),
+    'log_rate2': lambda x: tf.minimum(tf.maximum(x, -16.0), 15.0),
+    'log_scale': lambda x: tf.maximum(x, -16.0),
+    'mixing_concentration': tfp_hps.softplus_plus_eps(),
+    'mixing_rate': tfp_hps.softplus_plus_eps(),
+    'rate': tfp_hps.softplus_plus_eps(),
+    'rate1': tfp_hps.softplus_plus_eps(),
+    'rate2': tfp_hps.softplus_plus_eps(),
+    'scale': tfp_hps.softplus_plus_eps(),
+    'GeneralizedPareto.scale': (  # Avoid underflow in support bijector.
+        tfp_hps.softplus_plus_eps(1e-2)
+    ),
+    'Wishart.scale': tfp_hps.positive_definite,
+    'scale_diag': tfp_hps.softplus_plus_eps(),
+    'scale_tril': tfp_hps.lower_tril_positive_definite,
+    'tailweight': tfp_hps.softplus_plus_eps(),
+    'temperature': tfp_hps.softplus_plus_eps(),
+    'total_count': lambda x: tf.floor(tf.sigmoid(x / 100) * 100) + 1,
+    'concentration_shape': tfp_hps.shapes(min_ndims=1, min_lastdimsize=2),
+    'Bates': fix_bates,
+    'Bernoulli': lambda d: dict(d, dtype=tf.float32),
+    'CholeskyLKJ': fix_lkj,
+    'LKJ': fix_lkj,
     'MultivariateNormalDiagPlusLowRank.scale_diag':
-        # Ensure that the diagonal component is large enough to avoid being
-        # overwhelmed by the (singular) low-rank perturbation.
-        tfp_hps.softplus_plus_eps(1. + 1e-6),
-    'MultivariateNormalDiagPlusLowRank.scale_perturb_diag':
-        tfp_hps.softplus_plus_eps(),
+        # Ensure that the diagonal component is large enough to avoid being
+        # overwhelmed by the (singular) low-rank perturbation.
+        tfp_hps.softplus_plus_eps(1.0 + 1e-6),
+    'MultivariateNormalDiagPlusLowRank.scale_perturb_diag': (
+        tfp_hps.softplus_plus_eps()
+    ),
     'MultivariateNormalDiagPlusLowRank.scale_perturb_factor':
-        # Prevent large low-rank perturbations from creating numerically
-        # singular matrices.
-        tf.math.tanh,
+        # Prevent large low-rank perturbations from creating numerically
+        # singular matrices.
+        tf.math.tanh,
     'MultivariateNormalDiagPlusLowRankCovariance.cov_diag_factor':
-        # Ensure that the diagonal component is large enough to avoid being
-        # overwhelmed by the (singular) low-rank perturbation.
-        tfp_hps.softplus_plus_eps(1. + 1e-6),
+        # Ensure that the diagonal component is large enough to avoid being
+        # overwhelmed by the (singular) low-rank perturbation.
+        tfp_hps.softplus_plus_eps(1.0 + 1e-6),
     'MultivariateNormalDiagPlusLowRankCovariance.cov_perturb_factor':
-        # Prevent large low-rank perturbations from creating numerically
-        # singular matrices.
-        tf.math.tanh,
-    'NormalInverseGaussian':
-        fix_normal_inverse_gaussian,
-    'OrderedLogistic':
-        lambda d: dict(d, dtype=tf.float32),
-    'OnehotCategorical':
-        lambda d: dict(d, dtype=tf.float32),
-    'PERT':
-        fix_pert,
-    'StoppingRatioLogistic':
-        lambda d: dict(d, dtype=tf.float32),
-    'Triangular':
-        fix_triangular,
-    'TruncatedCauchy':
-        lambda d: dict(  # pylint:disable=g-long-lambda
-            d,
-            high=tfp_hps.ensure_high_gt_low(
-                d['low'], d['high'])),
-    'TruncatedNormal':
-        fix_truncated_normal,
-    'Uniform':
-        lambda d: dict(  # pylint:disable=g-long-lambda
-            d,
-            high=tfp_hps.ensure_high_gt_low(
-                d['low'], d['high'])),
-    'SphericalUniform':
-        fix_spherical_uniform,
-    'Wishart':
-        fix_wishart,
-    'WishartTriL':
-        fix_wishart,
-    'Zipf':
-        lambda d: dict(d, dtype=tf.float32),
-    'FiniteDiscrete':
-        fix_finite_discrete,
-    'GeneralizedNormal.power':
-        tfp_hps.softplus_plus_eps(),
-    'TwoPieceNormal.skewness':
-        tfp_hps.softplus_plus_eps(),
-    'TwoPieceStudentT.skewness':
-        tfp_hps.softplus_plus_eps(),
-    'NoncentralChi2.noncentrality':
-        tf.math.softplus,
+        # Prevent large low-rank perturbations from creating numerically
+        # singular matrices.
+        tf.math.tanh,
+    'NormalInverseGaussian': fix_normal_inverse_gaussian,
+    'OrderedLogistic': lambda d: dict(d, dtype=tf.float32),
+    'OnehotCategorical': lambda d: dict(d, dtype=tf.float32),
+    'PERT': fix_pert,
+    'StoppingRatioLogistic': lambda d: dict(d, dtype=tf.float32),
+    'Triangular': fix_triangular,
+    'TruncatedCauchy': lambda d: dict(  # pylint:disable=g-long-lambda
+        d, high=tfp_hps.ensure_high_gt_low(d['low'], d['high'])
+    ),
+    'TruncatedNormal': fix_truncated_normal,
+    'Uniform': lambda d: dict(  # pylint:disable=g-long-lambda
+        d, high=tfp_hps.ensure_high_gt_low(d['low'], d['high'])
+    ),
+    'SphericalUniform': fix_spherical_uniform,
+    'Wishart': fix_wishart,
+    'WishartTriL': fix_wishart,
+    'Zipf': lambda d: dict(d, dtype=tf.float32),
+    'FiniteDiscrete': fix_finite_discrete,
+    'GeneralizedNormal.power': tfp_hps.softplus_plus_eps(),
+    'TwoPieceNormal.skewness': tfp_hps.softplus_plus_eps(),
+    'TwoPieceStudentT.skewness': tfp_hps.softplus_plus_eps(),
+    'NoncentralChi2.noncentrality': tf.math.softplus,
 }
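
Reviewer note on the sampler: `FlatDirichlet._sample_n` uses the classic
uniform-spacings construction referenced in its Wikipedia link — sort n - 1
independent Uniform(0, 1) draws together with the sentinels 0 and 1, then take
consecutive differences; the resulting spacings are uniformly distributed over
the simplex, i.e. Dirichlet(1, ..., 1). A minimal self-contained NumPy sketch
of the same construction (illustration only, not part of the patch; the
helper name `flat_dirichlet_sample` is invented here):

    import math

    import numpy as np

    def flat_dirichlet_sample(num_samples, event_size, seed=0):
      """Draws points uniformly from the (event_size - 1)-simplex."""
      rng = np.random.default_rng(seed)
      # event_size - 1 uniforms per sample, padded with the sentinels 0 and 1.
      u = rng.uniform(size=(num_samples, event_size - 1))
      padded = np.concatenate(
          [np.zeros((num_samples, 1)), u, np.ones((num_samples, 1))], axis=-1)
      padded.sort(axis=-1)
      # Spacings between consecutive order statistics; each row sums to
      # 1 - 0 = 1, so every sample lies on the simplex.
      return np.diff(padded, axis=-1)

    samples = flat_dirichlet_sample(num_samples=4, event_size=5)
    assert samples.shape == (4, 5)
    np.testing.assert_allclose(samples.sum(axis=-1), np.ones(4))
    # The flat Dirichlet density is the constant Gamma(n) = (n - 1)!, which is
    # what testLogProb checks: for n = 5, log_prob == log(4!) == log(24).
    assert abs(math.lgamma(5) - math.log(24.0)) < 1e-12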