diff --git a/spinoffs/fun_mc/fun_mc/BUILD b/spinoffs/fun_mc/fun_mc/BUILD index ff870d2726..f20d2bda7c 100644 --- a/spinoffs/fun_mc/fun_mc/BUILD +++ b/spinoffs/fun_mc/fun_mc/BUILD @@ -46,6 +46,8 @@ py_library( deps = [ ":fun_mc_lib", ":prefab", + ":smc", + ":types", ":util_tfp", ], ) @@ -85,8 +87,10 @@ py_library( srcs = ["using_jax.py"], deps = [ ":api", + # jax dep, "//fun_mc/backends:rewrite", "//fun_mc/dynamic/backend_jax:backend", + # tensorflow_probability/substrates:jax dep, ], ) @@ -96,6 +100,8 @@ py_library( srcs = ["using_tensorflow.py"], deps = [ ":api", + # tensorflow dep, + # tensorflow_probability dep, "//fun_mc/backends:rewrite", "//fun_mc/dynamic/backend_tensorflow:backend", ], @@ -144,6 +150,33 @@ py_test( ], ) +# pytype +py_library( + name = "prefab", + srcs = ["prefab.py"], + deps = [ + ":backend", + ":fun_mc_lib", + ":malt", + ":sga_hmc", + ], +) + +py_test( + name = "prefab_test", + srcs = ["prefab_test.py"], + shard_count = 2, + deps = [ + ":fun_mc", + ":prefab", + ":test_util", + # jax dep, + # tensorflow dep, + # tensorflow_probability/python/internal:test_util dep, + # tensorflow/compiler/jit dep, + ], +) + # pytype py_library( name = "sga_hmc", @@ -172,31 +205,46 @@ py_test( # pytype py_library( - name = "prefab", - srcs = ["prefab.py"], + name = "smc", + srcs = ["smc.py"], deps = [ ":backend", ":fun_mc_lib", - ":malt", - ":sga_hmc", + ":types", ], ) -py_test( - name = "prefab_test", - srcs = ["prefab_test.py"], - shard_count = 2, +pytype_strict_contrib_test( + name = "smc_test", + srcs = ["smc_test.py"], + shard_count = 4, deps = [ - ":fun_mc", - ":prefab", + ":backend", + ":fun_mc_lib", + ":smc", ":test_util", + ":types", + # absl/testing:parameterized dep, # jax dep, + # jaxtyping dep, + # mock dep, # tensorflow dep, # tensorflow_probability/python/internal:test_util dep, # tensorflow/compiler/jit dep, ], ) +# pytype +py_library( + name = "types", + srcs = ["types.py"], + deps = [ + ":backend", + # jaxtyping dep, + # typeguard dep, + ], +) + # pytype py_library( name = "util_tfp", diff --git a/spinoffs/fun_mc/fun_mc/api.py b/spinoffs/fun_mc/fun_mc/api.py index 81d866a34d..9f9bc89615 100644 --- a/spinoffs/fun_mc/fun_mc/api.py +++ b/spinoffs/fun_mc/fun_mc/api.py @@ -17,10 +17,14 @@ from fun_mc import fun_mc_lib from fun_mc import prefab +from fun_mc import smc +from fun_mc import types from fun_mc import util_tfp from fun_mc.fun_mc_lib import * +from fun_mc.smc import * +from fun_mc.types import * __all__ = [ 'prefab', 'util_tfp', -] + fun_mc_lib.__all__ +] + fun_mc_lib.__all__ + smc.__all__ + types.__all__ diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD index 337bd2abf6..019998f434 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD @@ -32,8 +32,7 @@ py_library( name = "util", srcs = ["util.py"], deps = [ - # jax dep, - # jax:stax dep, + # jaxtyping dep, ], ) @@ -42,6 +41,5 @@ py_library( srcs = ["backend.py"], deps = [ ":util", - # tensorflow_probability/substrates:jax dep, ], ) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index cd9ea3e493..8ed20c1399 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -21,12 +21,15 @@ from jax import random from jax import tree_util import jax.numpy as jnp +import jaxtyping __all__ = [ + 'Array', 'assert_same_shallow_tree', 'block_until_ready', 'convert_to_tensor', 'diff', + 'DType', 'flatten_tree', 'get_shallow_tree', 'inverse_fn', @@ -39,8 +42,10 @@ 'random_categorical', 'random_integer', 'random_normal', + 'random_permutation', 'random_uniform', 'repeat', + 'Seed', 'split_seed', 'stack_dynamic_array', 'trace', @@ -50,6 +55,11 @@ ] +Array = jaxtyping.Array +DType = jax.typing.DTypeLike +Seed = jaxtyping.PRNGKeyArray + + def map_tree(fn, tree, *args): """Maps `fn` over the leaves of a nested structure.""" return tree_util.tree_map(fn, tree, *args) @@ -102,8 +112,9 @@ def make_tensor_seed(seed): if hasattr(seed, 'dtype') and jax.dtypes.issubdtype( seed.dtype, jax.dtypes.prng_key ): - return seed - return jnp.asarray(seed, jnp.uint32) + return jnp.asarray(seed) + else: + return jnp.asarray(seed, jnp.uint32) def split_seed(seed, count): @@ -114,7 +125,8 @@ def split_seed(seed, count): def random_uniform(shape, dtype, seed): """Generates a sample from uniform distribution over [0., 1).""" return random.uniform( - shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed)) + shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed) + ) def random_integer(shape, dtype, minval, maxval, seed): @@ -124,13 +136,15 @@ def random_integer(shape, dtype, minval, maxval, seed): dtype=dtype, minval=minval, maxval=maxval, - key=make_tensor_seed(seed)) + key=make_tensor_seed(seed), + ) def random_normal(shape, dtype, seed): """Generates a sample from a standard normal distribution.""" return random.normal( - shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed)) + shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed) + ) def _searchsorted(a, v): @@ -158,7 +172,8 @@ def random_categorical(logits, num_samples, seed): cum_sum = jnp.cumsum(probs, axis=-1) eta = random.uniform( - make_tensor_seed(seed), (num_samples,) + cum_sum.shape[:-1]) + make_tensor_seed(seed), (num_samples,) + cum_sum.shape[:-1] + ) cum_sum = jnp.broadcast_to(cum_sum, (num_samples,) + cum_sum.shape) flat_cum_sum = cum_sum.reshape([-1, cum_sum.shape[-1]]) @@ -166,11 +181,17 @@ def random_categorical(logits, num_samples, seed): return jax.vmap(_searchsorted)(flat_cum_sum, flat_eta).reshape(eta.shape).T +def random_permutation(value, seed): + """Randomly permutes the array.""" + return random.permutation(seed, value) + + def trace(state, fn, num_steps, unroll, max_steps, **_): """Implementation of `trace` operator, without the calling convention.""" # We need the shapes and dtypes of the outputs of `fn`. _, untraced_spec, traced_spec, stop_spec = jax.eval_shape( - fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state)) + fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state) + ) if isinstance(stop_spec, tuple): stop = () else: @@ -211,12 +232,16 @@ def trace(state, fn, num_steps, unroll, max_steps, **_): state, untraced, traced_element, stop = fn(state) else: traced_element = traced_init - map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists, - traced_element) + map_tree_up_to( + traced_spec, lambda l, e: l.append(e), traced_lists, traced_element + ) # Using asarray instead of stack to handle empty arrays correctly. - traced = map_tree_up_to(traced_spec, - lambda l, s: jnp.asarray(l, dtype=s.dtype), - traced_lists, traced_spec) + traced = map_tree_up_to( + traced_spec, + lambda l, s: jnp.asarray(l, dtype=s.dtype), + traced_lists, + traced_spec, + ) elif use_scan: def wrapper(state_untraced, _): diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/BUILD b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/BUILD index c3bced3350..d95d4124d7 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/BUILD +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/BUILD @@ -28,7 +28,6 @@ py_library( srcs = ["util.py"], deps = [ # numpy dep, - # tensorflow dep, ], ) @@ -37,7 +36,5 @@ py_library( srcs = ["backend.py"], deps = [ ":util", - # tensorflow dep, - # tensorflow_probability dep, ], ) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py index 92d0a5c275..41d9e0c43f 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py @@ -19,15 +19,18 @@ import numpy as np import six import tensorflow.compat.v2 as tf + from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import tnp = tf.experimental.numpy __all__ = [ + 'Array', 'assert_same_shallow_tree', 'block_until_ready', 'convert_to_tensor', 'diff', + 'DType', 'flatten_tree', 'get_shallow_tree', 'inverse_fn', @@ -40,16 +43,24 @@ 'random_categorical', 'random_integer', 'random_normal', + 'random_permutation', 'random_uniform', 'repeat', + 'Seed', 'split_seed', 'stack_dynamic_array', 'trace', + 'value_and_grad', 'value_and_ldj', 'write_to_dynamic_array', ] +Array = tf.Tensor +DType = tf.DType +Seed = tf.Tensor | int + + def map_tree(fn, tree, *args): """Maps `fn` over the leaves of a nested structure.""" return tf.nest.map_structure(fn, tree, *args) @@ -98,11 +109,9 @@ def make_tensor_seed(seed): """Converts a seed to a `Tensor` seed.""" if _is_stateful_seed(seed): iinfo = np.iinfo(np.int32) - return tf.random.uniform([2], - minval=iinfo.min, - maxval=iinfo.max, - dtype=tf.int32, - name='seed') + return tf.random.uniform( + [2], minval=iinfo.min, maxval=iinfo.max, dtype=tf.int32, name='seed' + ) else: return tf.convert_to_tensor(seed, dtype=tf.int32, name='seed') @@ -139,10 +148,12 @@ def random_integer(shape, dtype, minval, maxval, seed): """Generates a sample from uniform distribution over [minval, maxval).""" if _is_stateful_seed(seed): return tf.random.uniform( - shape=shape, dtype=dtype, minval=minval, maxval=maxval, seed=seed) + shape=shape, dtype=dtype, minval=minval, maxval=maxval, seed=seed + ) else: return tf.random.stateless_uniform( - shape=shape, dtype=dtype, minval=minval, maxval=maxval, seed=seed) + shape=shape, dtype=dtype, minval=minval, maxval=maxval, seed=seed + ) def random_normal(shape, dtype, seed): @@ -157,23 +168,36 @@ def random_categorical(logits, num_samples, seed): """Returns a sample from a categorical distribution. `logits` must be 2D.""" if _is_stateful_seed(seed): return tf.random.categorical( - logits=logits, num_samples=num_samples, seed=seed) + logits=logits, num_samples=num_samples, seed=seed + ) else: return tf.random.stateless_categorical( - logits=logits, num_samples=num_samples, seed=seed) + logits=logits, num_samples=num_samples, seed=seed + ) + + +def random_permutation(value, seed): + """Randomly permutes the array.""" + if _is_stateful_seed(seed): + return tf.random.shuffle(value, seed) + else: + return tf.random.experimental.stateless_shuffle(value, seed) def _eval_shape(fn, input_spec): """Gets output `TensorSpec`s from `fn` given input `TensorSpec`.""" - raw_compiled_fn = tf.function( - fn, autograph=False).get_concrete_function(input_spec) + raw_compiled_fn = tf.function(fn, autograph=False).get_concrete_function( + input_spec + ) def compiled_fn(x): return raw_compiled_fn(*tf.nest.flatten(x)) - output_spec = tf.nest.map_structure(tf.TensorSpec, - raw_compiled_fn.output_shapes, - raw_compiled_fn.output_dtypes) + output_spec = tf.nest.map_structure( + tf.TensorSpec, + raw_compiled_fn.output_shapes, + raw_compiled_fn.output_dtypes, + ) return compiled_fn, output_spec @@ -185,10 +209,10 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): state, first_untraced, first_traced, stop = fn(state) arrays = tf.nest.map_structure( lambda v: tf.TensorArray( # pylint: disable=g-long-lambda - v.dtype, - size=num_outputs, - element_shape=v.shape).write(0, v), - first_traced) + v.dtype, size=num_outputs, element_shape=v.shape + ).write(0, v), + first_traced, + ) start_idx = 1 else: # We need the shapes and dtypes of the outputs of `fn` function to create @@ -199,12 +223,13 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): arrays = tf.nest.map_structure( lambda spec: tf.TensorArray( # pylint: disable=g-long-lambda - spec.dtype, - size=num_outputs, - element_shape=spec.shape), - traced_spec) + spec.dtype, size=num_outputs, element_shape=spec.shape + ), + traced_spec, + ) first_untraced = tf.nest.map_structure( - lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec) + lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec + ) start_idx = 0 if isinstance(stop_spec, tuple): stop = () @@ -386,12 +411,13 @@ def wrapped(*args, **kwargs): def diff(x, prepend=None): """Like jnp.diff.""" if prepend is not None: - x = tf.concat([tf.convert_to_tensor(prepend, dtype=x.dtype)[tf.newaxis], x], - 0) + x = tf.concat( + [tf.convert_to_tensor(prepend, dtype=x.dtype)[tf.newaxis], x], 0 + ) return x[1:] - x[:-1] -def repeat(x, repeats, total_repeat_length=None): +def repeat(x, repeats, total_repeat_length): """Like jnp.repeat.""" res = tf.repeat(x, repeats) if total_repeat_length is not None: diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index afc87faa38..8d406ea43c 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -670,7 +670,7 @@ def call_transition_operator( new_args, extra = ret try: util.assert_same_shallow_tree(args, new_args) - except: + except Exception as e: args_s = _tree_repr(args) new_args_s = _tree_repr(new_args) raise TypeError( @@ -681,7 +681,7 @@ def call_transition_operator( args_s=args_s, new_args_s=new_args_s, ) - ) + ) from e return new_args, extra diff --git a/spinoffs/fun_mc/fun_mc/smc.py b/spinoffs/fun_mc/fun_mc/smc.py new file mode 100644 index 0000000000..ed5c32aae7 --- /dev/null +++ b/spinoffs/fun_mc/fun_mc/smc.py @@ -0,0 +1,167 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implementation of Sequential Monte Carlo.""" + +from typing import Protocol, TypeVar, runtime_checkable + +from fun_mc import backend +from fun_mc import types + +jax = backend.jax +jnp = backend.jnp +tfp = backend.tfp +util = backend.util +distribute_lib = backend.distribute_lib + +Array = types.Array +Seed = types.Seed +Float = types.Float +Int = types.Int +Bool = types.Bool +DType = types.DType +BoolScalar = types.BoolScalar +IntScalar = types.IntScalar +FloatScalar = types.FloatScalar +State = TypeVar('State') +Extra = TypeVar('Extra') +T = TypeVar('T') + +__all__ = [ + 'conditional_systematic_resampling', + 'SampleAncestorsFn', + 'systematic_resampling', +] + + +@runtime_checkable +class SampleAncestorsFn(Protocol): + """Function that generates ancestor indices for resampling.""" + + def __call__( + self, + log_weights: Float[Array, 'num_particles'], + seed: Seed, + ) -> Int[Array, 'num_particles']: + """Generate a set of ancestor indices from particle weights.""" + + +@types.runtime_typed +def systematic_resampling( + log_weights: Float[Array, 'num_particles'], + seed: Seed, + permute: bool = False, +) -> Int[Array, 'num_particles']: + """Generate parent indices via systematic resampling. + + Args: + log_weights: Unnormalized log-scale weights. + seed: PRNG seed. + permute: Whether to permute the parent indices. Otherwise, they are sorted + in an ascending order. + + Returns: + parent_idxs: parent indices such that the marginal probability that a + randomly chosen element will be `i` is equal to `softmax(log_weights)[i]`. + """ + shift_seed, permute_seed = util.split_seed(seed, 2) + log_weights = jnp.where( + jnp.isnan(log_weights), + jnp.array(-float('inf'), log_weights.dtype), + log_weights, + ) + probs = jax.nn.softmax(log_weights) + # A common situation is all -inf log_weights that creats a NaN vector. + probs = jnp.where( + jnp.all(jnp.isfinite(probs)), probs, jnp.ones_like(probs) / probs.shape[0] + ) + num_particles = probs.shape[0] + + shift = util.random_uniform([], log_weights.dtype, shift_seed) + pie = jnp.cumsum(probs) * num_particles + shift + repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32) + parent_idxs = util.repeat( + jnp.arange(num_particles), repeats, total_repeat_length=num_particles + ) + if permute: + parent_idxs = util.random_permutation(parent_idxs, permute_seed) + return parent_idxs + + +@types.runtime_typed +def conditional_systematic_resampling( + log_weights: Float[Array, 'num_particles'], + seed: Seed, +) -> Int[Array, 'num_particles']: + """Apply conditional systematic resampling to `softmax(log_weights)`. + + Equivalent to (but typically much more efficient than) the following + rejection sampler: + + ```python + for i in count(): + parents = systematic_resampling(log_weights, seed=i, permute=True) + if parents[0] == 0: + break + return parents + ``` + + The algorithm used is from [1]. + + Args: + log_weights: Unnormalized log-scale weights. + seed: PRNG seed. + + Returns: + parent_idxs: A sample from the posterior over the output of + `systematic_resampling`, conditioned on parent_idxs[0] == 0. + + #### References + + [1]: Chopin and Singh, 'On Particle Gibbs Sampling,' Bernoulli, 2015 + https://www.jstor.org/stable/43590414 + """ + mixture_seed, shift_seed, permute_seed = util.split_seed(seed, 3) + log_weights = jnp.where( + jnp.isnan(log_weights), + jnp.array(-float('inf'), log_weights.dtype), + log_weights, + ) + probs = jax.nn.softmax(log_weights) + num_particles = log_weights.shape[0] + + # Sample from the posterior over shift given that parents[0] == 0. This turns + # out to be a mixture of non-overlapping uniforms. + scaled_w1 = num_particles * probs[0] + r = scaled_w1 % 1.0 + prob_shift_less_than_r = r * jnp.ceil(scaled_w1) / scaled_w1 + shift = util.random_uniform( + shape=[], dtype=log_weights.dtype, seed=shift_seed + ) + shift = jnp.where( + util.random_uniform(shape=[], dtype=log_weights.dtype, seed=mixture_seed) + < prob_shift_less_than_r, + shift * r, + r + shift * (1 - r), + ) + # Proceed as usual once we've figured out the shift. + pie = jnp.cumsum(probs) * num_particles + shift + repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32) + parent_idxs = util.repeat( + jnp.arange(num_particles), repeats, total_repeat_length=num_particles + ) + # Permute parents[1:]. + permuted_parents = util.random_permutation(parent_idxs[1:], permute_seed) + parents = jnp.concatenate([parent_idxs[:1], permuted_parents]) + return parents diff --git a/spinoffs/fun_mc/fun_mc/smc_test.py b/spinoffs/fun_mc/fun_mc/smc_test.py new file mode 100644 index 0000000000..8acf2535b3 --- /dev/null +++ b/spinoffs/fun_mc/fun_mc/smc_test.py @@ -0,0 +1,170 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Dependency imports + +from absl.testing import parameterized +import jax as real_jax +import tensorflow.compat.v2 as real_tf +from tensorflow_probability.python.internal import test_util as tfp_test_util +from fun_mc import backend +from fun_mc import fun_mc_lib as fun_mc +from fun_mc import smc +from fun_mc import test_util +from fun_mc import types + +jax = backend.jax +jnp = backend.jnp +tfp = backend.tfp +util = backend.util +tfd = tfp.distributions +distribute_lib = backend.distribute_lib +Root = tfd.JointDistributionCoroutine.Root +Array = types.Array +Seed = types.Seed +Float = types.Float +Int = types.Int +Bool = types.Bool +BoolScalar = types.BoolScalar +IntScalar = types.IntScalar +FloatScalar = types.FloatScalar + + +real_tf.enable_v2_behavior() +real_tf.experimental.numpy.experimental_enable_numpy_behavior() +real_jax.config.update('jax_enable_x64', True) + +BACKEND = None # Rewritten by backends/rewrite.py. + + +def _test_seed() -> Seed: + seed = tfp_test_util.test_seed() % (2**32 - 1) + if BACKEND == 'backend_jax': + return jax.random.PRNGKey(seed) + else: + return util.make_tensor_seed([seed, 0]) + + +class SMCTest(tfp_test_util.TestCase): + + @property + def _dtype(self): + raise NotImplementedError() + + def _constant(self, value): + return jnp.array(value, self._dtype) + + @parameterized.parameters(True, False) + def test_systematic_resampling(self, permute): + seed = _test_seed() + + num_replications = 10000 + weights = self._constant([0.0, 0.5, 0.2, 0.3, 0.0]) + log_weights = jnp.log(weights) + + def kernel(seed): + seed, sample_seed = util.split_seed(seed, 2) + parents = smc.systematic_resampling( + log_weights, seed=sample_seed, permute=permute + ) + return seed, parents + + _, parents = jax.jit( + lambda seed: fun_mc.trace(seed, kernel, num_replications) + )(seed) + + # [num_samples, parents, parents] + freqs = jnp.mean( + jnp.array( + parents[..., jnp.newaxis] == jnp.arange(len(weights)), jnp.float32 + ), + (0, 1), + ) + + self.assertAllClose(freqs, weights, atol=0.05) + + if permute: + mean_index = jnp.sum(weights * jnp.arange(len(weights))) + self.assertAllClose( + jnp.mean(parents, 0), [mean_index] * len(weights), atol=0.05 + ) + + def test_conditional_systematic_resampling(self): + seed = _test_seed() + + num_replications = 10000 + weights = self._constant([0.2, 0.5, 0.2, 0.1, 0.0]) + log_weights = jnp.log(weights) + + def kernel(seed): + seed, systematic_seed, cond_seed = util.split_seed(seed, 3) + systematic_parents = smc.systematic_resampling( + log_weights, seed=systematic_seed, permute=True + ) + conditional_parents = smc.systematic_resampling( + log_weights, + seed=cond_seed, + ) + return seed, (systematic_parents, conditional_parents) + + _, (systematic_parents, conditional_parents) = jax.jit( + lambda seed: fun_mc.trace(seed, kernel, num_replications) + )(seed) + + self.assertFalse(jnp.all(systematic_parents[:, 0] == 0)) + self.assertTrue(jnp.all(conditional_parents[:, 0] == 0)) + + accepted_samples = jnp.array(systematic_parents[:, 0] == 0, jnp.float32) + rejection_freqs = jnp.sum( + jnp.mean( + accepted_samples[:, jnp.newaxis, jnp.newaxis] + * jnp.array( + systematic_parents[..., jnp.newaxis] + == jnp.arange(len(weights)), + jnp.float32, + ), + 1, + ), + 0, + ) / jnp.sum(accepted_samples) + conditional_freqs = jnp.mean( + jnp.array( + conditional_parents[..., jnp.newaxis] == jnp.arange(len(weights)), + jnp.float32, + ), + (0, 1), + ) + self.assertAllClose(rejection_freqs, conditional_freqs, atol=0.05) + + +@test_util.multi_backend_test(globals(), 'smc_test') +class SMCTest32(SMCTest): + + @property + def _dtype(self): + return jnp.float32 + + +@test_util.multi_backend_test(globals(), 'smc_test') +class SMCTest64(SMCTest): + + @property + def _dtype(self): + return jnp.float64 + + +del SMCTest + +if __name__ == '__main__': + tfp_test_util.main() diff --git a/spinoffs/fun_mc/fun_mc/types.py b/spinoffs/fun_mc/fun_mc/types.py new file mode 100644 index 0000000000..e40eeb8ba5 --- /dev/null +++ b/spinoffs/fun_mc/fun_mc/types.py @@ -0,0 +1,51 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Various types used in FunMC.""" + +from typing import Callable, TypeAlias, TypeVar + +import jaxtyping +from fun_mc import backend +import typeguard + +__all__ = [ + 'Array', + 'Bool', + 'BoolScalar', + 'DType', + 'Float', + 'FloatScalar', + 'Int', + 'IntScalar', + 'runtime_typed', + 'Seed', +] + +Array = backend.util.Array +Seed = backend.util.Seed +DType = backend.util.DType +Float = jaxtyping.Float +Int = jaxtyping.Int +Bool = jaxtyping.Bool +BoolScalar: TypeAlias = bool | Bool[Array, ''] +IntScalar: TypeAlias = int | Int[Array, ''] +FloatScalar: TypeAlias = float | Float[Array, ''] + +F = TypeVar('F', bound=Callable) + + +def runtime_typed(f: F) -> F: + """Adds runtime type checking.""" + return jaxtyping.jaxtyped(f, typechecker=typeguard.typechecked) diff --git a/spinoffs/fun_mc/pyproject.toml b/spinoffs/fun_mc/pyproject.toml index 713e0abb6e..cc7103d4b4 100644 --- a/spinoffs/fun_mc/pyproject.toml +++ b/spinoffs/fun_mc/pyproject.toml @@ -22,6 +22,8 @@ version = "0.1.0" dependencies = [ "immutabledict", "numpy>=1.13.3", + "jaxtyping", + "typeguard", ] requires-python = ">= 3.10" authors = [