diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 7fb140497b..aa2843bfb9 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -18,8 +18,6 @@ # //tensorflow_probability/python/internal/auto_batching # internally. -# Placeholder: py_library -# Placeholder: py_test load( "//tensorflow_probability/python:build_defs.bzl", "multi_substrate_py_library", @@ -548,6 +546,9 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/distributions:batch_reshape", + "//tensorflow_probability/python/distributions:batch_broadcast", + "//tensorflow_probability/python/distributions:independent" ], ) @@ -574,6 +575,8 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:uniform", + "//tensorflow_probability/python/distributions:categorical", + "//tensorflow_probability/python/distributions:hidden_markov_model", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math:gradient", # "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport @@ -652,6 +655,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:mvn_diag", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/experimental/mcmc:sequential_monte_carlo_kernel", "//tensorflow_probability/python/distributions:uniform", "//tensorflow_probability/python/distributions/internal:statistical_testing", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 1bcbc870f4..b920b0db85 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -25,6 +25,11 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.mcmc.internal import util as mcmc_util +from tensorflow_probability.python.distributions import batch_reshape +from tensorflow_probability.python.distributions import batch_broadcast +from tensorflow_probability.python.distributions import normal +from tensorflow_probability.python.distributions import uniform + __all__ = [ 'infer_trajectories', @@ -44,6 +49,39 @@ def _default_trace_fn(state, kernel_results): kernel_results.incremental_log_marginal_likelihood) +def _default_kernel(parameters): + mean, variance = tf.nn.moments(parameters, axes=[0]) + proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean), scale=tf.sqrt(variance)) + return proposal_distribution + + +def _default_extra_fn(step, + state, + seed + ): + return state.extra + + +def where_fn(accept, a, b, num_outer_particles, num_inner_particles): + is_scalar = tf.rank(a) == tf.constant(0) + is_nan = tf.math.is_nan(tf.cast(a, tf.float32)) + is_all_nan = tf.reduce_all(is_nan) + if is_scalar and is_all_nan: + return a + elif a.shape == 2 and b.shape == 2: + # extra + return a + elif a.shape == num_outer_particles and b.shape == num_outer_particles: + return mcmc_util.choose(accept, a, b) + elif a.shape == [num_outer_particles, num_inner_particles] and \ + b.shape == [num_outer_particles, num_inner_particles]: + return mcmc_util.choose(accept, a, b) + elif a.shape == () and b.shape == (): + return a + else: + raise ValueError("Unexpected tensor shapes") + + particle_filter_arg_str = """\ Each latent state is a `Tensor` or nested structure of `Tensor`s, as defined by the `initial_state_prior`. @@ -435,6 +473,344 @@ def seeded_one_step(seed_state_results, _): return traced_results +def smc_squared( + inner_observations, + initial_parameter_prior, + num_outer_particles, + inner_initial_state_prior, + inner_transition_fn, + inner_observation_fn, + num_inner_particles, + outer_trace_fn=_default_trace_fn, + outer_rejuvenation_criterion_fn=None, + outer_resample_criterion_fn=None, + outer_resample_fn=weighted_resampling.resample_systematic, + inner_resample_criterion_fn=smc_kernel.ess_below_threshold, + inner_resample_fn=weighted_resampling.resample_systematic, + extra_fn=_default_extra_fn, + parameter_proposal_kernel=_default_kernel, + inner_proposal_fn=None, + inner_initial_state_proposal=None, + outer_trace_criterion_fn=_always_trace, + parallel_iterations=1, + num_transitions_per_observation=1, + static_trace_allocation_size=None, + initial_parameter_proposal=None, + unbiased_gradients=True, + seed=None, +): + init_seed, loop_seed, step_seed = samplers.split_seed(seed, n=3, salt='smc_squared') + + num_observation_steps = ps.size0(tf.nest.flatten(inner_observations)[0]) + + # TODO: The following two lines compensates for having the first empty step in smc2 + num_timesteps = (1 + num_transitions_per_observation * + (num_observation_steps - 1)) + 1 + last_obs_expanded = tf.expand_dims(inner_observations[-1], axis=0) + inner_observations = tf.concat([inner_observations, last_obs_expanded], axis=0) + + if outer_rejuvenation_criterion_fn is None: + outer_rejuvenation_criterion_fn = lambda *_: tf.constant(False) + + if outer_resample_criterion_fn is None: + outer_resample_criterion_fn = lambda *_: tf.constant(False) + + # If trace criterion is `None`, we'll return only the final results. + never_trace = lambda *_: False + if outer_trace_criterion_fn is None: + static_trace_allocation_size = 0 + outer_trace_criterion_fn = never_trace + + if initial_parameter_proposal is None: + initial_state = initial_parameter_prior.sample(num_outer_particles, + seed=seed) + initial_log_weights = ps.zeros_like( + initial_parameter_prior.log_prob(initial_state)) + else: + initial_state = initial_parameter_proposal.sample(num_outer_particles, + seed=seed) + initial_log_weights = ( + initial_parameter_prior.log_prob(initial_state) - + initial_parameter_proposal.log_prob(initial_state) + ) + + # Normalize the initial weights. If we used a proposal, the weights are + # normalized in expectation, but actually normalizing them reduces variance. + initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0) + + inner_weighted_particles = _particle_filter_initial_weighted_particles( + observations=inner_observations, + observation_fn=inner_observation_fn(initial_state), + initial_state_prior=inner_initial_state_prior(0, initial_state), + initial_state_proposal=(inner_initial_state_proposal(0, initial_state) + if inner_initial_state_proposal is not None else None), + num_particles=num_inner_particles, + particles_dim=1, + seed=seed + ) + + init_state = smc_kernel.WeightedParticles(*inner_weighted_particles) + + batch_zeros = tf.zeros(ps.shape(initial_state)) + + initial_filter_results = smc_kernel.SequentialMonteCarloResults( + steps=0, + parent_indices=smc_kernel._dummy_indices_like(init_state.log_weights), + incremental_log_marginal_likelihood=batch_zeros, + accumulated_log_marginal_likelihood=batch_zeros, + seed=samplers.zeros_seed()) + + initial_state = smc_kernel.WeightedParticles( + particles=(initial_state, + inner_weighted_particles, + initial_filter_results.parent_indices, + initial_filter_results.incremental_log_marginal_likelihood, + initial_filter_results.accumulated_log_marginal_likelihood), + log_weights=initial_log_weights, + extra=(tf.constant(0), + initial_filter_results.seed) + ) + + outer_propose_and_update_log_weights_fn = ( + _outer_particle_filter_propose_and_update_log_weights_fn( + outer_rejuvenation_criterion_fn=outer_rejuvenation_criterion_fn, + inner_observations=inner_observations, + inner_transition_fn=inner_transition_fn, + inner_proposal_fn=inner_proposal_fn, + inner_observation_fn=inner_observation_fn, + inner_resample_fn=inner_resample_fn, + inner_resample_criterion_fn=inner_resample_criterion_fn, + parameter_proposal_kernel=parameter_proposal_kernel, + initial_parameter_prior=initial_parameter_prior, + num_transitions_per_observation=num_transitions_per_observation, + unbiased_gradients=unbiased_gradients, + inner_initial_state_prior=inner_initial_state_prior, + inner_initial_state_proposal=inner_initial_state_proposal, + num_inner_particles=num_inner_particles, + num_outer_particles=num_outer_particles, + extra_fn=extra_fn + ) + ) + + traced_results = sequential_monte_carlo( + initial_weighted_particles=initial_state, + propose_and_update_log_weights_fn=outer_propose_and_update_log_weights_fn, + resample_fn=outer_resample_fn, + resample_criterion_fn=outer_resample_criterion_fn, + trace_criterion_fn=outer_trace_criterion_fn, + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations, + unbiased_gradients=unbiased_gradients, + num_steps=num_timesteps, + particles_dim=0, + trace_fn=outer_trace_fn, + seed=loop_seed + ) + + return traced_results + + +def _outer_particle_filter_propose_and_update_log_weights_fn( + inner_observations, + inner_transition_fn, + inner_proposal_fn, + inner_observation_fn, + initial_parameter_prior, + inner_initial_state_prior, + inner_initial_state_proposal, + num_transitions_per_observation, + inner_resample_fn, + inner_resample_criterion_fn, + outer_rejuvenation_criterion_fn, + unbiased_gradients, + parameter_proposal_kernel, + num_inner_particles, + num_outer_particles, + extra_fn +): + """Build a function specifying a particle filter update step.""" + def _outer_propose_and_update_log_weights_fn(step, state, seed=None): + outside_parameters = state.particles[0] + inner_weighted_particles, log_weights = state.particles[1], state.log_weights + + filter_results = smc_kernel.SequentialMonteCarloResults( + steps=step, + parent_indices=state.particles[2], + incremental_log_marginal_likelihood=state.particles[3], + accumulated_log_marginal_likelihood=state.particles[4], + seed=state.extra[1]) + + inner_propose_and_update_log_weights_fn = ( + _particle_filter_propose_and_update_log_weights_fn( + observations=inner_observations, + transition_fn=inner_transition_fn(outside_parameters), + proposal_fn=(inner_proposal_fn(outside_parameters) + if inner_proposal_fn is not None else None), + observation_fn=inner_observation_fn(outside_parameters), + particles_dim=1, + num_transitions_per_observation=num_transitions_per_observation, + extra_fn=extra_fn + ) + ) + + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn, + resample_fn=inner_resample_fn, + resample_criterion_fn=inner_resample_criterion_fn, + particles_dim=1, + unbiased_gradients=unbiased_gradients + ) + + inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles, + filter_results, + seed=seed) + + updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood + + do_rejuvenation = outer_rejuvenation_criterion_fn(step, state) + + def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results): + proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed) + + rej_params_log_weights = ps.zeros_like( + initial_parameter_prior.log_prob(proposed_parameters) + ) + rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0) + + rej_inner_weighted_particles = _particle_filter_initial_weighted_particles( + observations=inner_observations, + observation_fn=inner_observation_fn(proposed_parameters), + initial_state_prior=inner_initial_state_prior(0, proposed_parameters), + initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters) + if inner_initial_state_proposal is not None else None), + num_particles=num_inner_particles, + particles_dim=1, + seed=seed) + + batch_zeros = tf.zeros(ps.shape(log_weights)) + + rej_filter_results = smc_kernel.SequentialMonteCarloResults( + steps=tf.constant(0, dtype=tf.int32), + parent_indices=smc_kernel._dummy_indices_like( + rej_inner_weighted_particles.log_weights + ), + incremental_log_marginal_likelihood=batch_zeros, + accumulated_log_marginal_likelihood=batch_zeros, + seed=samplers.zeros_seed()) + + rej_inner_particles_weights = rej_inner_weighted_particles.log_weights + + rej_inner_propose_and_update_log_weights_fn = ( + _particle_filter_propose_and_update_log_weights_fn( + observations=inner_observations, + transition_fn=inner_transition_fn(proposed_parameters), + proposal_fn=(inner_proposal_fn(proposed_parameters) + if inner_proposal_fn is not None else None), + observation_fn=inner_observation_fn(proposed_parameters), + extra_fn=extra_fn, + particles_dim=1, + num_transitions_per_observation=num_transitions_per_observation) + ) + + rej_kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn, + resample_fn=inner_resample_fn, + resample_criterion_fn=inner_resample_criterion_fn, + particles_dim=1, + unbiased_gradients=unbiased_gradients) + + def condition(i, + rej_inner_weighted_particles, + rej_filter_results, + rej_parameters_weights, + rej_params_log_weights): + return tf.less_equal(i, step) + + def body(i, + rej_inner_weighted_particles, + rej_filter_results, + rej_parameters_weights, + rej_params_log_weights): + + rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step( + rej_inner_weighted_particles, rej_filter_results, seed=seed + ) + + rej_parameters_weights += rej_inner_weighted_particles.log_weights + + rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood + return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights + + i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop( + condition, + body, + loop_vars=[0, + rej_inner_weighted_particles, + rej_filter_results, + rej_inner_particles_weights, + rej_params_log_weights + ] + ) + + log_a = rej_filter_results.accumulated_log_marginal_likelihood - \ + filter_results.accumulated_log_marginal_likelihood + \ + parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \ + parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters) + + acceptance_probs = tf.minimum(1., tf.exp(log_a)) + + random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed) + + # Determine if the proposed particle should be accepted or reject + accept = random_numbers > acceptance_probs + + # Update the chosen particles and filter restults based on the acceptance step + outside_parameters = tf.where(accept, outside_parameters, proposed_parameters) + updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights) + + inner_weighted_particles_particles = mcmc_util.choose( + accept, + inner_weighted_particles.particles, + rej_inner_weighted_particles.particles + ) + inner_weighted_particles_log_weights = mcmc_util.choose( + accept, + inner_weighted_particles.log_weights, + rej_inner_weighted_particles.log_weights + ) + + inner_weighted_particles = smc_kernel.WeightedParticles( + particles=inner_weighted_particles_particles, + log_weights=inner_weighted_particles_log_weights, + extra=inner_weighted_particles.extra + ) + + filter_results = tf.nest.map_structure( + lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles), + filter_results, + rej_filter_results + ) + + return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results + + outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond( + do_rejuvenation, + lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)), + lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results) + ) + + return smc_kernel.WeightedParticles( + particles=(outside_parameters, + inner_weighted_particles, + filter_results.parent_indices, + filter_results.incremental_log_marginal_likelihood, + filter_results.accumulated_log_marginal_likelihood), + log_weights=updated_log_weights, + extra=(step, + filter_results.seed)) + return _outer_propose_and_update_log_weights_fn + + @docstring_util.expand_docstring( particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1)) def particle_filter(observations, @@ -442,6 +818,7 @@ def particle_filter(observations, transition_fn, observation_fn, num_particles, + extra_fn=_default_extra_fn, initial_state_proposal=None, proposal_fn=None, resample_fn=weighted_resampling.resample_systematic, @@ -526,7 +903,9 @@ def particle_filter(observations, particles_dim=particles_dim, proposal_fn=proposal_fn, observation_fn=observation_fn, - num_transitions_per_observation=num_transitions_per_observation)) + num_transitions_per_observation=num_transitions_per_observation, + extra_fn=extra_fn + )) return sequential_monte_carlo( initial_weighted_particles=initial_weighted_particles, @@ -549,6 +928,7 @@ def _particle_filter_initial_weighted_particles(observations, initial_state_proposal, num_particles, particles_dim=0, + extra=np.nan, seed=None): """Initialize a set of weighted particles including the first observation.""" # Propose an initial state. @@ -574,6 +954,14 @@ def _particle_filter_initial_weighted_particles(observations, axis=particles_dim) # Return particles weighted by the initial observation. + if extra is np.nan: + if len(ps.shape(initial_log_weights)) == 1: + # initial extra for particle filter + extra = tf.constant(0) + else: + # initial extra for inner particles of smc_squared + extra = tf.constant(0, shape=ps.shape(initial_log_weights)) + return smc_kernel.WeightedParticles( particles=initial_state, log_weights=initial_log_weights + _compute_observation_log_weights( @@ -581,7 +969,8 @@ def _particle_filter_initial_weighted_particles(observations, particles=initial_state, observations=observations, observation_fn=observation_fn, - particles_dim=particles_dim)) + particles_dim=particles_dim), + extra=extra) def _particle_filter_propose_and_update_log_weights_fn( @@ -589,6 +978,7 @@ def _particle_filter_propose_and_update_log_weights_fn( transition_fn, proposal_fn, observation_fn, + extra_fn, num_transitions_per_observation=1, particles_dim=0): """Build a function specifying a particle filter update step.""" @@ -619,13 +1009,18 @@ def propose_and_update_log_weights_fn(step, state, seed=None): else: proposed_particles = transition_dist.sample(seed=seed) + updated_extra = extra_fn(step, + state, + seed) + with tf.control_dependencies(assertions): return smc_kernel.WeightedParticles( particles=proposed_particles, log_weights=log_weights + _compute_observation_log_weights( step + 1, proposed_particles, observations, observation_fn, num_transitions_per_observation=num_transitions_per_observation, - particles_dim=particles_dim)) + particles_dim=particles_dim), + extra=updated_extra) return propose_and_update_log_weights_fn @@ -670,6 +1065,8 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) + if particles_dim == 1: + observation = tf.expand_dims(observation, axis=0) observation = tf.nest.map_structure( lambda x: tf.expand_dims(x, axis=particles_dim), observation) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 6508eb6231..e190c76bda 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.bijectors import shift from tensorflow_probability.python.distributions import bernoulli from tensorflow_probability.python.distributions import deterministic +from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab from tensorflow_probability.python.distributions import joint_distribution_named as jdn from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm @@ -177,128 +178,6 @@ def observation_fn(_, state): self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) - def test_batch_of_filters_particles_dim_1(self): - - batch_shape = [3, 2] - num_particles = 1000 - num_timesteps = 40 - - # Batch of priors on object 1D positions and velocities. - initial_state_prior = jdn.JointDistributionNamed({ - 'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)), - 'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1) - }) - - def transition_fn(_, previous_state): - return jdn.JointDistributionNamed({ - 'position': - normal.Normal( - loc=previous_state['position'] + previous_state['velocity'], - scale=0.1), - 'velocity': - normal.Normal(loc=previous_state['velocity'], scale=0.01) - }) - - def observation_fn(_, state): - return normal.Normal(loc=state['position'], scale=0.1) - - # Batch of synthetic observations, . - true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype) - true_velocities = 0.1 * np.random.randn( - *batch_shape).astype(self.dtype) - observed_positions = ( - true_velocities * - np.arange(num_timesteps).astype( - self.dtype)[..., tf.newaxis, tf.newaxis] + - true_initial_positions) - - (particles, log_weights, parent_indices, - incremental_log_marginal_likelihoods) = self.evaluate( - particle_filter.particle_filter( - observations=observed_positions, - initial_state_prior=initial_state_prior, - transition_fn=transition_fn, - observation_fn=observation_fn, - num_particles=num_particles, - seed=test_util.test_seed(), - particles_dim=1)) - - self.assertAllEqual(particles['position'].shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(particles['velocity'].shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(parent_indices.shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(incremental_log_marginal_likelihoods.shape, - [num_timesteps] + batch_shape) - - self.assertAllClose( - self.evaluate( - tf.reduce_sum(tf.exp(log_weights) * - particles['position'], axis=2)), - observed_positions, - atol=0.3) - - velocity_means = tf.reduce_sum(tf.exp(log_weights) * - particles['velocity'], axis=2) - - self.assertAllClose( - self.evaluate(tf.reduce_mean(velocity_means, axis=0)), - true_velocities, atol=0.05) - - # Uncertainty in velocity should decrease over time. - velocity_stddev = self.evaluate( - tf.math.reduce_std(particles['velocity'], axis=2)) - self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.) - - trajectories = self.evaluate( - particle_filter.reconstruct_trajectories(particles, - parent_indices, - particles_dim=1)) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['position'].shape) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['velocity'].shape) - - # Verify that `infer_trajectories` also works on batches. - trajectories, incremental_log_marginal_likelihoods = self.evaluate( - particle_filter.infer_trajectories( - observations=observed_positions, - initial_state_prior=initial_state_prior, - transition_fn=transition_fn, - observation_fn=observation_fn, - num_particles=num_particles, - particles_dim=1, - seed=test_util.test_seed())) - - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['position'].shape) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['velocity'].shape) - self.assertAllEqual(incremental_log_marginal_likelihoods.shape, - [num_timesteps] + batch_shape) - def test_reconstruct_trajectories_toy_example(self): particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]]) # 1 -- 4 -- 7 @@ -734,6 +613,205 @@ def marginal_log_likelihood(level_scale, noise_scale): self.assertAllNotNone(grads) self.assertAllAssertsNested(self.assertNotAllZero, grads) + def test_smc_squared_rejuvenation_parameters(self): + def particle_dynamics(params, _, previous_state): + reshaped_params = tf.reshape(params, [params.shape[0]] + [1] * (previous_state.shape.rank - 1)) + broadcasted_params = tf.broadcast_to(reshaped_params, previous_state.shape) + return normal.Normal(previous_state + broadcasted_params + 1, 0.1) + + def rejuvenation_criterion(step, state): + # Rejuvenation every 2 steps + cond = tf.logical_and( + tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)), + tf.not_equal(state.extra[0], tf.constant(0)) + ) + return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) + + inner_observations = tf.range(30, dtype=tf.float32) + + num_outer_particles = 3 + num_inner_particles = 7 + + loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2]) + scale_diag = tf.broadcast_to([0.05, 0.05], [num_outer_particles, 2]) + + params, inner_pt = self.evaluate(particle_filter.smc_squared( + inner_observations=inner_observations, + inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag( + loc=loc, scale_diag=scale_diag + ), + initial_parameter_prior=normal.Normal(3., 1.), + num_outer_particles=num_outer_particles, + num_inner_particles=num_inner_particles, + outer_rejuvenation_criterion_fn=rejuvenation_criterion, + inner_transition_fn=lambda params: ( + lambda _, state: independent.Independent(particle_dynamics(params, _, state), 1)), + inner_observation_fn=lambda params: ( + lambda _, state: independent.Independent(normal.Normal(state, 2.), 1)), + outer_trace_fn=lambda s, r: ( + s.particles[0], + s.particles[1] + ), + parameter_proposal_kernel=lambda params: normal.Normal(params, 3), + seed=test_util.test_seed() + ) + ) + + abs_params = tf.abs(params) + differences = abs_params[1:] - abs_params[:-1] + mask_parameters = tf.reduce_all(tf.less_equal(differences, 0), axis=0) + + self.assertAllTrue(mask_parameters) + + def test_smc_squared_can_step_dynamics_faster_than_observations(self): + initial_state_prior = jdn.JointDistributionNamed({ + 'position': deterministic.Deterministic([1.]), + 'velocity': deterministic.Deterministic([0.]) + }) + + # Use 100 steps between observations to integrate a simple harmonic + # oscillator. + dt = 0.01 + def simple_harmonic_motion_transition_fn(_, state): + return jdn.JointDistributionNamed({ + 'position': + normal.Normal( + loc=state['position'] + dt * state['velocity'], + scale=dt * 0.01), + 'velocity': + normal.Normal( + loc=state['velocity'] - dt * state['position'], + scale=dt * 0.01) + }) + + def observe_position(_, state): + return normal.Normal(loc=state['position'], scale=0.01) + + particles, lps = self.evaluate(particle_filter.smc_squared( + inner_observations=tf.convert_to_tensor( + [tf.math.cos(0.), tf.math.cos(1.)]), + inner_initial_state_prior=lambda _, params: initial_state_prior, + initial_parameter_prior=deterministic.Deterministic(0.), + num_outer_particles=1, + inner_transition_fn=lambda params: simple_harmonic_motion_transition_fn, + inner_observation_fn=lambda params: observe_position, + num_inner_particles=1024, + outer_trace_fn=lambda s, r: ( + s.particles[1].particles, + s.particles[3] + ), + num_transitions_per_observation=100, + seed=test_util.test_seed()) + ) + + self.assertAllEqual(ps.shape(particles['position']), tf.constant([102, 1, 1024])) + + self.assertAllClose(tf.transpose(np.mean(particles['position'], axis=-1)), + tf.reshape(tf.math.cos(dt * np.arange(102)), [1, -1]), + atol=0.04) + + self.assertAllEqual(ps.shape(lps), [102, 1]) + self.assertGreater(lps[1][0], 1.) + self.assertGreater(lps[-1][0], 3.) + + def test_smc_squared_custom_outer_trace_fn(self): + def trace_fn(state, _): + # Traces the mean and stddev of the particle population at each step. + weights = tf.exp(state[0][1].log_weights[0]) + mean = tf.reduce_sum(weights * state[0][1].particles[0], axis=0) + variance = tf.reduce_sum( + weights * (state[0][1].particles[0] - mean[tf.newaxis, ...]) ** 2) + return {'mean': mean, + 'stddev': tf.sqrt(variance), + # In real usage we would likely not track the particles and + # weights. We keep them here just so we can double-check the + # stats, below. + 'particles': state[0][1].particles[0], + 'weights': weights} + + results = self.evaluate(particle_filter.smc_squared( + inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), + inner_initial_state_prior=lambda _, params: normal.Normal([0.], 1.), + initial_parameter_prior=deterministic.Deterministic(0.), + inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), + inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), + num_inner_particles=1024, + num_outer_particles=1, + outer_trace_fn=trace_fn, + seed=test_util.test_seed()) + ) + + # Verify that posterior means are increasing. + self.assertAllGreater(results['mean'][1:] - results['mean'][:-1], 0.) + + # Check that our traced means and scales match values computed + # by averaging over particles after the fact. + all_means = self.evaluate(tf.reduce_sum( + results['weights'] * results['particles'], axis=1)) + all_variances = self.evaluate( + tf.reduce_sum( + results['weights'] * + (results['particles'] - all_means[..., tf.newaxis])**2, + axis=1)) + self.assertAllClose(results['mean'], all_means) + self.assertAllClose(results['stddev'], np.sqrt(all_variances)) + + def test_smc_squared_indices_to_trace(self): + num_outer_particles = 7 + num_inner_particles = 13 + + def rejuvenation_criterion(step, state): + # Rejuvenation every 3 steps + cond = tf.logical_and( + tf.equal(tf.math.mod(step, tf.constant(3)), tf.constant(0)), + tf.not_equal(state.extra[0], tf.constant(0)) + ) + return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) + + (parameters, weight_parameters, inner_particles, inner_log_weights, lp) = self.evaluate( + particle_filter.smc_squared( + inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), + initial_parameter_prior=deterministic.Deterministic(0.), + inner_initial_state_prior=lambda _, params: normal.Normal([0.] * num_outer_particles, 1.), + inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 10.)), + inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 0.1)), + num_inner_particles=num_inner_particles, + num_outer_particles=num_outer_particles, + outer_rejuvenation_criterion_fn=rejuvenation_criterion, + outer_trace_fn=lambda s, r: ( # pylint: disable=g-long-lambda + s.particles[0], + s.log_weights, + s.particles[1].particles, + s.particles[1].log_weights, + r.accumulated_log_marginal_likelihood), + seed=test_util.test_seed()) + ) + + # TODO: smc_squared at the moment starts his run with an empty step + self.assertAllEqual(ps.shape(parameters), [6, 7]) + self.assertAllEqual(ps.shape(weight_parameters), [6, 7]) + self.assertAllEqual(ps.shape(inner_particles), [6, 7, 13]) + self.assertAllEqual(ps.shape(inner_log_weights), [6, 7, 13]) + self.assertAllEqual(ps.shape(lp), [6]) + + def test_extra(self): + def step_hundred(step, state, seed): + return step * 2 + + results = self.evaluate( + particle_filter.particle_filter( + observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), + initial_state_prior=normal.Normal(0., 1.), + transition_fn=lambda _, state: normal.Normal(state, 1.), + observation_fn=lambda _, state: normal.Normal(state, 1.), + num_particles=1024, + extra_fn=step_hundred, + trace_fn=lambda s, r: s.extra, + seed=test_util.test_seed()) + ) + + self.assertAllEqual(results, [0, 0, 2, 4, 6]) + # TODO(b/186068104): add tests with dynamic shapes. class ParticleFilterTestFloat32(_ParticleFilterTest): diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py index 73cb0f8414..300418c87d 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py @@ -34,7 +34,7 @@ # SequentialMonteCarlo `state` structure. class WeightedParticles(collections.namedtuple( - 'WeightedParticles', ['particles', 'log_weights'])): + 'WeightedParticles', ['particles', 'log_weights', 'extra'])): """Particles with corresponding log weights. This structure serves as the `state` for the `SequentialMonteCarlo` transition @@ -50,6 +50,10 @@ class WeightedParticles(collections.namedtuple( `exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. These must be used in conjunction with `particles` to compute expectations under the target distribution. + extra: a (structure of) Tensor(s) each of shape + `concat([[b1, ..., bN], event_shape])`, where `event_shape` + may differ across component `Tensor`s. This represents global state of the + sampling process that is not associated with individual particles. In some contexts, particles may be stacked across multiple inference steps, in which case all `Tensor` shapes will be prefixed by an additional dimension @@ -292,7 +296,7 @@ def one_step(self, state, kernel_results, seed=None): - tf.gather(normalized_log_weights, 0, axis=self.particles_dim)) do_resample = self.resample_criterion_fn( - state, particles_dim=self.particles_dim) + state, self.particles_dim) # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to # use the resampled values for each batch element according to @@ -326,7 +330,8 @@ def one_step(self, state, kernel_results, seed=None): normalized_log_weights)) return (WeightedParticles(particles=resampled_particles, - log_weights=log_weights), + log_weights=log_weights, + extra=state.extra), SequentialMonteCarloResults( steps=kernel_results.steps + 1, parent_indices=resample_indices, diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py index 2a9302a420..2e29f6c4dd 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py @@ -42,7 +42,9 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): return WeightedParticles( particles=proposed_particles, log_weights=weighted_particles.log_weights + - normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles)) + normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles), + extra=tf.constant(np.nan) + ) num_particles = 16 initial_state = self.evaluate( @@ -50,7 +52,9 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): particles=tf.random.normal([num_particles], seed=test_util.test_seed()), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))))) + -tf.math.log(float(num_particles))), + extra=tf.constant(np.nan) + )) # Run a couple of steps. seeds = samplers.split_seed( @@ -96,7 +100,9 @@ def testMarginalLikelihoodGradientIsDefined(self): WeightedParticles( particles=samplers.normal([num_particles], seed=seeds[0]), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))))) + -tf.math.log(float(num_particles))), + extra=tf.constant(np.nan) + )) def propose_and_update_log_weights_fn(_, weighted_particles, @@ -110,7 +116,9 @@ def propose_and_update_log_weights_fn(_, particles=proposed_particles, log_weights=(weighted_particles.log_weights + transition_dist.log_prob(proposed_particles) - - proposal_dist.log_prob(proposed_particles))) + proposal_dist.log_prob(proposed_particles)), + extra=tf.constant(np.nan) + ) def marginal_logprob(transition_scale): kernel = SequentialMonteCarlo( diff --git a/tfp_nightly.egg-info/PKG-INFO b/tfp_nightly.egg-info/PKG-INFO new file mode 100644 index 0000000000..96ea3cdd82 --- /dev/null +++ b/tfp_nightly.egg-info/PKG-INFO @@ -0,0 +1,244 @@ +Metadata-Version: 2.1 +Name: tfp-nightly +Version: 0.24.0.dev0 +Summary: Probabilistic modeling and statistical inference in TensorFlow +Home-page: http://github.com/tensorflow/probability +Author: Google LLC +Author-email: no-reply@google.com +License: Apache 2.0 +Keywords: tensorflow probability statistics bayesian machine learning +Platform: UNKNOWN +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Scientific/Engineering +Classifier: Topic :: Scientific/Engineering :: Mathematics +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.9 +Description-Content-Type: text/markdown +Provides-Extra: jax +Provides-Extra: tfds +License-File: LICENSE + +# TensorFlow Probability + +TensorFlow Probability is a library for probabilistic reasoning and statistical +analysis in TensorFlow. As part of the TensorFlow ecosystem, TensorFlow +Probability provides integration of probabilistic methods with deep networks, +gradient-based inference via automatic differentiation, and scalability to +large datasets and models via hardware acceleration (e.g., GPUs) and distributed +computation. + +__TFP also works as "Tensor-friendly Probability" in pure JAX!__: +`from tensorflow_probability.substrates import jax as tfp` -- +Learn more [here](https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX). + +Our probabilistic machine learning tools are structured as follows. + +__Layer 0: TensorFlow.__ Numerical operations. In particular, the LinearOperator +class enables matrix-free implementations that can exploit special structure +(diagonal, low-rank, etc.) for efficient computation. It is built and maintained +by the TensorFlow Probability team and is now part of +[`tf.linalg`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/ops/linalg) +in core TF. + +__Layer 1: Statistical Building Blocks__ + +* Distributions ([`tfp.distributions`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/distributions)): + A large collection of probability + distributions and related statistics with batch and + [broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + semantics. See the + [Distributions Tutorial](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Distributions_Tutorial.ipynb). +* Bijectors ([`tfp.bijectors`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/bijectors)): + Reversible and composable transformations of random variables. Bijectors + provide a rich class of transformed distributions, from classical examples + like the + [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) + to sophisticated deep learning models such as + [masked autoregressive flows](https://arxiv.org/abs/1705.07057). + +__Layer 2: Model Building__ + +* Joint Distributions (e.g., [`tfp.distributions.JointDistributionSequential`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/distributions/joint_distribution_sequential.py)): + Joint distributions over one or more possibly-interdependent distributions. + For an introduction to modeling with TFP's `JointDistribution`s, check out + [this colab](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb) +* Probabilistic Layers ([`tfp.layers`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/layers)): + Neural network layers with uncertainty over the functions they represent, + extending TensorFlow Layers. + +__Layer 3: Probabilistic Inference__ + +* Markov chain Monte Carlo ([`tfp.mcmc`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/mcmc)): + Algorithms for approximating integrals via sampling. Includes + [Hamiltonian Monte Carlo](https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo), + random-walk Metropolis-Hastings, and the ability to build custom transition + kernels. +* Variational Inference ([`tfp.vi`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/vi)): + Algorithms for approximating integrals via optimization. +* Optimizers ([`tfp.optimizer`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/optimizer)): + Stochastic optimization methods, extending TensorFlow Optimizers. Includes + [Stochastic Gradient Langevin Dynamics](http://www.icml-2011.org/papers/398_icmlpaper.pdf). +* Monte Carlo ([`tfp.monte_carlo`](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/monte_carlo)): + Tools for computing Monte Carlo expectations. + +TensorFlow Probability is under active development. Interfaces may change at any +time. + +## Examples + +See [`tensorflow_probability/examples/`](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/) +for end-to-end examples. It includes tutorial notebooks such as: + +* [Linear Mixed Effects Models](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb). + A hierarchical linear model for sharing statistical strength across examples. +* [Eight Schools](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Eight_Schools.ipynb). + A hierarchical normal model for exchangeable treatment effects. +* [Hierarchical Linear Models](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/HLM_TFP_R_Stan.ipynb). + Hierarchical linear models compared among TensorFlow Probability, R, and Stan. +* [Bayesian Gaussian Mixture Models](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb). + Clustering with a probabilistic generative model. +* [Probabilistic Principal Components Analysis](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_PCA.ipynb). + Dimensionality reduction with latent variables. +* [Gaussian Copulas](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Copula.ipynb). + Probability distributions for capturing dependence across random variables. +* [TensorFlow Distributions: A Gentle Introduction](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Distributions_Tutorial.ipynb). + Introduction to TensorFlow Distributions. +* [Understanding TensorFlow Distributions Shapes](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb). + How to distinguish between samples, batches, and events for arbitrarily shaped + probabilistic computations. +* [TensorFlow Probability Case Study: Covariance Estimation](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb). + A user's case study in applying TensorFlow Probability to estimate covariances. + +It also includes example scripts such as: + + Representation learning with a latent code and variational inference. +* [Vector-Quantized Autoencoder](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/vq_vae.py). + Discrete representation learning with vector quantization. +* [Disentangled Sequential Variational Autoencoder](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/disentangled_vae.py) + Disentangled representation learning over sequences with variational inference. +* [Bayesian Neural Networks](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/bayesian_neural_network.py). + Neural networks with uncertainty over their weights. +* [Bayesian Logistic Regression](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/logistic_regression.py). + Bayesian inference for binary classification. + +## Installation + +For additional details on installing TensorFlow, guidance installing +prerequisites, and (optionally) setting up virtual environments, see the +[TensorFlow installation guide](https://www.tensorflow.org/install). + +### Stable Builds + +To install the latest stable version, run the following: + +```shell +# Notes: + +# - The `--upgrade` flag ensures you'll get the latest version. +# - The `--user` flag ensures the packages are installed to your user directory +# rather than the system directory. +# - TensorFlow 2 packages require a pip >= 19.0 +python -m pip install --upgrade --user pip +python -m pip install --upgrade --user tensorflow tensorflow_probability +``` + +For CPU-only usage (and a smaller install), install with `tensorflow-cpu`. + +To use a pre-2.0 version of TensorFlow, run: + +```shell +python -m pip install --upgrade --user "tensorflow<2" "tensorflow_probability<0.9" +``` + +Note: Since [TensorFlow](https://www.tensorflow.org/install) is *not* included +as a dependency of the TensorFlow Probability package (in `setup.py`), you must +explicitly install the TensorFlow package (`tensorflow` or `tensorflow-cpu`). +This allows us to maintain one package instead of separate packages for CPU and +GPU-enabled TensorFlow. See the +[TFP release notes](https://github.com/tensorflow/probability/releases) for more +details about dependencies between TensorFlow and TensorFlow Probability. + + +### Nightly Builds + +There are also nightly builds of TensorFlow Probability under the pip package +`tfp-nightly`, which depends on one of `tf-nightly` or `tf-nightly-cpu`. +Nightly builds include newer features, but may be less stable than the +versioned releases. Both stable and nightly docs are available +[here](https://www.tensorflow.org/probability/api_docs/python/tfp?version=nightly). + +```shell +python -m pip install --upgrade --user tf-nightly tfp-nightly +``` + +### Installing from Source + +You can also install from source. This requires the [Bazel]( +https://bazel.build/) build system. It is highly recommended that you install +the nightly build of TensorFlow (`tf-nightly`) before trying to build +TensorFlow Probability from source. + +```shell +# sudo apt-get install bazel git python-pip # Ubuntu; others, see above links. +python -m pip install --upgrade --user tf-nightly +git clone https://github.com/tensorflow/probability.git +cd probability +bazel build --copt=-O3 --copt=-march=native :pip_pkg +PKGDIR=$(mktemp -d) +./bazel-bin/pip_pkg $PKGDIR +python -m pip install --upgrade --user $PKGDIR/*.whl +``` + +## Community + +As part of TensorFlow, we're committed to fostering an open and welcoming +environment. + +* [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow): Ask + or answer technical questions. +* [GitHub](https://github.com/tensorflow/probability/issues): Report bugs or + make feature requests. +* [TensorFlow Blog](https://blog.tensorflow.org/): Stay up to date on content + from the TensorFlow team and best articles from the community. +* [Youtube Channel](http://youtube.com/tensorflow/): Follow TensorFlow shows. +* [tfprobability@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfprobability): + Open mailing list for discussion and questions. + +See the [TensorFlow Community](https://www.tensorflow.org/community/) page for +more details. Check out our latest publicity here: + ++ [Coffee with a Googler: Probabilistic Machine Learning in TensorFlow]( + https://www.youtube.com/watch?v=BjUkL8DFH5Q) ++ [Introducing TensorFlow Probability]( + https://medium.com/tensorflow/introducing-tensorflow-probability-dca4c304e245) + +## Contributing + +We're eager to collaborate with you! See [`CONTRIBUTING.md`](CONTRIBUTING.md) +for a guide on how to contribute. This project adheres to TensorFlow's +[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to +uphold this code. + +## References + +If you use TensorFlow Probability in a paper, please cite: + ++ _TensorFlow Distributions._ Joshua V. Dillon, Ian Langmore, Dustin Tran, +Eugene Brevdo, Srinivas Vasudevan, Dave Moore, Brian Patton, Alex Alemi, Matt +Hoffman, Rif A. Saurous. +[arXiv preprint arXiv:1711.10604, 2017](https://arxiv.org/abs/1711.10604). + +(We're aware there's a lot more to TensorFlow Probability than Distributions, but the Distributions paper lays out our vision and is a fine thing to cite for now.) + + diff --git a/tfp_nightly.egg-info/SOURCES.txt b/tfp_nightly.egg-info/SOURCES.txt new file mode 100644 index 0000000000..8645f5fa31 --- /dev/null +++ b/tfp_nightly.egg-info/SOURCES.txt @@ -0,0 +1,1068 @@ +LICENSE +README.md +setup.py +tensorflow_probability/__init__.py +tensorflow_probability/python/__init__.py +tensorflow_probability/python/version.py +tensorflow_probability/python/bijectors/__init__.py +tensorflow_probability/python/bijectors/absolute_value.py +tensorflow_probability/python/bijectors/absolute_value_test.py +tensorflow_probability/python/bijectors/ascending.py +tensorflow_probability/python/bijectors/ascending_test.py +tensorflow_probability/python/bijectors/batch_normalization.py +tensorflow_probability/python/bijectors/batch_normalization_test.py +tensorflow_probability/python/bijectors/bijector.py +tensorflow_probability/python/bijectors/bijector_composition_test.py +tensorflow_probability/python/bijectors/bijector_properties_test.py +tensorflow_probability/python/bijectors/bijector_test.py +tensorflow_probability/python/bijectors/bijector_test_util.py +tensorflow_probability/python/bijectors/blockwise.py +tensorflow_probability/python/bijectors/blockwise_test.py +tensorflow_probability/python/bijectors/categorical_to_discrete.py +tensorflow_probability/python/bijectors/categorical_to_discrete_test.py +tensorflow_probability/python/bijectors/chain.py +tensorflow_probability/python/bijectors/chain_test.py +tensorflow_probability/python/bijectors/cholesky_outer_product.py +tensorflow_probability/python/bijectors/cholesky_outer_product_test.py +tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py +tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky_test.py +tensorflow_probability/python/bijectors/composition.py +tensorflow_probability/python/bijectors/correlation_cholesky.py +tensorflow_probability/python/bijectors/correlation_cholesky_test.py +tensorflow_probability/python/bijectors/cumsum.py +tensorflow_probability/python/bijectors/cumsum_test.py +tensorflow_probability/python/bijectors/discrete_cosine_transform.py +tensorflow_probability/python/bijectors/discrete_cosine_transform_test.py +tensorflow_probability/python/bijectors/exp.py +tensorflow_probability/python/bijectors/exp_test.py +tensorflow_probability/python/bijectors/expm1.py +tensorflow_probability/python/bijectors/expm1_test.py +tensorflow_probability/python/bijectors/ffjord.py +tensorflow_probability/python/bijectors/ffjord_test.py +tensorflow_probability/python/bijectors/fill_scale_tril.py +tensorflow_probability/python/bijectors/fill_scale_tril_test.py +tensorflow_probability/python/bijectors/fill_triangular.py +tensorflow_probability/python/bijectors/fill_triangular_test.py +tensorflow_probability/python/bijectors/frechet_cdf.py +tensorflow_probability/python/bijectors/frechet_cdf_test.py +tensorflow_probability/python/bijectors/generalized_pareto.py +tensorflow_probability/python/bijectors/generalized_pareto_test.py +tensorflow_probability/python/bijectors/gev_cdf.py +tensorflow_probability/python/bijectors/gev_cdf_test.py +tensorflow_probability/python/bijectors/glow.py +tensorflow_probability/python/bijectors/glow_test.py +tensorflow_probability/python/bijectors/gompertz_cdf.py +tensorflow_probability/python/bijectors/gompertz_cdf_test.py +tensorflow_probability/python/bijectors/gumbel_cdf.py +tensorflow_probability/python/bijectors/gumbel_cdf_test.py +tensorflow_probability/python/bijectors/householder.py +tensorflow_probability/python/bijectors/householder_test.py +tensorflow_probability/python/bijectors/hypothesis_testlib.py +tensorflow_probability/python/bijectors/identity.py +tensorflow_probability/python/bijectors/identity_test.py +tensorflow_probability/python/bijectors/inline.py +tensorflow_probability/python/bijectors/inline_test.py +tensorflow_probability/python/bijectors/invert.py +tensorflow_probability/python/bijectors/invert_test.py +tensorflow_probability/python/bijectors/iterated_sigmoid_centered.py +tensorflow_probability/python/bijectors/iterated_sigmoid_centered_test.py +tensorflow_probability/python/bijectors/joint_map.py +tensorflow_probability/python/bijectors/joint_map_test.py +tensorflow_probability/python/bijectors/kumaraswamy_cdf.py +tensorflow_probability/python/bijectors/kumaraswamy_cdf_test.py +tensorflow_probability/python/bijectors/lambertw_transform.py +tensorflow_probability/python/bijectors/lambertw_transform_test.py +tensorflow_probability/python/bijectors/ldj_ratio.py +tensorflow_probability/python/bijectors/ldj_ratio_test.py +tensorflow_probability/python/bijectors/masked_autoregressive.py +tensorflow_probability/python/bijectors/masked_autoregressive_test.py +tensorflow_probability/python/bijectors/matrix_inverse_tril.py +tensorflow_probability/python/bijectors/matrix_inverse_tril_test.py +tensorflow_probability/python/bijectors/moyal_cdf.py +tensorflow_probability/python/bijectors/moyal_cdf_test.py +tensorflow_probability/python/bijectors/normal_cdf.py +tensorflow_probability/python/bijectors/normal_cdf_test.py +tensorflow_probability/python/bijectors/pad.py +tensorflow_probability/python/bijectors/pad_test.py +tensorflow_probability/python/bijectors/permute.py +tensorflow_probability/python/bijectors/permute_test.py +tensorflow_probability/python/bijectors/power.py +tensorflow_probability/python/bijectors/power_test.py +tensorflow_probability/python/bijectors/power_transform.py +tensorflow_probability/python/bijectors/power_transform_test.py +tensorflow_probability/python/bijectors/rational_quadratic_spline.py +tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py +tensorflow_probability/python/bijectors/rayleigh_cdf.py +tensorflow_probability/python/bijectors/rayleigh_cdf_test.py +tensorflow_probability/python/bijectors/real_nvp.py +tensorflow_probability/python/bijectors/real_nvp_test.py +tensorflow_probability/python/bijectors/reciprocal.py +tensorflow_probability/python/bijectors/reciprocal_test.py +tensorflow_probability/python/bijectors/reshape.py +tensorflow_probability/python/bijectors/reshape_test.py +tensorflow_probability/python/bijectors/restructure.py +tensorflow_probability/python/bijectors/restructure_test.py +tensorflow_probability/python/bijectors/scale.py +tensorflow_probability/python/bijectors/scale_matvec_diag.py +tensorflow_probability/python/bijectors/scale_matvec_diag_test.py +tensorflow_probability/python/bijectors/scale_matvec_linear_operator.py +tensorflow_probability/python/bijectors/scale_matvec_linear_operator_test.py +tensorflow_probability/python/bijectors/scale_matvec_lu.py +tensorflow_probability/python/bijectors/scale_matvec_lu_test.py +tensorflow_probability/python/bijectors/scale_matvec_tril.py +tensorflow_probability/python/bijectors/scale_matvec_tril_test.py +tensorflow_probability/python/bijectors/scale_test.py +tensorflow_probability/python/bijectors/shift.py +tensorflow_probability/python/bijectors/shift_test.py +tensorflow_probability/python/bijectors/shifted_gompertz_cdf.py +tensorflow_probability/python/bijectors/shifted_gompertz_cdf_test.py +tensorflow_probability/python/bijectors/sigmoid.py +tensorflow_probability/python/bijectors/sigmoid_test.py +tensorflow_probability/python/bijectors/sinh.py +tensorflow_probability/python/bijectors/sinh_arcsinh.py +tensorflow_probability/python/bijectors/sinh_arcsinh_test.py +tensorflow_probability/python/bijectors/sinh_test.py +tensorflow_probability/python/bijectors/soft_clip.py +tensorflow_probability/python/bijectors/soft_clip_test.py +tensorflow_probability/python/bijectors/softfloor.py +tensorflow_probability/python/bijectors/softfloor_test.py +tensorflow_probability/python/bijectors/softmax_centered.py +tensorflow_probability/python/bijectors/softmax_centered_test.py +tensorflow_probability/python/bijectors/softplus.py +tensorflow_probability/python/bijectors/softplus_test.py +tensorflow_probability/python/bijectors/softsign.py +tensorflow_probability/python/bijectors/softsign_test.py +tensorflow_probability/python/bijectors/split.py +tensorflow_probability/python/bijectors/split_test.py +tensorflow_probability/python/bijectors/square.py +tensorflow_probability/python/bijectors/square_test.py +tensorflow_probability/python/bijectors/tanh.py +tensorflow_probability/python/bijectors/tanh_test.py +tensorflow_probability/python/bijectors/transform_diagonal.py +tensorflow_probability/python/bijectors/transform_diagonal_test.py +tensorflow_probability/python/bijectors/transpose.py +tensorflow_probability/python/bijectors/transpose_test.py +tensorflow_probability/python/bijectors/unit_vector.py +tensorflow_probability/python/bijectors/unit_vector_test.py +tensorflow_probability/python/bijectors/weibull_cdf.py +tensorflow_probability/python/bijectors/weibull_cdf_test.py +tensorflow_probability/python/debugging/__init__.py +tensorflow_probability/python/debugging/benchmarking/__init__.py +tensorflow_probability/python/debugging/benchmarking/benchmark_tf_function.py +tensorflow_probability/python/distributions/__init__.py +tensorflow_probability/python/distributions/autoregressive.py +tensorflow_probability/python/distributions/autoregressive_test.py +tensorflow_probability/python/distributions/batch_broadcast.py +tensorflow_probability/python/distributions/batch_broadcast_test.py +tensorflow_probability/python/distributions/batch_concat.py +tensorflow_probability/python/distributions/batch_concat_test.py +tensorflow_probability/python/distributions/batch_reshape.py +tensorflow_probability/python/distributions/batch_reshape_test.py +tensorflow_probability/python/distributions/bates.py +tensorflow_probability/python/distributions/bates_test.py +tensorflow_probability/python/distributions/bernoulli.py +tensorflow_probability/python/distributions/bernoulli_test.py +tensorflow_probability/python/distributions/beta.py +tensorflow_probability/python/distributions/beta_binomial.py +tensorflow_probability/python/distributions/beta_binomial_test.py +tensorflow_probability/python/distributions/beta_quotient.py +tensorflow_probability/python/distributions/beta_quotient_test.py +tensorflow_probability/python/distributions/beta_test.py +tensorflow_probability/python/distributions/binomial.py +tensorflow_probability/python/distributions/binomial_test.py +tensorflow_probability/python/distributions/blockwise.py +tensorflow_probability/python/distributions/blockwise_test.py +tensorflow_probability/python/distributions/categorical.py +tensorflow_probability/python/distributions/categorical_test.py +tensorflow_probability/python/distributions/cauchy.py +tensorflow_probability/python/distributions/cauchy_test.py +tensorflow_probability/python/distributions/chi.py +tensorflow_probability/python/distributions/chi2.py +tensorflow_probability/python/distributions/chi2_test.py +tensorflow_probability/python/distributions/chi_test.py +tensorflow_probability/python/distributions/cholesky_lkj.py +tensorflow_probability/python/distributions/cholesky_lkj_test.py +tensorflow_probability/python/distributions/cholesky_util.py +tensorflow_probability/python/distributions/cholesky_util_test.py +tensorflow_probability/python/distributions/continuous_bernoulli.py +tensorflow_probability/python/distributions/continuous_bernoulli_test.py +tensorflow_probability/python/distributions/deterministic.py +tensorflow_probability/python/distributions/deterministic_test.py +tensorflow_probability/python/distributions/dirichlet.py +tensorflow_probability/python/distributions/dirichlet_multinomial.py +tensorflow_probability/python/distributions/dirichlet_multinomial_test.py +tensorflow_probability/python/distributions/dirichlet_test.py +tensorflow_probability/python/distributions/discrete_rejection_sampling.py +tensorflow_probability/python/distributions/discrete_rejection_sampling_test.py +tensorflow_probability/python/distributions/distribution.py +tensorflow_probability/python/distributions/distribution_properties_test.py +tensorflow_probability/python/distributions/distribution_test.py +tensorflow_probability/python/distributions/doublesided_maxwell.py +tensorflow_probability/python/distributions/doublesided_maxwell_test.py +tensorflow_probability/python/distributions/dpp.py +tensorflow_probability/python/distributions/dpp_test.py +tensorflow_probability/python/distributions/empirical.py +tensorflow_probability/python/distributions/empirical_test.py +tensorflow_probability/python/distributions/exp_gamma.py +tensorflow_probability/python/distributions/exp_gamma_test.py +tensorflow_probability/python/distributions/exponential.py +tensorflow_probability/python/distributions/exponential_test.py +tensorflow_probability/python/distributions/exponentially_modified_gaussian.py +tensorflow_probability/python/distributions/exponentially_modified_gaussian_test.py +tensorflow_probability/python/distributions/finite_discrete.py +tensorflow_probability/python/distributions/finite_discrete_test.py +tensorflow_probability/python/distributions/gamma.py +tensorflow_probability/python/distributions/gamma_gamma.py +tensorflow_probability/python/distributions/gamma_gamma_test.py +tensorflow_probability/python/distributions/gamma_test.py +tensorflow_probability/python/distributions/gaussian_process.py +tensorflow_probability/python/distributions/gaussian_process_regression_model.py +tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py +tensorflow_probability/python/distributions/gaussian_process_test.py +tensorflow_probability/python/distributions/generalized_normal.py +tensorflow_probability/python/distributions/generalized_normal_test.py +tensorflow_probability/python/distributions/generalized_pareto.py +tensorflow_probability/python/distributions/generalized_pareto_test.py +tensorflow_probability/python/distributions/geometric.py +tensorflow_probability/python/distributions/geometric_test.py +tensorflow_probability/python/distributions/gev.py +tensorflow_probability/python/distributions/gev_test.py +tensorflow_probability/python/distributions/gumbel.py +tensorflow_probability/python/distributions/gumbel_test.py +tensorflow_probability/python/distributions/half_cauchy.py +tensorflow_probability/python/distributions/half_cauchy_test.py +tensorflow_probability/python/distributions/half_normal.py +tensorflow_probability/python/distributions/half_normal_test.py +tensorflow_probability/python/distributions/half_student_t.py +tensorflow_probability/python/distributions/half_student_t_test.py +tensorflow_probability/python/distributions/hidden_markov_model.py +tensorflow_probability/python/distributions/hidden_markov_model_test.py +tensorflow_probability/python/distributions/horseshoe.py +tensorflow_probability/python/distributions/horseshoe_test.py +tensorflow_probability/python/distributions/hypothesis_testlib.py +tensorflow_probability/python/distributions/independent.py +tensorflow_probability/python/distributions/independent_test.py +tensorflow_probability/python/distributions/inflated.py +tensorflow_probability/python/distributions/inflated_test.py +tensorflow_probability/python/distributions/inverse_gamma.py +tensorflow_probability/python/distributions/inverse_gamma_test.py +tensorflow_probability/python/distributions/inverse_gaussian.py +tensorflow_probability/python/distributions/inverse_gaussian_test.py +tensorflow_probability/python/distributions/jax_transformation_test.py +tensorflow_probability/python/distributions/johnson_su.py +tensorflow_probability/python/distributions/johnson_su_test.py +tensorflow_probability/python/distributions/joint_distribution.py +tensorflow_probability/python/distributions/joint_distribution_auto_batched.py +tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py +tensorflow_probability/python/distributions/joint_distribution_coroutine.py +tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py +tensorflow_probability/python/distributions/joint_distribution_named.py +tensorflow_probability/python/distributions/joint_distribution_named_test.py +tensorflow_probability/python/distributions/joint_distribution_sequential.py +tensorflow_probability/python/distributions/joint_distribution_sequential_test.py +tensorflow_probability/python/distributions/joint_distribution_util.py +tensorflow_probability/python/distributions/joint_distribution_util_test.py +tensorflow_probability/python/distributions/kullback_leibler.py +tensorflow_probability/python/distributions/kullback_leibler_test.py +tensorflow_probability/python/distributions/kumaraswamy.py +tensorflow_probability/python/distributions/kumaraswamy_test.py +tensorflow_probability/python/distributions/lambertw_f.py +tensorflow_probability/python/distributions/lambertw_f_test.py +tensorflow_probability/python/distributions/laplace.py +tensorflow_probability/python/distributions/laplace_test.py +tensorflow_probability/python/distributions/linear_gaussian_ssm.py +tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py +tensorflow_probability/python/distributions/lkj.py +tensorflow_probability/python/distributions/lkj_test.py +tensorflow_probability/python/distributions/log_prob_ratio.py +tensorflow_probability/python/distributions/logistic.py +tensorflow_probability/python/distributions/logistic_test.py +tensorflow_probability/python/distributions/logitnormal.py +tensorflow_probability/python/distributions/logitnormal_test.py +tensorflow_probability/python/distributions/loglogistic.py +tensorflow_probability/python/distributions/loglogistic_test.py +tensorflow_probability/python/distributions/lognormal.py +tensorflow_probability/python/distributions/lognormal_test.py +tensorflow_probability/python/distributions/markov_chain.py +tensorflow_probability/python/distributions/markov_chain_test.py +tensorflow_probability/python/distributions/masked.py +tensorflow_probability/python/distributions/masked_test.py +tensorflow_probability/python/distributions/matrix_normal_linear_operator.py +tensorflow_probability/python/distributions/matrix_normal_linear_operator_test.py +tensorflow_probability/python/distributions/matrix_t_linear_operator.py +tensorflow_probability/python/distributions/matrix_t_linear_operator_test.py +tensorflow_probability/python/distributions/mixture.py +tensorflow_probability/python/distributions/mixture_same_family.py +tensorflow_probability/python/distributions/mixture_same_family_test.py +tensorflow_probability/python/distributions/mixture_test.py +tensorflow_probability/python/distributions/moyal.py +tensorflow_probability/python/distributions/moyal_test.py +tensorflow_probability/python/distributions/multinomial.py +tensorflow_probability/python/distributions/multinomial_test.py +tensorflow_probability/python/distributions/multivariate_student_t.py +tensorflow_probability/python/distributions/multivariate_student_t_test.py +tensorflow_probability/python/distributions/mvn_diag.py +tensorflow_probability/python/distributions/mvn_diag_plus_low_rank.py +tensorflow_probability/python/distributions/mvn_diag_plus_low_rank_covariance.py +tensorflow_probability/python/distributions/mvn_diag_plus_low_rank_covariance_test.py +tensorflow_probability/python/distributions/mvn_diag_plus_low_rank_test.py +tensorflow_probability/python/distributions/mvn_diag_test.py +tensorflow_probability/python/distributions/mvn_full_covariance.py +tensorflow_probability/python/distributions/mvn_full_covariance_test.py +tensorflow_probability/python/distributions/mvn_linear_operator.py +tensorflow_probability/python/distributions/mvn_linear_operator_test.py +tensorflow_probability/python/distributions/mvn_low_rank_update_linear_operator_covariance.py +tensorflow_probability/python/distributions/mvn_low_rank_update_linear_operator_covariance_test.py +tensorflow_probability/python/distributions/mvn_tril.py +tensorflow_probability/python/distributions/mvn_tril_test.py +tensorflow_probability/python/distributions/negative_binomial.py +tensorflow_probability/python/distributions/negative_binomial_test.py +tensorflow_probability/python/distributions/noncentral_chi2.py +tensorflow_probability/python/distributions/noncentral_chi2_test.py +tensorflow_probability/python/distributions/normal.py +tensorflow_probability/python/distributions/normal_conjugate_posteriors.py +tensorflow_probability/python/distributions/normal_conjugate_posteriors_test.py +tensorflow_probability/python/distributions/normal_inverse_gaussian.py +tensorflow_probability/python/distributions/normal_inverse_gaussian_test.py +tensorflow_probability/python/distributions/normal_test.py +tensorflow_probability/python/distributions/numerical_properties_test.py +tensorflow_probability/python/distributions/onehot_categorical.py +tensorflow_probability/python/distributions/onehot_categorical_test.py +tensorflow_probability/python/distributions/ordered_logistic.py +tensorflow_probability/python/distributions/ordered_logistic_test.py +tensorflow_probability/python/distributions/pareto.py +tensorflow_probability/python/distributions/pareto_test.py +tensorflow_probability/python/distributions/pert.py +tensorflow_probability/python/distributions/pert_test.py +tensorflow_probability/python/distributions/pixel_cnn.py +tensorflow_probability/python/distributions/pixel_cnn_test.py +tensorflow_probability/python/distributions/plackett_luce.py +tensorflow_probability/python/distributions/plackett_luce_test.py +tensorflow_probability/python/distributions/platform_compatibility_test.py +tensorflow_probability/python/distributions/poisson.py +tensorflow_probability/python/distributions/poisson_lognormal.py +tensorflow_probability/python/distributions/poisson_lognormal_test.py +tensorflow_probability/python/distributions/poisson_test.py +tensorflow_probability/python/distributions/power_spherical.py +tensorflow_probability/python/distributions/power_spherical_test.py +tensorflow_probability/python/distributions/probit_bernoulli.py +tensorflow_probability/python/distributions/probit_bernoulli_test.py +tensorflow_probability/python/distributions/quantized_distribution.py +tensorflow_probability/python/distributions/quantized_distribution_test.py +tensorflow_probability/python/distributions/relaxed_bernoulli.py +tensorflow_probability/python/distributions/relaxed_bernoulli_test.py +tensorflow_probability/python/distributions/relaxed_onehot_categorical.py +tensorflow_probability/python/distributions/relaxed_onehot_categorical_test.py +tensorflow_probability/python/distributions/sample.py +tensorflow_probability/python/distributions/sample_test.py +tensorflow_probability/python/distributions/sigmoid_beta.py +tensorflow_probability/python/distributions/sigmoid_beta_test.py +tensorflow_probability/python/distributions/sinh_arcsinh.py +tensorflow_probability/python/distributions/sinh_arcsinh_test.py +tensorflow_probability/python/distributions/skellam.py +tensorflow_probability/python/distributions/skellam_test.py +tensorflow_probability/python/distributions/spherical_uniform.py +tensorflow_probability/python/distributions/spherical_uniform_test.py +tensorflow_probability/python/distributions/stochastic_process_properties_test.py +tensorflow_probability/python/distributions/stopping_ratio_logistic.py +tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py +tensorflow_probability/python/distributions/student_t.py +tensorflow_probability/python/distributions/student_t_process.py +tensorflow_probability/python/distributions/student_t_process_regression_model.py +tensorflow_probability/python/distributions/student_t_process_regression_model_test.py +tensorflow_probability/python/distributions/student_t_process_test.py +tensorflow_probability/python/distributions/student_t_test.py +tensorflow_probability/python/distributions/transformed_distribution.py +tensorflow_probability/python/distributions/transformed_distribution_test.py +tensorflow_probability/python/distributions/triangular.py +tensorflow_probability/python/distributions/triangular_test.py +tensorflow_probability/python/distributions/truncated_cauchy.py +tensorflow_probability/python/distributions/truncated_cauchy_test.py +tensorflow_probability/python/distributions/truncated_normal.py +tensorflow_probability/python/distributions/truncated_normal_test.py +tensorflow_probability/python/distributions/two_piece_normal.py +tensorflow_probability/python/distributions/two_piece_normal_test.py +tensorflow_probability/python/distributions/two_piece_student_t.py +tensorflow_probability/python/distributions/two_piece_student_t_test.py +tensorflow_probability/python/distributions/uniform.py +tensorflow_probability/python/distributions/uniform_test.py +tensorflow_probability/python/distributions/untestable_distributions.py +tensorflow_probability/python/distributions/variational_gaussian_process.py +tensorflow_probability/python/distributions/variational_gaussian_process_test.py +tensorflow_probability/python/distributions/vector_exponential_linear_operator.py +tensorflow_probability/python/distributions/von_mises.py +tensorflow_probability/python/distributions/von_mises_fisher.py +tensorflow_probability/python/distributions/von_mises_fisher_test.py +tensorflow_probability/python/distributions/von_mises_test.py +tensorflow_probability/python/distributions/weibull.py +tensorflow_probability/python/distributions/weibull_test.py +tensorflow_probability/python/distributions/wishart.py +tensorflow_probability/python/distributions/wishart_test.py +tensorflow_probability/python/distributions/zipf.py +tensorflow_probability/python/distributions/zipf_test.py +tensorflow_probability/python/distributions/internal/__init__.py +tensorflow_probability/python/distributions/internal/correlation_matrix_volumes.py +tensorflow_probability/python/distributions/internal/correlation_matrix_volumes_lib.py +tensorflow_probability/python/distributions/internal/correlation_matrix_volumes_test.py +tensorflow_probability/python/distributions/internal/statistical_testing.py +tensorflow_probability/python/distributions/internal/statistical_testing_test.py +tensorflow_probability/python/distributions/internal/stochastic_process_util.py +tensorflow_probability/python/experimental/__init__.py +tensorflow_probability/python/experimental/auto_batching/__init__.py +tensorflow_probability/python/experimental/auto_batching/allocation_strategy.py +tensorflow_probability/python/experimental/auto_batching/allocation_strategy_test.py +tensorflow_probability/python/experimental/auto_batching/backend_test_lib.py +tensorflow_probability/python/experimental/auto_batching/dsl.py +tensorflow_probability/python/experimental/auto_batching/dsl_test.py +tensorflow_probability/python/experimental/auto_batching/frontend.py +tensorflow_probability/python/experimental/auto_batching/frontend_test.py +tensorflow_probability/python/experimental/auto_batching/gast_util.py +tensorflow_probability/python/experimental/auto_batching/instructions.py +tensorflow_probability/python/experimental/auto_batching/instructions_test.py +tensorflow_probability/python/experimental/auto_batching/liveness.py +tensorflow_probability/python/experimental/auto_batching/lowering.py +tensorflow_probability/python/experimental/auto_batching/lowering_test.py +tensorflow_probability/python/experimental/auto_batching/numpy_backend.py +tensorflow_probability/python/experimental/auto_batching/numpy_backend_test.py +tensorflow_probability/python/experimental/auto_batching/stack_optimization.py +tensorflow_probability/python/experimental/auto_batching/stack_optimization_test.py +tensorflow_probability/python/experimental/auto_batching/stackless.py +tensorflow_probability/python/experimental/auto_batching/stackless_test.py +tensorflow_probability/python/experimental/auto_batching/test_programs.py +tensorflow_probability/python/experimental/auto_batching/tf_backend.py +tensorflow_probability/python/experimental/auto_batching/tf_backend_test.py +tensorflow_probability/python/experimental/auto_batching/type_inference.py +tensorflow_probability/python/experimental/auto_batching/type_inference_test.py +tensorflow_probability/python/experimental/auto_batching/virtual_machine.py +tensorflow_probability/python/experimental/auto_batching/virtual_machine_test.py +tensorflow_probability/python/experimental/auto_batching/xla.py +tensorflow_probability/python/experimental/bayesopt/__init__.py +tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py +tensorflow_probability/python/experimental/bayesopt/acquisition/acquisition_function.py +tensorflow_probability/python/experimental/bayesopt/acquisition/acquisition_function_test.py +tensorflow_probability/python/experimental/bayesopt/acquisition/expected_improvement.py +tensorflow_probability/python/experimental/bayesopt/acquisition/expected_improvement_test.py +tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search.py +tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search_test.py +tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement.py +tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement_test.py +tensorflow_probability/python/experimental/bayesopt/acquisition/upper_confidence_bound.py +tensorflow_probability/python/experimental/bayesopt/acquisition/upper_confidence_bound_test.py +tensorflow_probability/python/experimental/bayesopt/acquisition/weighted_power_scalarization.py +tensorflow_probability/python/experimental/bayesopt/acquisition/weighted_power_scalarization_test.py +tensorflow_probability/python/experimental/bijectors/__init__.py +tensorflow_probability/python/experimental/bijectors/distribution_bijectors.py +tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py +tensorflow_probability/python/experimental/bijectors/highway_flow.py +tensorflow_probability/python/experimental/bijectors/highway_flow_test.py +tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py +tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py +tensorflow_probability/python/experimental/bijectors/sharded.py +tensorflow_probability/python/experimental/bijectors/sharded_test.py +tensorflow_probability/python/experimental/distribute/__init__.py +tensorflow_probability/python/experimental/distribute/diagonal_mass_matrix_adaptation_test.py +tensorflow_probability/python/experimental/distribute/joint_distribution.py +tensorflow_probability/python/experimental/distribute/joint_distribution_test.py +tensorflow_probability/python/experimental/distribute/sharded.py +tensorflow_probability/python/experimental/distribute/sharded_test.py +tensorflow_probability/python/experimental/distributions/__init__.py +tensorflow_probability/python/experimental/distributions/importance_resample.py +tensorflow_probability/python/experimental/distributions/importance_resample_test.py +tensorflow_probability/python/experimental/distributions/increment_log_prob.py +tensorflow_probability/python/experimental/distributions/increment_log_prob_test.py +tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py +tensorflow_probability/python/experimental/distributions/joint_distribution_pinned_test.py +tensorflow_probability/python/experimental/distributions/marginal_fns.py +tensorflow_probability/python/experimental/distributions/marginal_fns_test.py +tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py +tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py +tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py +tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_test.py +tensorflow_probability/python/experimental/distributions/mvn_precision_factor_linop.py +tensorflow_probability/python/experimental/distributions/mvn_precision_factor_linop_test.py +tensorflow_probability/python/experimental/joint_distribution_layers/__init__.py +tensorflow_probability/python/experimental/joint_distribution_layers/layers.py +tensorflow_probability/python/experimental/joint_distribution_layers/layers_test.py +tensorflow_probability/python/experimental/linalg/__init__.py +tensorflow_probability/python/experimental/linalg/linear_operator_interpolated_psd_kernel.py +tensorflow_probability/python/experimental/linalg/linear_operator_interpolated_psd_kernel_test.py +tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel.py +tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py +tensorflow_probability/python/experimental/linalg/linear_operator_row_block.py +tensorflow_probability/python/experimental/linalg/linear_operator_row_block_test.py +tensorflow_probability/python/experimental/linalg/linear_operator_unitary.py +tensorflow_probability/python/experimental/linalg/linear_operator_unitary_test.py +tensorflow_probability/python/experimental/linalg/no_pivot_ldl.py +tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py +tensorflow_probability/python/experimental/marginalize/__init__.py +tensorflow_probability/python/experimental/marginalize/logeinsumexp.py +tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py +tensorflow_probability/python/experimental/marginalize/marginalizable.py +tensorflow_probability/python/experimental/marginalize/marginalizable_test.py +tensorflow_probability/python/experimental/math/__init__.py +tensorflow_probability/python/experimental/math/manual_special_functions.py +tensorflow_probability/python/experimental/math/manual_special_functions_test.py +tensorflow_probability/python/experimental/mcmc/__init__.py +tensorflow_probability/python/experimental/mcmc/covariance_reducer.py +tensorflow_probability/python/experimental/mcmc/covariance_reducer_test.py +tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py +tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py +tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler.py +tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler_test.py +tensorflow_probability/python/experimental/mcmc/expectations_reducer.py +tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py +tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py +tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py +tensorflow_probability/python/experimental/mcmc/initialization.py +tensorflow_probability/python/experimental/mcmc/initialization_test.py +tensorflow_probability/python/experimental/mcmc/kernel_builder.py +tensorflow_probability/python/experimental/mcmc/kernel_builder_test.py +tensorflow_probability/python/experimental/mcmc/kernel_outputs.py +tensorflow_probability/python/experimental/mcmc/kernel_outputs_test.py +tensorflow_probability/python/experimental/mcmc/nuts_autobatching.py +tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py +tensorflow_probability/python/experimental/mcmc/nuts_autobatching_xla_test.py +tensorflow_probability/python/experimental/mcmc/particle_filter.py +tensorflow_probability/python/experimental/mcmc/particle_filter_augmentation.py +tensorflow_probability/python/experimental/mcmc/particle_filter_augmentation_test.py +tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +tensorflow_probability/python/experimental/mcmc/pnuts_test.py +tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer.py +tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py +tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py +tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py +tensorflow_probability/python/experimental/mcmc/preconditioned_nuts.py +tensorflow_probability/python/experimental/mcmc/preconditioning_utils.py +tensorflow_probability/python/experimental/mcmc/progress_bar_reducer.py +tensorflow_probability/python/experimental/mcmc/progress_bar_reducer_test.py +tensorflow_probability/python/experimental/mcmc/reducer.py +tensorflow_probability/python/experimental/mcmc/run.py +tensorflow_probability/python/experimental/mcmc/sample.py +tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel.py +tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel_test.py +tensorflow_probability/python/experimental/mcmc/sample_fold.py +tensorflow_probability/python/experimental/mcmc/sample_fold_test.py +tensorflow_probability/python/experimental/mcmc/sample_sequential_monte_carlo.py +tensorflow_probability/python/experimental/mcmc/sample_sequential_monte_carlo_test.py +tensorflow_probability/python/experimental/mcmc/sample_test.py +tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py +tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +tensorflow_probability/python/experimental/mcmc/sharded.py +tensorflow_probability/python/experimental/mcmc/sharded_test.py +tensorflow_probability/python/experimental/mcmc/snaper_hmc.py +tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py +tensorflow_probability/python/experimental/mcmc/step.py +tensorflow_probability/python/experimental/mcmc/step_test.py +tensorflow_probability/python/experimental/mcmc/thermodynamic_integrals.py +tensorflow_probability/python/experimental/mcmc/thermodynamic_integrals_test.py +tensorflow_probability/python/experimental/mcmc/thinning_kernel.py +tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py +tensorflow_probability/python/experimental/mcmc/tracing_reducer.py +tensorflow_probability/python/experimental/mcmc/tracing_reducer_test.py +tensorflow_probability/python/experimental/mcmc/weighted_resampling.py +tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py +tensorflow_probability/python/experimental/mcmc/windowed_sampling.py +tensorflow_probability/python/experimental/mcmc/windowed_sampling_test.py +tensorflow_probability/python/experimental/mcmc/with_reductions.py +tensorflow_probability/python/experimental/mcmc/with_reductions_test.py +tensorflow_probability/python/experimental/mcmc/internal/__init__.py +tensorflow_probability/python/experimental/mcmc/internal/test_fixtures.py +tensorflow_probability/python/experimental/nn/__init__.py +tensorflow_probability/python/experimental/nn/affine_layers.py +tensorflow_probability/python/experimental/nn/affine_layers_test.py +tensorflow_probability/python/experimental/nn/convolutional_layers.py +tensorflow_probability/python/experimental/nn/convolutional_layers_test.py +tensorflow_probability/python/experimental/nn/convolutional_layers_v2.py +tensorflow_probability/python/experimental/nn/convolutional_layers_v2_test.py +tensorflow_probability/python/experimental/nn/convolutional_transpose_layers.py +tensorflow_probability/python/experimental/nn/convolutional_transpose_layers_test.py +tensorflow_probability/python/experimental/nn/layers.py +tensorflow_probability/python/experimental/nn/layers_test.py +tensorflow_probability/python/experimental/nn/variational_base.py +tensorflow_probability/python/experimental/nn/initializers/__init__.py +tensorflow_probability/python/experimental/nn/initializers/initializers.py +tensorflow_probability/python/experimental/nn/losses/__init__.py +tensorflow_probability/python/experimental/nn/losses/losses.py +tensorflow_probability/python/experimental/nn/util/__init__.py +tensorflow_probability/python/experimental/nn/util/convolution_util.py +tensorflow_probability/python/experimental/nn/util/convolution_util_test.py +tensorflow_probability/python/experimental/nn/util/kernel_bias.py +tensorflow_probability/python/experimental/nn/util/kernel_bias_test.py +tensorflow_probability/python/experimental/nn/util/random_variable.py +tensorflow_probability/python/experimental/nn/util/random_variable_test.py +tensorflow_probability/python/experimental/nn/util/utils.py +tensorflow_probability/python/experimental/parallel_filter/__init__.py +tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py +tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_test.py +tensorflow_probability/python/experimental/psd_kernels/__init__.py +tensorflow_probability/python/experimental/psd_kernels/additive_kernel.py +tensorflow_probability/python/experimental/psd_kernels/additive_kernel_test.py +tensorflow_probability/python/experimental/psd_kernels/feature_scaled_with_categorical.py +tensorflow_probability/python/experimental/psd_kernels/feature_scaled_with_categorical_test.py +tensorflow_probability/python/experimental/psd_kernels/feature_scaled_with_embedded_categorical.py +tensorflow_probability/python/experimental/psd_kernels/feature_scaled_with_embedded_categorical_test.py +tensorflow_probability/python/experimental/psd_kernels/multitask_kernel.py +tensorflow_probability/python/experimental/psd_kernels/multitask_kernel_test.py +tensorflow_probability/python/experimental/sequential/__init__.py +tensorflow_probability/python/experimental/sequential/ensemble_adjustment_kalman_filter.py +tensorflow_probability/python/experimental/sequential/ensemble_adjustment_kalman_filter_test.py +tensorflow_probability/python/experimental/sequential/ensemble_kalman_filter.py +tensorflow_probability/python/experimental/sequential/ensemble_kalman_filter_test.py +tensorflow_probability/python/experimental/sequential/extended_kalman_filter.py +tensorflow_probability/python/experimental/sequential/extended_kalman_filter_test.py +tensorflow_probability/python/experimental/sequential/iterated_filter.py +tensorflow_probability/python/experimental/sequential/iterated_filter_test.py +tensorflow_probability/python/experimental/stats/__init__.py +tensorflow_probability/python/experimental/stats/sample_stats.py +tensorflow_probability/python/experimental/stats/sample_stats_test.py +tensorflow_probability/python/experimental/sts_gibbs/__init__.py +tensorflow_probability/python/experimental/sts_gibbs/benchmarks_test.py +tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab.py +tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab_test.py +tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py +tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py +tensorflow_probability/python/experimental/sts_gibbs/sample_parameters.py +tensorflow_probability/python/experimental/sts_gibbs/sample_parameters_test.py +tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab.py +tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py +tensorflow_probability/python/experimental/substrates/__init__.py +tensorflow_probability/python/experimental/tangent_spaces/__init__.py +tensorflow_probability/python/experimental/tangent_spaces/simplex.py +tensorflow_probability/python/experimental/tangent_spaces/simplex_test.py +tensorflow_probability/python/experimental/tangent_spaces/spaces.py +tensorflow_probability/python/experimental/tangent_spaces/spaces_test.py +tensorflow_probability/python/experimental/tangent_spaces/spaces_test_util.py +tensorflow_probability/python/experimental/tangent_spaces/spherical.py +tensorflow_probability/python/experimental/tangent_spaces/spherical_test.py +tensorflow_probability/python/experimental/tangent_spaces/symmetric_matrix.py +tensorflow_probability/python/experimental/tangent_spaces/symmetric_matrix_test.py +tensorflow_probability/python/experimental/util/__init__.py +tensorflow_probability/python/experimental/util/composite_tensor.py +tensorflow_probability/python/experimental/util/deferred_module.py +tensorflow_probability/python/experimental/util/deferred_module_test.py +tensorflow_probability/python/experimental/util/jit_public_methods.py +tensorflow_probability/python/experimental/util/jit_public_methods_test.py +tensorflow_probability/python/experimental/util/special_methods.py +tensorflow_probability/python/experimental/util/trainable.py +tensorflow_probability/python/experimental/util/trainable_test.py +tensorflow_probability/python/experimental/vi/__init__.py +tensorflow_probability/python/experimental/vi/automatic_structured_vi.py +tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py +tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +tensorflow_probability/python/experimental/vi/util/__init__.py +tensorflow_probability/python/experimental/vi/util/trainable_linear_operators.py +tensorflow_probability/python/experimental/vi/util/trainable_linear_operators_test.py +tensorflow_probability/python/glm/__init__.py +tensorflow_probability/python/glm/family.py +tensorflow_probability/python/glm/family_test.py +tensorflow_probability/python/glm/fisher_scoring.py +tensorflow_probability/python/glm/fisher_scoring_test.py +tensorflow_probability/python/glm/proximal_hessian.py +tensorflow_probability/python/glm/proximal_hessian_test.py +tensorflow_probability/python/internal/__init__.py +tensorflow_probability/python/internal/all_util.py +tensorflow_probability/python/internal/assert_util.py +tensorflow_probability/python/internal/auto_composite_tensor.py +tensorflow_probability/python/internal/auto_composite_tensor_test.py +tensorflow_probability/python/internal/batch_shape_lib.py +tensorflow_probability/python/internal/batch_shape_lib_test.py +tensorflow_probability/python/internal/batched_rejection_sampler.py +tensorflow_probability/python/internal/batched_rejection_sampler_test.py +tensorflow_probability/python/internal/broadcast_util.py +tensorflow_probability/python/internal/broadcast_util_test.py +tensorflow_probability/python/internal/cache_util.py +tensorflow_probability/python/internal/cache_util_test.py +tensorflow_probability/python/internal/callable_util.py +tensorflow_probability/python/internal/callable_util_test.py +tensorflow_probability/python/internal/custom_gradient.py +tensorflow_probability/python/internal/custom_gradient_test.py +tensorflow_probability/python/internal/distribute_lib.py +tensorflow_probability/python/internal/distribute_lib_test.py +tensorflow_probability/python/internal/distribute_test_lib.py +tensorflow_probability/python/internal/distribution_util.py +tensorflow_probability/python/internal/distribution_util_test.py +tensorflow_probability/python/internal/docstring_util.py +tensorflow_probability/python/internal/docstring_util_test.py +tensorflow_probability/python/internal/dtype_util.py +tensorflow_probability/python/internal/dtype_util_test.py +tensorflow_probability/python/internal/empirical_statistical_testing.py +tensorflow_probability/python/internal/empirical_statistical_testing_test.py +tensorflow_probability/python/internal/hypothesis_testlib.py +tensorflow_probability/python/internal/hypothesis_testlib_test.py +tensorflow_probability/python/internal/implementation_selection.py +tensorflow_probability/python/internal/implementation_selection_test.py +tensorflow_probability/python/internal/lazy_loader.py +tensorflow_probability/python/internal/loop_util.py +tensorflow_probability/python/internal/loop_util_test.py +tensorflow_probability/python/internal/monte_carlo.py +tensorflow_probability/python/internal/name_util.py +tensorflow_probability/python/internal/nest_util.py +tensorflow_probability/python/internal/nest_util_test.py +tensorflow_probability/python/internal/numerics_testing.py +tensorflow_probability/python/internal/numerics_testing_test.py +tensorflow_probability/python/internal/parameter_properties.py +tensorflow_probability/python/internal/prefer_static.py +tensorflow_probability/python/internal/prefer_static_shape64_test.py +tensorflow_probability/python/internal/prefer_static_test.py +tensorflow_probability/python/internal/reparameterization.py +tensorflow_probability/python/internal/samplers.py +tensorflow_probability/python/internal/samplers_test.py +tensorflow_probability/python/internal/slicing.py +tensorflow_probability/python/internal/slicing_test.py +tensorflow_probability/python/internal/special_math.py +tensorflow_probability/python/internal/special_math_test.py +tensorflow_probability/python/internal/structural_tuple.py +tensorflow_probability/python/internal/structural_tuple_test.py +tensorflow_probability/python/internal/tensor_util.py +tensorflow_probability/python/internal/tensor_util_test.py +tensorflow_probability/python/internal/tensorshape_util.py +tensorflow_probability/python/internal/tensorshape_util_test.py +tensorflow_probability/python/internal/test_combinations.py +tensorflow_probability/python/internal/test_combinations_test.py +tensorflow_probability/python/internal/test_util.py +tensorflow_probability/python/internal/test_util_test.py +tensorflow_probability/python/internal/tf_keras.py +tensorflow_probability/python/internal/trainable_state_util.py +tensorflow_probability/python/internal/trainable_state_util_test.py +tensorflow_probability/python/internal/unnest.py +tensorflow_probability/python/internal/unnest_test.py +tensorflow_probability/python/internal/variadic_reduce.py +tensorflow_probability/python/internal/vectorization_util.py +tensorflow_probability/python/internal/vectorization_util_test.py +tensorflow_probability/python/internal/backend/__init__.py +tensorflow_probability/python/internal/backend/numpy/__init__.py +tensorflow_probability/python/internal/backend/numpy/__internal__.py +tensorflow_probability/python/internal/backend/numpy/_utils.py +tensorflow_probability/python/internal/backend/numpy/bitwise.py +tensorflow_probability/python/internal/backend/numpy/compat.py +tensorflow_probability/python/internal/backend/numpy/composite_tensor.py +tensorflow_probability/python/internal/backend/numpy/composite_tensor_gradient.py +tensorflow_probability/python/internal/backend/numpy/config.py +tensorflow_probability/python/internal/backend/numpy/control_flow.py +tensorflow_probability/python/internal/backend/numpy/data_structures.py +tensorflow_probability/python/internal/backend/numpy/debugging.py +tensorflow_probability/python/internal/backend/numpy/deprecation.py +tensorflow_probability/python/internal/backend/numpy/dtype.py +tensorflow_probability/python/internal/backend/numpy/errors.py +tensorflow_probability/python/internal/backend/numpy/functional_ops.py +tensorflow_probability/python/internal/backend/numpy/initializers.py +tensorflow_probability/python/internal/backend/numpy/keras_layers.py +tensorflow_probability/python/internal/backend/numpy/linalg.py +tensorflow_probability/python/internal/backend/numpy/linalg_impl.py +tensorflow_probability/python/internal/backend/numpy/misc.py +tensorflow_probability/python/internal/backend/numpy/nest.py +tensorflow_probability/python/internal/backend/numpy/nn.py +tensorflow_probability/python/internal/backend/numpy/numpy_array.py +tensorflow_probability/python/internal/backend/numpy/numpy_keras.py +tensorflow_probability/python/internal/backend/numpy/numpy_logging.py +tensorflow_probability/python/internal/backend/numpy/numpy_math.py +tensorflow_probability/python/internal/backend/numpy/numpy_signal.py +tensorflow_probability/python/internal/backend/numpy/numpy_test.py +tensorflow_probability/python/internal/backend/numpy/ops.py +tensorflow_probability/python/internal/backend/numpy/private.py +tensorflow_probability/python/internal/backend/numpy/random_generators.py +tensorflow_probability/python/internal/backend/numpy/raw_ops.py +tensorflow_probability/python/internal/backend/numpy/resource_variable_ops.py +tensorflow_probability/python/internal/backend/numpy/rewrite_equivalence_test.py +tensorflow_probability/python/internal/backend/numpy/sets_lib.py +tensorflow_probability/python/internal/backend/numpy/sparse_lib.py +tensorflow_probability/python/internal/backend/numpy/tensor_array_ops.py +tensorflow_probability/python/internal/backend/numpy/tensor_array_ops_test.py +tensorflow_probability/python/internal/backend/numpy/tensor_spec.py +tensorflow_probability/python/internal/backend/numpy/test_lib.py +tensorflow_probability/python/internal/backend/numpy/tf_inspect.py +tensorflow_probability/python/internal/backend/numpy/type_spec.py +tensorflow_probability/python/internal/backend/numpy/type_spec_registry.py +tensorflow_probability/python/internal/backend/numpy/v1.py +tensorflow_probability/python/internal/backend/numpy/v2.py +tensorflow_probability/python/internal/backend/numpy/variable_utils.py +tensorflow_probability/python/internal/backend/numpy/variables.py +tensorflow_probability/python/internal/backend/numpy/gen/__init__.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_addition.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_adjoint.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_lower_triangular.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_composition.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_diag.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_full_matrix.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_householder.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_identity.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_inversion.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_low_rank_update.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_lower_triangular.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_permutation.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_toeplitz.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_util.py +tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_zeros.py +tensorflow_probability/python/internal/backend/numpy/gen/property_hint_util.py +tensorflow_probability/python/internal/backend/numpy/gen/slicing.py +tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +tensorflow_probability/python/layers/__init__.py +tensorflow_probability/python/layers/conv_variational.py +tensorflow_probability/python/layers/conv_variational_test.py +tensorflow_probability/python/layers/dense_variational.py +tensorflow_probability/python/layers/dense_variational_test.py +tensorflow_probability/python/layers/dense_variational_v2.py +tensorflow_probability/python/layers/dense_variational_v2_test.py +tensorflow_probability/python/layers/distribution_layer.py +tensorflow_probability/python/layers/distribution_layer_test.py +tensorflow_probability/python/layers/initializers.py +tensorflow_probability/python/layers/initializers_test.py +tensorflow_probability/python/layers/masked_autoregressive.py +tensorflow_probability/python/layers/masked_autoregressive_test.py +tensorflow_probability/python/layers/util.py +tensorflow_probability/python/layers/variable_input.py +tensorflow_probability/python/layers/variable_input_test.py +tensorflow_probability/python/layers/weight_norm.py +tensorflow_probability/python/layers/weight_norm_test.py +tensorflow_probability/python/layers/internal/__init__.py +tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py +tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py +tensorflow_probability/python/layers/internal/tensor_tuple.py +tensorflow_probability/python/layers/internal/tensor_tuple_test.py +tensorflow_probability/python/math/__init__.py +tensorflow_probability/python/math/bessel.py +tensorflow_probability/python/math/bessel_test.py +tensorflow_probability/python/math/custom_gradient.py +tensorflow_probability/python/math/custom_gradient_test.py +tensorflow_probability/python/math/diag_jacobian.py +tensorflow_probability/python/math/diag_jacobian_test.py +tensorflow_probability/python/math/generic.py +tensorflow_probability/python/math/generic_test.py +tensorflow_probability/python/math/gradient.py +tensorflow_probability/python/math/gradient_test.py +tensorflow_probability/python/math/gram_schmidt.py +tensorflow_probability/python/math/gram_schmidt_test.py +tensorflow_probability/python/math/hypergeometric.py +tensorflow_probability/python/math/hypergeometric_test.py +tensorflow_probability/python/math/integration.py +tensorflow_probability/python/math/integration_test.py +tensorflow_probability/python/math/interpolation.py +tensorflow_probability/python/math/interpolation_test.py +tensorflow_probability/python/math/linalg.py +tensorflow_probability/python/math/linalg_test.py +tensorflow_probability/python/math/minimize.py +tensorflow_probability/python/math/minimize_test.py +tensorflow_probability/python/math/numeric.py +tensorflow_probability/python/math/numeric_test.py +tensorflow_probability/python/math/root_search.py +tensorflow_probability/python/math/root_search_test.py +tensorflow_probability/python/math/scan_associative.py +tensorflow_probability/python/math/scan_associative_test.py +tensorflow_probability/python/math/sparse.py +tensorflow_probability/python/math/sparse_test.py +tensorflow_probability/python/math/special.py +tensorflow_probability/python/math/special_test.py +tensorflow_probability/python/math/ode/__init__.py +tensorflow_probability/python/math/ode/base.py +tensorflow_probability/python/math/ode/bdf.py +tensorflow_probability/python/math/ode/bdf_util.py +tensorflow_probability/python/math/ode/bdf_util_test.py +tensorflow_probability/python/math/ode/dormand_prince.py +tensorflow_probability/python/math/ode/ode_test.py +tensorflow_probability/python/math/ode/runge_kutta_util.py +tensorflow_probability/python/math/ode/runge_kutta_util_test.py +tensorflow_probability/python/math/ode/util.py +tensorflow_probability/python/math/ode/util_test.py +tensorflow_probability/python/math/ode/xla_test.py +tensorflow_probability/python/math/psd_kernels/__init__.py +tensorflow_probability/python/math/psd_kernels/changepoint.py +tensorflow_probability/python/math/psd_kernels/changepoint_test.py +tensorflow_probability/python/math/psd_kernels/exp_sin_squared.py +tensorflow_probability/python/math/psd_kernels/exp_sin_squared_test.py +tensorflow_probability/python/math/psd_kernels/exponential_curve.py +tensorflow_probability/python/math/psd_kernels/exponential_curve_test.py +tensorflow_probability/python/math/psd_kernels/exponentiated_quadratic.py +tensorflow_probability/python/math/psd_kernels/exponentiated_quadratic_test.py +tensorflow_probability/python/math/psd_kernels/feature_scaled.py +tensorflow_probability/python/math/psd_kernels/feature_scaled_test.py +tensorflow_probability/python/math/psd_kernels/feature_transformed.py +tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py +tensorflow_probability/python/math/psd_kernels/gamma_exponential.py +tensorflow_probability/python/math/psd_kernels/gamma_exponential_test.py +tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py +tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py +tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed_test.py +tensorflow_probability/python/math/psd_kernels/matern.py +tensorflow_probability/python/math/psd_kernels/matern_test.py +tensorflow_probability/python/math/psd_kernels/parabolic.py +tensorflow_probability/python/math/psd_kernels/parabolic_test.py +tensorflow_probability/python/math/psd_kernels/pointwise_exponential.py +tensorflow_probability/python/math/psd_kernels/pointwise_exponential_test.py +tensorflow_probability/python/math/psd_kernels/polynomial.py +tensorflow_probability/python/math/psd_kernels/polynomial_test.py +tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel.py +tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel_test.py +tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py +tensorflow_probability/python/math/psd_kernels/rational_quadratic.py +tensorflow_probability/python/math/psd_kernels/rational_quadratic_test.py +tensorflow_probability/python/math/psd_kernels/schur_complement.py +tensorflow_probability/python/math/psd_kernels/schur_complement_test.py +tensorflow_probability/python/math/psd_kernels/spectral_mixture.py +tensorflow_probability/python/math/psd_kernels/spectral_mixture_test.py +tensorflow_probability/python/math/psd_kernels/internal/__init__.py +tensorflow_probability/python/math/psd_kernels/internal/test_util.py +tensorflow_probability/python/math/psd_kernels/internal/test_util_test.py +tensorflow_probability/python/math/psd_kernels/internal/util.py +tensorflow_probability/python/math/psd_kernels/internal/util_test.py +tensorflow_probability/python/mcmc/__init__.py +tensorflow_probability/python/mcmc/diagnostic.py +tensorflow_probability/python/mcmc/diagnostic_test.py +tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py +tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation_test.py +tensorflow_probability/python/mcmc/eight_schools_hmc.py +tensorflow_probability/python/mcmc/eight_schools_hmc_eager_test.py +tensorflow_probability/python/mcmc/eight_schools_hmc_graph_test.py +tensorflow_probability/python/mcmc/hmc.py +tensorflow_probability/python/mcmc/hmc_test.py +tensorflow_probability/python/mcmc/kernel.py +tensorflow_probability/python/mcmc/langevin.py +tensorflow_probability/python/mcmc/langevin_test.py +tensorflow_probability/python/mcmc/metropolis_hastings.py +tensorflow_probability/python/mcmc/metropolis_hastings_test.py +tensorflow_probability/python/mcmc/nuts.py +tensorflow_probability/python/mcmc/nuts_test.py +tensorflow_probability/python/mcmc/random_walk_metropolis.py +tensorflow_probability/python/mcmc/random_walk_metropolis_test.py +tensorflow_probability/python/mcmc/replica_exchange_mc.py +tensorflow_probability/python/mcmc/replica_exchange_mc_test.py +tensorflow_probability/python/mcmc/sample.py +tensorflow_probability/python/mcmc/sample_annealed_importance.py +tensorflow_probability/python/mcmc/sample_annealed_importance_test.py +tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py +tensorflow_probability/python/mcmc/sample_halton_sequence_test.py +tensorflow_probability/python/mcmc/sample_test.py +tensorflow_probability/python/mcmc/simple_step_size_adaptation.py +tensorflow_probability/python/mcmc/simple_step_size_adaptation_test.py +tensorflow_probability/python/mcmc/slice_sampler_kernel.py +tensorflow_probability/python/mcmc/slice_sampler_test.py +tensorflow_probability/python/mcmc/transformed_kernel.py +tensorflow_probability/python/mcmc/transformed_kernel_test.py +tensorflow_probability/python/mcmc/internal/__init__.py +tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py +tensorflow_probability/python/mcmc/internal/leapfrog_integrator_test.py +tensorflow_probability/python/mcmc/internal/slice_sampler_utils.py +tensorflow_probability/python/mcmc/internal/util.py +tensorflow_probability/python/mcmc/internal/util_test.py +tensorflow_probability/python/monte_carlo/__init__.py +tensorflow_probability/python/monte_carlo/expectation.py +tensorflow_probability/python/monte_carlo/expectation_test.py +tensorflow_probability/python/optimizer/__init__.py +tensorflow_probability/python/optimizer/bfgs.py +tensorflow_probability/python/optimizer/bfgs_test.py +tensorflow_probability/python/optimizer/bfgs_utils.py +tensorflow_probability/python/optimizer/differential_evolution.py +tensorflow_probability/python/optimizer/differential_evolution_test.py +tensorflow_probability/python/optimizer/lbfgs.py +tensorflow_probability/python/optimizer/lbfgs_test.py +tensorflow_probability/python/optimizer/nelder_mead.py +tensorflow_probability/python/optimizer/nelder_mead_test.py +tensorflow_probability/python/optimizer/proximal_hessian_sparse.py +tensorflow_probability/python/optimizer/proximal_hessian_sparse_test.py +tensorflow_probability/python/optimizer/sgld.py +tensorflow_probability/python/optimizer/sgld_test.py +tensorflow_probability/python/optimizer/variational_sgd.py +tensorflow_probability/python/optimizer/variational_sgd_test.py +tensorflow_probability/python/optimizer/convergence_criteria/__init__.py +tensorflow_probability/python/optimizer/convergence_criteria/convergence_criterion.py +tensorflow_probability/python/optimizer/convergence_criteria/loss_not_decreasing.py +tensorflow_probability/python/optimizer/convergence_criteria/loss_not_decreasing_test.py +tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated.py +tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated_test.py +tensorflow_probability/python/optimizer/linesearch/__init__.py +tensorflow_probability/python/optimizer/linesearch/hager_zhang.py +tensorflow_probability/python/optimizer/linesearch/hager_zhang_test.py +tensorflow_probability/python/optimizer/linesearch/internal/__init__.py +tensorflow_probability/python/optimizer/linesearch/internal/hager_zhang_lib.py +tensorflow_probability/python/optimizer/linesearch/internal/hager_zhang_lib_test.py +tensorflow_probability/python/random/__init__.py +tensorflow_probability/python/random/random_ops.py +tensorflow_probability/python/random/random_ops_test.py +tensorflow_probability/python/stats/__init__.py +tensorflow_probability/python/stats/calibration.py +tensorflow_probability/python/stats/calibration_test.py +tensorflow_probability/python/stats/kendalls_tau.py +tensorflow_probability/python/stats/kendalls_tau_test.py +tensorflow_probability/python/stats/leave_one_out.py +tensorflow_probability/python/stats/leave_one_out_test.py +tensorflow_probability/python/stats/moving_stats.py +tensorflow_probability/python/stats/moving_stats_test.py +tensorflow_probability/python/stats/quantiles.py +tensorflow_probability/python/stats/quantiles_test.py +tensorflow_probability/python/stats/ranking.py +tensorflow_probability/python/stats/ranking_test.py +tensorflow_probability/python/stats/sample_stats.py +tensorflow_probability/python/stats/sample_stats_test.py +tensorflow_probability/python/sts/__init__.py +tensorflow_probability/python/sts/decomposition.py +tensorflow_probability/python/sts/decomposition_test.py +tensorflow_probability/python/sts/default_model.py +tensorflow_probability/python/sts/default_model_test.py +tensorflow_probability/python/sts/fitting.py +tensorflow_probability/python/sts/fitting_test.py +tensorflow_probability/python/sts/forecast.py +tensorflow_probability/python/sts/forecast_test.py +tensorflow_probability/python/sts/holiday_effects.py +tensorflow_probability/python/sts/holiday_effects_test.py +tensorflow_probability/python/sts/regularization.py +tensorflow_probability/python/sts/regularization_test.py +tensorflow_probability/python/sts/structural_time_series.py +tensorflow_probability/python/sts/structural_time_series_test.py +tensorflow_probability/python/sts/anomaly_detection/__init__.py +tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_lib.py +tensorflow_probability/python/sts/anomaly_detection/anomaly_detection_test.py +tensorflow_probability/python/sts/components/__init__.py +tensorflow_probability/python/sts/components/autoregressive.py +tensorflow_probability/python/sts/components/autoregressive_integrated_moving_average.py +tensorflow_probability/python/sts/components/autoregressive_integrated_moving_average_test.py +tensorflow_probability/python/sts/components/autoregressive_moving_average.py +tensorflow_probability/python/sts/components/autoregressive_moving_average_test.py +tensorflow_probability/python/sts/components/autoregressive_test.py +tensorflow_probability/python/sts/components/dynamic_regression.py +tensorflow_probability/python/sts/components/dynamic_regression_test.py +tensorflow_probability/python/sts/components/local_level.py +tensorflow_probability/python/sts/components/local_level_test.py +tensorflow_probability/python/sts/components/local_linear_trend.py +tensorflow_probability/python/sts/components/local_linear_trend_test.py +tensorflow_probability/python/sts/components/regression.py +tensorflow_probability/python/sts/components/regression_test.py +tensorflow_probability/python/sts/components/seasonal.py +tensorflow_probability/python/sts/components/seasonal_test.py +tensorflow_probability/python/sts/components/semilocal_linear_trend.py +tensorflow_probability/python/sts/components/semilocal_linear_trend_test.py +tensorflow_probability/python/sts/components/smooth_seasonal.py +tensorflow_probability/python/sts/components/smooth_seasonal_test.py +tensorflow_probability/python/sts/components/sum.py +tensorflow_probability/python/sts/components/sum_test.py +tensorflow_probability/python/sts/internal/__init__.py +tensorflow_probability/python/sts/internal/missing_values_util.py +tensorflow_probability/python/sts/internal/missing_values_util_test.py +tensorflow_probability/python/sts/internal/seasonality_util.py +tensorflow_probability/python/sts/internal/seasonality_util_test.py +tensorflow_probability/python/sts/internal/util.py +tensorflow_probability/python/sts/internal/util_test.py +tensorflow_probability/python/util/__init__.py +tensorflow_probability/python/util/deferred_tensor.py +tensorflow_probability/python/util/deferred_tensor_test.py +tensorflow_probability/python/util/seed_stream.py +tensorflow_probability/python/util/seed_stream_test.py +tensorflow_probability/python/vi/__init__.py +tensorflow_probability/python/vi/csiszar_divergence.py +tensorflow_probability/python/vi/csiszar_divergence_test.py +tensorflow_probability/python/vi/mutual_information.py +tensorflow_probability/python/vi/mutual_information_test.py +tensorflow_probability/python/vi/optimization.py +tensorflow_probability/python/vi/optimization_test.py +tensorflow_probability/substrates/__init__.py +tensorflow_probability/substrates/jax/__init__.py +tensorflow_probability/substrates/numpy/__init__.py +tfp_nightly.egg-info/PKG-INFO +tfp_nightly.egg-info/SOURCES.txt +tfp_nightly.egg-info/dependency_links.txt +tfp_nightly.egg-info/not-zip-safe +tfp_nightly.egg-info/requires.txt +tfp_nightly.egg-info/top_level.txt \ No newline at end of file diff --git a/tfp_nightly.egg-info/dependency_links.txt b/tfp_nightly.egg-info/dependency_links.txt new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tfp_nightly.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/tfp_nightly.egg-info/not-zip-safe b/tfp_nightly.egg-info/not-zip-safe new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tfp_nightly.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/tfp_nightly.egg-info/requires.txt b/tfp_nightly.egg-info/requires.txt new file mode 100644 index 0000000000..2a08bbd673 --- /dev/null +++ b/tfp_nightly.egg-info/requires.txt @@ -0,0 +1,14 @@ +absl-py +six>=1.10.0 +numpy>=1.13.3 +decorator +cloudpickle>=1.3 +gast>=0.3.2 +dm-tree + +[jax] +jax +jaxlib + +[tfds] +tfds-nightly diff --git a/tfp_nightly.egg-info/top_level.txt b/tfp_nightly.egg-info/top_level.txt new file mode 100644 index 0000000000..ecabf3d7f4 --- /dev/null +++ b/tfp_nightly.egg-info/top_level.txt @@ -0,0 +1 @@ +tensorflow_probability