From 4c2e69318dd577a8e553022716f665eb8ef82158 Mon Sep 17 00:00:00 2001 From: siege Date: Tue, 16 Jan 2024 17:34:34 -0800 Subject: [PATCH] BoringMC: Change the SOT to be JAX rather than TensorFlow. This relies heavily on TensorFlow's experimental numpy API, so `tf.experimental.numpy.experimental_enable_numpy_behavior()` must be called before using FunMC from TF. PiperOrigin-RevId: 599007999 --- spinoffs/fun_mc/fun_mc/BUILD | 9 +- spinoffs/fun_mc/fun_mc/backend.py | 2 +- .../backends/tensorflow_integration_test.py | 2 + .../fun_mc/fun_mc/dynamic/backend_jax/BUILD | 10 - .../fun_mc/dynamic/backend_jax/backend.py | 9 +- .../fun_mc/dynamic/backend_jax/tf_on_jax.py | 208 --- .../dynamic/backend_tensorflow/backend.py | 42 +- spinoffs/fun_mc/fun_mc/fun_mc_lib.py | 1381 +++++++++------- spinoffs/fun_mc/fun_mc/fun_mc_test.py | 1408 +++++++++-------- spinoffs/fun_mc/fun_mc/malt.py | 176 ++- spinoffs/fun_mc/fun_mc/malt_test.py | 113 +- spinoffs/fun_mc/fun_mc/prefab.py | 412 ++--- spinoffs/fun_mc/fun_mc/prefab_test.py | 236 +-- spinoffs/fun_mc/fun_mc/sga_hmc.py | 370 +++-- spinoffs/fun_mc/fun_mc/sga_hmc_test.py | 192 ++- spinoffs/fun_mc/fun_mc/test_util.py | 27 +- spinoffs/fun_mc/fun_mc/using_jax.py | 1 + spinoffs/fun_mc/fun_mc/using_tensorflow.py | 1 + spinoffs/fun_mc/fun_mc/util_tfp.py | 59 +- spinoffs/fun_mc/fun_mc/util_tfp_test.py | 110 +- 20 files changed, 2647 insertions(+), 2121 deletions(-) delete mode 100644 spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py diff --git a/spinoffs/fun_mc/fun_mc/BUILD b/spinoffs/fun_mc/fun_mc/BUILD index 14100447c5..ff870d2726 100644 --- a/spinoffs/fun_mc/fun_mc/BUILD +++ b/spinoffs/fun_mc/fun_mc/BUILD @@ -15,8 +15,8 @@ # Description: # Functional MC API. -# Placeholder: py_test # [internal] load pytype.bzl (pytype_library) +# Placeholder: py_test licenses(["notice"]) @@ -55,7 +55,7 @@ py_library( name = "backend", srcs = ["backend.py"], deps = [ - "//fun_mc/dynamic/backend_tensorflow:backend", + "//fun_mc/dynamic/backend_jax:backend", ], ) @@ -108,6 +108,7 @@ py_test( shard_count = 8, deps = [ ":fun_mc", + ":prefab", ":test_util", # absl/testing:parameterized dep, # scipy dep, @@ -134,6 +135,7 @@ py_test( shard_count = 2, deps = [ ":fun_mc", + ":malt", ":test_util", # jax dep, # tensorflow dep, @@ -159,6 +161,7 @@ py_test( shard_count = 2, deps = [ ":fun_mc", + ":sga_hmc", ":test_util", # jax dep, # tensorflow dep, @@ -185,6 +188,7 @@ py_test( shard_count = 2, deps = [ ":fun_mc", + ":prefab", ":test_util", # jax dep, # tensorflow dep, @@ -200,7 +204,6 @@ py_library( deps = [ ":backend", ":fun_mc_lib", - # tensorflow_probability dep, ], ) diff --git a/spinoffs/fun_mc/fun_mc/backend.py b/spinoffs/fun_mc/fun_mc/backend.py index 29e538b1a6..975a87f635 100644 --- a/spinoffs/fun_mc/fun_mc/backend.py +++ b/spinoffs/fun_mc/fun_mc/backend.py @@ -14,4 +14,4 @@ # ============================================================================ """Default backend implementation.""" -from fun_mc.dynamic.backend_tensorflow.backend import * # pylint: disable=wildcard-import +from fun_mc.dynamic.backend_jax.backend import * # pylint: disable=wildcard-import diff --git a/spinoffs/fun_mc/fun_mc/backends/tensorflow_integration_test.py b/spinoffs/fun_mc/fun_mc/backends/tensorflow_integration_test.py index 4aa27ecc9d..70e2cfca97 100644 --- a/spinoffs/fun_mc/fun_mc/backends/tensorflow_integration_test.py +++ b/spinoffs/fun_mc/fun_mc/backends/tensorflow_integration_test.py @@ -16,6 +16,8 @@ from fun_mc import using_tensorflow as fun_mc from absl.testing import absltest +tf.experimental.numpy.experimental_enable_numpy_behavior() + class TensorFlowIntegrationTest(absltest.TestCase): diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD index b201cd39b3..337bd2abf6 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD @@ -37,20 +37,10 @@ py_library( ], ) -py_library( - name = "tf_on_jax", - srcs = ["tf_on_jax.py"], - deps = [ - # jax dep, - # jax:stax dep, - ], -) - py_library( name = "backend", srcs = ["backend.py"], deps = [ - ":tf_on_jax", ":util", # tensorflow_probability/substrates:jax dep, ], diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py index 666a080f3d..3fb141710d 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py @@ -14,18 +14,19 @@ # ============================================================================ """JAX backend.""" -from fun_mc.dynamic.backend_jax import tf_on_jax +import jax +import jax.numpy as jnp + from fun_mc.dynamic.backend_jax import util from tensorflow_probability.substrates import jax as tfp from tensorflow_probability.substrates.jax.internal import distribute_lib from tensorflow_probability.substrates.jax.internal import prefer_static -tf = tf_on_jax.tf - __all__ = [ 'distribute_lib', 'prefer_static', - 'tf', + 'jax', + 'jnp', 'tfp', 'util', ] diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py deleted file mode 100644 index c6baada131..0000000000 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2021 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Rough TensorFlow API implemented via JAX.""" - -import contextlib -import functools -import types -from typing import Any - -import jax -from jax import lax -from jax import tree_util -from jax.example_libraries import stax -import jax.numpy as jnp -import numpy as np - -__all__ = [ - 'tf', -] - -tf = types.ModuleType('tensorflow', '') - - -def _impl(path_in_tf=(), private=True, name=None): - """Implements a TensorFlow function.""" - - def _decorator(fn): - """Implements a TensorFlow function.""" - cur_mod = tf - for path_element in path_in_tf: - if not hasattr(cur_mod, path_element): - new_mod = types.ModuleType(path_element, '') - setattr(cur_mod, path_element, new_mod) - cur_mod = getattr(cur_mod, path_element) - if name is None: - if private: - final_name = fn.__name__[1:] - else: - final_name = fn.__name__ - else: - final_name = name - setattr(cur_mod, final_name, fn) - return fn - - return _decorator - - -_impl_np = functools.partial(_impl, private=False) - - -@_impl() -def _cond(pred, true_fn, false_fn): - # TODO(siege): I'm not sure this is completely correct, does lax.cond - # correctly handle closures? - return lax.cond(pred, (), lambda _: true_fn(), (), lambda _: false_fn()) - - -@_impl() -def _convert_to_tensor(value, dtype=None, name=None): - del name - return jnp.asarray(value, dtype) - - -@_impl() -def _while_loop(cond, body, loop_vars, **kwargs): # pylint: disable=missing-docstring - del kwargs - - # JAX doesn't do the automatic unwrapping of variables. - def cond_wrapper(loop_vars): - return cond(*loop_vars) - - def body_wrapper(loop_vars): - return body(*loop_vars) - - return lax.while_loop(cond_wrapper, body_wrapper, loop_vars) - - -@_impl() -def _cast(v, dtype): - return jnp.asarray(v).astype(dtype) - - -@_impl() -@contextlib.contextmanager -def _name_scope(name): - yield name - - -@_impl() -def _rank(x): - # JAX doesn't have rank implemented. - return len(x.shape) - - -@_impl() -def _function(x): - return jax.jit(x) - - -@_impl() -def _one_hot(indices, depth): - indices = jnp.asarray(indices) - flat_indices = indices.reshape([-1]) - flat_ret = jnp.eye(depth)[flat_indices] - return flat_ret.reshape(indices.shape + (depth,)) - - -@_impl() -def _gather(params, indices): - params = jnp.asarray(params) - indices = jnp.asarray(indices) - return params[indices] - - -@_impl() -def _range(*args, **kwargs): - """Implements tf.range.""" - # TODO(siege): This is a hack, the correct solution is to fix reduce_sum etc - # to correctly handle jnp.array axes. - if any( - tree_util.tree_flatten( - tree_util.tree_map(lambda x: isinstance(x, jnp.ndarray), - (args, kwargs)))[0]): - return jnp.arange(*args, **kwargs) - else: - return np.arange(*args, **kwargs) - - -@_impl() -def _eye(num_rows, num_columns=None, batch_shape=None, dtype=jnp.float32): - """Implements tf.eye.""" - x = jnp.eye(num_rows, num_columns).astype(dtype) - if batch_shape is not None: - x = jnp.broadcast_to(x, tuple(batch_shape) + x.shape) - return x - - -@_impl() -def _get_static_value(value): - try: - return np.array(value) - except TypeError: - return None - - -tf.TensorSpec = Any -tf.DType = Any - - -_impl(name='add_n')(sum) -_impl(['nn'], name='softmax')(stax.softmax) -_impl(name='custom_gradient')(jax.custom_gradient) -_impl(name='stop_gradient')(jax.lax.stop_gradient) - -tf.newaxis = None - -_impl_np()(jnp.cumsum) -_impl_np()(jnp.exp) -_impl_np()(jnp.einsum) -_impl_np()(jnp.floor) -_impl_np()(jnp.float32) -_impl_np()(jnp.float64) -_impl_np()(jnp.int32) -_impl_np()(jnp.maximum) -_impl_np()(jnp.minimum) -_impl_np()(jnp.ones) -_impl_np()(jnp.ones_like) -_impl_np()(jnp.reshape) -_impl_np()(jnp.shape) -_impl_np()(jnp.size) -_impl_np()(jnp.sqrt) -_impl_np()(jnp.where) -_impl_np()(jnp.zeros) -_impl_np()(jnp.zeros_like) -_impl_np()(jnp.transpose) -_impl_np(name='fill')(jnp.full) -_impl_np(['nn'])(jax.nn.softmax) -_impl_np(['math'])(jnp.ceil) -_impl_np(['math'])(jnp.log) -_impl_np(['math'], name='mod')(jnp.mod) -_impl_np(['math'])(jnp.sqrt) -_impl_np(['math'], name='is_finite')(jnp.isfinite) -_impl_np(['math'], name='is_nan')(jnp.isnan) -_impl_np(['math'], name='pow')(jnp.power) -_impl_np(['math'], name='reduce_all')(jnp.all) -_impl_np(['math'], name='reduce_prod')(jnp.prod) -_impl_np(['math'], name='reduce_variance')(jnp.var) -_impl_np(name='abs')(jnp.abs) -_impl_np(name='Tensor')(jnp.ndarray) -_impl_np(name='concat')(jnp.concatenate) -_impl_np(name='constant')(jnp.array) -_impl_np(name='expand_dims')(jnp.expand_dims) -_impl_np(name='reduce_max')(jnp.max) -_impl_np(name='reduce_mean')(jnp.mean) -_impl_np(name='reduce_sum')(jnp.sum) -_impl_np(name='square')(jnp.square) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py index 962faa502f..a1461e16c4 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py @@ -14,16 +14,56 @@ # ============================================================================ """TensorFlow backend.""" +import types import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python.internal import distribute_lib from tensorflow_probability.python.internal import prefer_static from fun_mc.dynamic.backend_tensorflow import util +tnp = tf.experimental.numpy + +_lax = types.ModuleType('lax') +_lax.cond = tf.cond +_lax.stop_gradient = tf.stop_gradient + +_nn = types.ModuleType('nn') +_nn.softmax = tf.nn.softmax +_nn.one_hot = tf.one_hot + + +class _ShapeDtypeStruct: + pass + + +jax = types.ModuleType('jax') +jax.ShapeDtypeStruct = _ShapeDtypeStruct +jax.jit = tf.function +jax.lax = _lax +jax.custom_gradient = tf.custom_gradient +jax.nn = _nn + + +class _JNP(types.ModuleType): + + def __getattr__(self, name): + return getattr(tnp, name) + + +jnp = _JNP('numpy') +jnp.dtype = tf.DType +# These are technically provided by TensorFlow, but only after numpy mode is +# enabled. +jnp.ndarray = tf.Tensor +jnp.float32 = tf.float32 +jnp.float64 = tf.float64 + + __all__ = [ 'distribute_lib', 'prefer_static', - 'tf', + 'jnp', + 'jax', 'tfp', 'util', ] diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index 498a909a0b..a507e5d2ba 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -39,10 +39,10 @@ from typing import Any, Callable, NamedTuple, Optional, Union import numpy as np - from fun_mc import backend -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util ps = backend.prefer_static @@ -140,34 +140,44 @@ 'TransportMap', ] -# We quote tf types to avoid unconditionally loading the TF backend. -AnyTensor = Union['tf.Tensor', np.ndarray, np.generic] -BooleanTensor = Union[bool, 'tf.Tensor', np.ndarray, np.bool_] -IntTensor = Union[int, 'tf.Tensor', np.ndarray, np.integer] -FloatTensor = Union[float, 'tf.Tensor', np.ndarray, np.floating] +AnyArray = Union[jnp.ndarray, np.ndarray, np.generic] +BooleanArray = Union[jnp.ndarray, np.ndarray, np.bool_] +IntArray = Union[jnp.ndarray, np.ndarray, np.integer] +FloatArray = Union[jnp.ndarray, np.ndarray, np.floating] +Shape = tuple[int, ...] # TODO(b/109648354): Correctly represent the recursive nature of this type. -TensorNest = Union[AnyTensor, Sequence[AnyTensor], Mapping[Any, AnyTensor]] -TensorSpecNest = Union['tf.TensorSpec', Sequence['tf.TensorSpec'], - Mapping[Any, 'tf.TensorSpec']] -BijectorNest = Union[tfb.Bijector, Sequence[tfb.Bijector], - Mapping[Any, tfb.Bijector]] -BooleanNest = Union[BooleanTensor, Sequence[BooleanTensor], - Mapping[Any, BooleanTensor]] -FloatNest = Union[FloatTensor, Sequence[FloatTensor], Mapping[Any, FloatTensor]] -IntNest = Union[IntTensor, Sequence[IntTensor], Mapping[Any, IntTensor]] -StringNest = Union[str, Sequence[str], Mapping[Any, str]] -DTypeNest = Union['tf.DType', Sequence['tf.DType'], Mapping[Any, 'tf.DType']] -State = TensorNest # pylint: disable=invalid-name -StateExtra = TensorNest # pylint: disable=invalid-name -TransitionOperator = Callable[..., tuple[State, TensorNest]] -TransportMap = Callable[..., tuple[State, TensorNest]] -PotentialFn = Union[Callable[[TensorNest], tuple['tf.Tensor', TensorNest]], - Callable[..., tuple['tf.Tensor', TensorNest]]] -GradFn = Union[Callable[[TensorNest], tuple[TensorNest, TensorNest]], - Callable[..., tuple[TensorNest, TensorNest]]] - - -def _trace_extra(state: State, extra: TensorNest) -> TensorNest: +ArrayNest = Union[AnyArray, Sequence[AnyArray], Mapping[Any, AnyArray]] +ArraySpecNest = Union[ + jax.ShapeDtypeStruct, + Sequence[jax.ShapeDtypeStruct], + Mapping[Any, jax.ShapeDtypeStruct], +] +BijectorNest = Union[ + tfb.Bijector, Sequence[tfb.Bijector], Mapping[Any, tfb.Bijector] +] +BooleanNest = Union[ + BooleanArray, Sequence[BooleanArray], Mapping[Any, BooleanArray] +] +FloatNest = Union[FloatArray, Sequence[FloatArray], Mapping[Any, FloatArray]] +ShapeNest = Union[Shape, Sequence[Shape], Mapping[Any, Shape]] +IntNest = Union[IntArray, Sequence[IntArray], Mapping[Any, IntArray]] +StringNest = Union[Sequence[str], Mapping[Any, str]] +DTypeNest = Union[jnp.dtype, Sequence[jnp.dtype], Mapping[Any, jnp.dtype]] +State = ArrayNest # pylint: disable=invalid-name +StateExtra = ArrayNest # pylint: disable=invalid-name +TransitionOperator = Callable[..., tuple[State, ArrayNest]] +TransportMap = Callable[..., tuple[State, ArrayNest]] +PotentialFn = Union[ + Callable[[ArrayNest], tuple[jnp.ndarray, ArrayNest]], + Callable[..., tuple[jnp.ndarray, ArrayNest]], +] +GradFn = Union[ + Callable[[ArrayNest], tuple[ArrayNest, ArrayNest]], + Callable[..., tuple[ArrayNest, ArrayNest]], +] + + +def _trace_extra(state: State, extra: ArrayNest) -> ArrayNest: del state return extra @@ -206,20 +216,20 @@ def _select(trace_mask, traced, untraced): def trace( state: State, fn: TransitionOperator, - num_steps: IntTensor, - trace_fn: Callable[[State, TensorNest], TensorNest] = _trace_extra, - trace_mask: BooleanNest = True, + num_steps: IntArray, + trace_fn: Callable[[State, ArrayNest], ArrayNest] = _trace_extra, + trace_mask: bool | BooleanNest = True, unroll: bool = False, parallel_iterations: int = 10, -) -> tuple[State, TensorNest]: +) -> tuple[State, ArrayNest]: """`TransitionOperator` that runs `fn` repeatedly and traces its outputs. Args: - state: A nest of `Tensor`s or None. + state: A nest of `Array`s or None. fn: A `TransitionOperator`. num_steps: Number of steps to run the function for. Must be greater than 1. trace_fn: Callable that the unpacked outputs of `fn` and returns a nest of - `Tensor`s. These will potentially be stacked and returned as the second + `Array`s. These will potentially be stacked and returned as the second return value. By default, just the `extra` return from `fn` is returned. trace_mask: A potentially shallow nest with boolean leaves applied to the return value of `trace_fn`. This controls whether or not to actually trace @@ -269,8 +279,9 @@ def fn(x): """ def wrapper(state): - state, extra = util.map_tree(util.convert_to_tensor, - call_transition_operator(fn, state)) + state, extra = util.map_tree( + util.convert_to_tensor, call_transition_operator(fn, state) + ) trace_element = util.map_tree( util.convert_to_tensor, trace_fn(state, extra) ) @@ -320,7 +331,7 @@ def trace_mask(self) -> BooleanNest: return next(iter(self.trace_mask_holder.keys())).trace_mask @util.named_call - def trace(self, only_valid=True) -> TensorNest: + def trace(self, only_valid=True) -> ArrayNest: """Returns the stacked and unstacked values. Args: @@ -342,7 +353,7 @@ def interruptible_trace_init( state: State, fn: TransitionOperator, num_steps: int, - trace_mask: BooleanNest = True, + trace_mask: bool | BooleanNest = True, ) -> InterruptibleTraceState: """Initializes the state interruptible trace operator. @@ -366,21 +377,20 @@ def interruptible_trace_init( """ state = util.map_tree(util.convert_to_tensor, state) _, trace_element = util.eval_shape( - lambda s: call_transition_operator(fn, s), - state + lambda s: call_transition_operator(fn, s), state ) untraced, traced = _split_trace(trace_element, trace_mask) traced = util.map_tree( lambda x: util.new_dynamic_array(x.shape, x.dtype, num_steps), traced ) - untraced = util.map_tree(lambda x: tf.zeros(x.shape, x.dtype), untraced) + untraced = util.map_tree(lambda x: jnp.zeros(x.shape, x.dtype), untraced) return InterruptibleTraceState( state=state, traced=traced, untraced=untraced, - step=tf.zeros([], tf.int32), + step=jnp.zeros([], jnp.int32), trace_mask_holder={_TraceMaskHolder(trace_mask): ()}, ) @@ -425,12 +435,9 @@ def fun(x, y): assert([1.0, 2.0, 3.0, 4.0] == x_trace) assert([2.0, 4.0, 6.0, 8.0] == y_trace) ``` - """ - inner_state, trace_element = call_transition_operator( - fn, state.state - ) + inner_state, trace_element = call_transition_operator(fn, state.state) untraced, traced = util.map_tree( util.convert_to_tensor, _split_trace(trace_element, state.trace_mask) @@ -490,8 +497,9 @@ def call_fn( Returns: ret: Return value of `fn`. """ - if (isinstance(args, collections.abc.Sequence) and - not _is_namedtuple_like(args)): + if isinstance(args, collections.abc.Sequence) and not _is_namedtuple_like( + args + ): return fn(*args) elif isinstance(args, collections.abc.Mapping): return fn(**args) @@ -499,8 +507,9 @@ def call_fn( return fn(args) -def recover_state_from_args(args: Sequence[Any], kwargs: Mapping[str, Any], - state_structure: Any) -> Any: +def recover_state_from_args( + args: Sequence[Any], kwargs: Mapping[str, Any], state_structure: Any +) -> Any: """Attempts to recover the state that was transmitted via *args, **kwargs.""" orig_args = args if isinstance(state_structure, collections.abc.Mapping): @@ -518,20 +527,26 @@ def recover_state_from_args(args: Sequence[Any], kwargs: Mapping[str, Any], else: if k not in kwargs: raise ValueError( - ('Missing \'{}\' from kwargs.\nargs=\n{}\nkwargs=\n{}\n' - 'state_structure=\n{}').format(k, orig_args, kwargs, - _tree_repr(state_structure))) + ( + "Missing '{}' from kwargs.\nargs=\n{}\nkwargs=\n{}\n" + 'state_structure=\n{}' + ).format(k, orig_args, kwargs, _tree_repr(state_structure)) + ) state[k] = kwargs[k] return state - elif (isinstance(state_structure, collections.abc.Sequence) and - not _is_namedtuple_like(state_structure)): + elif isinstance( + state_structure, collections.abc.Sequence + ) and not _is_namedtuple_like(state_structure): # Sadly, we have no way of inferring the state index from kwargs, so we # disallow them. # TODO(siege): We could support length-1 sequences in principle. if kwargs: - raise ValueError('This wrapper does not accept keyword arguments for a ' - 'sequence-like state structure=\n{}'.format( - _tree_repr(state_structure))) + raise ValueError( + 'This wrapper does not accept keyword arguments for a ' + 'sequence-like state structure=\n{}'.format( + _tree_repr(state_structure) + ) + ) return type(state_structure)(args) elif args: return args[0] @@ -546,7 +561,7 @@ def recover_state_from_args(args: Sequence[Any], kwargs: Mapping[str, Any], def call_potential_fn( fn: PotentialFn, args: Union[tuple[Any], Mapping[str, Any], Any], -) -> tuple[tf.Tensor, Any]: +) -> tuple[jnp.ndarray, Any]: """Calls a transition operator with `args`. `fn` must fulfill the `PotentialFn` contract: @@ -566,27 +581,31 @@ def call_potential_fn( TypeError: If `fn` doesn't fulfill the contract. """ ret = call_fn(fn, args) - error_template = ('`{fn:}` must have a signature ' - '`fn(args) -> (tf.Tensor, extra)`' - ' but when called with `args=`\n{args:}\nreturned ' - '`ret=`\n{ret:}\ninstead. The structure of ' - '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' - 'A common solution is to adjust the `return`s in `fn` to ' - 'be `return args, ()`.') + error_template = ( + '`{fn:}` must have a signature ' + '`fn(args) -> (jnp.ndarray, extra)`' + ' but when called with `args=`\n{args:}\nreturned ' + '`ret=`\n{ret:}\ninstead. The structure of ' + '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' + 'A common solution is to adjust the `return`s in `fn` to ' + 'be `return args, ()`.' + ) if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2: args_s = _tree_repr(args) ret_s = _tree_repr(ret) raise TypeError( error_template.format( - fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s)) + fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s + ) + ) return ret def call_transition_operator( fn: TransitionOperator, args: State, -) -> tuple[State, TensorNest]: +) -> tuple[State, ArrayNest]: """Calls a transition operator with `args`. `fn` must fulfill the `TransitionOperator` contract: @@ -607,27 +626,32 @@ def call_transition_operator( TypeError: If `fn` doesn't fulfill the contract. """ ret = call_fn(fn, args) - error_template = ('`{fn:}` must have a signature ' - '`fn(args) -> (new_args, extra)`' - ' but when called with `args=`\n{args:}\nreturned ' - '`ret=`\n{ret:}\ninstead. The structure of ' - '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' - 'A common solution is to adjust the `return`s in `fn` to ' - 'be `return args, ()`.') + error_template = ( + '`{fn:}` must have a signature ' + '`fn(args) -> (new_args, extra)`' + ' but when called with `args=`\n{args:}\nreturned ' + '`ret=`\n{ret:}\ninstead. The structure of ' + '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' + 'A common solution is to adjust the `return`s in `fn` to ' + 'be `return args, ()`.' + ) if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2: args_s = _tree_repr(args) ret_s = _tree_repr(ret) raise TypeError( error_template.format( - fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s)) + fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s + ) + ) error_template = ( '`{fn:}` must have a signature ' '`fn(args) -> (new_args, extra)`' ' but when called with `args=`\n{args:}\nreturned ' '`new_args=`\n{new_args:}\ninstead. The structure of ' - '`args=`\n{args_s:}\nThe structure of `new_args=`\n{new_args_s:}\n') + '`args=`\n{args_s:}\nThe structure of `new_args=`\n{new_args_s:}\n' + ) new_args, extra = ret try: util.assert_same_shallow_tree(args, new_args) @@ -640,14 +664,16 @@ def call_transition_operator( args=args, new_args=new_args, args_s=args_s, - new_args_s=new_args_s)) + new_args_s=new_args_s, + ) + ) return new_args, extra def call_transport_map( fn: TransportMap, args: State, -) -> tuple[State, TensorNest]: +) -> tuple[State, ArrayNest]: """Calls a transport map with `args`. `fn` must fulfill the `TransportMap` contract: @@ -668,27 +694,31 @@ def call_transport_map( """ ret = call_fn(fn, args) - error_template = ('`{fn:}` must have a signature ' - '`fn(args) -> (out, extra)`' - ' but when called with `args=`\n{args:}\nreturned ' - '`ret=`\n{ret:}\ninstead. The structure of ' - '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' - 'A common solution is to adjust the `return`s in `fn` to ' - 'be `return args, ()`.') + error_template = ( + '`{fn:}` must have a signature ' + '`fn(args) -> (out, extra)`' + ' but when called with `args=`\n{args:}\nreturned ' + '`ret=`\n{ret:}\ninstead. The structure of ' + '`args=`\n{args_s:}\nThe structure of `ret=`\n{ret_s:}\n' + 'A common solution is to adjust the `return`s in `fn` to ' + 'be `return args, ()`.' + ) if not isinstance(ret, Sequence) or len(ret) != 2: args_s = _tree_repr(args) ret_s = _tree_repr(ret) raise TypeError( error_template.format( - fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s)) + fn=fn, args=args, ret=ret, args_s=args_s, ret_s=ret_s + ) + ) return ret def call_transport_map_with_ldj( fn: TransitionOperator, args: State, -) -> tuple[State, TensorNest, TensorNest]: +) -> tuple[State, ArrayNest, ArrayNest]: """Calls `fn` and returns the log-det jacobian to `fn`'s first output. Args: @@ -709,7 +739,7 @@ def wrapper(args): def call_potential_fn_with_grads( fn: PotentialFn, args: Union[tuple[Any], Mapping[str, Any], Any] -) -> tuple[tf.Tensor, TensorNest, TensorNest]: +) -> tuple[jnp.ndarray, ArrayNest, ArrayNest]: """Calls `fn` and returns the gradients with respect to `fn`'s first output. Args: @@ -728,8 +758,7 @@ def wrapper(args): return util.value_and_grad(wrapper, args) -def maybe_broadcast_structure(from_structure: Any, - to_structure: Any) -> Any: +def maybe_broadcast_structure(from_structure: Any, to_structure: Any) -> Any: """Maybe broadcasts `from_structure` to `to_structure`. This assumes that `from_structure` is a shallow version of `to_structure`. @@ -803,19 +832,22 @@ def reparameterize_potential_fn( if state_structure is None and init_state is None: raise ValueError( - 'At least one of `state_structure` or `init_state` must be ' - 'passed in.') + 'At least one of `state_structure` or `init_state` must be passed in.' + ) def wrapper(*args, **kwargs): """Transformed wrapper.""" real_state_structure = ( - state_structure if state_structure is not None else init_state) - transformed_state = recover_state_from_args(args, kwargs, - real_state_structure) + state_structure if state_structure is not None else init_state + ) + transformed_state = recover_state_from_args( + args, kwargs, real_state_structure + ) if track_volume: state, map_extra, ldj = call_transport_map_with_ldj( - transport_map_fn, transformed_state) + transport_map_fn, transformed_state + ) else: state, map_extra = call_transport_map(transport_map_fn, transformed_state) @@ -828,17 +860,20 @@ def wrapper(*args, **kwargs): if init_state is not None: inverse_transform_map_fn = util.inverse_fn(transport_map_fn) - transformed_state, _ = call_transport_map(inverse_transform_map_fn, - init_state) + transformed_state, _ = call_transport_map( + inverse_transform_map_fn, init_state + ) else: transformed_state = None return wrapper, transformed_state -def transform_log_prob_fn(log_prob_fn: PotentialFn, - bijector: BijectorNest, - init_state: Optional[State] = None) -> Any: +def transform_log_prob_fn( + log_prob_fn: PotentialFn, + bijector: BijectorNest, + init_state: Optional[State] = None, +) -> Any: """Transforms a log-prob function using a bijector. This takes a log-prob function and creates a new log-prob function that now @@ -870,40 +905,49 @@ def transform_log_prob_fn(log_prob_fn: PotentialFn, """ bijector_structure = util.get_shallow_tree( - lambda b: isinstance(b, tfb.Bijector), bijector) + lambda b: isinstance(b, tfb.Bijector), bijector + ) def wrapper(*args, **kwargs): """Transformed wrapper.""" bijector_ = bijector args = recover_state_from_args(args, kwargs, bijector_) - args = util.map_tree(lambda x: 0. + x, args) + args = util.map_tree(lambda x: 0.0 + x, args) - original_space_args = util.map_tree_up_to(bijector_structure, - lambda b, x: b.forward(x), - bijector_, args) - original_space_log_prob, extra = call_potential_fn(log_prob_fn, - original_space_args) + original_space_args = util.map_tree_up_to( + bijector_structure, lambda b, x: b.forward(x), bijector_, args + ) + original_space_log_prob, extra = call_potential_fn( + log_prob_fn, original_space_args + ) event_ndims = util.map_tree( - lambda x: len(x.shape) - len(original_space_log_prob.shape), args) + lambda x: len(x.shape) - len(original_space_log_prob.shape), args + ) return original_space_log_prob + sum( util.flatten_tree( util.map_tree_up_to( bijector_structure, lambda b, x, e: b.forward_log_det_jacobian(x, event_ndims=e), - bijector_, args, event_ndims))), [original_space_args, extra] + bijector_, + args, + event_ndims, + ) + ) + ), [original_space_args, extra] if init_state is None: return wrapper else: - return wrapper, util.map_tree_up_to(bijector_structure, - lambda b, s: b.inverse(s), bijector, - init_state) + return wrapper, util.map_tree_up_to( + bijector_structure, lambda b, s: b.inverse(s), bijector, init_state + ) class IntegratorStepState(NamedTuple): """Integrator step state.""" + state: State state_grads: State momentum: State @@ -911,24 +955,26 @@ class IntegratorStepState(NamedTuple): class IntegratorStepExtras(NamedTuple): """Integrator step extras.""" - target_log_prob: FloatTensor + + target_log_prob: FloatArray state_extra: StateExtra - kinetic_energy: FloatTensor + kinetic_energy: FloatArray kinetic_energy_extra: Any momentum_grads: State -IntegratorStep = Callable[[IntegratorStepState], tuple[IntegratorStepState, - IntegratorStepExtras]] +IntegratorStep = Callable[ + [IntegratorStepState], tuple[IntegratorStepState, IntegratorStepExtras] +] @util.named_call def splitting_integrator_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, - coefficients: Sequence[FloatTensor], + coefficients: Sequence[FloatArray], forward: bool = True, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: """Symmetric symplectic integrator `TransitionOperator`. @@ -966,9 +1012,9 @@ def splitting_integrator_step( momentum_grads = None step_size = maybe_broadcast_structure(step_size, state) - state = util.map_tree(tf.convert_to_tensor, state) - momentum = util.map_tree(tf.convert_to_tensor, momentum) - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) + momentum = util.map_tree(jnp.asarray, momentum) + state = util.map_tree(jnp.asarray, state) idx_and_coefficients = enumerate(coefficients) if not forward: @@ -978,34 +1024,46 @@ def splitting_integrator_step( # pylint: disable=cell-var-from-loop if i % 2 == 0: # Update momentum. - state_grads = util.map_tree(tf.convert_to_tensor, state_grads) + state_grads = util.map_tree(jnp.asarray, state_grads) - momentum = util.map_tree(lambda m, sg, s: m + c * sg * s, momentum, - state_grads, step_size) + momentum = util.map_tree( + lambda m, sg, s: m + c * sg * s, momentum, state_grads, step_size + ) - kinetic_energy, kinetic_energy_extra, momentum_grads = call_potential_fn_with_grads( - kinetic_energy_fn, momentum) + kinetic_energy, kinetic_energy_extra, momentum_grads = ( + call_potential_fn_with_grads(kinetic_energy_fn, momentum) + ) else: # Update position. if momentum_grads is None: _, _, momentum_grads = call_potential_fn_with_grads( - kinetic_energy_fn, momentum) + kinetic_energy_fn, momentum + ) - state = util.map_tree(lambda x, mg, s: x + c * mg * s, state, - momentum_grads, step_size) + state = util.map_tree( + lambda x, mg, s: x + c * mg * s, state, momentum_grads, step_size + ) target_log_prob, state_extra, state_grads = call_potential_fn_with_grads( - target_log_prob_fn, state) + target_log_prob_fn, state + ) - return (IntegratorStepState(state, state_grads, momentum), - IntegratorStepExtras(target_log_prob, state_extra, kinetic_energy, - kinetic_energy_extra, momentum_grads)) + return ( + IntegratorStepState(state, state_grads, momentum), + IntegratorStepExtras( + target_log_prob, + state_extra, + kinetic_energy, + kinetic_energy_extra, + momentum_grads, + ), + ) @util.named_call def leapfrog_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: @@ -1022,19 +1080,20 @@ def leapfrog_step( integrator_step_state: IntegratorStepState. integrator_step_extras: IntegratorStepExtras. """ - coefficients = [0.5, 1., 0.5] + coefficients = [0.5, 1.0, 0.5] return splitting_integrator_step( integrator_step_state, step_size, target_log_prob_fn, kinetic_energy_fn, - coefficients=coefficients) + coefficients=coefficients, + ) @util.named_call def ruth4_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: @@ -1058,21 +1117,22 @@ def ruth4_step( [1]: Ruth, Ronald D. (August 1983). "A Canonical Integration Technique". Nuclear Science, IEEE Trans. on. NS-30 (4): 2669-2671 """ - c = 2**(1. / 3) - coefficients = (1. / (2 - c)) * np.array([0.5, 1., 0.5 - 0.5 * c, -c]) + c = 2 ** (1.0 / 3) + coefficients = (1.0 / (2 - c)) * np.array([0.5, 1.0, 0.5 - 0.5 * c, -c]) coefficients = list(coefficients) + list(reversed(coefficients))[1:] return splitting_integrator_step( integrator_step_state, step_size, target_log_prob_fn, kinetic_energy_fn, - coefficients=coefficients) + coefficients=coefficients, + ) @util.named_call def blanes_3_stage_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: @@ -1103,20 +1163,21 @@ def blanes_3_stage_step( """ a1 = 0.11888010966 b1 = 0.29619504261 - coefficients = [a1, b1, 0.5 - a1, 1. - 2. * b1] + coefficients = [a1, b1, 0.5 - a1, 1.0 - 2.0 * b1] coefficients = coefficients + list(reversed(coefficients))[1:] return splitting_integrator_step( integrator_step_state, step_size, target_log_prob_fn, kinetic_energy_fn, - coefficients=coefficients) + coefficients=coefficients, + ) @util.named_call def blanes_4_stage_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: @@ -1144,23 +1205,24 @@ def blanes_4_stage_step( a1 = 0.071353913 a2 = 0.268548791 b1 = 0.191667800 - coefficients = [a1, b1, a2, 0.5 - b1, 1. - 2. * (a1 + a2)] + coefficients = [a1, b1, a2, 0.5 - b1, 1.0 - 2.0 * (a1 + a2)] coefficients = coefficients + list(reversed(coefficients))[1:] return splitting_integrator_step( integrator_step_state, step_size, target_log_prob_fn, kinetic_energy_fn, - coefficients=coefficients) + coefficients=coefficients, + ) @util.named_call def mclachlan_optimal_4th_order_step( integrator_step_state: IntegratorStepState, - step_size: FloatTensor, + step_size: FloatArray, target_log_prob_fn: PotentialFn, kinetic_energy_fn: PotentialFn, - forward: BooleanTensor, + forward: BooleanArray, ) -> tuple[IntegratorStepState, IntegratorStepExtras]: """4th order integrator for Hamiltonians with a quadratic kinetic energy. @@ -1174,7 +1236,7 @@ def mclachlan_optimal_4th_order_step( state. target_log_prob_fn: Target log prob fn. kinetic_energy_fn: Kinetic energy fn. - forward: A scalar `bool` Tensor. Whether to run this integrator in the + forward: A scalar `bool` Array. Whether to run this integrator in the forward direction. Note that this is done for the entire state, not per-batch. @@ -1207,29 +1269,32 @@ def _step(direction): target_log_prob_fn, kinetic_energy_fn, coefficients=coefficients, - forward=direction) + forward=direction, + ) - # In principle we can avoid the cond, and use `tf.where` to select between the - # coefficients. This would require a superfluous momentum update, but in + # In principle we can avoid the cond, and use `jnp.where` to select between + # the coefficients. This would require a superfluous momentum update, but in # principle is feasible. We're not doing it because it would complicate the # code slightly, and there is limited motivation to do it since reversing the # directions for all the chains at once is typically valid as well. - return tf.cond(forward, lambda: _step(True), lambda: _step(False)) + return jax.lax.cond(forward, lambda: _step(True), lambda: _step(False)) class MetropolisHastingsExtra(NamedTuple): """Metropolis-hastings extra outputs.""" - is_accepted: FloatTensor - log_uniform: FloatTensor + + is_accepted: FloatArray + log_uniform: FloatArray @util.named_call def metropolis_hastings_step( current_state: State, proposed_state: State, - energy_change: FloatTensor, - log_uniform: Optional[FloatTensor] = None, - seed=None) -> tuple[State, MetropolisHastingsExtra]: + energy_change: FloatArray, + log_uniform: Optional[FloatArray] = None, + seed=None, +) -> tuple[State, MetropolisHastingsExtra]: """Metropolis-Hastings step. This probabilistically chooses between `current_state` and `proposed_state` @@ -1249,23 +1314,26 @@ def metropolis_hastings_step( new_state: The chosen state. mh_extra: MetropolisHastingsExtra. """ - current_state = util.map_tree(tf.convert_to_tensor, current_state) - proposed_state = util.map_tree(tf.convert_to_tensor, proposed_state) - energy_change = tf.convert_to_tensor(energy_change) + current_state = util.map_tree(jnp.asarray, current_state) + proposed_state = util.map_tree(jnp.asarray, proposed_state) + energy_change = jnp.asarray(energy_change) log_accept_ratio = -energy_change if log_uniform is None: - log_uniform = tf.math.log( + log_uniform = jnp.log( util.random_uniform( shape=log_accept_ratio.shape, dtype=log_accept_ratio.dtype, - seed=seed)) + seed=seed, + ) + ) is_accepted = log_uniform < log_accept_ratio next_state = choose(is_accepted, proposed_state, current_state) return next_state, MetropolisHastingsExtra( - is_accepted=is_accepted, log_uniform=log_uniform) + is_accepted=is_accepted, log_uniform=log_uniform + ) class PersistentMetropolistHastingsState(NamedTuple): @@ -1275,10 +1343,11 @@ class PersistentMetropolistHastingsState(NamedTuple): level: Value uniformly distributed on [-1, 1], absolute value of which is used as the slice variable for the acceptance test. """ + # We borrow the [-1, 1] encoding from the original paper; it has the effect of # flipping the drift direction automatically, which has the effect of # prolonging the persistent bouts of acceptance. - level: FloatTensor + level: FloatArray class PersistentMetropolistHastingsExtra(NamedTuple): @@ -1288,15 +1357,16 @@ class PersistentMetropolistHastingsExtra(NamedTuple): is_accepted: Whether the proposed state was accepted. accepted_state: The accepted state. """ - is_accepted: BooleanTensor + + is_accepted: BooleanArray accepted_state: State @util.named_call def persistent_metropolis_hastings_init( - shape: IntTensor, - dtype: tf.DType = tf.float32, - init_level: FloatTensor = 0., + shape: Shape, + dtype: jnp.dtype = jnp.float32, + init_level: float | FloatArray = 0.0, ) -> PersistentMetropolistHastingsState: """Initializes `PersistentMetropolistHastingsState`. @@ -1308,8 +1378,9 @@ def persistent_metropolis_hastings_init( Returns: pmh_state: `PersistentMetropolistHastingsState` """ - return PersistentMetropolistHastingsState(level=init_level + - tf.zeros(shape, dtype)) + return PersistentMetropolistHastingsState( + level=init_level + jnp.zeros(shape, dtype) + ) @util.named_call @@ -1317,8 +1388,8 @@ def persistent_metropolis_hastings_step( pmh_state: PersistentMetropolistHastingsState, current_state: State, proposed_state: State, - energy_change: FloatTensor, - drift: FloatTensor, + energy_change: FloatArray, + drift: FloatArray, ) -> tuple[ PersistentMetropolistHastingsState, PersistentMetropolistHastingsExtra ]: @@ -1347,11 +1418,11 @@ def persistent_metropolis_hastings_step( Metropolis accept/reject decisions. """ log_accept_ratio = -energy_change - is_accepted = tf.math.log(tf.abs(pmh_state.level)) < log_accept_ratio + is_accepted = jnp.log(jnp.abs(pmh_state.level)) < log_accept_ratio # N.B. we'll never accept when energy_change is NaN, so `level` should remain # non-NaN at all times. level = pmh_state.level - level = tf.where(is_accepted, level * tf.exp(energy_change), level) + level = jnp.where(is_accepted, level * jnp.exp(energy_change), level) level += drift level = (1 + level) % 2 - 1 return pmh_state._replace(level=level), PersistentMetropolistHastingsExtra( @@ -1364,11 +1435,13 @@ def persistent_metropolis_hastings_step( @util.named_call -def gaussian_momentum_sample(state: Optional[State] = None, - shape: Optional[IntTensor] = None, - dtype: Optional[DTypeNest] = None, - named_axis: Optional[StringNest] = None, - seed=None) -> State: +def gaussian_momentum_sample( + state: Optional[State] = None, + shape: Optional[ShapeNest] = None, + dtype: Optional[DTypeNest] = None, + named_axis: Optional[StringNest] = None, + seed=None, +) -> State: """Generates a sample from a Gaussian (Normal) momentum distribution. One of `state` or the pair of `shape`/`dtype` need to be specified to obtain @@ -1376,7 +1449,7 @@ def gaussian_momentum_sample(state: Optional[State] = None, structure. Args: - state: A nest of `Tensor`s with the shape and dtype being the same as the + state: A nest of `Array`s with the shape and dtype being the same as the output. shape: A nest of shapes, which matches the output shapes. dtype: A nest of dtypes, which matches the output dtypes. @@ -1384,7 +1457,7 @@ def gaussian_momentum_sample(state: Optional[State] = None, seed: For reproducibility. Returns: - sample: A nest of `Tensor`s with the same structure, shape and dtypes as one + sample: A nest of `Array`s with the same structure, shape and dtypes as one of the two sets inputs, distributed with Normal distribution. """ if dtype is None or shape is None: @@ -1405,9 +1478,9 @@ def _one_part(dtype, shape, seed, named_axis): def make_gaussian_kinetic_energy_fn( - chain_ndims: IntTensor, + chain_ndims: int, named_axis: Optional[StringNest] = None, -) -> Callable[..., tuple[tf.Tensor, TensorNest]]: +) -> Callable[..., tuple[jnp.ndarray, ArrayNest]]: """Returns a function that computes the kinetic energy of a state. Args: @@ -1435,39 +1508,50 @@ def kinetic_energy_fn(*args, **kwargs): # leaves). Instead, we go the other way, and decompose named_axis into # args, kwargs. These new objects are guaranteed to line up with the # decomposed state. - named_axis_args = call_fn(lambda *args, **kwargs: (args, kwargs), - named_axis) + named_axis_args = call_fn( + lambda *args, **kwargs: (args, kwargs), named_axis + ) def _one_part(x, named_axis): return backend.distribute_lib.reduce_sum( - tf.square(x), tuple(range(chain_ndims, len(x.shape))), named_axis) - - return 0.5 * sum( - util.flatten_tree( - util.map_tree_up_to(state_args, _one_part, state_args, - named_axis_args))), () + jnp.square(x), tuple(range(chain_ndims, len(x.shape))), named_axis + ) + + return ( + 0.5 + * sum( + util.flatten_tree( + util.map_tree_up_to( + state_args, _one_part, state_args, named_axis_args + ) + ) + ), + (), + ) return kinetic_energy_fn class IntegratorState(NamedTuple): """Integrator state.""" + state: State state_extra: StateExtra state_grads: State - target_log_prob: FloatTensor + target_log_prob: FloatArray momentum: State class IntegratorExtras(NamedTuple): """Integrator extra outputs.""" - initial_energy: FloatTensor - initial_kinetic_energy: FloatTensor + + initial_energy: FloatArray | tuple[()] + initial_kinetic_energy: FloatArray | tuple[()] initial_kinetic_energy_extra: Any - final_energy: FloatTensor - final_kinetic_energy: FloatTensor + final_energy: FloatArray | tuple[()] + final_kinetic_energy: FloatArray | tuple[()] final_kinetic_energy_extra: Any - energy_change: FloatTensor + energy_change: FloatArray integrator_trace: Any momentum_grads: State @@ -1475,14 +1559,14 @@ class IntegratorExtras(NamedTuple): @util.named_call def hamiltonian_integrator( int_state: IntegratorState, - num_steps: IntTensor, + num_steps: IntArray, integrator_step_fn: IntegratorStep, kinetic_energy_fn: PotentialFn, integrator_trace_fn: Callable[ - [IntegratorStepState, IntegratorStepExtras], TensorNest + [IntegratorStepState, IntegratorStepExtras], ArrayNest ] = lambda *args: (), unroll: bool = False, - max_num_steps: Optional[IntTensor] = None, + max_num_steps: Optional[IntArray] = None, ) -> tuple[IntegratorState, IntegratorExtras]: """Intergrates a discretized set of Hamiltonian equations. @@ -1491,10 +1575,10 @@ def hamiltonian_integrator( Args: int_state: Current `IntegratorState`. - num_steps: Integer scalar or N-D `Tensor`. Number of steps to take. If this + num_steps: Integer scalar or N-D `Array`. Number of steps to take. If this is not a scalar, then each corresponding independent system will be evaluated for that number of steps, followed by copying the final state to - avoid creating a ragged Tensor. Keep this in mind when interpreting the + avoid creating a ragged Array. Keep this in mind when interpreting the `integrator_trace` in the auxiliary output. integrator_step_fn: Instance of `IntegratorStep`. kinetic_energy_fn: Function to compute the kinetic energy from momentums. @@ -1518,7 +1602,8 @@ def hamiltonian_integrator( is_ragged = len(num_steps.shape) > 0 or max_num_steps is not None # pylint: disable=g-explicit-length-test initial_kinetic_energy, initial_kinetic_energy_extra = call_potential_fn( - kinetic_energy_fn, momentum) + kinetic_energy_fn, momentum + ) initial_energy = -target_log_prob + initial_kinetic_energy if is_ragged: @@ -1537,9 +1622,13 @@ def hamiltonian_integrator( integrator_wrapper_state = ( step, IntegratorStepState(state, state_grads, momentum), - IntegratorStepExtras(target_log_prob, state_extra, initial_kinetic_energy, - initial_kinetic_energy_extra, - util.map_tree(tf.zeros_like, momentum)), + IntegratorStepExtras( + target_log_prob, + state_extra, + initial_kinetic_energy, + initial_kinetic_energy_extra, + util.map_tree(jnp.zeros_like, momentum), + ), ) def integrator_wrapper(step, integrator_step_state, integrator_step_extra): @@ -1547,13 +1636,16 @@ def integrator_wrapper(step, integrator_step_state, integrator_step_extra): old_integrator_step_state = integrator_step_state old_integrator_step_extra = integrator_step_extra integrator_step_state, integrator_step_extra = integrator_step_fn( - integrator_step_state) + integrator_step_state + ) if is_ragged: - integrator_step_state = choose(step < num_steps, integrator_step_state, - old_integrator_step_state) - integrator_step_extra = choose(step < num_steps, integrator_step_extra, - old_integrator_step_extra) + integrator_step_state = choose( + step < num_steps, integrator_step_state, old_integrator_step_state + ) + integrator_step_extra = choose( + step < num_steps, integrator_step_extra, old_integrator_step_extra + ) step = step + 1 return (step, integrator_step_state, integrator_step_extra), [] @@ -1569,15 +1661,18 @@ def integrator_trace_wrapper_fn(args, _): unroll=unroll, ) - final_energy = (-integrator_step_extra.target_log_prob + - integrator_step_extra.kinetic_energy) + final_energy = ( + -integrator_step_extra.target_log_prob + + integrator_step_extra.kinetic_energy + ) state = IntegratorState( state=integrator_step_state.state, state_extra=integrator_step_extra.state_extra, state_grads=integrator_step_state.state_grads, target_log_prob=integrator_step_extra.target_log_prob, - momentum=integrator_step_state.momentum) + momentum=integrator_step_state.momentum, + ) extra = IntegratorExtras( initial_energy=initial_energy, @@ -1588,7 +1683,8 @@ def integrator_trace_wrapper_fn(args, _): final_kinetic_energy_extra=integrator_step_extra.kinetic_energy_extra, energy_change=final_energy - initial_energy, integrator_trace=integrator_trace, - momentum_grads=integrator_step_extra.momentum_grads) + momentum_grads=integrator_step_extra.momentum_grads, + ) return state, extra @@ -1596,14 +1692,14 @@ def integrator_trace_wrapper_fn(args, _): @util.named_call def obabo_langevin_integrator( int_state: IntegratorState, - num_steps: IntTensor, + num_steps: IntArray, integrator_step_fn: IntegratorStep, momentum_refresh_fn: Callable[[State, Any], State], energy_change_fn: Callable[ - [IntegratorState, IntegratorState], tuple[FloatTensor, Any] + [IntegratorState, IntegratorState], tuple[FloatArray, Any] ], integrator_trace_fn: Callable[ - [IntegratorState, IntegratorStepState, IntegratorStepExtras], TensorNest + [IntegratorState, IntegratorStepState, IntegratorStepExtras], ArrayNest ] = lambda *args: (), unroll: bool = False, seed: Any = None, @@ -1656,11 +1752,13 @@ def step_fn(int_state, energy_change, seed): integrator_step_state = IntegratorStepState( state=int_state.state, state_grads=int_state.state_grads, - momentum=int_state.momentum) + momentum=int_state.momentum, + ) # Integrate. integrator_step_state, integrator_step_extra = integrator_step_fn( - integrator_step_state) + integrator_step_state + ) new_int_state = int_state._replace( state=integrator_step_state.state, @@ -1681,20 +1779,23 @@ def step_fn(int_state, energy_change, seed): if integrator_trace_fn is None: integrator_trace = () else: - integrator_trace = integrator_trace_fn(new_int_state, - integrator_step_state, - integrator_step_extra) - return (new_int_state, energy_change, seed), (integrator_trace, - integrator_step_extra) - - (int_state, energy_change, - _), (integrator_trace, integrator_step_extra) = trace( - (int_state, tf.zeros_like(int_state.target_log_prob), seed), - step_fn, - num_steps, - unroll=unroll, - trace_mask=(True, False), - ) + integrator_trace = integrator_trace_fn( + new_int_state, integrator_step_state, integrator_step_extra + ) + return (new_int_state, energy_change, seed), ( + integrator_trace, + integrator_step_extra, + ) + + (int_state, energy_change, _), (integrator_trace, integrator_step_extra) = ( + trace( + (int_state, jnp.zeros_like(int_state.target_log_prob), seed), + step_fn, + num_steps, + unroll=unroll, + trace_mask=(True, False), + ) + ) extra = IntegratorExtras( initial_energy=(), @@ -1705,23 +1806,26 @@ def step_fn(int_state, energy_change, seed): final_kinetic_energy_extra=(), energy_change=energy_change, integrator_trace=integrator_trace, - momentum_grads=integrator_step_extra.momentum_grads) + momentum_grads=integrator_step_extra.momentum_grads, + ) return int_state, extra class HamiltonianMonteCarloState(NamedTuple): """Hamiltonian Monte Carlo state.""" + state: State state_grads: State - target_log_prob: FloatTensor + target_log_prob: FloatArray state_extra: StateExtra class HamiltonianMonteCarloExtra(NamedTuple): """Hamiltonian Monte Carlo extra outputs.""" - is_accepted: BooleanTensor - log_accept_ratio: FloatTensor + + is_accepted: BooleanArray + log_accept_ratio: FloatArray proposed_hmc_state: State integrator_state: IntegratorState integrator_extra: IntegratorExtras @@ -1731,8 +1835,8 @@ class HamiltonianMonteCarloExtra(NamedTuple): @util.named_call def hamiltonian_monte_carlo_init( - state: TensorNest, - target_log_prob_fn: PotentialFn) -> HamiltonianMonteCarloState: + state: ArrayNest, target_log_prob_fn: PotentialFn +) -> HamiltonianMonteCarloState: """Initializes the `HamiltonianMonteCarloState`. Args: @@ -1742,13 +1846,14 @@ def hamiltonian_monte_carlo_init( Returns: hmc_state: State of the `hamiltonian_monte_carlo_step` `TransitionOperator`. """ - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) target_log_prob, state_extra, state_grads = util.map_tree( - tf.convert_to_tensor, + jnp.asarray, call_potential_fn_with_grads(target_log_prob_fn, state), ) - return HamiltonianMonteCarloState(state, state_grads, target_log_prob, - state_extra) + return HamiltonianMonteCarloState( + state, state_grads, target_log_prob, state_extra + ) @util.named_call @@ -1756,7 +1861,7 @@ def _default_hamiltonian_monte_carlo_energy_change_fn( current_integrator_state: IntegratorState, proposed_integrator_state: IntegratorState, integrator_extra: IntegratorExtras, -) -> tuple[FloatTensor, Any]: +) -> tuple[FloatArray, Any]: """Default HMC energy change function.""" del current_integrator_state del proposed_integrator_state @@ -1768,20 +1873,20 @@ def hamiltonian_monte_carlo_step( hmc_state: HamiltonianMonteCarloState, target_log_prob_fn: PotentialFn, step_size: Optional[Any] = None, - num_integrator_steps: Optional[IntTensor] = None, + num_integrator_steps: Optional[IntArray] = None, momentum: Optional[State] = None, kinetic_energy_fn: Optional[PotentialFn] = None, momentum_sample_fn: Optional[MomentumSampleFn] = None, integrator_trace_fn: Callable[ - [IntegratorStepState, IntegratorStepExtras], TensorNest + [IntegratorStepState, IntegratorStepExtras], ArrayNest ] = lambda *args: (), - log_uniform: Optional[FloatTensor] = None, + log_uniform: Optional[FloatArray] = None, integrator_fn=None, unroll_integrator: bool = False, - max_num_integrator_steps: Optional[IntTensor] = None, + max_num_integrator_steps: Optional[IntArray] = None, energy_change_fn: Callable[ [IntegratorState, IntegratorState, IntegratorExtras], - tuple[FloatTensor, Any], + tuple[FloatArray, Any], ] = _default_hamiltonian_monte_carlo_energy_change_fn, named_axis: Optional[StringNest] = None, seed=None, @@ -1794,7 +1899,7 @@ def hamiltonian_monte_carlo_step( step_size = 0.2 num_steps = 2000 num_integrator_steps = 10 - state = tf.ones([16, 2]) + state = jnp.ones([16, 2]) base_mean = [1., 0] base_cov = [[1, 0.5], [0.5, 1]] @@ -1810,7 +1915,7 @@ def orig_target_log_prob_fn(x): target_log_prob_fn, state = fun_mc.transform_log_prob_fn( orig_target_log_prob_fn, bijector, state) - kernel = tf.function(lambda state: fun_mc.hamiltonian_monte_carlo_step( + kernel = jax.jit(lambda state: fun_mc.hamiltonian_monte_carlo_step( state, step_size=step_size, num_integrator_steps=num_integrator_steps, @@ -1864,15 +1969,13 @@ def orig_target_log_prob_fn(x): if kinetic_energy_fn is None: kinetic_energy_fn = make_gaussian_kinetic_energy_fn( - len(target_log_prob.shape) if target_log_prob.shape is not None else tf # pytype: disable=attribute-error - .rank(target_log_prob), - named_axis=named_axis) + len(target_log_prob.shape), named_axis=named_axis + ) if momentum_sample_fn is None: momentum_sample_fn = lambda seed: gaussian_momentum_sample( # pylint: disable=g-long-lambda - state=state, - seed=seed, - named_axis=named_axis) + state=state, seed=seed, named_axis=named_axis + ) if integrator_fn is None: integrator_fn = lambda state: hamiltonian_integrator( # pylint: disable=g-long-lambda @@ -1882,11 +1985,13 @@ def orig_target_log_prob_fn(x): state, step_size=step_size, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn), + kinetic_energy_fn=kinetic_energy_fn, + ), kinetic_energy_fn=kinetic_energy_fn, unroll=unroll_integrator, max_num_steps=max_num_integrator_steps, - integrator_trace_fn=integrator_trace_fn) + integrator_trace_fn=integrator_trace_fn, + ) if momentum is None: seed, sample_seed = util.split_seed(seed, 2) @@ -1906,7 +2011,8 @@ def orig_target_log_prob_fn(x): state=integrator_state.state, state_grads=integrator_state.state_grads, target_log_prob=integrator_state.target_log_prob, - state_extra=integrator_state.state_extra) + state_extra=integrator_state.state_extra, + ) energy_change, energy_change_extra = energy_change_fn( initial_integrator_state, @@ -1919,7 +2025,8 @@ def orig_target_log_prob_fn(x): proposed_state, energy_change, log_uniform=log_uniform, - seed=seed) + seed=seed, + ) hmc_state = typing.cast(HamiltonianMonteCarloState, hmc_state) return hmc_state, HamiltonianMonteCarloExtra( @@ -1929,14 +2036,17 @@ def orig_target_log_prob_fn(x): integrator_state=integrator_state, integrator_extra=integrator_extra, energy_change_extra=energy_change_extra, - initial_momentum=momentum) + initial_momentum=momentum, + ) @util.named_call -def sign_adaptation(control: FloatNest, - output: FloatTensor, - set_point: FloatTensor, - adaptation_rate: FloatTensor = 0.01) -> FloatNest: +def sign_adaptation( + control: FloatNest, + output: FloatArray, + set_point: FloatArray, + adaptation_rate: float | FloatArray = 0.01, +) -> FloatNest: """A function to do simple sign-based control of a variable. ``` @@ -1955,8 +2065,11 @@ def sign_adaptation(control: FloatNest, """ def _get_new_control(control, output, set_point): - new_control = choose(output > set_point, control * (1. + adaptation_rate), - control / (1. + adaptation_rate)) + new_control = choose( + output > set_point, + control * (1.0 + adaptation_rate), + control / (1.0 + adaptation_rate), + ) return new_control output = maybe_broadcast_structure(output, control) @@ -1967,7 +2080,7 @@ def _get_new_control(control, output, set_point): @util.named_call def choose(condition, x, y): - """A nest-aware, left-broadcasting `tf.where`. + """A nest-aware, left-broadcasting `jnp.where`. Args: condition: Boolean nest. Must left-broadcast with `x` and `y`. @@ -1984,30 +2097,33 @@ def _choose_base_case(condition, x, y): def _expand_condition_like(x): """Helper to expand `condition` like the shape of some input arg.""" expand_shape = list(condition.shape) + [1] * ( - len(x.shape) - len(condition.shape)) - return tf.reshape(condition, expand_shape) + len(x.shape) - len(condition.shape) + ) + return jnp.reshape(condition, expand_shape) if x is y: return x - x = tf.convert_to_tensor(x) - y = tf.convert_to_tensor(y) - return tf.where(_expand_condition_like(x), x, y) + x = jnp.asarray(x) + y = jnp.asarray(y) + return jnp.where(_expand_condition_like(x), x, y) - condition = tf.convert_to_tensor(condition) + condition = jnp.asarray(condition) return util.map_tree(lambda a, r: _choose_base_case(condition, a, r), x, y) class AdamState(NamedTuple): """Adam state.""" + state: State m: State v: State - t: IntTensor + t: IntArray class AdamExtra(NamedTuple): """Adam extra outputs.""" - loss: FloatTensor + + loss: FloatArray loss_extra: Any grads: State @@ -2015,21 +2131,24 @@ class AdamExtra(NamedTuple): @util.named_call def adam_init(state: FloatNest) -> AdamState: """Initializes `AdamState`.""" - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) return AdamState( state=state, - m=util.map_tree(tf.zeros_like, state), - v=util.map_tree(tf.zeros_like, state), - t=tf.constant(0, dtype=tf.int32)) + m=util.map_tree(jnp.zeros_like, state), + v=util.map_tree(jnp.zeros_like, state), + t=jnp.zeros([], dtype=jnp.int32), + ) @util.named_call -def adam_step(adam_state: AdamState, - loss_fn: PotentialFn, - learning_rate: FloatNest, - beta_1: FloatNest = 0.9, - beta_2: FloatNest = 0.999, - epsilon: FloatNest = 1e-8) -> tuple[AdamState, AdamExtra]: +def adam_step( + adam_state: AdamState, + loss_fn: PotentialFn, + learning_rate: FloatNest, + beta_1: float | FloatNest = 0.9, + beta_2: float | FloatNest = 0.999, + epsilon: float | FloatNest = 1e-8, +) -> tuple[AdamState, AdamExtra]: """Performs one step of the Adam optimization method. Args: @@ -2066,40 +2185,44 @@ def adam_step(adam_state: AdamState, def _one_part(state, g, m, v, learning_rate, beta_1, beta_2, epsilon): """Updates one part of the state.""" - t_f = tf.cast(t, state.dtype) - beta_1 = tf.convert_to_tensor(beta_1, state.dtype) - beta_2 = tf.convert_to_tensor(beta_2, state.dtype) + t_f = jnp.array(t, state.dtype) + beta_1 = jnp.asarray(beta_1, state.dtype) + beta_2 = jnp.asarray(beta_2, state.dtype) learning_rate = learning_rate * ( - tf.math.sqrt(1. - tf.math.pow(beta_2, t_f)) / - (1. - tf.math.pow(beta_1, t_f))) + jnp.sqrt(1.0 - jnp.power(beta_2, t_f)) / (1.0 - jnp.power(beta_1, t_f)) + ) - m_t = beta_1 * m + (1. - beta_1) * g - v_t = beta_2 * v + (1. - beta_2) * tf.square(g) - state = state - learning_rate * m_t / (tf.math.sqrt(v_t) + epsilon) + m_t = beta_1 * m + (1.0 - beta_1) * g + v_t = beta_2 * v + (1.0 - beta_2) * jnp.square(g) + state = state - learning_rate * m_t / (jnp.sqrt(v_t) + epsilon) return state, m_t, v_t loss, loss_extra, grads = call_potential_fn_with_grads(loss_fn, state) - state_m_v = util.map_tree(_one_part, state, grads, m, v, learning_rate, - beta_1, beta_2, epsilon) + state_m_v = util.map_tree( + _one_part, state, grads, m, v, learning_rate, beta_1, beta_2, epsilon + ) adam_state = AdamState( state=util.map_tree_up_to(state, lambda x: x[0], state_m_v), m=util.map_tree_up_to(state, lambda x: x[1], state_m_v), v=util.map_tree_up_to(state, lambda x: x[2], state_m_v), - t=adam_state.t + 1) + t=adam_state.t + 1, + ) return adam_state, AdamExtra(loss_extra=loss_extra, loss=loss, grads=grads) class GradientDescentState(NamedTuple): """Gradient Descent state.""" + state: State class GradientDescentExtra(NamedTuple): """Gradient Descent extra outputs.""" - loss: FloatTensor + + loss: FloatArray loss_extra: Any grads: State @@ -2107,14 +2230,15 @@ class GradientDescentExtra(NamedTuple): @util.named_call def gradient_descent_init(state: FloatNest) -> GradientDescentState: """Initializes `GradientDescentState`.""" - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) return GradientDescentState(state=state) @util.named_call def gradient_descent_step( - gd_state: GradientDescentState, loss_fn: PotentialFn, - learning_rate: FloatNest + gd_state: GradientDescentState, + loss_fn: PotentialFn, + learning_rate: FloatNest, ) -> tuple[GradientDescentState, GradientDescentExtra]: """Performs a step of regular gradient descent. @@ -2141,15 +2265,17 @@ def _one_part(state, g, learning_rate): gd_state = GradientDescentState(state=state) return gd_state, GradientDescentExtra( - loss_extra=loss_extra, loss=loss, grads=grads) + loss_extra=loss_extra, loss=loss, grads=grads + ) @util.named_call def gaussian_proposal( state: State, - scale: FloatNest = 1., + scale: float | FloatNest = 1.0, named_axis: Optional[StringNest] = None, - seed: Optional[Any] = None) -> tuple[State, tuple[tuple[()], float]]: + seed: Optional[Any] = None, +) -> tuple[State, tuple[tuple[()], float]]: """Axis-aligned gaussian random-walk proposal. Args: @@ -2171,28 +2297,31 @@ def gaussian_proposal( def _sample_part(x, scale, seed, named_axis): seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis) return x + scale * util.random_normal( # pylint: disable=g-long-lambda - x.shape, x.dtype, seed) + x.shape, x.dtype, seed + ) - new_state = util.map_tree_up_to(state, _sample_part, state, scale, seeds, - named_axis) + new_state = util.map_tree_up_to( + state, _sample_part, state, scale, seeds, named_axis + ) - return new_state, ((), 0.) + return new_state, ((), 0.0) class MaximalReflectiveCouplingProposalExtra(NamedTuple): """Extra results from the `maximal_reflection_coupling_proposal`.""" - log_couple_ratio: FloatTensor - coupling_proposed: BooleanTensor + + log_couple_ratio: FloatArray + coupling_proposed: BooleanArray @util.named_call def maximal_reflection_coupling_proposal( state: State, chain_ndims: int = 0, - scale: FloatNest = 1, + scale: float | FloatNest = 1.0, named_axis: Optional[StringNest] = None, - epsilon: FloatTensor = 1e-20, - seed: Optional[Any] = None + epsilon: float | FloatArray = 1e-20, + seed: Optional[Any] = None, ) -> tuple[State, tuple[MaximalReflectiveCouplingProposalExtra, float]]: """Maximal reflection coupling proposal. @@ -2240,18 +2369,31 @@ def _struct_sum(s): mu1 = util.map_tree(lambda x: x[:num_chains], state) mu2 = util.map_tree(lambda x: x[num_chains:], state) event_dims = util.map_tree( - lambda x: tuple(range(1 + chain_ndims, len(x.shape))), mu1) + lambda x: tuple(range(1 + chain_ndims, len(x.shape))), mu1 + ) z = util.map_tree(lambda s, x1, x2: (x1 - x2) / s, scale, mu1, mu2) - z_norm = tf.sqrt( + z_norm = jnp.sqrt( _struct_sum( - util.map_tree_up_to(z, lambda z, ed, na: _sum(tf.square(z), ed, na), - z, event_dims, named_axis))) + util.map_tree_up_to( + z, + lambda z, ed, na: _sum(jnp.square(z), ed, na), + z, + event_dims, + named_axis, + ) + ) + ) e = util.map_tree( - lambda z: z / # pylint: disable=g-long-lambda - (tf.reshape(z_norm, z_norm.shape + (1,) * - (len(z.shape) - len(z_norm.shape))) + epsilon), - z) - batch_shape = util.flatten_tree(mu1)[0].shape[1:1 + chain_ndims] + lambda z: z # pylint: disable=g-long-lambda + / ( + jnp.reshape( + z_norm, z_norm.shape + (1,) * (len(z.shape) - len(z_norm.shape)) + ) + + epsilon + ), + z, + ) + batch_shape = util.flatten_tree(mu1)[0].shape[1 : 1 + chain_ndims] num_parts = len(util.flatten_tree(state)) all_seeds = util.split_seed(seed, num_parts + 1) @@ -2265,57 +2407,83 @@ def _sample_part(x, seed, named_axis): x = util.map_tree_up_to(mu1, _sample_part, mu1, x_seeds, named_axis) e_dot_x = _struct_sum( - util.map_tree_up_to(x, lambda x, e, ed, na: _sum(x * e, ed, na), x, e, - event_dims, named_axis)) + util.map_tree_up_to( + x, + lambda x, e, ed, na: _sum(x * e, ed, na), + x, + e, + event_dims, + named_axis, + ) + ) log_couple_ratio = _struct_sum( util.map_tree_up_to( - x, lambda x, z, ed, na: -_sum(x * z + tf.square(z) / 2, ed, na), x, z, - event_dims, named_axis)) + x, + lambda x, z, ed, na: -_sum(x * z + jnp.square(z) / 2, ed, na), + x, + z, + event_dims, + named_axis, + ) + ) - p_couple = tf.exp(tf.minimum(0., log_couple_ratio)) - coupling_proposed = util.random_uniform( - batch_shape, dtype=p_couple.dtype, seed=couple_seed) < p_couple + p_couple = jnp.exp(jnp.minimum(0.0, log_couple_ratio)) + coupling_proposed = ( + util.random_uniform(batch_shape, dtype=p_couple.dtype, seed=couple_seed) + < p_couple + ) y_reflected = util.map_tree( - lambda x, e: x - 2 * tf.reshape( # pylint: disable=g-long-lambda - e_dot_x, e_dot_x.shape + (1,) * - (len(e.shape) - len(e_dot_x.shape))) * e, + lambda x, e: ( # pylint: disable=g-long-lambda + x + - 2 + * jnp.reshape( + e_dot_x, + e_dot_x.shape + (1,) * (len(e.shape) - len(e_dot_x.shape)), + ) + * e + ), x, - e) + e, + ) x2 = util.map_tree(lambda x, mu1, s: mu1 + s * x, x, mu1, scale) y2 = util.map_tree(lambda y, mu2, s: mu2 + s * y, y_reflected, mu2, scale) y2 = choose(coupling_proposed, x2, y2) - new_state = util.map_tree(lambda x, y: tf.concat([x, y], axis=0), x2, y2) + new_state = util.map_tree( + lambda x, y: jnp.concatenate([x, y], axis=0), x2, y2 + ) extra = MaximalReflectiveCouplingProposalExtra( log_couple_ratio=log_couple_ratio, coupling_proposed=coupling_proposed, ) - return new_state, (extra, 0.) + return new_state, (extra, 0.0) class RandomWalkMetropolisState(NamedTuple): """Random Walk Metropolis state.""" + state: State - target_log_prob: FloatTensor + target_log_prob: FloatArray state_extra: StateExtra class RandomWalkMetropolisExtra(NamedTuple): """Random Walk Metropolis extra outputs.""" - is_accepted: BooleanTensor - log_accept_ratio: FloatTensor + + is_accepted: BooleanArray + log_accept_ratio: FloatArray proposal_extra: Any proposed_rwm_state: RandomWalkMetropolisState @util.named_call def random_walk_metropolis_init( - state: State, - target_log_prob_fn: PotentialFn) -> RandomWalkMetropolisState: + state: State, target_log_prob_fn: PotentialFn +) -> RandomWalkMetropolisState: """Initializes the `RandomWalkMetropolisState`. Args: @@ -2338,8 +2506,9 @@ def random_walk_metropolis_step( rwm_state: RandomWalkMetropolisState, target_log_prob_fn: PotentialFn, proposal_fn: TransitionOperator, - log_uniform: Optional[FloatTensor] = None, - seed=None) -> tuple[RandomWalkMetropolisState, RandomWalkMetropolisExtra]: + log_uniform: Optional[FloatArray] = None, + seed=None, +) -> tuple[RandomWalkMetropolisState, RandomWalkMetropolisExtra]: """Random Walk Metropolis Hastings `TransitionOperator`. The `proposal_fn` takes in the current state, and must return a proposed @@ -2362,16 +2531,18 @@ def random_walk_metropolis_step( rwm_extra: RandomWalkMetropolisExtra """ seed, sample_seed = util.split_seed(seed, 2) - proposed_state, (proposal_extra, - log_proposed_bias) = proposal_fn(rwm_state.state, - sample_seed) + proposed_state, (proposal_extra, log_proposed_bias) = proposal_fn( + rwm_state.state, sample_seed + ) proposed_target_log_prob, proposed_state_extra = call_potential_fn( - target_log_prob_fn, proposed_state) + target_log_prob_fn, proposed_state + ) # TODO(siege): Is it really a "log accept ratio" if we need to clamp it to 0? log_accept_ratio = ( - proposed_target_log_prob - rwm_state.target_log_prob - log_proposed_bias) + proposed_target_log_prob - rwm_state.target_log_prob - log_proposed_bias + ) proposed_rwm_state = RandomWalkMetropolisState( state=proposed_state, @@ -2399,14 +2570,16 @@ def random_walk_metropolis_step( class RunningVarianceState(NamedTuple): - num_points: IntTensor + num_points: IntNest mean: FloatNest variance: FloatNest @util.named_call -def running_variance_init(shape: IntTensor, - dtype: DTypeNest) -> RunningVarianceState: +def running_variance_init( + shape: ShapeNest, + dtype: DTypeNest, +) -> RunningVarianceState: """Initializes the `RunningVarianceState`. Args: @@ -2417,12 +2590,12 @@ def running_variance_init(shape: IntTensor, state: `RunningVarianceState`. """ return RunningVarianceState( - num_points=util.map_tree(lambda _: tf.zeros([], tf.int32), dtype), - mean=util.map_tree_up_to(dtype, tf.zeros, shape, dtype), + num_points=util.map_tree(lambda _: jnp.zeros([], jnp.int32), dtype), + mean=util.map_tree_up_to(dtype, jnp.zeros, shape, dtype), # The initial value of variance is discarded upon the first update, but # setting it to something reasonable (ones) is convenient in case the # state is read before an update. - variance=util.map_tree_up_to(dtype, tf.ones, shape, dtype), + variance=util.map_tree_up_to(dtype, jnp.ones, shape, dtype), ) @@ -2452,7 +2625,7 @@ def running_variance_step( Args: state: `RunningVarianceState`. - vec: A Tensor to incorporate into the variance estimate. + vec: A Array to incorporate into the variance estimate. axis: If not `None`, treat these axes as being additional axes to aggregate over. window_size: A nest of ints, broadcastable with the structure of `vec`. If @@ -2475,55 +2648,67 @@ def running_variance_step( def _one_part(vec, mean, variance, num_points): """Updates a single part.""" - vec = tf.convert_to_tensor(vec, mean.dtype) + vec = jnp.asarray(vec, mean.dtype) if axis is None: vec_mean = vec - vec_variance = tf.zeros_like(variance) + vec_variance = jnp.zeros_like(variance) else: - vec_mean = tf.reduce_mean(vec, axis) - vec_variance = tf.math.reduce_variance(vec, axis) + vec_mean = jnp.mean(vec, axis) + vec_variance = jnp.var(vec, axis) mean_diff = vec_mean - mean - mean_diff_sq = tf.square(mean_diff) + mean_diff_sq = jnp.square(mean_diff) variance_diff = vec_variance - variance - additional_points = tf.size(vec) // tf.size(mean) - additional_points_f = tf.cast(additional_points, vec.dtype) - num_points_f = tf.cast(num_points, vec.dtype) + additional_points = jnp.size(vec) // jnp.size(mean) + additional_points_f = jnp.array(additional_points, vec.dtype) + num_points_f = jnp.array(num_points, vec.dtype) weight = additional_points_f / (num_points_f + additional_points_f) new_mean = mean + mean_diff * weight new_variance = ( - variance + variance_diff * weight + weight * - (1. - weight) * mean_diff_sq) + variance + + variance_diff * weight + + weight * (1.0 - weight) * mean_diff_sq + ) return new_mean, new_variance, num_points + additional_points - new_mean_variance_num_points = util.map_tree(_one_part, vec, state.mean, - state.variance, state.num_points) + new_mean_variance_num_points = util.map_tree( + _one_part, vec, state.mean, state.variance, state.num_points + ) - new_mean = util.map_tree_up_to(state.mean, lambda x: x[0], - new_mean_variance_num_points) - new_variance = util.map_tree_up_to(state.mean, lambda x: x[1], - new_mean_variance_num_points) - new_num_points = util.map_tree_up_to(state.mean, lambda x: x[2], - new_mean_variance_num_points) + new_mean = util.map_tree_up_to( + state.mean, lambda x: x[0], new_mean_variance_num_points + ) + new_variance = util.map_tree_up_to( + state.mean, lambda x: x[1], new_mean_variance_num_points + ) + new_num_points = util.map_tree_up_to( + state.mean, lambda x: x[2], new_mean_variance_num_points + ) if window_size is not None: window_size = maybe_broadcast_structure(window_size, new_num_points) - new_num_points = util.map_tree(tf.minimum, new_num_points, window_size) - return RunningVarianceState( - num_points=new_num_points, mean=new_mean, variance=new_variance), () + new_num_points = util.map_tree(jnp.minimum, new_num_points, window_size) + return ( + RunningVarianceState( + num_points=new_num_points, mean=new_mean, variance=new_variance + ), + (), + ) class RunningCovarianceState(NamedTuple): """Running Covariance state.""" + num_points: IntNest mean: FloatNest covariance: FloatNest @util.named_call -def running_covariance_init(shape: IntTensor, - dtype: DTypeNest) -> RunningCovarianceState: +def running_covariance_init( + shape: ShapeNest, dtype: DTypeNest +) -> RunningCovarianceState: """Initializes the `RunningCovarianceState`. Args: @@ -2534,20 +2719,23 @@ def running_covariance_init(shape: IntTensor, state: `RunningCovarianceState`. """ return RunningCovarianceState( - num_points=util.map_tree(lambda _: tf.zeros([], tf.int32), dtype), - mean=util.map_tree_up_to(dtype, tf.zeros, shape, dtype), + num_points=util.map_tree(lambda _: jnp.zeros([], jnp.int32), dtype), + mean=util.map_tree_up_to(dtype, jnp.zeros, shape, dtype), covariance=util.map_tree_up_to( # The initial value of covariance is discarded upon the first update, # but setting it to something reasonable (the identity matrix) is # convenient in case the state is read before an update. dtype, - lambda shape, dtype: tf.eye( # pylint: disable=g-long-lambda - shape[-1], - batch_shape=shape[:-1], - dtype=dtype, + lambda shape, dtype: jnp.broadcast_to( # pylint: disable=g-long-lambda + jnp.eye( + shape[-1], + dtype=dtype, + ), + tuple(shape[:-1]) + (shape[-1], shape[-1]), ), shape, - dtype), + dtype, + ), ) @@ -2581,7 +2769,7 @@ def running_covariance_step( Args: state: `RunningCovarianceState`. - vec: A Tensor to incorporate into the variance estimate. + vec: A Array to incorporate into the variance estimate. axis: If not `None`, treat these axes as being additional axes to aggregate over. window_size: A nest of ints, broadcastable with the structure of `vec`. If @@ -2604,55 +2792,65 @@ def running_covariance_step( def _one_part(vec, mean, covariance, num_points): """Updates a single part.""" - vec = tf.convert_to_tensor(vec, mean.dtype) + vec = jnp.asarray(vec, mean.dtype) if axis is None: vec_mean = vec - vec_covariance = tf.zeros_like(covariance) + vec_covariance = jnp.zeros_like(covariance) else: - vec_mean = tf.reduce_mean(vec, axis) + vec_mean = jnp.mean(vec, axis) vec_covariance = tfp.stats.covariance(vec, sample_axis=axis) mean_diff = vec_mean - mean mean_diff_sq = ( - mean_diff[..., :, tf.newaxis] * mean_diff[..., tf.newaxis, :]) + mean_diff[..., :, jnp.newaxis] * mean_diff[..., jnp.newaxis, :] + ) covariance_diff = vec_covariance - covariance - additional_points = tf.size(vec) // tf.size(mean) - additional_points_f = tf.cast(additional_points, vec.dtype) - num_points_f = tf.cast(num_points, vec.dtype) + additional_points = jnp.size(vec) // jnp.size(mean) + additional_points_f = jnp.array(additional_points, vec.dtype) + num_points_f = jnp.array(num_points, vec.dtype) weight = additional_points_f / (num_points_f + additional_points_f) new_mean = mean + mean_diff * weight new_covariance = ( - covariance + covariance_diff * weight + weight * - (1. - weight) * mean_diff_sq) + covariance + + covariance_diff * weight + + weight * (1.0 - weight) * mean_diff_sq + ) return new_mean, new_covariance, num_points + additional_points - new_mean_covariance_num_points = util.map_tree(_one_part, vec, state.mean, - state.covariance, - state.num_points) + new_mean_covariance_num_points = util.map_tree( + _one_part, vec, state.mean, state.covariance, state.num_points + ) - new_mean = util.map_tree_up_to(state.mean, lambda x: x[0], - new_mean_covariance_num_points) - new_covariance = util.map_tree_up_to(state.mean, lambda x: x[1], - new_mean_covariance_num_points) - new_num_points = util.map_tree_up_to(state.mean, lambda x: x[2], - new_mean_covariance_num_points) + new_mean = util.map_tree_up_to( + state.mean, lambda x: x[0], new_mean_covariance_num_points + ) + new_covariance = util.map_tree_up_to( + state.mean, lambda x: x[1], new_mean_covariance_num_points + ) + new_num_points = util.map_tree_up_to( + state.mean, lambda x: x[2], new_mean_covariance_num_points + ) if window_size is not None: window_size = maybe_broadcast_structure(window_size, new_num_points) - new_num_points = util.map_tree(tf.minimum, new_num_points, window_size) - return RunningCovarianceState( - num_points=new_num_points, mean=new_mean, covariance=new_covariance), () + new_num_points = util.map_tree(jnp.minimum, new_num_points, window_size) + return ( + RunningCovarianceState( + num_points=new_num_points, mean=new_mean, covariance=new_covariance + ), + (), + ) class RunningMeanState(NamedTuple): """Running Mean state.""" + num_points: IntNest - mean: FloatTensor + mean: FloatArray @util.named_call -def running_mean_init(shape: IntTensor, - dtype: DTypeNest) -> RunningMeanState: +def running_mean_init(shape: ShapeNest, dtype: DTypeNest) -> RunningMeanState: """Initializes the `RunningMeanState`. Args: @@ -2663,8 +2861,8 @@ def running_mean_init(shape: IntTensor, state: `RunningMeanState`. """ return RunningMeanState( - num_points=util.map_tree(lambda _: tf.zeros([], tf.int32), dtype), - mean=util.map_tree_up_to(dtype, tf.zeros, shape, dtype), + num_points=util.map_tree(lambda _: jnp.zeros([], jnp.int32), dtype), + mean=util.map_tree_up_to(dtype, jnp.zeros, shape, dtype), ) @@ -2689,7 +2887,7 @@ def running_mean_step( Args: state: `RunningMeanState`. - vec: A Tensor to incorporate into the mean. + vec: A Array to incorporate into the mean. axis: If not `None`, treat these axes as being additional axes to aggregate over. window_size: A nest of ints, broadcastable with the structure of `vec`. If @@ -2712,42 +2910,48 @@ def running_mean_step( def _one_part(vec, mean, num_points): """Updates a single part.""" - vec = tf.convert_to_tensor(vec, mean.dtype) + vec = jnp.asarray(vec, mean.dtype) if axis is None: vec_mean = vec else: - vec_mean = tf.reduce_mean(vec, axis) + vec_mean = jnp.mean(vec, axis) mean_diff = vec_mean - mean - additional_points = tf.size(vec) // tf.size(mean) - additional_points_f = tf.cast(additional_points, vec.dtype) - num_points_f = tf.cast(num_points, vec.dtype) + additional_points = jnp.size(vec) // jnp.size(mean) + additional_points_f = jnp.array(additional_points, vec.dtype) + num_points_f = jnp.array(num_points, vec.dtype) weight = additional_points_f / (num_points_f + additional_points_f) new_mean = mean + mean_diff * weight return new_mean, num_points + additional_points - new_mean_num_points = util.map_tree(_one_part, vec, state.mean, - state.num_points) + new_mean_num_points = util.map_tree( + _one_part, vec, state.mean, state.num_points + ) - new_mean = util.map_tree_up_to(state.mean, lambda x: x[0], - new_mean_num_points) - new_num_points = util.map_tree_up_to(state.mean, lambda x: x[1], - new_mean_num_points) + new_mean = util.map_tree_up_to( + state.mean, lambda x: x[0], new_mean_num_points + ) + new_num_points = util.map_tree_up_to( + state.mean, lambda x: x[1], new_mean_num_points + ) if window_size is not None: window_size = maybe_broadcast_structure(window_size, new_num_points) - new_num_points = util.map_tree(tf.minimum, new_num_points, window_size) + new_num_points = util.map_tree(jnp.minimum, new_num_points, window_size) return RunningMeanState(num_points=new_num_points, mean=new_mean), () class PotentialScaleReductionState(RunningVarianceState): """Potential Scale Reduction state.""" + pass @util.named_call -def potential_scale_reduction_init(shape, - dtype) -> PotentialScaleReductionState: +def potential_scale_reduction_init( + shape: ShapeNest, + dtype: DTypeNest, +) -> PotentialScaleReductionState: """Initializes `PotentialScaleReductionState`. Args: @@ -2766,7 +2970,8 @@ def potential_scale_reduction_init(shape, @util.named_call def potential_scale_reduction_step( state: PotentialScaleReductionState, - sample) -> tuple[PotentialScaleReductionState, tuple[()]]: + sample: State, +) -> tuple[PotentialScaleReductionState, tuple[()]]: """Updates `PotentialScaleReductionState`. This computes the potential scale reduction statistic from [1]. Note that @@ -2794,14 +2999,17 @@ def potential_scale_reduction_step( # We are wrapping running variance so that the user doesn't get the chance to # set the reduction axis, which would break the assumptions of # `potential_scale_reduction_extract`. - return PotentialScaleReductionState( - *running_variance_step(state, sample)[0]), () + return ( + PotentialScaleReductionState(*running_variance_step(state, sample)[0]), + (), + ) @util.named_call def potential_scale_reduction_extract( state: PotentialScaleReductionState, - independent_chain_ndims: IntNest = 1) -> FloatNest: + independent_chain_ndims: int | IntNest = 1, +) -> FloatNest: """Extracts the potential scale reduction statistic. Args: @@ -2812,38 +3020,45 @@ def potential_scale_reduction_extract( Returns: rhat: Potential scale reduction. """ - independent_chain_ndims = maybe_broadcast_structure(independent_chain_ndims, - state.mean) + independent_chain_ndims = maybe_broadcast_structure( + independent_chain_ndims, state.mean + ) def _psr_part(num_points, mean, variance, independent_chain_ndims): """Compute PSR for a single part.""" # TODO(siege): Keeping these per-component points is mildly wasteful because # unlike general running variance estimation, these are always the same # across parts. - num_points = tf.cast(num_points, mean.dtype) - num_chains = tf.cast( - np.prod(mean.shape[:independent_chain_ndims]), mean.dtype) + num_points = jnp.array(num_points, mean.dtype) + num_chains = jnp.array( + np.prod(mean.shape[:independent_chain_ndims]), mean.dtype + ) independent_dims = list(range(independent_chain_ndims)) # Within chain variance. - var_w = num_points / (num_points - 1) * tf.reduce_mean( - variance, independent_dims) + var_w = num_points / (num_points - 1) * jnp.mean(variance, independent_dims) # Between chain variance. - var_b = num_chains / (num_chains - 1) * tf.math.reduce_variance( - mean, independent_dims) + var_b = num_chains / (num_chains - 1) * jnp.var(mean, independent_dims) # Estimate of the true variance of the target distribution. sigma2p = (num_points - 1) / num_points * var_w + var_b - return ((num_chains + 1) / num_chains * sigma2p / var_w - (num_points - 1) / - (num_chains * num_points)) - - return util.map_tree(_psr_part, state.num_points, state.mean, state.variance, - independent_chain_ndims) + return (num_chains + 1) / num_chains * sigma2p / var_w - ( + num_points - 1 + ) / (num_chains * num_points) + + return util.map_tree( + _psr_part, + state.num_points, + state.mean, + state.variance, + independent_chain_ndims, + ) class RunningApproximateAutoCovarianceState(NamedTuple): """Running Approximate Auto-Covariance state.""" + buffer: FloatNest - num_steps: IntTensor + num_steps: IntArray mean: FloatNest auto_covariance: FloatNest @@ -2851,7 +3066,7 @@ class RunningApproximateAutoCovarianceState(NamedTuple): @util.named_call def running_approximate_auto_covariance_init( max_lags: int, - state_shape: IntTensor, + state_shape: IntArray, dtype: DTypeNest, axis: Optional[Union[int, list[int], tuple[int]]] = None, ) -> RunningApproximateAutoCovarianceState: @@ -2874,33 +3089,41 @@ def running_approximate_auto_covariance_init( else: # TODO(siege): Can this be done without doing the surrogate computation? mean_shape = util.map_tree_up_to( - dtype, lambda s: tf.reduce_sum(tf.zeros(s), axis).shape, state_shape) + dtype, lambda s: jnp.sum(jnp.zeros(s), axis).shape, state_shape + ) def _shape_with_lags(shape): if isinstance(shape, (tuple, list)): return [max_lags + 1] + list(shape) else: - return tf.concat([[max_lags + 1], - tf.convert_to_tensor(shape, tf.int32)], - axis=0) + return jnp.concatenate( + [[max_lags + 1], jnp.asarray(shape, jnp.int32)], axis=0 + ) return RunningApproximateAutoCovarianceState( buffer=util.map_tree_up_to( - dtype, lambda d, s: tf.zeros(_shape_with_lags(s), dtype=d), dtype, - state_shape), - num_steps=tf.zeros([], dtype=tf.int32), - mean=util.map_tree_up_to(dtype, lambda d, s: tf.zeros(s, dtype=d), dtype, - mean_shape), + dtype, + lambda d, s: jnp.zeros(_shape_with_lags(s), dtype=d), + dtype, + state_shape, + ), + num_steps=jnp.zeros([], dtype=jnp.int32), + mean=util.map_tree_up_to( + dtype, lambda d, s: jnp.zeros(s, dtype=d), dtype, mean_shape + ), auto_covariance=util.map_tree_up_to( - dtype, lambda d, s: tf.zeros(_shape_with_lags(s), dtype=d), dtype, - mean_shape), + dtype, + lambda d, s: jnp.zeros(_shape_with_lags(s), dtype=d), + dtype, + mean_shape, + ), ) @util.named_call def running_approximate_auto_covariance_step( state: RunningApproximateAutoCovarianceState, - vec: TensorNest, + vec: ArrayNest, axis: Optional[Union[int, list[int], tuple[int]]] = None, ) -> tuple[RunningApproximateAutoCovarianceState, tuple[()]]: """Updates `RunningApproximateAutoCovarianceState`. @@ -2940,19 +3163,19 @@ def running_approximate_auto_covariance_step( def _one_part(vec, buf, mean, auto_cov): """Compute the auto-covariance for one part.""" buf_size = buf.shape[0] - tail_idx = tf.range(0, buf_size - 1) - num_steps = state.num_steps - tf.range(buf_size) - num_steps = tf.maximum(0, num_steps) + tail_idx = jnp.arange(0, buf_size - 1) + num_steps = state.num_steps - jnp.arange(buf_size) + num_steps = jnp.maximum(0, num_steps) - buf = tf.gather(buf, tail_idx) - buf = tf.concat([vec[tf.newaxis], buf], 0) + buf = buf[tail_idx] + buf = jnp.concatenate([vec[jnp.newaxis], buf], 0) centered_buf = buf - mean centered_vec = vec - mean num_steps_0 = num_steps[0] # Need to broadcast on the right with autocov. - steps_shape = ([-1] + [1] * (len(auto_cov.shape) - len(num_steps.shape))) - num_steps = tf.reshape(num_steps, steps_shape) + steps_shape = [-1] + [1] * (len(auto_cov.shape) - len(num_steps.shape)) + num_steps = jnp.reshape(num_steps, steps_shape) # TODO(siege): Simplify this to look like running_variance_step. # pyformat: disable @@ -2961,40 +3184,44 @@ def _one_part(vec, buf, mean, auto_cov): additional_points_f = 1 # This assumes `additional_points` is the same for every step, # verified by the buf update logic above. - num_points_f = additional_points_f * tf.cast(num_steps, mean.dtype) + num_points_f = additional_points_f * jnp.array(num_steps, mean.dtype) auto_cov = (( num_points_f * (num_points_f + additional_points_f) * auto_cov + num_points_f * centered_vec * centered_buf) / - tf.square(num_points_f + additional_points_f)) + jnp.square(num_points_f + additional_points_f)) else: - vec_shape = tf.convert_to_tensor(vec.shape) - additional_points = tf.math.reduce_prod(tf.gather(vec_shape, axis)) - additional_points_f = tf.cast(additional_points, vec.dtype) - num_points_f = additional_points_f * tf.cast(num_steps, mean.dtype) + vec_shape = np.asarray(vec.shape) + additional_points = jnp.prod(vec_shape[np.asarray(axis)]) + additional_points_f = jnp.array(additional_points, vec.dtype) + num_points_f = additional_points_f * jnp.array(num_steps, mean.dtype) buf_axis = util.map_tree(lambda a: a + 1, axis) auto_cov = ( num_points_f * (num_points_f + additional_points_f) * auto_cov + - num_points_f * tf.reduce_sum(centered_vec * centered_buf, buf_axis) - - tf.reduce_sum(vec, axis) * tf.reduce_sum(buf, buf_axis) + - additional_points_f * tf.reduce_sum(vec * buf, buf_axis)) / ( - tf.square(num_points_f + additional_points_f)) - centered_vec = tf.reduce_sum(centered_vec, axis) + num_points_f * jnp.sum(centered_vec * centered_buf, buf_axis) - + jnp.sum(vec, axis) * jnp.sum(buf, buf_axis) + + additional_points_f * jnp.sum(vec * buf, buf_axis)) / ( + jnp.square(num_points_f + additional_points_f)) + centered_vec = jnp.sum(centered_vec, axis) # pyformat: enable - num_points_0_f = additional_points_f * tf.cast(num_steps_0, mean.dtype) + num_points_0_f = additional_points_f * jnp.array(num_steps_0, mean.dtype) mean = mean + centered_vec / (num_points_0_f + additional_points_f) return buf, auto_cov, mean - new_buffer_auto_cov_mean = util.map_tree(_one_part, vec, state.buffer, - state.mean, state.auto_covariance) + new_buffer_auto_cov_mean = util.map_tree( + _one_part, vec, state.buffer, state.mean, state.auto_covariance + ) - new_buffer = util.map_tree_up_to(state.buffer, lambda x: x[0], - new_buffer_auto_cov_mean) - new_auto_cov = util.map_tree_up_to(state.buffer, lambda x: x[1], - new_buffer_auto_cov_mean) - new_mean = util.map_tree_up_to(state.buffer, lambda x: x[2], - new_buffer_auto_cov_mean) + new_buffer = util.map_tree_up_to( + state.buffer, lambda x: x[0], new_buffer_auto_cov_mean + ) + new_auto_cov = util.map_tree_up_to( + state.buffer, lambda x: x[1], new_buffer_auto_cov_mean + ) + new_mean = util.map_tree_up_to( + state.buffer, lambda x: x[2], new_buffer_auto_cov_mean + ) state = RunningApproximateAutoCovarianceState( num_steps=state.num_steps + 1, @@ -3007,7 +3234,7 @@ def _one_part(vec, buf, mean, auto_cov): def make_surrogate_loss_fn( grad_fn: Optional[GradFn] = None, - loss_value: tf.Tensor = 0., + loss_value: float | jnp.ndarray = 0.0, ) -> Any: """Creates a surrogate loss function with specified gradients. @@ -3042,10 +3269,11 @@ def make_surrogate_loss_fn( def loss_fn(*args, **kwargs): """The surrogate loss function.""" - @tf.custom_gradient + @jax.custom_gradient def grad_wrapper(*flat_args_kwargs): - new_args, new_kwargs = util.unflatten_tree((args, kwargs), - flat_args_kwargs) + new_args, new_kwargs = util.unflatten_tree( + (args, kwargs), flat_args_kwargs + ) g, e = grad_fn(*new_args, **new_kwargs) # pytype: disable=wrong-arg-count def inner_grad_fn(*_): @@ -3060,13 +3288,14 @@ def inner_grad_fn(*_): class SimpleDualAveragesState(NamedTuple): """Simple Dual Averages state.""" + state: State - step: IntTensor + step: IntArray grad_running_mean_state: RunningMeanState class SimpleDualAveragesExtra(NamedTuple): - loss: FloatTensor + loss: FloatArray loss_extra: Any grads: State @@ -3074,7 +3303,7 @@ class SimpleDualAveragesExtra(NamedTuple): @util.named_call def simple_dual_averages_init( state: FloatNest, - grad_mean_smoothing_steps: IntNest = 0, + grad_mean_smoothing_steps: int | IntNest = 0, ) -> SimpleDualAveragesState: """Initializes Simple Dual Averages state. @@ -3092,15 +3321,18 @@ def simple_dual_averages_init( """ grad_rms = running_mean_init( util.map_tree(lambda s: s.shape, state), - util.map_tree(lambda s: s.dtype, state)) + util.map_tree(lambda s: s.dtype, state), + ) grad_rms = grad_rms._replace( - num_points=util.map_tree(lambda _: grad_mean_smoothing_steps, - grad_rms.num_points)) + num_points=util.map_tree( + lambda _: grad_mean_smoothing_steps, grad_rms.num_points + ) + ) return SimpleDualAveragesState( state=state, # The algorithm assumes this starts at 1. - step=1, + step=jnp.ones([], jnp.int32), grad_running_mean_state=grad_rms, ) @@ -3110,7 +3342,7 @@ def simple_dual_averages_step( sda_state: SimpleDualAveragesState, loss_fn: PotentialFn, shrink_weight: FloatNest, - shrink_point: State = 0., + shrink_point: float | State = 0.0, ) -> tuple[SimpleDualAveragesState, SimpleDualAveragesExtra]: """Performs one step of the Simple Dual Averages algorithm [1]. @@ -3153,9 +3385,9 @@ def simple_dual_averages_step( grad_rms, _ = running_mean_step(sda_state.grad_running_mean_state, grads) def _one_part(shrink_point, shrink_weight, grad_running_mean): - shrink_point = tf.convert_to_tensor(shrink_point, grad_running_mean.dtype) - step_f = tf.cast(step, grad_running_mean.dtype) - return shrink_point - tf.sqrt(step_f) / shrink_weight * grad_running_mean + shrink_point = jnp.asarray(shrink_point, grad_running_mean.dtype) + step_f = jnp.array(step, grad_running_mean.dtype) + return shrink_point - jnp.sqrt(step_f) / shrink_weight * grad_running_mean state = util.map_tree(_one_part, shrink_point, shrink_weight, grad_rms.mean) @@ -3173,23 +3405,25 @@ def _one_part(shrink_point, shrink_weight, grad_running_mean): return sda_state, sda_extra -def _global_norm(x: FloatNest) -> FloatTensor: - return tf.sqrt(sum(tf.reduce_sum(tf.square(v)) for v in util.flatten_tree(x))) +def _global_norm(x: FloatNest) -> FloatArray: + return jnp.sqrt(sum(jnp.sum(jnp.square(v)) for v in util.flatten_tree(x))) -def clip_grads(x: FloatNest, - max_global_norm: FloatTensor, - eps: FloatTensor = 1e-9, - zero_out_nan: bool = True) -> FloatNest: +def clip_grads( + x: FloatNest, + max_global_norm: float | FloatArray, + eps: float | FloatArray = 1e-9, + zero_out_nan: bool = True, +) -> FloatNest: """Clip gradients flowing through x. By default, non-finite gradients are zeroed out. Args: - x: (Possibly nested) floating point `Tensor`. - max_global_norm: Floating point `Tensor`. Maximum global norm of gradients + x: (Possibly nested) floating point `Array`. + max_global_norm: Floating point `Array`. Maximum global norm of gradients flowing through `x`. - eps: Floating point `Tensor`. Epsilon used when normalizing the gradient + eps: Floating point `Array`. Epsilon used when normalizing the gradient norm. zero_out_nan: Boolean. If `True` non-finite gradients are zeroed out. @@ -3197,17 +3431,16 @@ def clip_grads(x: FloatNest, x: Same value as the input. """ - @tf.custom_gradient + @jax.custom_gradient def grad_wrapper(*x): - def grad_fn(*g): g = util.flatten_tree(g) if zero_out_nan: - g = [tf.where(tf.math.is_finite(v), v, tf.zeros_like(v)) for v in g] + g = [jnp.where(jnp.isfinite(v), v, jnp.zeros_like(v)) for v in g] norm = _global_norm(g) + eps def clip_part(v): - return tf.where(norm < max_global_norm, v, v * max_global_norm / norm) + return jnp.where(norm < max_global_norm, v, v * max_global_norm / norm) res = tuple(clip_part(v) for v in g) if len(res) == 1: @@ -3219,13 +3452,14 @@ def clip_part(v): return x, grad_fn return util.unflatten_tree( - x, util.flatten_tree(grad_wrapper(*util.flatten_tree(x)))) + x, util.flatten_tree(grad_wrapper(*util.flatten_tree(x))) + ) -TransitionExtra = TensorNest -LogWeightExtra = TensorNest -ResampleExtra = TensorNest -Stage = IntTensor +TransitionExtra = ArrayNest +LogWeightExtra = ArrayNest +ResampleExtra = ArrayNest +Stage = IntArray class AnnealedImportanceSamplingState(NamedTuple): @@ -3236,14 +3470,15 @@ class AnnealedImportanceSamplingState(NamedTuple): log_weight: Log weight of the particles. stage: Current stage. """ + state: Any - log_weight: FloatTensor + log_weight: FloatArray stage: Stage - def ess(self) -> FloatTensor: + def ess(self) -> FloatArray: """Estimates the effective sample size.""" - norm_weights = tf.nn.softmax(self.log_weight) - return 1. / tf.reduce_sum(norm_weights**2) + norm_weights = jax.nn.softmax(self.log_weight) + return 1.0 / jnp.sum(norm_weights**2) class AnnealedImportanceSamplingExtra(NamedTuple): @@ -3254,16 +3489,16 @@ class AnnealedImportanceSamplingExtra(NamedTuple): transition_extra: Extra outputs from the transition operator. log_weight_extra: Extra outputs from log-weight computation. """ - stage_log_weight: FloatTensor + + stage_log_weight: FloatArray transition_extra: TransitionExtra log_weight_extra: LogWeightExtra @util.named_call def annealed_importance_sampling_init( - state: State, - initial_log_weight: FloatTensor, - initial_stage: Stage = 0) -> AnnealedImportanceSamplingState: + state: State, initial_log_weight: FloatArray, initial_stage: int | Stage = 0 +) -> AnnealedImportanceSamplingState: """Initializes the annealed importance sampler. Args: @@ -3274,11 +3509,11 @@ def annealed_importance_sampling_init( Returns: `AnnealedImportanceSamplingState`. """ - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) return AnnealedImportanceSamplingState( state=state, - log_weight=tf.convert_to_tensor(initial_log_weight), - stage=tf.convert_to_tensor(initial_stage, tf.int32), + log_weight=jnp.asarray(initial_log_weight), + stage=jnp.asarray(initial_stage, jnp.int32), ) @@ -3286,14 +3521,12 @@ def annealed_importance_sampling_init( def annealed_importance_sampling_step( ais_state: AnnealedImportanceSamplingState, transition_operator: Callable[ - [State, Stage, Callable[[State], tuple[FloatTensor, StateExtra]]], + [State, Stage, Callable[[State], tuple[FloatArray, StateExtra]]], tuple[State, TransitionExtra], ], make_tlp_fn: Callable[[Stage], PotentialFn], log_weight_fn: Optional[ - Callable[ - [State, State, Stage, TransitionExtra], tuple[FloatTensor, Any] - ] + Callable[[State, State, Stage, TransitionExtra], tuple[FloatArray, Any]] ] = None, ) -> tuple[AnnealedImportanceSamplingState, AnnealedImportanceSamplingExtra]: """Takes a step of the annealed importance sampler (AIS). @@ -3374,8 +3607,8 @@ def transition_operator(state, stage, tlp_fn): Args: ais_state: `AnnealedImportanceSamplingState` - transition_operator: The forward MCMC kernel. It has signature: - `(state, stage, tlp_fn) -> (state, extra)`. + transition_operator: The forward MCMC kernel. It has signature: `(state, + stage, tlp_fn) -> (state, extra)`. make_tlp_fn: A function which, given the stage index, returns an annealed density. log_weight_fn: Optional function to compute the incremental log weight of a @@ -3400,11 +3633,12 @@ def _default_log_weight_fn(old_state, new_state, stage, transition_extra): log_weight_fn = _default_log_weight_fn new_state, transition_extra = transition_operator( - ais_state.state, ais_state.stage, make_tlp_fn(ais_state.stage)) + ais_state.state, ais_state.stage, make_tlp_fn(ais_state.stage) + ) - stage_log_weight, log_weight_extra = log_weight_fn(ais_state.state, new_state, - ais_state.stage, - transition_extra) + stage_log_weight, log_weight_extra = log_weight_fn( + ais_state.state, new_state, ais_state.stage, transition_extra + ) ais_state = ais_state._replace( state=new_state, @@ -3422,10 +3656,10 @@ def _default_log_weight_fn(old_state, new_state, stage, transition_extra): @util.named_call def systematic_resample( particles: State, - log_weights: FloatTensor, + log_weights: FloatArray, seed: Any, - do_resample: Optional[BooleanTensor] = None, -) -> tuple[tuple[State, FloatTensor], IntTensor]: + do_resample: Optional[BooleanArray] = None, +) -> tuple[tuple[State, FloatArray], IntArray]: """Systematically resamples particles in proportion to their weights. This uses the algorithm from [1]. @@ -3447,25 +3681,29 @@ def systematic_resample( Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818 """ - log_weights = tf.convert_to_tensor(log_weights) - log_weights = tf.where( - tf.math.is_nan(log_weights), tf.cast(-float('inf'), log_weights.dtype), - log_weights) - probs = tf.nn.softmax(log_weights) + log_weights = jnp.asarray(log_weights) + log_weights = jnp.where( + jnp.isnan(log_weights), + jnp.array(-float('inf'), log_weights.dtype), + log_weights, + ) + probs = jax.nn.softmax(log_weights) num_particles = probs.shape[0] shift = util.random_uniform([], log_weights.dtype, seed) - pie = tf.cumsum(probs) * num_particles + shift - repeats = tf.cast(util.diff(tf.floor(pie), prepend=0), tf.int32) + pie = jnp.cumsum(probs) * num_particles + shift + repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32) parent_idxs = util.repeat( - tf.range(num_particles), repeats, total_repeat_length=num_particles) + jnp.arange(num_particles), repeats, total_repeat_length=num_particles + ) if do_resample is not None: - parent_idxs = tf.where(do_resample, parent_idxs, tf.range(num_particles)) - new_particles = util.map_tree(lambda x: tf.gather(x, parent_idxs), particles) - new_log_weights = tf.fill(log_weights.shape, - tfp.math.reduce_logmeanexp(log_weights)) + parent_idxs = jnp.where(do_resample, parent_idxs, jnp.arange(num_particles)) + new_particles = util.map_tree(lambda x: x[parent_idxs], particles) + new_log_weights = jnp.full( + log_weights.shape, tfp.math.reduce_logmeanexp(log_weights) + ) if do_resample is not None: - new_log_weights = tf.where(do_resample, new_log_weights, log_weights) + new_log_weights = jnp.where(do_resample, new_log_weights, log_weights) return (new_particles, new_log_weights), parent_idxs @@ -3473,19 +3711,18 @@ def systematic_resample( def annealed_importance_sampling_resample( ais_state: AnnealedImportanceSamplingState, resample_fn: Callable[ - [State, FloatTensor, Any, BooleanTensor], - tuple[tuple[State, tf.Tensor], ResampleExtra], + [State, FloatArray, Any, BooleanArray], + tuple[tuple[State, jnp.ndarray], ResampleExtra], ] = systematic_resample, - min_ess_threshold: FloatTensor = 0.5, + min_ess_threshold: float | FloatArray = 0.5, seed: Any = None, ) -> tuple[AnnealedImportanceSamplingState, ResampleExtra]: """Resamples the particles in AnnealedImportanceSamplingState.""" - log_weight = tf.convert_to_tensor(ais_state.log_weight) + log_weight = jnp.asarray(ais_state.log_weight) do_resample = ( ais_state.ess() - < tf.cast(log_weight.shape[0], log_weight.dtype) - * min_ess_threshold + < jnp.array(log_weight.shape[0], log_weight.dtype) * min_ess_threshold ) (state, log_weight), extra = resample_fn( ais_state.state, ais_state.log_weight, seed, do_resample @@ -3501,9 +3738,10 @@ class GeometricAnnealingPathExtra(NamedTuple): final_extra: Extra outputs from the `final_target_log_prob_fn`. fraction: Interpolation fraction. """ + initial_extra: StateExtra final_extra: StateExtra - fraction: FloatTensor + fraction: FloatArray def geometric_annealing_path( @@ -3511,7 +3749,7 @@ def geometric_annealing_path( num_stages: Stage, initial_target_log_prob_fn: PotentialFn, final_target_log_prob_fn: PotentialFn, - fraction_fn: Optional[Callable[[FloatTensor], tf.Tensor]] = None, + fraction_fn: Optional[Callable[[FloatArray], jnp.ndarray]] = None, ) -> PotentialFn: """Returns a geometrically interpolated target density function. @@ -3538,12 +3776,13 @@ def annealed_target_log_prob_fn(*args, **kwargs): dtype = init_tlp.dtype - fraction = tf.cast(stage, dtype) / tf.cast(num_stages, dtype) + fraction = jnp.array(stage, dtype) / jnp.array(num_stages, dtype) if fraction_fn is not None: fraction = fraction_fn(fraction) extra = GeometricAnnealingPathExtra( - initial_extra=init_extra, final_extra=fin_extra, fraction=fraction) + initial_extra=init_extra, final_extra=fin_extra, fraction=fraction + ) return init_tlp * (1 - fraction) + fin_tlp * (fraction), extra diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index 4941891633..d45b9c879a 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -21,23 +21,24 @@ # Dependency imports from absl.testing import parameterized -import jax +import jax as real_jax from jax import config as jax_config import numpy as np import scipy.stats import tensorflow.compat.v2 as real_tf - from tensorflow_probability.python.internal import test_util as tfp_test_util from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc from fun_mc import prefab from fun_mc import test_util -tf = backend.tf +jnp = backend.jnp +jax = backend.jax tfp = backend.tfp util = backend.util real_tf.enable_v2_behavior() +real_tf.experimental.numpy.experimental_enable_numpy_behavior() jax_config.update('jax_enable_x64', True) TestNamedTuple = collections.namedtuple('TestNamedTuple', 'x, y') @@ -45,8 +46,10 @@ BACKEND = None # Rewritten by backends/rewrite.py. if BACKEND == 'backend_jax': - os.environ['XLA_FLAGS'] = (f'{os.environ.get("XLA_FLAGS", "")} ' - '--xla_force_host_platform_device_count=4') + os.environ['XLA_FLAGS'] = ( + f'{os.environ.get("XLA_FLAGS", "")} ' + '--xla_force_host_platform_device_count=4' + ) def _test_seed(): @@ -66,7 +69,6 @@ def _rev_mclachlan_optimal_4th_order_step(*args, **kwargs): def _skip_on_jax(fn): - @functools.wraps(fn) def _wrapper(self, *args, **kwargs): if not self._is_on_jax: @@ -110,16 +112,17 @@ def _gen_cov(data, axis): source_1[last_unaggregated_src_dim] = 'x' source_2 = list(symbols[:rank]) source_2[last_unaggregated_src_dim] = 'y' - dest = dest[:last_unaggregated_dest_dim] + [ - 'x', 'y' - ] + dest[last_unaggregated_dest_dim + 1:] + dest = ( + dest[:last_unaggregated_dest_dim] + + ['x', 'y'] + + dest[last_unaggregated_dest_dim + 1 :] + ) formula = '{source_1},{source_2}->{dest}'.format( - source_1=''.join(source_1), - source_2=''.join(source_2), - dest=''.join(dest)) - cov = ( - np.einsum(formula, centered_data, centered_data) / - np.prod(np.array(shape)[np.array(axis)])) + source_1=''.join(source_1), source_2=''.join(source_2), dest=''.join(dest) + ) + cov = np.einsum(formula, centered_data, centered_data) / np.prod( + np.array(shape)[np.array(axis)] + ) return cov @@ -135,12 +138,11 @@ def testGenCov(self): class FunMCTest(tfp_test_util.TestCase, parameterized.TestCase): - _is_on_jax = BACKEND == 'backend_jax' def _make_seed(self, seed): if self._is_on_jax: - return jax.random.PRNGKey(seed) + return real_jax.random.PRNGKey(seed) else: return util.make_tensor_seed([seed, 0]) @@ -149,65 +151,65 @@ def _dtype(self): raise NotImplementedError() def _constant(self, value): - return tf.constant(value, self._dtype) + return jnp.array(value, self._dtype) @parameterized.named_parameters( ('Unrolled', True), ('NotUnrolled', False), ) def testTraceSingle(self, unroll): - def fun(x): - return x + 1., 2 * x + return x + 1.0, 2 * x x, e_trace = fun_mc.trace( - state=0., + state=0.0, fn=fun, num_steps=5, trace_fn=lambda _, xp1: xp1, - unroll=unroll) + unroll=unroll, + ) - self.assertAllEqual(5., x) - self.assertAllEqual([0., 2., 4., 6., 8.], e_trace) + self.assertAllEqual(5.0, x) + self.assertAllEqual([0.0, 2.0, 4.0, 6.0, 8.0], e_trace) @parameterized.named_parameters( ('Unrolled', True), ('NotUnrolled', False), ) def testTraceNested(self, unroll): - def fun(x, y): - return (x + 1., y + 2.), () + return (x + 1.0, y + 2.0), () (x, y), (x_trace, y_trace) = fun_mc.trace( - state=(0., 0.), + state=(0.0, 0.0), fn=fun, num_steps=5, trace_fn=lambda xy, _: xy, - unroll=unroll) + unroll=unroll, + ) - self.assertAllEqual(5., x) - self.assertAllEqual(10., y) - self.assertAllEqual([1., 2., 3., 4., 5.], x_trace) - self.assertAllEqual([2., 4., 6., 8., 10.], y_trace) + self.assertAllEqual(5.0, x) + self.assertAllEqual(10.0, y) + self.assertAllEqual([1.0, 2.0, 3.0, 4.0, 5.0], x_trace) + self.assertAllEqual([2.0, 4.0, 6.0, 8.0, 10.0], y_trace) @parameterized.named_parameters( ('Unrolled', True), ('NotUnrolled', False), ) def testTraceTrace(self, unroll): - def fun(x): return fun_mc.trace( - x, lambda x: (x + 1., x + 1.), 2, trace_mask=False, unroll=unroll) + x, lambda x: (x + 1.0, x + 1.0), 2, trace_mask=False, unroll=unroll + ) - x, trace = fun_mc.trace(0., fun, 2) - self.assertAllEqual(4., x) - self.assertAllEqual([2., 4.], trace) + x, trace = fun_mc.trace(0.0, fun, 2) + self.assertAllEqual(4.0, x) + self.assertAllEqual([2.0, 4.0], trace) def testTraceDynamic(self): - @tf.function + @jax.jit def trace_n(num_steps): return fun_mc.trace(0, lambda x: (x + 1, ()), num_steps)[0] @@ -219,19 +221,20 @@ def trace_n(num_steps): ('NotUnrolled', False), ) def testTraceMask(self, unroll): - def fun(x): return x + 1, (2 * x, 3 * x) x, (trace_1, trace_2) = fun_mc.trace( - state=0, fn=fun, num_steps=3, trace_mask=(True, False), unroll=unroll) + state=0, fn=fun, num_steps=3, trace_mask=(True, False), unroll=unroll + ) self.assertAllEqual(3, x) self.assertAllEqual([0, 2, 4], trace_1) self.assertAllEqual(6, trace_2) x, (trace_1, trace_2) = fun_mc.trace( - state=0, fn=fun, num_steps=3, trace_mask=False, unroll=unroll) + state=0, fn=fun, num_steps=3, trace_mask=False, unroll=unroll + ) self.assertAllEqual(3, x) self.assertAllEqual(4, trace_1) @@ -290,33 +293,34 @@ def testCallFnDict(self): ('ArgsToList1', (1,), {}, [1]), ('ArgsToTuple3', (1, 2, 3), {}, [1, 2, 3]), ('ArgsToList3', (1, 2, 3), {}, [1, 2, 3]), - ('ArgsToOrdDict3', - (1, 2, 3), {}, collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)])), - ('ArgsKwargsToOrdDict3', (1, 2), { - 'a': 3 - }, collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)])), - ('KwargsToOrdDict3', (), { - 'a': 3, - 'b': 2, - 'c': 1 - }, collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)])), - ('KwargsToDict3', (), { - 'a': 3, - 'b': 2, - 'c': 1 - }, { - 'c': 1, - 'b': 2, - 'a': 3 - }), + ( + 'ArgsToOrdDict3', + (1, 2, 3), + {}, + collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)]), + ), + ( + 'ArgsKwargsToOrdDict3', + (1, 2), + {'a': 3}, + collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)]), + ), + ( + 'KwargsToOrdDict3', + (), + {'a': 3, 'b': 2, 'c': 1}, + collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)]), + ), + ('KwargsToDict3', (), {'a': 3, 'b': 2, 'c': 1}, {'c': 1, 'b': 2, 'a': 3}), ('ArgsToNamedTuple', (TestNamedTuple(1, 2),), {}, TestNamedTuple(1, 2)), - ('KwargsToNamedTuple', (), { - 'a': TestNamedTuple(1, 2) - }, TestNamedTuple(1, 2)), + ( + 'KwargsToNamedTuple', + (), + {'a': TestNamedTuple(1, 2)}, + TestNamedTuple(1, 2), + ), ('ArgsToScalar', (1,), {}, 1), - ('KwargsToScalar', (), { - 'a': 1 - }, 1), + ('KwargsToScalar', (), {'a': 1}, 1), ('Tuple0', (), {}, ()), ('List0', (), {}, []), ('Dict0', (), {}, {}), @@ -327,28 +331,19 @@ def testRecoverStateFromArgs(self, args, kwargs, state_structure): self.assertAllEqual(state_structure, state) @parameterized.named_parameters( - ('BadKwargs', (), { - 'a': 1, - 'b': 2 - }, 'c'), - ('ArgsOverlap', (1, 2), { - 'c': 1, - 'b': 2 - }, 'a'), + ('BadKwargs', (), {'a': 1, 'b': 2}, 'c'), + ('ArgsOverlap', (1, 2), {'c': 1, 'b': 2}, 'a'), ) def testRecoverStateFromArgsMissing(self, args, kwargs, missing): state_structure = collections.OrderedDict([('c', 1), ('b', 2), ('a', 3)]) - with self.assertRaisesRegex(ValueError, - 'Missing \'{}\' from kwargs.'.format(missing)): + with self.assertRaisesRegex( + ValueError, "Missing '{}' from kwargs.".format(missing) + ): fun_mc.recover_state_from_args(args, kwargs, state_structure) @parameterized.named_parameters( - ('Tuple1', { - 'a': 1 - }, (1,)), - ('List1', { - 'a': 1 - }, [1]), + ('Tuple1', {'a': 1}, (1,)), + ('List1', {'a': 1}, [1]), ) def testRecoverStateFromArgsNoKwargs(self, kwargs, state_structure): with self.assertRaisesRegex(ValueError, 'This wrapper does not'): @@ -365,44 +360,39 @@ def testBroadcastStructure(self): self.assertEqual([[1, 1], [2, 2, 2]], struct) def testCallPotentialFn(self): - def potential(x): return x, () - x, extra = fun_mc.call_potential_fn(potential, 0.) + x, extra = fun_mc.call_potential_fn(potential, 0.0) - self.assertEqual(0., x) + self.assertEqual(0.0, x) self.assertEqual((), extra) def testCallPotentialFnMissingExtra(self): - def potential(x): return x with self.assertRaisesRegex(TypeError, 'A common solution is to adjust'): - fun_mc.call_potential_fn(potential, 0.) + fun_mc.call_potential_fn(potential, 0.0) def testCallTransitionOperator(self): - def kernel(x, y): del y return [x, [1]], () - [x, [y]], extra = fun_mc.call_transition_operator(kernel, [0., None]) - self.assertEqual(0., x) + [x, [y]], extra = fun_mc.call_transition_operator(kernel, [0.0, None]) + self.assertEqual(0.0, x) self.assertEqual(1, y) self.assertEqual((), extra) def testCallTransitionOperatorMissingExtra(self): - def potential(x): return x with self.assertRaisesRegex(TypeError, 'A common solution is to adjust'): - fun_mc.call_transition_operator(potential, 0.) + fun_mc.call_transition_operator(potential, 0.0) def testCallTransitionOperatorBadArgs(self): - def potential(x, y, z): del z return (x, y), () @@ -411,68 +401,79 @@ def potential(x, y, z): fun_mc.call_transition_operator(potential, (1, 2, 3)) def testTransformLogProbFn(self): - def log_prob_fn(x, y): - return (tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x) + - tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () + return ( + tfp.distributions.Normal(self._constant(0.0), 1.0).log_prob(x) + + tfp.distributions.Normal(self._constant(1.0), 1.0).log_prob(y) + ), () bijectors = [ - tfp.bijectors.Scale(scale=self._constant(2.)), - tfp.bijectors.Scale(scale=self._constant(3.)) + tfp.bijectors.Scale(scale=self._constant(2.0)), + tfp.bijectors.Scale(scale=self._constant(3.0)), ] - (transformed_log_prob_fn, - transformed_init_state) = fun_mc.transform_log_prob_fn( - log_prob_fn, bijectors, - [self._constant(2.), self._constant(3.)]) + (transformed_log_prob_fn, transformed_init_state) = ( + fun_mc.transform_log_prob_fn( + log_prob_fn, bijectors, [self._constant(2.0), self._constant(3.0)] + ) + ) self.assertIsInstance(transformed_init_state, list) - self.assertAllClose([1., 1.], transformed_init_state) - tlp, (orig_space, _) = ( - transformed_log_prob_fn(self._constant(1.), self._constant(1.))) - lp = log_prob_fn(self._constant(2.), self._constant(3.))[0] + sum( - b.forward_log_det_jacobian(self._constant(1.), event_ndims=0) - for b in bijectors) - - self.assertAllClose([2., 3.], orig_space) + self.assertAllClose([1.0, 1.0], transformed_init_state) + tlp, (orig_space, _) = transformed_log_prob_fn( + self._constant(1.0), self._constant(1.0) + ) + lp = log_prob_fn(self._constant(2.0), self._constant(3.0))[0] + sum( + b.forward_log_det_jacobian(self._constant(1.0), event_ndims=0) + for b in bijectors + ) + + self.assertAllClose([2.0, 3.0], orig_space) self.assertAllClose(lp, tlp) def testTransformLogProbFnKwargs(self): - def log_prob_fn(x, y): - return (tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x) + - tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () + return ( + tfp.distributions.Normal(self._constant(0.0), 1.0).log_prob(x) + + tfp.distributions.Normal(self._constant(1.0), 1.0).log_prob(y) + ), () bijectors = { - 'x': tfp.bijectors.Scale(scale=self._constant(2.)), - 'y': tfp.bijectors.Scale(scale=self._constant(3.)) + 'x': tfp.bijectors.Scale(scale=self._constant(2.0)), + 'y': tfp.bijectors.Scale(scale=self._constant(3.0)), } - (transformed_log_prob_fn, - transformed_init_state) = fun_mc.transform_log_prob_fn( - log_prob_fn, bijectors, { - 'x': self._constant(2.), - 'y': self._constant(3.), - }) + (transformed_log_prob_fn, transformed_init_state) = ( + fun_mc.transform_log_prob_fn( + log_prob_fn, + bijectors, + { + 'x': self._constant(2.0), + 'y': self._constant(3.0), + }, + ) + ) self.assertIsInstance(transformed_init_state, dict) self.assertAllCloseNested( { - 'x': self._constant(1.), - 'y': self._constant(1.), - }, transformed_init_state) + 'x': self._constant(1.0), + 'y': self._constant(1.0), + }, + transformed_init_state, + ) tlp, (orig_space, _) = transformed_log_prob_fn( - x=self._constant(1.), y=self._constant(1.)) - lp = log_prob_fn( - x=self._constant(2.), y=self._constant(3.))[0] + sum( - b.forward_log_det_jacobian(self._constant(1.), event_ndims=0) - for b in bijectors.values()) - - self.assertAllCloseNested({ - 'x': self._constant(2.), - 'y': self._constant(3.) - }, orig_space) + x=self._constant(1.0), y=self._constant(1.0) + ) + lp = log_prob_fn(x=self._constant(2.0), y=self._constant(3.0))[0] + sum( + b.forward_log_det_jacobian(self._constant(1.0), event_ndims=0) + for b in bijectors.values() + ) + + self.assertAllCloseNested( + {'x': self._constant(2.0), 'y': self._constant(3.0)}, orig_space + ) self.assertAllClose(lp, tlp) # The +1's here are because we initialize the `state_grads` at 1, which @@ -481,27 +482,34 @@ def log_prob_fn(x, y): ('Leapfrog', lambda: fun_mc.leapfrog_step, 1 + 1), ('Ruth4', lambda: fun_mc.ruth4_step, 3 + 1), ('Blanes3', lambda: fun_mc.blanes_3_stage_step, 3 + 1), - ('McLachlan4Fwd', lambda: _fwd_mclachlan_optimal_4th_order_step, 4 + 1, - 9), - ('McLachlan4Rev', lambda: _rev_mclachlan_optimal_4th_order_step, 4 + 1, - 9), + ( + 'McLachlan4Fwd', + lambda: _fwd_mclachlan_optimal_4th_order_step, + 4 + 1, + 9, + ), + ( + 'McLachlan4Rev', + lambda: _rev_mclachlan_optimal_4th_order_step, + 4 + 1, + 9, + ), ) - def testIntegratorStep(self, - method_fn, - num_tlp_calls, - num_tlp_calls_jax=None): + def testIntegratorStep( + self, method_fn, num_tlp_calls, num_tlp_calls_jax=None + ): method = method_fn() tlp_call_counter = [0] def target_log_prob_fn(q): tlp_call_counter[0] += 1 - return -q**2, 1. + return -(q**2), 1.0 def kinetic_energy_fn(p): - return tf.abs(p)**3., 2. + return jnp.abs(p) ** 3.0, 2.0 - state = self._constant(1.) + state = self._constant(1.0) _, _, state_grads = fun_mc.call_potential_fn_with_grads( target_log_prob_fn, state, @@ -509,21 +517,27 @@ def kinetic_energy_fn(p): state, extras = method( integrator_step_state=fun_mc.IntegratorStepState( - state=state, state_grads=state_grads, momentum=self._constant(2.)), + state=state, state_grads=state_grads, momentum=self._constant(2.0) + ), step_size=self._constant(0.1), target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) if num_tlp_calls_jax is not None and self._is_on_jax: num_tlp_calls = num_tlp_calls_jax self.assertEqual(num_tlp_calls, tlp_call_counter[0]) - self.assertEqual(1., extras.state_extra) - self.assertEqual(2., extras.kinetic_energy_extra) + self.assertEqual(1.0, extras.state_extra) + self.assertEqual(2.0, extras.kinetic_energy_extra) - initial_hamiltonian = -target_log_prob_fn( - self._constant(1.))[0] + kinetic_energy_fn(self._constant(2.))[0] - fin_hamiltonian = -target_log_prob_fn(state.state)[0] + kinetic_energy_fn( - state.momentum)[0] + initial_hamiltonian = ( + -target_log_prob_fn(self._constant(1.0))[0] + + kinetic_energy_fn(self._constant(2.0))[0] + ) + fin_hamiltonian = ( + -target_log_prob_fn(state.state)[0] + + kinetic_energy_fn(state.momentum)[0] + ) self.assertAllClose(fin_hamiltonian, initial_hamiltonian, atol=0.2) @@ -533,16 +547,15 @@ def kinetic_energy_fn(p): ('Blanes3', fun_mc.blanes_3_stage_step), ) def testIntegratorStepReversible(self, method): - def target_log_prob_fn(q): - return -q**2, [] + return -(q**2), [] def kinetic_energy_fn(p): - return p**2., [] + return p**2.0, [] seed = self._make_seed(_test_seed()) - state = self._constant(1.) + state = self._constant(1.0) _, _, state_grads = fun_mc.call_potential_fn_with_grads( target_log_prob_fn, state, @@ -552,30 +565,32 @@ def kinetic_energy_fn(p): integrator_step_state=fun_mc.IntegratorStepState( state=state, state_grads=state_grads, - momentum=util.random_normal([], self._dtype, seed)), + momentum=util.random_normal([], self._dtype, seed), + ), step_size=self._constant(0.1), target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) state_rev, _ = method( integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum), step_size=self._constant(0.1), target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) self.assertAllClose(state, state_rev.state, atol=1e-6) def testMclachlanIntegratorStepReversible(self): - def target_log_prob_fn(q): - return -q**2, [] + return -(q**2), [] def kinetic_energy_fn(p): - return p**2., [] + return p**2.0, [] seed = self._make_seed(_test_seed()) - state = self._constant(1.) + state = self._constant(1.0) _, _, state_grads = fun_mc.call_potential_fn_with_grads( target_log_prob_fn, state, @@ -585,45 +600,49 @@ def kinetic_energy_fn(p): integrator_step_state=fun_mc.IntegratorStepState( state=state, state_grads=state_grads, - momentum=util.random_normal([], self._dtype, seed)), + momentum=util.random_normal([], self._dtype, seed), + ), step_size=self._constant(0.1), target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) state_rev, _ = _rev_mclachlan_optimal_4th_order_step( integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum), step_size=self._constant(0.1), target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) self.assertAllClose(state, state_rev.state, atol=1e-6) def testMetropolisHastingsStep(self): seed = self._make_seed(_test_seed()) - zero = self._constant(0.) - one = self._constant(1.) + zero = self._constant(0.0) + one = self._constant(1.0) accepted, mh_extra = fun_mc.metropolis_hastings_step( - current_state=zero, - proposed_state=one, - energy_change=-np.inf, - seed=seed) + current_state=zero, proposed_state=one, energy_change=-np.inf, seed=seed + ) self.assertAllEqual(one, accepted) self.assertAllEqual(True, mh_extra.is_accepted) accepted, mh_extra = fun_mc.metropolis_hastings_step( - current_state=zero, proposed_state=one, energy_change=np.inf, seed=seed) + current_state=zero, proposed_state=one, energy_change=np.inf, seed=seed + ) self.assertAllEqual(zero, accepted) self.assertAllEqual(False, mh_extra.is_accepted) accepted, mh_extra = fun_mc.metropolis_hastings_step( - current_state=zero, proposed_state=one, energy_change=np.nan, seed=seed) + current_state=zero, proposed_state=one, energy_change=np.nan, seed=seed + ) self.assertAllEqual(zero, accepted) self.assertAllEqual(False, mh_extra.is_accepted) accepted, mh_extra = fun_mc.metropolis_hastings_step( - current_state=zero, proposed_state=one, energy_change=np.nan, seed=seed) + current_state=zero, proposed_state=one, energy_change=np.nan, seed=seed + ) self.assertAllEqual(zero, accepted) self.assertAllEqual(False, mh_extra.is_accepted) @@ -632,7 +651,8 @@ def testMetropolisHastingsStep(self): proposed_state=one, log_uniform=-one, energy_change=self._constant(-np.log(0.5)), - seed=seed) + seed=seed, + ) self.assertAllEqual(one, accepted) self.assertAllEqual(True, mh_extra.is_accepted) @@ -641,16 +661,18 @@ def testMetropolisHastingsStep(self): proposed_state=one, log_uniform=zero, energy_change=self._constant(-np.log(0.5)), - seed=seed) + seed=seed, + ) self.assertAllEqual(zero, accepted) self.assertAllEqual(False, mh_extra.is_accepted) accepted, _ = fun_mc.metropolis_hastings_step( - current_state=tf.zeros(1000, dtype=self._dtype), - proposed_state=tf.ones(1000, dtype=self._dtype), - energy_change=-tf.math.log(0.5 * tf.ones(1000, dtype=self._dtype)), - seed=seed) - self.assertAllClose(0.5, tf.reduce_mean(accepted), rtol=0.1) + current_state=jnp.zeros(1000, dtype=self._dtype), + proposed_state=jnp.ones(1000, dtype=self._dtype), + energy_change=-jnp.log(0.5 * jnp.ones(1000, dtype=self._dtype)), + seed=seed, + ) + self.assertAllClose(0.5, jnp.mean(accepted), rtol=0.1) def testMetropolisHastingsStepStructure(self): struct_type = collections.namedtuple('Struct', 'a, b') @@ -662,10 +684,12 @@ def testMetropolisHastingsStepStructure(self): current_state=current, proposed_state=proposed, energy_change=-np.inf, - seed=self._make_seed(_test_seed())) + seed=self._make_seed(_test_seed()), + ) self.assertAllEqual(True, mh_extra.is_accepted) self.assertAllEqual( - util.flatten_tree(proposed), util.flatten_tree(accepted)) + util.flatten_tree(proposed), util.flatten_tree(accepted) + ) @parameterized.named_parameters( ('Unrolled', True), @@ -675,14 +699,13 @@ def testBasicHMC(self, unroll): step_size = self._constant(0.2) num_steps = 2000 num_leapfrog_steps = 10 - state = tf.ones([16, 2], dtype=self._dtype) + state = jnp.ones([16, 2], dtype=self._dtype) - base_mean = self._constant([2., 3.]) - base_scale = self._constant([2., 0.5]) + base_mean = self._constant([2.0, 3.0]) + base_scale = self._constant([2.0, 0.5]) def target_log_prob_fn(x): - return -tf.reduce_sum(0.5 * tf.square( - (x - base_mean) / base_scale), -1), () + return -jnp.sum(0.5 * jnp.square((x - base_mean) / base_scale), -1), () def kernel(hmc_state, seed): hmc_seed, seed = util.split_seed(seed, 2) @@ -692,29 +715,36 @@ def kernel(hmc_state, seed): num_integrator_steps=num_leapfrog_steps, target_log_prob_fn=target_log_prob_fn, unroll_integrator=unroll, - seed=hmc_seed) + seed=hmc_seed, + ) return (hmc_state, seed), hmc_state.state seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), - seed), - fn=kernel, - num_steps=num_steps))(state, seed) + _, chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), + seed, + ), + fn=kernel, + num_steps=num_steps, + ) + )(state, seed) # Discard the warmup samples. chain = chain[1000:] - sample_mean = tf.reduce_mean(chain, axis=[0, 1]) - sample_var = tf.math.reduce_variance(chain, axis=[0, 1]) + sample_mean = jnp.mean(chain, axis=[0, 1]) + sample_var = jnp.var(chain, axis=[0, 1]) - true_samples = util.random_normal( - shape=[4096, 2], dtype=self._dtype, seed=seed) * base_scale + base_mean + true_samples = ( + util.random_normal(shape=[4096, 2], dtype=self._dtype, seed=seed) + * base_scale + + base_mean + ) - true_mean = tf.reduce_mean(true_samples, axis=0) - true_var = tf.math.reduce_variance(true_samples, axis=0) + true_mean = jnp.mean(true_samples, axis=0) + true_var = jnp.var(true_samples, axis=0) self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1) self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1) @@ -723,21 +753,23 @@ def testPreconditionedHMC(self): step_size = self._constant(0.2) num_steps = 2000 num_leapfrog_steps = 10 - state = tf.ones([16, 2], dtype=self._dtype) + state = jnp.ones([16, 2], dtype=self._dtype) - base_mean = self._constant([1., 0]) + base_mean = self._constant([1.0, 0]) base_cov = self._constant([[1, 0.5], [0.5, 1]]) bijector = tfp.bijectors.Softplus() base_dist = tfp.distributions.MultivariateNormalFullCovariance( - loc=base_mean, covariance_matrix=base_cov) + loc=base_mean, covariance_matrix=base_cov + ) target_dist = bijector(base_dist) def orig_target_log_prob_fn(x): return target_dist.log_prob(x), () target_log_prob_fn, state = fun_mc.transform_log_prob_fn( - orig_target_log_prob_fn, bijector, state) + orig_target_log_prob_fn, bijector, state + ) # pylint: disable=g-long-lambda def kernel(hmc_state, seed): @@ -747,152 +779,109 @@ def kernel(hmc_state, seed): step_size=step_size, num_integrator_steps=num_leapfrog_steps, target_log_prob_fn=target_log_prob_fn, - seed=hmc_seed) + seed=hmc_seed, + ) return (hmc_state, seed), hmc_state.state_extra[0] seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), - seed), - fn=kernel, - num_steps=num_steps))(state, seed) + _, chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), + seed, + ), + fn=kernel, + num_steps=num_steps, + ) + )(state, seed) # Discard the warmup samples. chain = chain[1000:] - sample_mean = tf.reduce_mean(chain, axis=[0, 1]) + sample_mean = jnp.mean(chain, axis=[0, 1]) sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed())) - true_mean = tf.reduce_mean(true_samples, axis=0) + true_mean = jnp.mean(true_samples, axis=0) true_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1) self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1) - @parameterized.parameters((tf.function, 1), (_no_compile, 2)) - @_skip_on_jax # `trace` doesn't have an efficient path in JAX yet. - def testHMCCountTargetLogProb(self, compile_fn, expected_count): - - counter = [0] - - @compile_fn - def target_log_prob_fn(x): - counter[0] += 1 - return -tf.square(x), [] - - # pylint: disable=g-long-lambda - @tf.function - def trace(): - kernel = lambda state: fun_mc.hamiltonian_monte_carlo_step( - state, - step_size=self._constant(0.1), - num_integrator_steps=3, - target_log_prob_fn=target_log_prob_fn, - seed=_test_seed()) - - fun_mc.trace( - state=fun_mc.hamiltonian_monte_carlo_init( - tf.zeros([1], dtype=self._dtype), target_log_prob_fn), - fn=kernel, - num_steps=4, - trace_fn=lambda *args: ()) - - trace() - - self.assertEqual(expected_count, counter[0]) - - @_skip_on_jax # `trace` doesn't have an efficient path in JAX yet. - def testHMCCountTargetLogProbEfficient(self): - - counter = [0] - - def target_log_prob_fn(x): - counter[0] += 1 - return -tf.square(x), [] - - @tf.function - def trace(): - # pylint: disable=g-long-lambda - kernel = lambda state: fun_mc.hamiltonian_monte_carlo_step( - state, - step_size=self._constant(0.1), - num_integrator_steps=3, - target_log_prob_fn=target_log_prob_fn, - seed=self._make_seed(_test_seed())) - - fun_mc.trace( - state=fun_mc.hamiltonian_monte_carlo_init( - state=tf.zeros([1], dtype=self._dtype), - target_log_prob_fn=target_log_prob_fn), - fn=kernel, - num_steps=4, - trace_fn=lambda *args: ()) - - trace() - - self.assertEqual(2, counter[0]) - def testAdaptiveStepSize(self): step_size = self._constant(0.2) num_steps = 200 num_adapt_steps = 100 num_leapfrog_steps = 10 - state = tf.ones([16, 2], dtype=self._dtype) + state = jnp.ones([16, 2], dtype=self._dtype) - base_mean = self._constant([1., 0]) + base_mean = self._constant([1.0, 0]) base_cov = self._constant([[1, 0.5], [0.5, 1]]) - @tf.function + @jax.jit def computation(state, seed): bijector = tfp.bijectors.Softplus() base_dist = tfp.distributions.MultivariateNormalFullCovariance( - loc=base_mean, covariance_matrix=base_cov) + loc=base_mean, covariance_matrix=base_cov + ) target_dist = bijector(base_dist) def orig_target_log_prob_fn(x): return target_dist.log_prob(x), () target_log_prob_fn, state = fun_mc.transform_log_prob_fn( - orig_target_log_prob_fn, bijector, state) + orig_target_log_prob_fn, bijector, state + ) def kernel(hmc_state, step_size_state, step, seed): hmc_seed, seed = util.split_seed(seed, 2) hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step( hmc_state, - step_size=tf.exp(step_size_state.state), + step_size=jnp.exp(step_size_state.state), num_integrator_steps=num_leapfrog_steps, target_log_prob_fn=target_log_prob_fn, - seed=hmc_seed) + seed=hmc_seed, + ) rate = prefab._polynomial_decay( # pylint: disable=protected-access step=step, step_size=self._constant(0.01), power=0.5, decay_steps=num_adapt_steps, - final_step_size=0.) - mean_p_accept = tf.reduce_mean( - tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio))) + final_step_size=0.0, + ) + mean_p_accept = jnp.mean( + jnp.exp( + jnp.minimum(self._constant(0.0), hmc_extra.log_accept_ratio) + ) + ) loss_fn = fun_mc.make_surrogate_loss_fn( - lambda _: (0.9 - mean_p_accept, ())) + lambda _: (0.9 - mean_p_accept, ()) + ) step_size_state, _ = fun_mc.adam_step( - step_size_state, loss_fn, learning_rate=rate) + step_size_state, loss_fn, learning_rate=rate + ) - return ((hmc_state, step_size_state, step + 1, seed), - (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)) + return ( + (hmc_state, step_size_state, step + 1, seed), + (hmc_state.state_extra[0], hmc_extra.log_accept_ratio), + ) _, (chain, log_accept_ratio_trace) = fun_mc.trace( - state=(fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), - fun_mc.adam_init(tf.math.log(step_size)), 0, seed), + state=( + fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), + fun_mc.adam_init(jnp.log(step_size)), + 0, + seed, + ), fn=kernel, num_steps=num_adapt_steps + num_steps, ) true_samples = target_dist.sample( - 4096, seed=self._make_seed(_test_seed())) + 4096, seed=self._make_seed(_test_seed()) + ) return chain, log_accept_ratio_trace, true_samples seed = self._make_seed(_test_seed()) @@ -901,41 +890,43 @@ def kernel(hmc_state, step_size_state, step, seed): log_accept_ratio_trace = log_accept_ratio_trace[num_adapt_steps:] chain = chain[num_adapt_steps:] - sample_mean = tf.reduce_mean(chain, axis=[0, 1]) + sample_mean = jnp.mean(chain, axis=[0, 1]) sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) - true_mean = tf.reduce_mean(true_samples, axis=0) + true_mean = jnp.mean(true_samples, axis=0) true_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) self.assertAllClose(true_mean, sample_mean, rtol=0.05, atol=0.05) self.assertAllClose(true_cov, sample_cov, rtol=0.05, atol=0.05) self.assertAllClose( - tf.reduce_mean(tf.exp(tf.minimum(0., log_accept_ratio_trace))), + jnp.mean(jnp.exp(jnp.minimum(0.0, log_accept_ratio_trace))), 0.9, - rtol=0.1) + rtol=0.1, + ) def testSignAdaptation(self): new_control = fun_mc.sign_adaptation( - control=self._constant(1.), + control=self._constant(1.0), output=self._constant(0.5), - set_point=self._constant(1.), - adaptation_rate=self._constant(0.1)) - self.assertAllClose(new_control, 1. / 1.1) + set_point=self._constant(1.0), + adaptation_rate=self._constant(0.1), + ) + self.assertAllClose(new_control, 1.0 / 1.1) new_control = fun_mc.sign_adaptation( - control=self._constant(1.), + control=self._constant(1.0), output=self._constant(0.5), - set_point=self._constant(0.), - adaptation_rate=self._constant(0.1)) - self.assertAllClose(new_control, 1. * 1.1) + set_point=self._constant(0.0), + adaptation_rate=self._constant(0.1), + ) + self.assertAllClose(new_control, 1.0 * 1.1) def testOBABOLangevinIntegrator(self): - def target_log_prob_fn(q): - return -q**2, q + return -(q**2), q def kinetic_energy_fn(p): - return p**2., p + return p**2.0, p def momentum_refresh_fn(p, seed): del seed @@ -950,7 +941,8 @@ def energy_change_fn(old_is, new_is): state, step_size=0.1, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn) + kinetic_energy_fn=kinetic_energy_fn, + ) lt_integrator_fn = lambda state: fun_mc.obabo_langevin_integrator( # pylint: disable=g-long-lambda state, @@ -970,10 +962,11 @@ def energy_change_fn(old_is, new_is): integrator_trace_fn=lambda state, _: state.state, ) - state = tf.zeros([2], dtype=self._dtype) - momentum = tf.ones([2], dtype=self._dtype) + state = jnp.zeros([2], dtype=self._dtype) + momentum = jnp.ones([2], dtype=self._dtype) target_log_prob, _, state_grads = fun_mc.call_potential_fn_with_grads( - target_log_prob_fn, state) + target_log_prob_fn, state + ) start_state = fun_mc.IntegratorState( target_log_prob=target_log_prob, @@ -990,17 +983,18 @@ def energy_change_fn(old_is, new_is): # present. self.assertAllClose(lt_state, ham_state) self.assertAllClose(lt_extra.energy_change, ham_extra.energy_change) - self.assertAllClose(lt_extra.energy_change, - ham_extra.final_energy - ham_extra.initial_energy) + self.assertAllClose( + lt_extra.energy_change, + ham_extra.final_energy - ham_extra.initial_energy, + ) self.assertAllClose(lt_extra.integrator_trace, ham_extra.integrator_trace) def testRaggedIntegrator(self): - def target_log_prob_fn(q): - return -q**2, q + return -(q**2), q def kinetic_energy_fn(p): - return tf.abs(p)**3., p + return jnp.abs(p) ** 3.0, p integrator_fn = lambda state, num_steps: fun_mc.hamiltonian_integrator( # pylint: disable=g-long-lambda state, @@ -1009,14 +1003,17 @@ def kinetic_energy_fn(p): state, step_size=0.1, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn), + kinetic_energy_fn=kinetic_energy_fn, + ), kinetic_energy_fn=kinetic_energy_fn, - integrator_trace_fn=lambda state, extra: (state, extra)) + integrator_trace_fn=lambda state, extra: (state, extra), + ) - state = tf.zeros([2], dtype=self._dtype) - momentum = tf.ones([2], dtype=self._dtype) + state = jnp.zeros([2], dtype=self._dtype) + momentum = jnp.ones([2], dtype=self._dtype) target_log_prob, _, state_grads = fun_mc.call_potential_fn_with_grads( - target_log_prob_fn, state) + target_log_prob_fn, state + ) start_state = fun_mc.IntegratorState( target_log_prob=target_log_prob, @@ -1037,7 +1034,8 @@ def kinetic_energy_fn(p): def get_batch(state, idx): # For the integrator trace, we'll grab the final value. return util.map_tree( - lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state) + lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state + ) self.assertAllClose(get_batch(state_1, 0), get_batch(state_1_2, 0)) self.assertAllClose(get_batch(state_2, 0), get_batch(state_1_2, 1)) @@ -1051,12 +1049,11 @@ def get_slice(state, num, idx): self.assertAllClose(get_slice(state_2, 2, 0), get_slice(state_1_2, 2, 1)) def testRaggedIntegratorMaxSteps(self): - def target_log_prob_fn(q): - return -q**2, q + return -(q**2), q def kinetic_energy_fn(p): - return tf.abs(p)**3., p + return jnp.abs(p) ** 3.0, p integrator_fn = lambda state, num_steps: fun_mc.hamiltonian_integrator( # pylint: disable=g-long-lambda state, @@ -1065,15 +1062,18 @@ def kinetic_energy_fn(p): state, step_size=0.1, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn), + kinetic_energy_fn=kinetic_energy_fn, + ), kinetic_energy_fn=kinetic_energy_fn, max_num_steps=3, - integrator_trace_fn=lambda state, extra: (state, extra)) + integrator_trace_fn=lambda state, extra: (state, extra), + ) - state = tf.zeros([2], dtype=self._dtype) - momentum = tf.ones([2], dtype=self._dtype) + state = jnp.zeros([2], dtype=self._dtype) + momentum = jnp.ones([2], dtype=self._dtype) target_log_prob, _, state_grads = fun_mc.call_potential_fn_with_grads( - target_log_prob_fn, state) + target_log_prob_fn, state + ) start_state = fun_mc.IntegratorState( target_log_prob=target_log_prob, @@ -1094,7 +1094,8 @@ def kinetic_energy_fn(p): def get_batch(state, idx): # For the integrator trace, we'll grab the final value. return util.map_tree( - lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state) + lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state + ) self.assertAllClose(get_batch(state_1, 0), get_batch(state_1_2, 0)) self.assertAllClose(get_batch(state_2, 0), get_batch(state_1_2, 1)) @@ -1108,121 +1109,122 @@ def get_slice(state, num, idx): self.assertAllClose(get_slice(state_2, 2, 0), get_slice(state_1_2, 2, 1)) self.assertAllEqual( - 3, - util.flatten_tree(state_1[1].integrator_trace)[0].shape[0]) + 3, util.flatten_tree(state_1[1].integrator_trace)[0].shape[0] + ) self.assertAllEqual( - 3, - util.flatten_tree(state_2[1].integrator_trace)[0].shape[0]) + 3, util.flatten_tree(state_2[1].integrator_trace)[0].shape[0] + ) self.assertAllEqual( - 3, - util.flatten_tree(state_1_2[1].integrator_trace)[0].shape[0]) + 3, util.flatten_tree(state_1_2[1].integrator_trace)[0].shape[0] + ) def testAdam(self): - def loss_fn(x, y): - return tf.square(x - 1.) + tf.square(y - 2.), [] + return jnp.square(x - 1.0) + jnp.square(y - 2.0), [] _, [(x, y), loss] = fun_mc.trace( - fun_mc.adam_init([self._constant(0.), - self._constant(0.)]), + fun_mc.adam_init([self._constant(0.0), self._constant(0.0)]), lambda adam_state: fun_mc.adam_step( # pylint: disable=g-long-lambda - adam_state, - loss_fn, - learning_rate=self._constant(0.01)), + adam_state, loss_fn, learning_rate=self._constant(0.01) + ), num_steps=1000, - trace_fn=lambda state, extra: [state.state, extra.loss]) + trace_fn=lambda state, extra: [state.state, extra.loss], + ) - self.assertAllClose(1., x[-1], atol=1e-3) - self.assertAllClose(2., y[-1], atol=1e-3) - self.assertAllClose(0., loss[-1], atol=1e-3) + self.assertAllClose(1.0, x[-1], atol=1e-3) + self.assertAllClose(2.0, y[-1], atol=1e-3) + self.assertAllClose(0.0, loss[-1], atol=1e-3) def testGradientDescent(self): - def loss_fn(x, y): - return tf.square(x - 1.) + tf.square(y - 2.), [] + return jnp.square(x - 1.0) + jnp.square(y - 2.0), [] _, [(x, y), loss] = fun_mc.trace( - fun_mc.GradientDescentState([self._constant(0.), - self._constant(0.)]), + fun_mc.GradientDescentState([self._constant(0.0), self._constant(0.0)]), lambda gd_state: fun_mc.gradient_descent_step( # pylint: disable=g-long-lambda - gd_state, - loss_fn, - learning_rate=self._constant(0.01)), + gd_state, loss_fn, learning_rate=self._constant(0.01) + ), num_steps=1000, - trace_fn=lambda state, extra: [state.state, extra.loss]) + trace_fn=lambda state, extra: [state.state, extra.loss], + ) - self.assertAllClose(1., x[-1], atol=1e-3) - self.assertAllClose(2., y[-1], atol=1e-3) - self.assertAllClose(0., loss[-1], atol=1e-3) + self.assertAllClose(1.0, x[-1], atol=1e-3) + self.assertAllClose(2.0, y[-1], atol=1e-3) + self.assertAllClose(0.0, loss[-1], atol=1e-3) def testSimpleDualAverages(self): - def loss_fn(x, y): - return tf.square(x - 1.) + tf.square(y - 2.), [] + return jnp.square(x - 1.0) + jnp.square(y - 2.0), [] def kernel(sda_state, rms_state): - sda_state, _ = fun_mc.simple_dual_averages_step(sda_state, loss_fn, 1.) + sda_state, _ = fun_mc.simple_dual_averages_step(sda_state, loss_fn, 1.0) rms_state, _ = fun_mc.running_mean_step(rms_state, sda_state.state) return (sda_state, rms_state), rms_state.mean _, (x, y) = fun_mc.trace( ( fun_mc.simple_dual_averages_init( - [self._constant(0.), self._constant(0.)]), + [self._constant(0.0), self._constant(0.0)] + ), fun_mc.running_mean_init([[], []], [self._dtype, self._dtype]), ), kernel, num_steps=1000, ) - self.assertAllClose(1., x[-1], atol=1e-1) - self.assertAllClose(2., y[-1], atol=1e-1) + self.assertAllClose(1.0, x[-1], atol=1e-1) + self.assertAllClose(2.0, y[-1], atol=1e-1) def testGaussianProposal(self): num_samples = 1000 state = { - 'x': - tf.zeros([num_samples, 1], dtype=self._dtype) + - self._constant([0., 1.]), - 'y': - tf.zeros([num_samples], dtype=self._dtype) + self._constant(3.) + 'x': jnp.zeros([num_samples, 1], dtype=self._dtype) + self._constant( + [0.0, 1.0] + ), + 'y': jnp.zeros([num_samples], dtype=self._dtype) + self._constant(3.0), } - scale = 2. + scale = 2.0 state_samples, _ = fun_mc.gaussian_proposal( - state, scale=scale, seed=self._make_seed(_test_seed())) + state, scale=scale, seed=self._make_seed(_test_seed()) + ) _, p_val_x0 = scipy.stats.kstest( state_samples['x'][:, 0], - lambda x: scipy.stats.norm.cdf(x, loc=0., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=0.0, scale=2.0), + ) _, p_val_x1 = scipy.stats.kstest( state_samples['x'][:, 1], - lambda x: scipy.stats.norm.cdf(x, loc=1., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=1.0, scale=2.0), + ) _, p_val_y = scipy.stats.kstest( - state_samples['y'], lambda x: scipy.stats.norm.cdf(x, loc=3., scale=2.)) + state_samples['y'], + lambda x: scipy.stats.norm.cdf(x, loc=3.0, scale=2.0), + ) self.assertGreater(p_val_x0, 1e-3) self.assertGreater(p_val_x1, 1e-3) self.assertGreater(p_val_y, 1e-3) - mean = util.map_tree(lambda x: tf.reduce_mean(x, 0), state_samples) - variance = util.map_tree(lambda x: tf.math.reduce_variance(x, 0), - state_samples) + mean = util.map_tree(lambda x: jnp.mean(x, 0), state_samples) + variance = util.map_tree(lambda x: jnp.var(x, 0), state_samples) self.assertAllClose(state['x'][0], mean['x'], atol=0.2) self.assertAllClose(state['y'][0], mean['y'], atol=0.2) self.assertAllClose( - scale**2 * tf.ones_like(mean['x']), variance['x'], atol=0.5) + scale**2 * jnp.ones_like(mean['x']), variance['x'], atol=0.5 + ) self.assertAllClose( - scale**2 * tf.ones_like(mean['y']), variance['y'], atol=0.5) + scale**2 * jnp.ones_like(mean['y']), variance['y'], atol=0.5 + ) def testGaussianProposalNamedAxis(self): if BACKEND != 'backend_jax': self.skipTest('JAX-only') state = { - 'sharded': tf.zeros([4], self._dtype), - 'shared': tf.zeros([], self._dtype), + 'sharded': jnp.zeros([4], self._dtype), + 'shared': jnp.zeros([], self._dtype), } in_axes = { 'sharded': 0, @@ -1234,57 +1236,67 @@ def testGaussianProposalNamedAxis(self): } @functools.partial( - jax.pmap, in_axes=(in_axes, None), axis_name='named_axis') + real_jax.pmap, in_axes=(in_axes, None), axis_name='named_axis' + ) def proposal_fn(state, seed): samples, _ = fun_mc.gaussian_proposal( - state, scale=1., named_axis=named_axis, seed=seed) + state, scale=1.0, named_axis=named_axis, seed=seed + ) return samples samples = proposal_fn(state, self._make_seed(_test_seed())) self.assertAllClose(samples['shared'][0], samples['shared'][1]) self.assertTrue( - np.any(np.abs(samples['sharded'][0] - samples['sharded'][1]) > 1e-3)) + np.any(np.abs(samples['sharded'][0] - samples['sharded'][1]) > 1e-3) + ) def testMaximalReflectiveProposal(self): state = { - 'x': self._constant([[0., 1.], [2., 3.]]), - 'y': self._constant([3., 4.]) + 'x': self._constant([[0.0, 1.0], [2.0, 3.0]]), + 'y': self._constant([3.0, 4.0]), } - scale = 2. + scale = 2.0 def kernel(seed): proposal_seed, seed = util.split_seed(seed, 2) new_state, (extra, _) = fun_mc.maximal_reflection_coupling_proposal( - state, scale=scale, seed=proposal_seed) + state, scale=scale, seed=proposal_seed + ) return seed, (new_state, extra.coupling_proposed) # Simulate an MCMC run. _, (state_samples, coupling_proposed) = fun_mc.trace( - self._make_seed(_test_seed()), kernel, 1000) + self._make_seed(_test_seed()), kernel, 1000 + ) - mean = util.map_tree(lambda x: tf.reduce_mean(x, 0), state_samples) - variance = util.map_tree(lambda x: tf.math.reduce_variance(x, 0), - state_samples) + mean = util.map_tree(lambda x: jnp.mean(x, 0), state_samples) + variance = util.map_tree(lambda x: jnp.var(x, 0), state_samples) _, p_val_x00 = scipy.stats.kstest( state_samples['x'][:, 0, 0], - lambda x: scipy.stats.norm.cdf(x, loc=0., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=0.0, scale=2.0), + ) _, p_val_x01 = scipy.stats.kstest( state_samples['x'][:, 0, 1], - lambda x: scipy.stats.norm.cdf(x, loc=1., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=1.0, scale=2.0), + ) _, p_val_x10 = scipy.stats.kstest( state_samples['x'][:, 1, 0], - lambda x: scipy.stats.norm.cdf(x, loc=2., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=2.0, scale=2.0), + ) _, p_val_x11 = scipy.stats.kstest( state_samples['x'][:, 1, 1], - lambda x: scipy.stats.norm.cdf(x, loc=3., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=3.0, scale=2.0), + ) _, p_val_y0 = scipy.stats.kstest( state_samples['y'][:, 0], - lambda x: scipy.stats.norm.cdf(x, loc=3., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=3.0, scale=2.0), + ) _, p_val_y1 = scipy.stats.kstest( state_samples['y'][:, 1], - lambda x: scipy.stats.norm.cdf(x, loc=4., scale=2.)) + lambda x: scipy.stats.norm.cdf(x, loc=4.0, scale=2.0), + ) self.assertGreater(p_val_x00, 1e-3) self.assertGreater(p_val_x01, 1e-3) self.assertGreater(p_val_x10, 1e-3) @@ -1295,16 +1307,22 @@ def kernel(seed): self.assertAllClose(state['x'], mean['x'], atol=0.2) self.assertAllClose(state['y'], mean['y'], atol=0.2) self.assertAllClose( - scale**2 * tf.ones_like(mean['x']), variance['x'], atol=0.5) + scale**2 * jnp.ones_like(mean['x']), variance['x'], atol=0.5 + ) self.assertAllClose( - scale**2 * tf.ones_like(mean['y']), variance['y'], atol=0.5) + scale**2 * jnp.ones_like(mean['y']), variance['y'], atol=0.5 + ) coupled = coupling_proposed & np.array( util.flatten_tree( util.map_tree( - lambda x: tf.math.reduce_all( # pylint: disable=g-long-lambda - (x[:, :1] == x[:, 1:]), tuple(range(2, len(x.shape)))), - state_samples))).all(0) + lambda x: jnp.all( # pylint: disable=g-long-lambda + (x[:, :1] == x[:, 1:]), tuple(range(2, len(x.shape))) + ), + state_samples, + ) + ) + ).all(0) self.assertAllClose(coupling_proposed, coupled) def testMaximalReflectiveProposalNamedAxis(self): @@ -1313,11 +1331,11 @@ def testMaximalReflectiveProposalNamedAxis(self): state = { # The trailing shape is [coupled axis, independent chains]. - 'sharded': - tf.zeros([4, 2, 1024], self._dtype) + - self._constant([0., 1.])[:, tf.newaxis], - 'shared': - tf.zeros([2, 1024], self._dtype), + 'sharded': ( + jnp.zeros([4, 2, 1024], self._dtype) + + self._constant([0.0, 1.0])[:, jnp.newaxis] + ), + 'shared': jnp.zeros([2, 1024], self._dtype), } in_axes = { 'sharded': 0, @@ -1329,10 +1347,12 @@ def testMaximalReflectiveProposalNamedAxis(self): } @functools.partial( - jax.pmap, in_axes=(in_axes, None), axis_name='named_axis') + real_jax.pmap, in_axes=(in_axes, None), axis_name='named_axis' + ) def proposal_fn(state, seed): samples, (extra, _) = fun_mc.maximal_reflection_coupling_proposal( - state, chain_ndims=1, scale=1., named_axis=named_axis, seed=seed) + state, chain_ndims=1, scale=1.0, named_axis=named_axis, seed=seed + ) return samples, extra samples, extra = proposal_fn(state, self._make_seed(_test_seed())) @@ -1341,15 +1361,16 @@ def proposal_fn(state, seed): self.assertAllClose(extra.log_couple_ratio[0], extra.log_couple_ratio[1]) self.assertAllClose(extra.coupling_proposed[0], extra.coupling_proposed[1]) self.assertTrue( - np.any(np.abs(samples['sharded'][0] - samples['sharded'][1]) > 1e-3)) + np.any(np.abs(samples['sharded'][0] - samples['sharded'][1]) > 1e-3) + ) def testHMCNamedAxis(self): if BACKEND != 'backend_jax': self.skipTest('JAX-only') state = { - 'sharded': tf.zeros([4, 1024], self._dtype), - 'shared': tf.zeros([1024], self._dtype), + 'sharded': jnp.zeros([4, 1024], self._dtype), + 'shared': jnp.zeros([1024], self._dtype), } in_axes = { 'sharded': 0, @@ -1361,50 +1382,68 @@ def testHMCNamedAxis(self): } def target_log_prob_fn(sharded, shared): - return -(backend.distribute_lib.psum(tf.square(sharded), 'named_axis') + - tf.square(shared)), () + return ( + -( + backend.distribute_lib.psum(jnp.square(sharded), 'named_axis') + + jnp.square(shared) + ), + (), + ) @functools.partial( - jax.pmap, in_axes=(in_axes, None), axis_name='named_axis') + real_jax.pmap, in_axes=(in_axes, None), axis_name='named_axis' + ) def kernel(state, seed): hmc_state = fun_mc.hamiltonian_monte_carlo_init( - state, target_log_prob_fn=target_log_prob_fn) + state, target_log_prob_fn=target_log_prob_fn + ) hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step( hmc_state, step_size=self._constant(0.2), num_integrator_steps=4, target_log_prob_fn=target_log_prob_fn, named_axis=named_axis, - seed=seed) + seed=seed, + ) return hmc_state, hmc_extra seed = self._make_seed(_test_seed()) hmc_state, hmc_extra = kernel(state, seed) - self.assertAllClose(hmc_state.state['shared'][0], - hmc_state.state['shared'][1]) + self.assertAllClose( + hmc_state.state['shared'][0], hmc_state.state['shared'][1] + ) self.assertTrue( np.any( - np.abs(hmc_state.state['sharded'][0] - - hmc_state.state['sharded'][1]) > 1e-3)) + np.abs( + hmc_state.state['sharded'][0] - hmc_state.state['sharded'][1] + ) + > 1e-3 + ) + ) self.assertAllClose(hmc_extra.is_accepted[0], hmc_extra.is_accepted[1]) - self.assertAllClose(hmc_extra.log_accept_ratio[0], - hmc_extra.log_accept_ratio[1]) + self.assertAllClose( + hmc_extra.log_accept_ratio[0], hmc_extra.log_accept_ratio[1] + ) def testRandomWalkMetropolis(self): num_steps = 1000 - state = tf.ones([16], dtype=tf.int32) - target_logits = self._constant([1., 2., 3., 4.]) + 2. - proposal_logits = self._constant([4., 3., 2., 1.]) + 2. + state = jnp.ones([16], dtype=jnp.int32) + target_logits = self._constant([1.0, 2.0, 3.0, 4.0]) + 2.0 + proposal_logits = self._constant([4.0, 3.0, 2.0, 1.0]) + 2.0 def target_log_prob_fn(x): - return tf.gather(target_logits, x), () + return jnp.asarray(target_logits)[x], () def proposal_fn(x, seed): - current_logits = tf.gather(proposal_logits, x) - proposal = util.random_categorical(proposal_logits[tf.newaxis], - x.shape[0], seed)[0] - proposed_logits = tf.gather(proposal_logits, proposal) - return tf.cast(proposal, x.dtype), ((), proposed_logits - current_logits) + current_logits = proposal_logits[x] + proposal = util.random_categorical( + proposal_logits[jnp.newaxis], x.shape[0], seed + )[0] + proposed_logits = proposal_logits[proposal] + return jnp.array(proposal, x.dtype), ( + (), + proposed_logits - current_logits, + ) def kernel(rwm_state, seed): rwm_seed, seed = util.split_seed(seed, 2) @@ -1412,24 +1451,28 @@ def kernel(rwm_state, seed): rwm_state, target_log_prob_fn=target_log_prob_fn, proposal_fn=proposal_fn, - seed=rwm_seed) + seed=rwm_seed, + ) return (rwm_state, seed), rwm_extra seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(fun_mc.random_walk_metropolis_init(state, target_log_prob_fn), - seed), - fn=kernel, - num_steps=num_steps, - trace_fn=lambda state, extra: state[0].state))(state, seed) + _, chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + fun_mc.random_walk_metropolis_init(state, target_log_prob_fn), + seed, + ), + fn=kernel, + num_steps=num_steps, + trace_fn=lambda state, extra: state[0].state, + ) + )(state, seed) # Discard the warmup samples. chain = chain[500:] - sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1]) - self.assertAllClose(tf.nn.softmax(target_logits), sample_mean, atol=0.11) + sample_mean = jnp.mean(jax.nn.one_hot(chain, 4), axis=[0, 1]) + self.assertAllClose(jax.nn.softmax(target_logits), sample_mean, atol=0.11) @parameterized.named_parameters( ('Basic', (10, 3), None), @@ -1447,15 +1490,19 @@ def kernel(rms, idx): rms, _ = fun_mc.running_mean_step(rms, data[idx], axis=aggregation) return (rms, idx + 1), () - true_aggregation = (0,) + (() if aggregation is None else tuple( - [a + 1 for a in util.flatten_tree(aggregation)])) + true_aggregation = (0,) + ( + () + if aggregation is None + else tuple([a + 1 for a in util.flatten_tree(aggregation)]) + ) true_mean = np.mean(data, true_aggregation) (rms, _), _ = fun_mc.trace( state=(fun_mc.running_mean_init(true_mean.shape, data.dtype), 0), fn=kernel, num_steps=len(data), - trace_fn=lambda *args: ()) + trace_fn=lambda *args: (), + ) self.assertAllClose(true_mean, rms.mean) @@ -1464,8 +1511,10 @@ def testRunningMeanMaxPoints(self): rng = np.random.RandomState(_test_seed()) data = self._constant( np.concatenate( - [rng.randn(window_size), 1. + 2. * rng.randn(window_size * 10)], - axis=0)) + [rng.randn(window_size), 1.0 + 2.0 * rng.randn(window_size * 10)], + axis=0, + ) + ) def kernel(rms, idx): rms, _ = fun_mc.running_mean_step(rms, data[idx], window_size=window_size) @@ -1481,7 +1530,7 @@ def kernel(rms, idx): # After window_size, we're doing exponential moving average, and pick up the # mean after the change in the distribution. Since the moving average is # computed only over ~window_size points, this test is rather noisy. - self.assertAllClose(1., mean[-1], atol=0.2) + self.assertAllClose(1.0, mean[-1], atol=0.2) @parameterized.named_parameters( ('Basic', (10, 3), None), @@ -1495,8 +1544,11 @@ def testRunningVariance(self, shape, aggregation): rng = np.random.RandomState(_test_seed()) data = self._constant(rng.randn(*shape)) - true_aggregation = (0,) + (() if aggregation is None else tuple( - [a + 1 for a in util.flatten_tree(aggregation)])) + true_aggregation = (0,) + ( + () + if aggregation is None + else tuple([a + 1 for a in util.flatten_tree(aggregation)]) + ) true_mean = np.mean(data, true_aggregation) true_var = np.var(data, true_aggregation) @@ -1508,7 +1560,8 @@ def kernel(rvs, idx): state=(fun_mc.running_variance_init(true_mean.shape, data[0].dtype), 0), fn=kernel, num_steps=len(data), - trace_fn=lambda *args: ()) + trace_fn=lambda *args: (), + ) self.assertAllClose(true_mean, rvs.mean) self.assertAllClose(true_var, rvs.variance) @@ -1517,12 +1570,15 @@ def testRunningVarianceMaxPoints(self): rng = np.random.RandomState(_test_seed()) data = self._constant( np.concatenate( - [rng.randn(window_size), 1. + 2. * rng.randn(window_size * 10)], - axis=0)) + [rng.randn(window_size), 1.0 + 2.0 * rng.randn(window_size * 10)], + axis=0, + ) + ) def kernel(rvs, idx): rvs, _ = fun_mc.running_variance_step( - rvs, data[idx], window_size=window_size) + rvs, data[idx], window_size=window_size + ) return (rvs, idx + 1), (rvs.mean, rvs.variance) _, (mean, var) = fun_mc.trace( @@ -1537,8 +1593,8 @@ def kernel(rvs, idx): # mean/variance after the change in the distribution. Since the moving # average is computed only over ~window_size points, this test is rather # noisy. - self.assertAllClose(1., mean[-1], atol=0.2) - self.assertAllClose(4., var[-1], atol=0.8) + self.assertAllClose(1.0, mean[-1], atol=0.2) + self.assertAllClose(4.0, var[-1], atol=0.8) @parameterized.named_parameters( ('Basic', (10, 3), None), @@ -1550,8 +1606,11 @@ def testRunningCovariance(self, shape, aggregation): rng = np.random.RandomState(_test_seed()) data = self._constant(rng.randn(*shape)) - true_aggregation = (0,) + (() if aggregation is None else tuple( - [a + 1 for a in util.flatten_tree(aggregation)])) + true_aggregation = (0,) + ( + () + if aggregation is None + else tuple([a + 1 for a in util.flatten_tree(aggregation)]) + ) true_mean = np.mean(data, true_aggregation) true_cov = _gen_cov(data, true_aggregation) @@ -1560,11 +1619,14 @@ def kernel(rcs, idx): return (rcs, idx + 1), () (rcs, _), _ = fun_mc.trace( - state=(fun_mc.running_covariance_init(true_mean.shape, - data[0].dtype), 0), + state=( + fun_mc.running_covariance_init(true_mean.shape, data[0].dtype), + 0, + ), fn=kernel, num_steps=len(data), - trace_fn=lambda *args: ()) + trace_fn=lambda *args: (), + ) self.assertAllClose(true_mean, rcs.mean) self.assertAllClose(true_cov, rcs.covariance) @@ -1575,15 +1637,17 @@ def testRunningCovarianceMaxPoints(self): np.concatenate( [ rng.randn(window_size, 2), - np.array([1., 2.]) + - np.array([2., 3.]) * rng.randn(window_size * 10, 2) + np.array([1.0, 2.0]) + + np.array([2.0, 3.0]) * rng.randn(window_size * 10, 2), ], axis=0, - )) + ) + ) def kernel(rvs, idx): rvs, _ = fun_mc.running_covariance_step( - rvs, data[idx], window_size=window_size) + rvs, data[idx], window_size=window_size + ) return (rvs, idx + 1), (rvs.mean, rvs.covariance) _, (mean, cov) = fun_mc.trace( @@ -1593,15 +1657,17 @@ def kernel(rvs, idx): ) # Up to window_size, we compute the running mean/variance exactly. self.assertAllClose( - np.mean(data[:window_size], axis=0), mean[window_size - 1]) + np.mean(data[:window_size], axis=0), mean[window_size - 1] + ) self.assertAllClose( - _gen_cov(data[:window_size], axis=0), cov[window_size - 1]) + _gen_cov(data[:window_size], axis=0), cov[window_size - 1] + ) # After window_size, we're doing exponential moving average, and pick up the # mean/variance after the change in the distribution. Since the moving # average is computed only over ~window_size points, this test is rather # noisy. - self.assertAllClose(np.array([1., 2.]), mean[-1], atol=0.2) - self.assertAllClose(np.array([[4., 0.], [0., 9.]]), cov[-1], atol=1.) + self.assertAllClose(np.array([1.0, 2.0]), mean[-1], atol=0.2) + self.assertAllClose(np.array([[4.0, 0.0], [0.0, 9.0]]), cov[-1], atol=1.0) @parameterized.named_parameters( ('BasicScalar', (10, 20), 1), @@ -1615,19 +1681,24 @@ def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims): chains = 0.4 * rng.randn(*chain_shape).astype(np.float32) + chain_means true_rhat = tfp.mcmc.potential_scale_reduction( - chains, independent_chain_ndims=independent_chain_ndims) + chains, independent_chain_ndims=independent_chain_ndims + ) chains = self._constant(chains) psrs, _ = fun_mc.trace( - state=fun_mc.potential_scale_reduction_init(chain_shape[1:], - self._dtype), + state=fun_mc.potential_scale_reduction_init( + chain_shape[1:], self._dtype + ), fn=lambda psrs: fun_mc.potential_scale_reduction_step( # pylint: disable=g-long-lambda - psrs, chains[psrs.num_points]), + psrs, chains[psrs.num_points] + ), num_steps=chain_shape[0], - trace_fn=lambda *_: ()) + trace_fn=lambda *_: (), + ) running_rhat = fun_mc.potential_scale_reduction_extract( - psrs, independent_chain_ndims=independent_chain_ndims) + psrs, independent_chain_ndims=independent_chain_ndims + ) self.assertAllClose(true_rhat, running_rhat) @parameterized.named_parameters( @@ -1637,8 +1708,9 @@ def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims): ('Aggregated0', (3, 2), 1, 0), ('Aggregated01', (3, 4, 2), 1, (0, 1)), ) - def testRunningApproximateAutoCovariance(self, state_shape, event_ndims, - aggregation): + def testRunningApproximateAutoCovariance( + self, state_shape, event_ndims, aggregation + ): # We'll use HMC as the source of our chain. # While HMC is being sampled, we also compute the running autocovariance. step_size = 0.2 @@ -1646,14 +1718,14 @@ def testRunningApproximateAutoCovariance(self, state_shape, event_ndims, num_leapfrog_steps = 10 max_lags = 300 - state = tf.zeros(state_shape, dtype=self._dtype) + state = jnp.zeros(state_shape, dtype=self._dtype) def target_log_prob_fn(x): - lp = -0.5 * tf.square(x) + lp = -0.5 * jnp.square(x) if event_ndims is None: return lp, () else: - return tf.reduce_sum(lp, -1), () + return jnp.sum(lp, -1), () def kernel(hmc_state, raac_state, seed): hmc_seed, seed = util.split_seed(seed, 2) @@ -1662,123 +1734,129 @@ def kernel(hmc_state, raac_state, seed): step_size=step_size, num_integrator_steps=num_leapfrog_steps, target_log_prob_fn=target_log_prob_fn, - seed=hmc_seed) + seed=hmc_seed, + ) raac_state, _ = fun_mc.running_approximate_auto_covariance_step( - raac_state, hmc_state.state, axis=aggregation) + raac_state, hmc_state.state, axis=aggregation + ) return (hmc_state, raac_state, seed), hmc_extra seed = self._make_seed(_test_seed()) # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs # for the jit to do anything. - (_, raac_state, _), chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=( - fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), - fun_mc.running_approximate_auto_covariance_init( - max_lags=max_lags, - state_shape=state_shape, - dtype=state.dtype, - axis=aggregation), - seed, - ), - fn=kernel, - num_steps=num_steps, - trace_fn=lambda state, extra: state[0].state))(state, seed) - - true_aggregation = (0,) + (() if aggregation is None else tuple( - [a + 1 for a in util.flatten_tree(aggregation)])) - true_variance = np.array( - tf.math.reduce_variance(np.array(chain), true_aggregation)) + (_, raac_state, _), chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), + fun_mc.running_approximate_auto_covariance_init( + max_lags=max_lags, + state_shape=state_shape, + dtype=state.dtype, + axis=aggregation, + ), + seed, + ), + fn=kernel, + num_steps=num_steps, + trace_fn=lambda state, extra: state[0].state, + ) + )(state, seed) + + true_aggregation = (0,) + ( + () + if aggregation is None + else tuple([a + 1 for a in util.flatten_tree(aggregation)]) + ) + true_variance = np.array(jnp.var(np.array(chain), true_aggregation)) true_autocov = np.array( - tfp.stats.auto_correlation(np.array(chain), axis=0, max_lags=max_lags)) + tfp.stats.auto_correlation(np.array(chain), axis=0, max_lags=max_lags) + ) if aggregation is not None: - true_autocov = tf.reduce_mean( - true_autocov, [a + 1 for a in util.flatten_tree(aggregation)]) + true_autocov = jnp.mean( + true_autocov, [a + 1 for a in util.flatten_tree(aggregation)] + ) self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5) self.assertAllClose( true_autocov, raac_state.auto_covariance / raac_state.auto_covariance[0], - atol=0.1) + atol=0.1, + ) @parameterized.named_parameters( - ('Positional1', 0.), - ('Positional2', (0., 1.)), - ('Named1', { - 'a': 0. - }), - ('Named2', { - 'a': 0., - 'b': 1. - }), + ('Positional1', 0.0), + ('Positional2', (0.0, 1.0)), + ('Named1', {'a': 0.0}), + ('Named2', {'a': 0.0, 'b': 1.0}), ) def testSurrogateLossFn(self, state): - def grad_fn(*args, **kwargs): # This is uglier than user code due to the parameterized test... new_state = util.unflatten_tree(state, util.flatten_tree((args, kwargs))) - return util.map_tree(lambda x: x + 1., new_state), new_state + return util.map_tree(lambda x: x + 1.0, new_state), new_state loss_fn = fun_mc.make_surrogate_loss_fn(grad_fn) # Mutate the state to make sure we didn't capture anything. - state = util.map_tree(lambda x: self._constant(x + 1.), state) + state = util.map_tree(lambda x: self._constant(x + 1.0), state) ret, extra, grads = fun_mc.call_potential_fn_with_grads(loss_fn, state) # The default is 0. - self.assertAllClose(0., ret) + self.assertAllClose(0.0, ret) # The gradients of the surrogate loss are state + 1. - self.assertAllCloseNested(util.map_tree(lambda x: x + 1., state), grads) + self.assertAllCloseNested(util.map_tree(lambda x: x + 1.0, state), grads) self.assertAllCloseNested(state, extra) def testSurrogateLossFnDecorator(self): - - @fun_mc.make_surrogate_loss_fn(loss_value=1.) + @fun_mc.make_surrogate_loss_fn(loss_value=1.0) def loss_fn(_): - return 3., 2. + return 3.0, 2.0 - ret, extra, grads = fun_mc.call_potential_fn_with_grads(loss_fn, 0.) - self.assertAllClose(1., ret) - self.assertAllClose(2., extra) - self.assertAllClose(3., grads) + ret, extra, grads = fun_mc.call_potential_fn_with_grads(loss_fn, 0.0) + self.assertAllClose(1.0, ret) + self.assertAllClose(2.0, extra) + self.assertAllClose(3.0, grads) @parameterized.named_parameters( ('Probability', True), ('Loss', False), ) def testReparameterizeFn(self, track_volume): - def potential_fn(x, y): - return -x**2 + -y**2, () + return -(x**2) + -(y**2), () def transport_map_fn(x, y): - return [2 * x, 3 * y], ((), tf.math.log(2.) + tf.math.log(3.)) + return [2 * x, 3 * y], ((), jnp.log(2.0) + jnp.log(3.0)) def inverse_map_fn(x, y): - return [x / 2, y / 3], ((), -tf.math.log(2.) - tf.math.log(3.)) + return [x / 2, y / 3], ((), -jnp.log(2.0) - jnp.log(3.0)) transport_map_fn.inverse = inverse_map_fn - (transformed_potential_fn, - transformed_init_state) = fun_mc.reparameterize_potential_fn( - potential_fn, - transport_map_fn, - [self._constant(2.), self._constant(3.)], - track_volume=track_volume) + (transformed_potential_fn, transformed_init_state) = ( + fun_mc.reparameterize_potential_fn( + potential_fn, + transport_map_fn, + [self._constant(2.0), self._constant(3.0)], + track_volume=track_volume, + ) + ) self.assertIsInstance(transformed_init_state, list) - self.assertAllClose([1., 1.], transformed_init_state) - transformed_potential, (orig_space, _, _) = transformed_potential_fn(1., 1.) - potential = potential_fn(2., 3.)[0] + self.assertAllClose([1.0, 1.0], transformed_init_state) + transformed_potential, (orig_space, _, _) = transformed_potential_fn( + 1.0, 1.0 + ) + potential = potential_fn(2.0, 3.0)[0] if track_volume: - potential += tf.math.log(2.) + tf.math.log(3.) + potential += jnp.log(2.0) + jnp.log(3.0) - self.assertAllClose([2., 3.], orig_space) + self.assertAllClose([2.0, 3.0], orig_space) self.assertAllClose(potential, transformed_potential) def testPersistentMH(self): - def target_log_prob_fn(x): - return -x**2 / 2, () + return -(x**2) / 2, () def kernel(pmh_state, rwm_state, seed): seed, rwm_seed = util.split_seed(seed, 2) @@ -1789,71 +1867,110 @@ def kernel(pmh_state, rwm_state, seed): rwm_state, target_log_prob_fn=target_log_prob_fn, proposal_fn=lambda state, seed: fun_mc.gaussian_proposal( # pylint: disable=g-long-lambda - state, seed=seed), - seed=rwm_seed) + state, seed=seed + ), + seed=rwm_seed, + ) pmh_state, pmh_extra = fun_mc.persistent_metropolis_hastings_step( pmh_state, # Use dummy states for testing. - current_state=self._constant(0.), - proposed_state=self._constant(1.), + current_state=self._constant(0.0), + proposed_state=self._constant(1.0), # Coprime with 1000 below. drift=0.127, - energy_change=-rwm_extra.log_accept_ratio) - return (pmh_state, rwm_state, - seed), (pmh_extra.is_accepted, pmh_extra.accepted_state, - rwm_extra.is_accepted) + energy_change=-rwm_extra.log_accept_ratio, + ) + return (pmh_state, rwm_state, seed), ( + pmh_extra.is_accepted, + pmh_extra.accepted_state, + rwm_extra.is_accepted, + ) _, (pmh_is_accepted, pmh_accepted_state, rwm_is_accepted) = fun_mc.trace( - (fun_mc.persistent_metropolis_hastings_init([], self._dtype), - fun_mc.random_walk_metropolis_init( - self._constant(0.), - target_log_prob_fn), self._make_seed(_test_seed())), kernel, 1000) + ( + fun_mc.persistent_metropolis_hastings_init([], self._dtype), + fun_mc.random_walk_metropolis_init( + self._constant(0.0), target_log_prob_fn + ), + self._make_seed(_test_seed()), + ), + kernel, + 1000, + ) - pmh_is_accepted = tf.cast(pmh_is_accepted, self._dtype) - rwm_is_accepted = tf.cast(rwm_is_accepted, self._dtype) + pmh_is_accepted = jnp.array(pmh_is_accepted, self._dtype) + rwm_is_accepted = jnp.array(rwm_is_accepted, self._dtype) self.assertAllClose( - tf.reduce_mean(rwm_is_accepted), - tf.reduce_mean(pmh_is_accepted), - atol=0.05) + jnp.mean(rwm_is_accepted), jnp.mean(pmh_is_accepted), atol=0.05 + ) self.assertAllClose(pmh_is_accepted, pmh_accepted_state) @parameterized.named_parameters( - ('ScalarLarge', -1., lambda x: x**2, True, -1.), + ('ScalarLarge', -1.0, lambda x: x**2, True, -1.0), ('ScalarSmall', -0.1, lambda x: x**2, True, -0.2), - ('Vector', np.array([-3, -4.]), lambda x: tf.reduce_sum(x**2), True, - np.array([-3. / 5., -4. / 5.])), - ('List', [ - -3, - -4., - ], lambda x: x[0]**2 + x[1]**2, True, [ - -3. / 5., - -4. / 5., - ]), - ('Dict', { - 'a': -3, - 'b': -4., - }, lambda x: x['a']**2 + x['b']**2, True, { - 'a': -3. / 5., - 'b': -4. / 5., - }), - ('NaNToZero', [ - -3, - np.float32('nan'), - ], lambda x: x[0]**2 + x[1]**2, True, [ - -1., - 0., - ]), - ('NaNKept', [-3, np.float32('nan')], lambda x: x[0]**2 + x[1]**2, False, - [np.float32('nan'), np.float32('nan')]), + ( + 'Vector', + np.array([-3, -4.0]), + lambda x: jnp.sum(x**2), + True, + np.array([-3.0 / 5.0, -4.0 / 5.0]), + ), + ( + 'List', + [ + -3, + -4.0, + ], + lambda x: x[0] ** 2 + x[1] ** 2, + True, + [ + -3.0 / 5.0, + -4.0 / 5.0, + ], + ), + ( + 'Dict', + { + 'a': -3, + 'b': -4.0, + }, + lambda x: x['a'] ** 2 + x['b'] ** 2, + True, + { + 'a': -3.0 / 5.0, + 'b': -4.0 / 5.0, + }, + ), + ( + 'NaNToZero', + [ + -3, + np.float32('nan'), + ], + lambda x: x[0] ** 2 + x[1] ** 2, + True, + [ + -1.0, + 0.0, + ], + ), + ( + 'NaNKept', + [-3, np.float32('nan')], + lambda x: x[0] ** 2 + x[1] ** 2, + False, + [np.float32('nan'), np.float32('nan')], + ), ) def testClipGradients(self, x, fn, zero_out_nan, expected_grad): x = util.map_tree(self._constant, x) - max_global_norm = self._constant(1.) + max_global_norm = self._constant(1.0) expected_grad = util.map_tree(self._constant, expected_grad) def eval_fn(x): x = fun_mc.clip_grads( - x, max_global_norm=max_global_norm, zero_out_nan=zero_out_nan) + x, max_global_norm=max_global_norm, zero_out_nan=zero_out_nan + ) return fn(x), () value, _, (grad,) = fun_mc.call_potential_fn_with_grads(eval_fn, (x,)) @@ -1862,54 +1979,53 @@ def eval_fn(x): self.assertAllCloseNested(expected_grad, grad) def testSystematicResample(self): - probs = self._constant([0., 0.5, 0.2, 0.3, 0.]) - log_weights = tf.math.log(probs) - particles = tf.range(probs.shape[0]) + probs = self._constant([0.0, 0.5, 0.2, 0.3, 0.0]) + log_weights = jnp.log(probs) + particles = jnp.arange(probs.shape[0]) - @tf.function + @jax.jit def body(seed): seed, resample_seed = util.split_seed(seed, 2) - (new_particles, - new_log_weights), _ = fun_mc.systematic_resample(particles, log_weights, - resample_seed) + (new_particles, new_log_weights), _ = fun_mc.systematic_resample( + particles, log_weights, resample_seed + ) return seed, (new_particles, new_log_weights) _, (new_particles, new_log_weights) = fun_mc.trace( - self._make_seed(_test_seed()), body, 1000, trace_mask=(True, False)) + self._make_seed(_test_seed()), body, 1000, trace_mask=(True, False) + ) - new_particles_probs = tf.reduce_mean( - tf.cast(new_particles[..., tf.newaxis] == particles, tf.float32), - (0, 1)) + new_particles_probs = jnp.mean( + jnp.array(new_particles[..., jnp.newaxis] == particles, jnp.float32), + (0, 1), + ) self.assertAllClose(new_particles_probs, probs, atol=0.05) - self.assertEqual(new_particles_probs[0], 0.) - self.assertEqual(new_particles_probs[-1], 0.) + self.assertEqual(new_particles_probs[0], 0.0) + self.assertEqual(new_particles_probs[-1], 0.0) self.assertAllClose( new_log_weights, - tf.fill(probs.shape, tfp.math.reduce_logmeanexp(log_weights))) + jnp.full(probs.shape, tfp.math.reduce_logmeanexp(log_weights)), + ) def testSystematicResampleAncestors(self): - log_weights = self._constant([-float('inf'), 0.]) - particles = tf.range(log_weights.shape[0]) + log_weights = self._constant([-float('inf'), 0.0]) + particles = jnp.arange(log_weights.shape[0]) seed = self._make_seed(_test_seed()) (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( particles, log_weights, seed=seed ) - self.assertAllEqual(new_particles, tf.ones_like(particles)) - self.assertAllEqual( - new_log_weights, tf.math.log(self._constant([0.5, 0.5])) - ) - self.assertAllEqual(ancestors, tf.ones_like(particles)) + self.assertAllEqual(new_particles, jnp.ones_like(particles)) + self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5]))) + self.assertAllEqual(ancestors, jnp.ones_like(particles)) (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( particles, log_weights, do_resample=True, seed=seed ) - self.assertAllEqual(new_particles, tf.ones_like(particles)) - self.assertAllEqual( - new_log_weights, tf.math.log(self._constant([0.5, 0.5])) - ) - self.assertAllEqual(ancestors, tf.ones_like(particles)) + self.assertAllEqual(new_particles, jnp.ones_like(particles)) + self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5]))) + self.assertAllEqual(ancestors, jnp.ones_like(particles)) (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( particles, log_weights, do_resample=False, seed=seed @@ -1919,41 +2035,42 @@ def testSystematicResampleAncestors(self): self.assertAllEqual(ancestors, particles) def testAIS(self): - def tlp_1(x): - return -x**2 / 2., () + return -(x**2) / 2.0, () def tlp_2(x): - return -(x - 2)**2 / 2 / 16., () + return -((x - 2) ** 2) / 2 / 16.0, () - @tf.function + @jax.jit def kernel(ais_state, seed): - hmc_seed, resample_seed, seed = util.split_seed(seed, 3) ais_state, _ = fun_mc.annealed_importance_sampling_resample( - ais_state, - seed=resample_seed) + ais_state, seed=resample_seed + ) def transition_operator(state, stage, tlp_fn): - f = tf.cast(stage, state.dtype) / num_stages + f = jnp.array(stage, state.dtype) / num_stages hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn) hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step( hmc_state, tlp_fn, - step_size=f * 4. + (1. - f) * 1., + step_size=f * 4.0 + (1.0 - f) * 1.0, num_integrator_steps=1, - seed=hmc_seed) + seed=hmc_seed, + ) return hmc_state.state, () ais_state, _ = fun_mc.annealed_importance_sampling_step( - ais_state, transition_operator, + ais_state, + transition_operator, functools.partial( fun_mc.geometric_annealing_path, num_stages=num_stages, initial_target_log_prob_fn=tlp_1, final_target_log_prob_fn=tlp_2, - )) + ), + ) return (ais_state, seed), () @@ -1963,18 +2080,23 @@ def transition_operator(state, stage, tlp_fn): init_state = util.random_normal([num_particles], self._dtype, init_seed) (ais_state, _), _ = fun_mc.trace( - (fun_mc.annealed_importance_sampling_init(init_state, tf.zeros( - [num_particles], self._dtype)), seed), + ( + fun_mc.annealed_importance_sampling_init( + init_state, jnp.zeros([num_particles], self._dtype) + ), + seed, + ), kernel, num_stages, ) - weights = tf.exp(ais_state.log_weight) - self.assertAllClose(4., tf.reduce_mean(weights), atol=0.7) + weights = jnp.exp(ais_state.log_weight) + self.assertAllClose(4.0, jnp.mean(weights), atol=0.7) self.assertAllClose( - 2., - tf.reduce_sum(tf.nn.softmax(ais_state.log_weight) * ais_state.state), - atol=0.8) + 2.0, + jnp.sum(jax.nn.softmax(ais_state.log_weight) * ais_state.state), + atol=0.8, + ) @test_util.multi_backend_test(globals(), 'fun_mc_test') @@ -1982,7 +2104,7 @@ class FunMCTest32(FunMCTest): @property def _dtype(self): - return tf.float32 + return jnp.float32 @test_util.multi_backend_test(globals(), 'fun_mc_test') @@ -1990,7 +2112,7 @@ class FunMCTest64(FunMCTest): @property def _dtype(self): - return tf.float64 + return jnp.float64 del FunMCTest diff --git a/spinoffs/fun_mc/fun_mc/malt.py b/spinoffs/fun_mc/fun_mc/malt.py index 072d66ed52..6f662cabad 100644 --- a/spinoffs/fun_mc/fun_mc/malt.py +++ b/spinoffs/fun_mc/fun_mc/malt.py @@ -27,7 +27,8 @@ from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util distribute_lib = backend.distribute_lib @@ -40,66 +41,84 @@ ] -def _gaussian_momentum_refresh_fn(old_momentum: fun_mc.State, - damping: Optional[fun_mc.FloatTensor] = 0., - step_size: Optional[fun_mc.FloatTensor] = 1., - named_axis: Optional[ - fun_mc.StringNest] = None, - seed: Optional[Any] = None) -> fun_mc.State: +def _gaussian_momentum_refresh_fn( + old_momentum: fun_mc.State, + damping: Optional[float | fun_mc.FloatArray] = 0.0, + step_size: Optional[float | fun_mc.FloatArray] = 1.0, + named_axis: Optional[fun_mc.StringNest] = None, + seed: Optional[Any] = None, +) -> fun_mc.State: """Momentum refresh function for Gaussian momentum distribution.""" if named_axis is None: named_axis = util.map_tree(lambda _: [], old_momentum) damping = fun_mc.maybe_broadcast_structure(damping, old_momentum) step_size = fun_mc.maybe_broadcast_structure(step_size, old_momentum) - decay_fraction = util.map_tree(lambda d, s: tf.exp(-d * s), damping, - step_size) - noise_fraction = util.map_tree(lambda df: tf.sqrt(1. - tf.square(df)), - decay_fraction) + decay_fraction = util.map_tree( + lambda d, s: jnp.exp(-d * s), damping, step_size + ) + noise_fraction = util.map_tree( + lambda df: jnp.sqrt(1.0 - jnp.square(df)), decay_fraction + ) - def _sample_part(old_momentum, seed, named_axis, decay_fraction, - noise_fraction): + def _sample_part( + old_momentum, seed, named_axis, decay_fraction, noise_fraction + ): seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis) - return (decay_fraction * old_momentum + noise_fraction * - util.random_normal(old_momentum.shape, old_momentum.dtype, seed)) + return decay_fraction * old_momentum + noise_fraction * util.random_normal( + old_momentum.shape, old_momentum.dtype, seed + ) seeds = util.unflatten_tree( - old_momentum, util.split_seed(seed, len(util.flatten_tree(old_momentum)))) - new_momentum = util.map_tree_up_to(old_momentum, _sample_part, old_momentum, - seeds, named_axis, decay_fraction, - noise_fraction) + old_momentum, util.split_seed(seed, len(util.flatten_tree(old_momentum))) + ) + new_momentum = util.map_tree_up_to( + old_momentum, + _sample_part, + old_momentum, + seeds, + named_axis, + decay_fraction, + noise_fraction, + ) return new_momentum def _default_energy_change_fn( - old_int_state: fun_mc.State, - new_int_state: fun_mc.State, + old_int_state: fun_mc.IntegratorState, + new_int_state: fun_mc.IntegratorState, kinetic_energy_fn: Optional[fun_mc.PotentialFn], -) -> Tuple[fun_mc.FloatTensor, Tuple[Any, Any]]: +) -> Tuple[fun_mc.FloatArray, Tuple[Any, Any]]: """Default energy change function.""" old_kinetic_energy, old_kinetic_energy_extra = fun_mc.call_potential_fn( - kinetic_energy_fn, old_int_state.momentum) + kinetic_energy_fn, old_int_state.momentum + ) new_kinetic_energy, new_kinetic_energy_extra = fun_mc.call_potential_fn( - kinetic_energy_fn, new_int_state.momentum) + kinetic_energy_fn, new_int_state.momentum + ) old_energy = -old_int_state.target_log_prob + old_kinetic_energy new_energy = -new_int_state.target_log_prob + new_kinetic_energy - return new_energy - old_energy, (old_kinetic_energy_extra, - new_kinetic_energy_extra) + return new_energy - old_energy, ( + old_kinetic_energy_extra, + new_kinetic_energy_extra, + ) class MetropolisAdjustedLangevinTrajectoriesState(NamedTuple): """Integrator state.""" + state: fun_mc.State state_extra: Any state_grads: fun_mc.State - target_log_prob: fun_mc.FloatTensor + target_log_prob: fun_mc.FloatArray class MetropolisAdjustedLangevinTrajectoriesExtra(NamedTuple): """Hamiltonian Monte Carlo extra outputs.""" - is_accepted: fun_mc.BooleanTensor - log_accept_ratio: fun_mc.FloatTensor + + is_accepted: fun_mc.BooleanArray + log_accept_ratio: fun_mc.FloatArray proposed_malt_state: fun_mc.State integrator_state: fun_mc.IntegratorState integrator_extra: fun_mc.IntegratorExtras @@ -119,45 +138,61 @@ def metropolis_adjusted_langevin_trajectories_init( malt_state: State of the `metropolis_adjusted_langevin_trajectories_step` `TransitionOperator`. """ - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.array, state) target_log_prob, state_extra, state_grads = util.map_tree( - tf.convert_to_tensor, + jnp.array, fun_mc.call_potential_fn_with_grads(target_log_prob_fn, state), ) return MetropolisAdjustedLangevinTrajectoriesState( state=state, state_grads=state_grads, target_log_prob=target_log_prob, - state_extra=state_extra) + state_extra=state_extra, + ) def metropolis_adjusted_langevin_trajectories_step( malt_state: MetropolisAdjustedLangevinTrajectoriesState, target_log_prob_fn: fun_mc.PotentialFn, step_size: Optional[Any] = None, - num_integrator_steps: Optional[fun_mc.IntTensor] = None, - damping: Optional[fun_mc.FloatTensor] = None, + num_integrator_steps: Optional[fun_mc.IntArray] = None, + damping: Optional[fun_mc.FloatArray] = None, momentum: Optional[fun_mc.State] = None, - integrator_trace_fn: Optional[Callable[[ - fun_mc.IntegratorState, fun_mc.IntegratorStepState, fun_mc - .IntegratorStepExtras - ], fun_mc.TensorNest]] = None, + integrator_trace_fn: Optional[ + Callable[ + [ + fun_mc.IntegratorState, + fun_mc.IntegratorStepState, + fun_mc.IntegratorStepExtras, + ], + fun_mc.ArrayNest, + ] + ] = None, unroll_integrator: bool = False, - log_uniform: Optional[fun_mc.FloatTensor] = None, + log_uniform: Optional[fun_mc.FloatArray] = None, kinetic_energy_fn: Optional[fun_mc.PotentialFn] = None, momentum_sample_fn: Optional[fun_mc.MomentumSampleFn] = None, - momentum_refresh_fn: Optional[Callable[[fun_mc.State, Any], - fun_mc.State]] = None, + momentum_refresh_fn: Optional[ + Callable[[fun_mc.State, Any], fun_mc.State] + ] = None, energy_change_fn: Optional[ - Callable[[fun_mc.IntegratorState, fun_mc.IntegratorState], - Tuple[fun_mc.FloatTensor, Any]]] = None, - integrator_fn: Optional[Callable[[fun_mc.IntegratorState, Any], - Tuple[fun_mc.IntegratorState, - fun_mc.IntegratorExtras]]] = None, + Callable[ + [fun_mc.IntegratorState, fun_mc.IntegratorState], + Tuple[fun_mc.FloatArray, Any], + ] + ] = None, + integrator_fn: Optional[ + Callable[ + [fun_mc.IntegratorState, Any], + Tuple[fun_mc.IntegratorState, fun_mc.IntegratorExtras], + ] + ] = None, named_axis: Optional[fun_mc.StringNest] = None, - seed: Any = None -) -> Tuple[MetropolisAdjustedLangevinTrajectoriesState, - MetropolisAdjustedLangevinTrajectoriesExtra]: + seed: Any = None, +) -> Tuple[ + MetropolisAdjustedLangevinTrajectoriesState, + MetropolisAdjustedLangevinTrajectoriesExtra, +]: """MALT `TransitionOperator`. This implements the Metropolis Adjusted Langevin Trajectories (MALT) algorithm @@ -246,24 +281,29 @@ def orig_target_log_prob_fn(x): if integrator_fn is None: if kinetic_energy_fn is None: kinetic_energy_fn = fun_mc.make_gaussian_kinetic_energy_fn( - (len(target_log_prob.shape) - if target_log_prob.shape is not None else tf.rank(target_log_prob)), # pytype: disable=attribute-error - named_axis=named_axis) + ( + len(target_log_prob.shape) + if target_log_prob.shape is not None + else len(target_log_prob.shape) + ), # pytype: disable=attribute-error + named_axis=named_axis, + ) if energy_change_fn is None: energy_change_fn = lambda old_is, new_is: _default_energy_change_fn( # pylint: disable=g-long-lambda - old_is, new_is, kinetic_energy_fn) + old_is, new_is, kinetic_energy_fn + ) if momentum_sample_fn is None: momentum_sample_fn = lambda seed: fun_mc.gaussian_momentum_sample( # pylint: disable=g-long-lambda - state=malt_state.state, - seed=seed, - named_axis=named_axis) + state=malt_state.state, seed=seed, named_axis=named_axis + ) if momentum_refresh_fn is None: momentum_refresh_fn = lambda m, seed: _gaussian_momentum_refresh_fn( # pylint: disable=g-long-lambda m, seed=seed, - step_size=step_size / 2., + step_size=step_size / 2.0, damping=damping, - named_axis=named_axis) + named_axis=named_axis, + ) integrator_fn = lambda int_state, seed: fun_mc.obabo_langevin_integrator( # pylint: disable=g-long-lambda int_state=int_state, num_steps=num_integrator_steps, @@ -271,12 +311,14 @@ def orig_target_log_prob_fn(x): fun_mc.leapfrog_step, step_size=step_size, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn), + kinetic_energy_fn=kinetic_energy_fn, + ), momentum_refresh_fn=momentum_refresh_fn, integrator_trace_fn=integrator_trace_fn, energy_change_fn=energy_change_fn, unroll=unroll_integrator, - seed=seed) + seed=seed, + ) mh_seed, sample_seed, integrator_seed = util.split_seed(seed, 3) if momentum is None: @@ -290,21 +332,24 @@ def orig_target_log_prob_fn(x): momentum=momentum, ) - integrator_state, integrator_extra = integrator_fn(initial_integrator_state, - integrator_seed) + integrator_state, integrator_extra = integrator_fn( + initial_integrator_state, integrator_seed + ) proposed_state = MetropolisAdjustedLangevinTrajectoriesState( state=integrator_state.state, state_grads=integrator_state.state_grads, target_log_prob=integrator_state.target_log_prob, - state_extra=integrator_state.state_extra) + state_extra=integrator_state.state_extra, + ) malt_state, mh_extra = fun_mc.metropolis_hastings_step( malt_state, proposed_state, integrator_extra.energy_change, log_uniform=log_uniform, - seed=mh_seed) + seed=mh_seed, + ) return malt_state, MetropolisAdjustedLangevinTrajectoriesExtra( is_accepted=mh_extra.is_accepted, @@ -312,4 +357,5 @@ def orig_target_log_prob_fn(x): log_accept_ratio=-integrator_extra.energy_change, integrator_state=integrator_state, integrator_extra=integrator_extra, - initial_momentum=momentum) + initial_momentum=momentum, + ) diff --git a/spinoffs/fun_mc/fun_mc/malt_test.py b/spinoffs/fun_mc/fun_mc/malt_test.py index beb927a192..091c269489 100644 --- a/spinoffs/fun_mc/fun_mc/malt_test.py +++ b/spinoffs/fun_mc/fun_mc/malt_test.py @@ -19,18 +19,18 @@ # Dependency imports -import jax +import jax as real_jax from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf - from tensorflow_probability.python.internal import test_util as tfp_test_util from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc from fun_mc import malt from fun_mc import test_util -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util tfd = tfp.distributions @@ -38,13 +38,16 @@ Root = tfd.JointDistributionCoroutine.Root real_tf.enable_v2_behavior() +real_tf.experimental.numpy.experimental_enable_numpy_behavior() jax_config.update('jax_enable_x64', True) BACKEND = None # Rewritten by backends/rewrite.py. if BACKEND == 'backend_jax': - os.environ['XLA_FLAGS'] = (f'{os.environ.get("XLA_FLAGS", "")} ' - '--xla_force_host_platform_device_count=4') + os.environ['XLA_FLAGS'] = ( + f'{os.environ.get("XLA_FLAGS", "")} ' + '--xla_force_host_platform_device_count=4' + ) def _test_seed(): @@ -52,7 +55,6 @@ def _test_seed(): class MALTTest(tfp_test_util.TestCase): - _is_on_jax = BACKEND == 'backend_jax' def _make_seed(self, seed): @@ -66,28 +68,30 @@ def _dtype(self): raise NotImplementedError() def _constant(self, value): - return tf.constant(value, self._dtype) + return jnp.array(value, self._dtype) def testPreconditionedMALT(self): step_size = self._constant(0.2) num_steps = 2000 num_leapfrog_steps = 10 damping = 0.5 - state = tf.ones([16, 2], dtype=self._dtype) + state = jnp.ones([16, 2], dtype=self._dtype) - base_mean = self._constant([1., 0]) + base_mean = self._constant([1.0, 0]) base_cov = self._constant([[1, 0.5], [0.5, 1]]) bijector = tfp.bijectors.Softplus() base_dist = tfp.distributions.MultivariateNormalFullCovariance( - loc=base_mean, covariance_matrix=base_cov) + loc=base_mean, covariance_matrix=base_cov + ) target_dist = bijector(base_dist) def orig_target_log_prob_fn(x): return target_dist.log_prob(x), () target_log_prob_fn, state = fun_mc.transform_log_prob_fn( - orig_target_log_prob_fn, bijector, state) + orig_target_log_prob_fn, bijector, state + ) # pylint: disable=g-long-lambda def kernel(malt_state, seed): @@ -98,27 +102,33 @@ def kernel(malt_state, seed): damping=damping, num_integrator_steps=num_leapfrog_steps, target_log_prob_fn=target_log_prob_fn, - seed=malt_seed) + seed=malt_seed, + ) return (malt_state, seed), malt_state.state_extra[0] seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(malt.metropolis_adjusted_langevin_trajectories_init( - state, target_log_prob_fn), seed), - fn=kernel, - num_steps=num_steps))(state, seed) + _, chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + malt.metropolis_adjusted_langevin_trajectories_init( + state, target_log_prob_fn + ), + seed, + ), + fn=kernel, + num_steps=num_steps, + ) + )(state, seed) # Discard the warmup samples. chain = chain[1000:] - sample_mean = tf.reduce_mean(chain, axis=[0, 1]) + sample_mean = jnp.mean(chain, axis=[0, 1]) sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed())) - true_mean = tf.reduce_mean(true_samples, axis=0) + true_mean = jnp.mean(true_samples, axis=0) true_cov = tfp.stats.covariance(chain, sample_axis=[0, 1]) self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1) @@ -129,8 +139,8 @@ def testMALTNamedAxis(self): self.skipTest('JAX-only') state = { - 'sharded': tf.zeros([4, 1024], self._dtype), - 'shared': tf.zeros([1024], self._dtype), + 'sharded': jnp.zeros([4, 1024], self._dtype), + 'shared': jnp.zeros([1024], self._dtype), } in_axes = { 'sharded': 0, @@ -142,36 +152,51 @@ def testMALTNamedAxis(self): } def target_log_prob_fn(sharded, shared): - return -(backend.distribute_lib.psum(tf.square(sharded), 'named_axis') + - tf.square(shared)), () + return ( + -( + backend.distribute_lib.psum(jnp.square(sharded), 'named_axis') + + jnp.square(shared) + ), + (), + ) @functools.partial( - jax.pmap, in_axes=(in_axes, None), axis_name='named_axis') + real_jax.pmap, in_axes=(in_axes, None), axis_name='named_axis' + ) def kernel(state, seed): malt_state = malt.metropolis_adjusted_langevin_trajectories_init( - state, target_log_prob_fn=target_log_prob_fn) - (malt_state, - malt_extra) = malt.metropolis_adjusted_langevin_trajectories_step( - malt_state, - damping=self._constant(0.5), - step_size=self._constant(0.2), - num_integrator_steps=4, - target_log_prob_fn=target_log_prob_fn, - named_axis=named_axis, - seed=seed) + state, target_log_prob_fn=target_log_prob_fn + ) + (malt_state, malt_extra) = ( + malt.metropolis_adjusted_langevin_trajectories_step( + malt_state, + damping=self._constant(0.5), + step_size=self._constant(0.2), + num_integrator_steps=4, + target_log_prob_fn=target_log_prob_fn, + named_axis=named_axis, + seed=seed, + ) + ) return malt_state, malt_extra seed = self._make_seed(_test_seed()) malt_state, malt_extra = kernel(state, seed) - self.assertAllClose(malt_state.state['shared'][0], - malt_state.state['shared'][1]) + self.assertAllClose( + malt_state.state['shared'][0], malt_state.state['shared'][1] + ) self.assertTrue( np.any( - np.abs(malt_state.state['sharded'][0] - - malt_state.state['sharded'][1]) > 1e-3)) + np.abs( + malt_state.state['sharded'][0] - malt_state.state['sharded'][1] + ) + > 1e-3 + ) + ) self.assertAllClose(malt_extra.is_accepted[0], malt_extra.is_accepted[1]) - self.assertAllClose(malt_extra.log_accept_ratio[0], - malt_extra.log_accept_ratio[1]) + self.assertAllClose( + malt_extra.log_accept_ratio[0], malt_extra.log_accept_ratio[1] + ) @test_util.multi_backend_test(globals(), 'malt_test') @@ -179,7 +204,7 @@ class MALTTest32(MALTTest): @property def _dtype(self): - return tf.float32 + return jnp.float32 @test_util.multi_backend_test(globals(), 'malt_test') @@ -187,7 +212,7 @@ class MALTTest64(MALTTest): @property def _dtype(self): - return tf.float64 + return jnp.float64 del MALTTest diff --git a/spinoffs/fun_mc/fun_mc/prefab.py b/spinoffs/fun_mc/fun_mc/prefab.py index edcf37c96c..d42963c0cf 100644 --- a/spinoffs/fun_mc/fun_mc/prefab.py +++ b/spinoffs/fun_mc/fun_mc/prefab.py @@ -32,49 +32,58 @@ from fun_mc import fun_mc_lib as fun_mc from fun_mc import malt from fun_mc import sga_hmc -# Re-export sga_hmc and malt symbols. +## Re-export sga_hmc and malt symbols. from fun_mc.malt import * # pylint: disable=wildcard-import from fun_mc.sga_hmc import * # pylint: disable=wildcard-import -tf = backend.tf +jnp = backend.jnp tfp = backend.tfp util = backend.util -__all__ = [ - 'adaptive_hamiltonian_monte_carlo_init', - 'adaptive_hamiltonian_monte_carlo_step', - 'AdaptiveHamiltonianMonteCarloState', - 'interactive_trace', - 'step_size_adaptation_init', - 'step_size_adaptation_step', - 'StepSizeAdaptationExtra', - 'StepSizeAdaptationState', -] + sga_hmc.__all__ + malt.__all__ +__all__ = ( + [ + 'adaptive_hamiltonian_monte_carlo_init', + 'adaptive_hamiltonian_monte_carlo_step', + 'AdaptiveHamiltonianMonteCarloState', + 'interactive_trace', + 'step_size_adaptation_init', + 'step_size_adaptation_step', + 'StepSizeAdaptationExtra', + 'StepSizeAdaptationState', + ] + + sga_hmc.__all__ + + malt.__all__ +) @util.named_call(name='polynomial_decay') -def _polynomial_decay(step: fun_mc.AnyTensor, - step_size: fun_mc.FloatTensor, - decay_steps: fun_mc.AnyTensor, - final_step_size: fun_mc.FloatTensor, - power: fun_mc.FloatTensor = 1.) -> fun_mc.FloatTensor: +def _polynomial_decay( + step: fun_mc.AnyArray, + step_size: fun_mc.FloatArray, + decay_steps: fun_mc.AnyArray, + final_step_size: fun_mc.FloatArray, + power: float | fun_mc.FloatArray = 1.0, +) -> fun_mc.FloatArray: """Polynomial decay step size schedule.""" - step_size = tf.convert_to_tensor(step_size) - step_f = tf.cast(step, step_size.dtype) - decay_steps_f = tf.cast(decay_steps, step_size.dtype) - step_mult = (1. - step_f / decay_steps_f)**power - step_mult = tf.where(step >= decay_steps, tf.zeros_like(step_mult), step_mult) + step_size = jnp.asarray(step_size) + step_f = jnp.array(step, step_size.dtype) + decay_steps_f = jnp.array(decay_steps, step_size.dtype) + step_mult = (1.0 - step_f / decay_steps_f) ** power + step_mult = jnp.where( + step >= decay_steps, jnp.zeros_like(step_mult), step_mult + ) return step_mult * (step_size - final_step_size) + final_step_size class StepSizeAdaptationState(NamedTuple): """Step size adaptation state.""" - step: fun_mc.IntTensor + + step: fun_mc.IntArray opt_state: fun_mc.AdamState rms_state: fun_mc.RunningMeanState def opt_step_size(self): - return tf.exp(self.opt_state.state) + return jnp.exp(self.opt_state.state) @property def rms_step_size(self): @@ -82,36 +91,41 @@ def rms_step_size(self): def step_size(self, num_adaptation_steps=None): if num_adaptation_steps is not None: - return tf.where(self.step < num_adaptation_steps, self.opt_step_size(), - self.rms_step_size) + return jnp.where( + self.step < num_adaptation_steps, + self.opt_step_size(), + self.rms_step_size, + ) else: return self.opt_step_size() class StepSizeAdaptationExtra(NamedTuple): opt_extra: fun_mc.AdamExtra - accept_prob: fun_mc.FloatTensor + accept_prob: fun_mc.FloatArray @util.named_call def step_size_adaptation_init( - init_step_size: fun_mc.FloatTensor) -> StepSizeAdaptationState: + init_step_size: fun_mc.FloatArray, +) -> StepSizeAdaptationState: """Initializes `StepSizeAdaptationState`. Args: - init_step_size: Floating point Tensor. Initial step size. + init_step_size: Floating point Array. Initial step size. Returns: step_size_adaptation_state: `StepSizeAdaptationState` """ - init_step_size = tf.convert_to_tensor(init_step_size) - rms_state = fun_mc.running_mean_init(init_step_size.shape, - init_step_size.dtype) + init_step_size = jnp.asarray(init_step_size) + rms_state = fun_mc.running_mean_init( + init_step_size.shape, init_step_size.dtype + ) rms_state = rms_state._replace(mean=init_step_size) return StepSizeAdaptationState( - step=tf.constant(0, tf.int32), - opt_state=fun_mc.adam_init(tf.math.log(init_step_size)), + step=jnp.array(0, jnp.int32), + opt_state=fun_mc.adam_init(jnp.log(init_step_size)), rms_state=rms_state, ) @@ -119,14 +133,14 @@ def step_size_adaptation_init( @util.named_call def step_size_adaptation_step( state: StepSizeAdaptationState, - log_accept_ratio: fun_mc.FloatTensor, - num_adaptation_steps: Optional[fun_mc.IntTensor], - target_accept_prob: fun_mc.FloatTensor = 0.8, - adaptation_rate: fun_mc.FloatTensor = 0.05, - adaptation_rate_decay_power: fun_mc.FloatTensor = 0.1, - averaging_window_steps: fun_mc.IntTensor = 100, - min_log_accept_prob: fun_mc.FloatTensor = np.log(1e-5), - reduce_fn: Callable[[fun_mc.FloatTensor], fun_mc.FloatTensor] = ( + log_accept_ratio: fun_mc.FloatArray, + num_adaptation_steps: Optional[fun_mc.IntArray], + target_accept_prob: float | fun_mc.FloatArray = 0.8, + adaptation_rate: float | fun_mc.FloatArray = 0.05, + adaptation_rate_decay_power: float | fun_mc.FloatArray = 0.1, + averaging_window_steps: int | fun_mc.IntArray = 100, + min_log_accept_prob: fun_mc.FloatArray = np.log(1e-5), + reduce_fn: Callable[[fun_mc.FloatArray], fun_mc.FloatArray] = ( tfp.math.reduce_logmeanexp ), opt_kwargs: Optional[dict[str, Any]] = None, @@ -159,72 +173,85 @@ def step_size_adaptation_step( """ opt_kwargs = {} if opt_kwargs is None else opt_kwargs dtype = log_accept_ratio.dtype - adaptation_rate = tf.convert_to_tensor(adaptation_rate, dtype=dtype) - target_accept_prob = tf.convert_to_tensor(target_accept_prob, dtype=dtype) - adaptation_rate_decay_power = tf.convert_to_tensor( - adaptation_rate_decay_power, dtype=dtype) - min_log_accept_prob = tf.fill(log_accept_ratio.shape, - tf.constant(min_log_accept_prob, dtype)) - - log_accept_prob = tf.minimum(log_accept_ratio, tf.zeros([], dtype)) - log_accept_prob = tf.maximum(log_accept_prob, min_log_accept_prob) - log_accept_prob = tf.where( - tf.math.is_finite(log_accept_prob), log_accept_prob, min_log_accept_prob) - accept_prob = tf.exp(reduce_fn(log_accept_prob)) - - loss_fn = fun_mc.make_surrogate_loss_fn(lambda _: # pylint: disable=g-long-lambda - (target_accept_prob - accept_prob, () - )) + adaptation_rate = jnp.asarray(adaptation_rate, dtype=dtype) + target_accept_prob = jnp.asarray(target_accept_prob, dtype=dtype) + adaptation_rate_decay_power = jnp.asarray( + adaptation_rate_decay_power, dtype=dtype + ) + min_log_accept_prob = jnp.full( + log_accept_ratio.shape, jnp.array(min_log_accept_prob, dtype) + ) + + log_accept_prob = jnp.minimum(log_accept_ratio, jnp.zeros([], dtype)) + log_accept_prob = jnp.maximum(log_accept_prob, min_log_accept_prob) + log_accept_prob = jnp.where( + jnp.isfinite(log_accept_prob), log_accept_prob, min_log_accept_prob + ) + accept_prob = jnp.exp(reduce_fn(log_accept_prob)) + + loss_fn = fun_mc.make_surrogate_loss_fn( + lambda _: ( # pylint: disable=g-long-lambda + target_accept_prob - accept_prob, + (), + ) + ) if num_adaptation_steps is not None: adaptation_rate = _polynomial_decay( step=state.step, step_size=adaptation_rate, decay_steps=num_adaptation_steps, - final_step_size=0., + final_step_size=0.0, power=adaptation_rate_decay_power, ) # Optimize step size. - opt_state, opt_extra = fun_mc.adam_step(state.opt_state, loss_fn, - adaptation_rate, **opt_kwargs) + opt_state, opt_extra = fun_mc.adam_step( + state.opt_state, loss_fn, adaptation_rate, **opt_kwargs + ) # Do iterate averaging. old_rms_state = state.rms_state rms_state, _ = fun_mc.running_mean_step( old_rms_state, - tf.exp(opt_state.state), - window_size=averaging_window_steps) + jnp.exp(opt_state.state), + window_size=averaging_window_steps, + ) if num_adaptation_steps is not None: rms_state = util.map_tree( - lambda n, o: tf.where(state.step < num_adaptation_steps, n, o), - rms_state, old_rms_state) + lambda n, o: jnp.where(state.step < num_adaptation_steps, n, o), + rms_state, + old_rms_state, + ) state = state._replace( - opt_state=opt_state, rms_state=rms_state, step=state.step + 1) + opt_state=opt_state, rms_state=rms_state, step=state.step + 1 + ) extra = StepSizeAdaptationExtra(opt_extra=opt_extra, accept_prob=accept_prob) return state, extra class AdaptiveHamiltonianMonteCarloState(NamedTuple): """Adaptive HMC `TransitionOperator` state.""" + hmc_state: fun_mc.HamiltonianMonteCarloState running_var_state: fun_mc.RunningVarianceState ssa_state: StepSizeAdaptationState - step: fun_mc.IntTensor + step: fun_mc.IntArray class AdaptiveHamiltonianMonteCarloExtra(NamedTuple): """Extra outputs for Adaptive HMC `TransitionOperator`.""" + hmc_state: fun_mc.HamiltonianMonteCarloState hmc_extra: fun_mc.HamiltonianMonteCarloExtra - step_size: fun_mc.FloatTensor - num_leapfrog_steps: fun_mc.IntTensor - mean_num_leapfrog_steps: fun_mc.IntTensor + step_size: fun_mc.FloatArray + num_leapfrog_steps: fun_mc.IntArray + mean_num_leapfrog_steps: fun_mc.IntArray @property - def state(self) -> fun_mc.TensorNest: + def state(self) -> fun_mc.ArrayNest: """Returns the chain state. Note that this assumes that `target_log_prob_fn` has the `state_extra` be a @@ -234,18 +261,18 @@ def state(self) -> fun_mc.TensorNest: return self.hmc_state.state_extra[0] @property - def is_accepted(self) -> fun_mc.BooleanTensor: + def is_accepted(self) -> fun_mc.BooleanArray: return self.hmc_extra.is_accepted @util.named_call def adaptive_hamiltonian_monte_carlo_init( - state: fun_mc.TensorNest, + state: fun_mc.ArrayNest, target_log_prob_fn: fun_mc.PotentialFn, - step_size: fun_mc.FloatTensor = 1e-2, - initial_mean: fun_mc.FloatNest = 0., - initial_scale: fun_mc.FloatNest = 1., - scale_smoothing_steps: fun_mc.IntTensor = 10, + step_size: float | fun_mc.FloatArray = 1e-2, + initial_mean: float | fun_mc.FloatNest = 0.0, + initial_scale: float | fun_mc.FloatNest = 1.0, + scale_smoothing_steps: int | fun_mc.IntArray = 10, ) -> AdaptiveHamiltonianMonteCarloState: """Initializes `AdaptiveHamiltonianMonteCarloState`. @@ -267,7 +294,7 @@ def adaptive_hamiltonian_monte_carlo_init( hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn) dtype = util.flatten_tree(hmc_state.state)[0].dtype chain_ndims = len(hmc_state.target_log_prob.shape) - running_var_state = fun_mc.running_variance_init( + running_var_state = fun_mc.running_variance_init( # pytype: disable=wrong-keyword-args shape=util.map_tree(lambda s: s.shape[chain_ndims:], hmc_state.state), dtype=util.map_tree(lambda s: s.dtype, hmc_state.state), ) @@ -281,24 +308,30 @@ def adaptive_hamiltonian_monte_carlo_init( running_var_state = running_var_state._replace( num_points=util.map_tree( lambda p: ( # pylint: disable=g-long-lambda - int(np.prod(hmc_state.target_log_prob.shape)) * tf.cast( - scale_smoothing_steps, p.dtype)), - running_var_state.num_points), + int(np.prod(hmc_state.target_log_prob.shape)) + * jnp.array(scale_smoothing_steps, p.dtype) + ), + running_var_state.num_points, + ), mean=util.map_tree( # pylint: disable=g-long-lambda - lambda m, init_m: tf.ones_like(m) * init_m, running_var_state.mean, - initial_mean), + lambda m, init_m: jnp.ones_like(m) * init_m, + running_var_state.mean, + initial_mean, + ), variance=util.map_tree( # pylint: disable=g-long-lambda - lambda v, init_s: tf.ones_like(v) * init_s**2, - running_var_state.variance, initial_scale), + lambda v, init_s: jnp.ones_like(v) * init_s**2, + running_var_state.variance, + initial_scale, + ), ) - ssa_state = step_size_adaptation_init( - tf.convert_to_tensor(step_size, dtype=dtype)) + ssa_state = step_size_adaptation_init(jnp.asarray(step_size, dtype=dtype)) return AdaptiveHamiltonianMonteCarloState( hmc_state=hmc_state, running_var_state=running_var_state, ssa_state=ssa_state, - step=tf.zeros([], tf.int32)) + step=jnp.zeros([], jnp.int32), + ) @util.named_call @@ -306,7 +339,7 @@ def _uniform_jitter(mean_num_leapfrog_steps, step, seed): del step return util.random_integer( [], - dtype=tf.int32, + dtype=jnp.int32, minval=1, maxval=2 * mean_num_leapfrog_steps, seed=seed, @@ -317,21 +350,21 @@ def _uniform_jitter(mean_num_leapfrog_steps, step, seed): def adaptive_hamiltonian_monte_carlo_step( adaptive_hmc_state: AdaptiveHamiltonianMonteCarloState, target_log_prob_fn: fun_mc.PotentialFn, - num_adaptation_steps: Optional[fun_mc.IntTensor], - variance_window_steps: fun_mc.IntTensor = 100, - trajectory_length_factor: fun_mc.FloatTensor = 1.0, + num_adaptation_steps: Optional[fun_mc.IntArray], + variance_window_steps: int | fun_mc.IntArray = 100, + trajectory_length_factor: float | fun_mc.FloatArray = 1.0, num_trajectory_ramp_steps: Optional[int] = 100, - trajectory_warmup_power: fun_mc.FloatTensor = 1.0, + trajectory_warmup_power: float | fun_mc.FloatArray = 1.0, max_num_leapfrog_steps: Optional[int] = 100, - step_size_adaptation_rate: fun_mc.FloatTensor = 0.05, - step_size_adaptation_rate_decay_power: fun_mc.FloatTensor = 0.1, - target_accept_prob: fun_mc.FloatTensor = 0.8, - step_size_averaging_window_steps: fun_mc.IntTensor = 100, + step_size_adaptation_rate: float | fun_mc.FloatArray = 0.05, + step_size_adaptation_rate_decay_power: float | fun_mc.FloatArray = 0.1, + target_accept_prob: float | fun_mc.FloatArray = 0.8, + step_size_averaging_window_steps: int | fun_mc.IntArray = 100, jitter_sample_fn: Callable[ - [fun_mc.IntTensor, fun_mc.IntTensor, Any], fun_mc.IntTensor + [fun_mc.IntArray, fun_mc.IntArray, Any], fun_mc.IntArray ] = (_uniform_jitter), log_accept_ratio_reduce_fn: Callable[ - [fun_mc.FloatTensor], fun_mc.FloatTensor + [fun_mc.FloatArray], fun_mc.FloatArray ] = (tfp.math.reduce_logmeanexp), hmc_kwargs: Optional[dict[str, Any]] = None, seed: Any = None, @@ -427,7 +460,7 @@ def target_log_prob_fn(*x): # Start out at zeros (in the unconstrained space). state, _ = transform_fn( - *map(lambda e: tf.zeros([num_chains] + list(e)), model.event_shape)) + *map(lambda e: jnp.zeros([num_chains] + list(e)), model.event_shape)) reparam_log_prob_fn, reparam_state = fun_mc.reparameterize_potential_fn( target_log_prob_fn, transform_fn, state) @@ -445,7 +478,7 @@ def kernel(adaptive_hmc_state): adaptive_hmc_extra.step_size) - _, (state_chain, is_accepted_chain, step_size_chain) = tf.function( + _, (state_chain, is_accepted_chain, step_size_chain) = jax.jit( lambda: fun_mc.trace( state=fun_mc.prefab.adaptive_hamiltonian_monte_carlo_init( reparam_state, reparam_log_prob_fn), @@ -458,25 +491,22 @@ def kernel(adaptive_hmc_state): is_accepted_chain = is_accepted_chain[num_warmup_steps:] # Compute diagnostics. - accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32)) + accept_rate = jnp.mean(jnp.array(is_accepted_chain, jnp.float32)) ess = tfp.mcmc.effective_sample_size( state_chain, filter_beyond_positive_pairs=True, cross_chain_dims=[1, 1]) rhat = tfp.mcmc.potential_scale_reduction(state_chain) # Compute relevant quantities. - sample_mean = [tf.reduce_mean(s, axis=[0, 1]) for s in state_chain] - sample_var = [tf.math.reduce_variance(s, axis=[0, 1]) for s in state_chain] + sample_mean = [jnp.mean(s, axis=[0, 1]) for s in state_chain] + sample_var = [jnp.var(s, axis=[0, 1]) for s in state_chain] # It's also important to look at the `step_size_chain` (e.g. via a plot), to # verify that adaptation succeeded. ``` - """ dtype = util.flatten_tree(adaptive_hmc_state.hmc_state.state)[0].dtype - trajectory_length_factor = tf.convert_to_tensor( - trajectory_length_factor, dtype=dtype) - trajectory_warmup_power = tf.convert_to_tensor( - trajectory_warmup_power, dtype=dtype) + trajectory_length_factor = jnp.asarray(trajectory_length_factor, dtype=dtype) + trajectory_warmup_power = jnp.asarray(trajectory_warmup_power, dtype=dtype) hmc_state = adaptive_hmc_state.hmc_state running_var_state = adaptive_hmc_state.running_var_state @@ -487,7 +517,7 @@ def kernel(adaptive_hmc_state): if num_trajectory_ramp_steps is not None: trajectory_length_factor = _polynomial_decay( step=step, - step_size=tf.constant(0., dtype), + step_size=jnp.array(0.0, dtype), decay_steps=num_trajectory_ramp_steps, final_step_size=trajectory_length_factor, power=trajectory_warmup_power, @@ -495,20 +525,22 @@ def kernel(adaptive_hmc_state): # Compute the per-component step_size and num_leapfrog_steps from the variance # estimate. - scale = util.map_tree(tf.math.sqrt, running_var_state.variance) + scale = util.map_tree(jnp.sqrt, running_var_state.variance) step_size = ssa_state.step_size(num_adaptation_steps=num_adaptation_steps) - num_leapfrog_steps = tf.cast( - tf.math.ceil(trajectory_length_factor / step_size), tf.int32) - num_leapfrog_steps = tf.maximum(1, num_leapfrog_steps) + num_leapfrog_steps = jnp.array( + jnp.ceil(trajectory_length_factor / step_size), jnp.int32 + ) + num_leapfrog_steps = jnp.maximum(1, num_leapfrog_steps) if max_num_leapfrog_steps is not None: - num_leapfrog_steps = tf.minimum(max_num_leapfrog_steps, num_leapfrog_steps) + num_leapfrog_steps = jnp.minimum(max_num_leapfrog_steps, num_leapfrog_steps) # We implement mass-matrix adaptation via step size rescaling, as this is a # little bit simpler to code up. step_size = util.map_tree(lambda scale: scale * step_size, scale) hmc_seed, jitter_seed = util.split_seed(seed, 2) - jittered_num_leapfrog_steps = jitter_sample_fn(num_leapfrog_steps, step, - jitter_seed) + jittered_num_leapfrog_steps = jitter_sample_fn( + num_leapfrog_steps, step, jitter_seed + ) # Run a step of HMC. hmc_kwargs = hmc_kwargs or {} @@ -528,15 +560,17 @@ def kernel(adaptive_hmc_state): running_var_state, hmc_state.state, axis=tuple(range(chain_ndims)) if chain_ndims else None, - window_size=int(np.prod(hmc_state.target_log_prob.shape)) * - variance_window_steps) + window_size=int(np.prod(hmc_state.target_log_prob.shape)) + * variance_window_steps, + ) if num_adaptation_steps is not None: # Take care of adaptation for variance. running_var_state = util.map_tree( - lambda n, o: tf.where(step < num_adaptation_steps, n, o), # pylint: disable=g-long-lambda + lambda n, o: jnp.where(step < num_adaptation_steps, n, o), # pylint: disable=g-long-lambda running_var_state, - old_running_var_state) + old_running_var_state, + ) # Update the scalar step size as a function of acceptance rate. ssa_state, _ = step_size_adaptation_step( @@ -562,7 +596,8 @@ def kernel(adaptive_hmc_state): hmc_extra=hmc_extra, step_size=step_size, mean_num_leapfrog_steps=num_leapfrog_steps, - num_leapfrog_steps=jittered_num_leapfrog_steps) + num_leapfrog_steps=jittered_num_leapfrog_steps, + ) return adaptive_hmc_state, extra @@ -578,21 +613,21 @@ def _tqdm_progress_bar_fn(iterable: Iterable[Any]) -> Iterable[Any]: def interactive_trace( state: fun_mc.State, fn: fun_mc.TransitionOperator, - num_steps: fun_mc.IntTensor, - trace_mask: fun_mc.BooleanNest = True, + num_steps: fun_mc.IntArray, + trace_mask: bool | fun_mc.BooleanNest = True, iteration_axis: int = 0, block_until_ready: bool = True, progress_bar_fn: Callable[[Iterable[Any]], Iterable[Any]] = ( _tqdm_progress_bar_fn ), -) -> tuple[fun_mc.State, fun_mc.TensorNest]: +) -> tuple[fun_mc.State, fun_mc.ArrayNest]: """Wrapped around fun_mc.trace, suited for interactive work. This is accomplished through unrolling fun_mc.trace, as well as optionally using a progress bar (TQDM by default). Args: - state: A nest of `Tensor`s or None. + state: A nest of `Array`s or None. fn: A `TransitionOperator`. num_steps: Number of steps to run the function for. Must be greater than 1. trace_mask: A potentially shallow nest with boolean leaves applied to the @@ -621,10 +656,7 @@ def interactive_trace( but with leaves replaced with stacked and unstacked values according to the `trace_mask`. """ - num_steps = tf.get_static_value(num_steps) - if num_steps is None: - raise ValueError( - 'Interactive tracing requires `num_steps` to be statically known.') + num_steps = int(num_steps) if progress_bar_fn is None: pbar = None @@ -656,10 +688,12 @@ def fn_with_progress(state): def fix_part(x): x = util.move_axis(x, 0, iteration_axis - 1) - x = tf.reshape( + x = jnp.reshape( x, - tuple(x.shape[:iteration_axis - 1]) + (-1,) + - tuple(x.shape[iteration_axis + 1:])) + tuple(x.shape[: iteration_axis - 1]) + + (-1,) + + tuple(x.shape[iteration_axis + 1 :]), + ) return x trace = util.map_tree(fix_part, trace) @@ -668,19 +702,21 @@ def fix_part(x): class PersistentHamiltonianMonteCarloState(NamedTuple): """Persistent Hamiltonian Monte Carlo state.""" + state: fun_mc.State state_grads: fun_mc.State momentum: fun_mc.State - target_log_prob: fun_mc.FloatTensor + target_log_prob: fun_mc.FloatArray state_extra: Any - direction: fun_mc.FloatTensor + direction: fun_mc.FloatArray pmh_state: fun_mc.PersistentMetropolistHastingsState class PersistentHamiltonianMonteCarloExtra(NamedTuple): """Persistent Hamiltonian Monte Carlo extra outputs.""" - is_accepted: fun_mc.BooleanTensor - log_accept_ratio: fun_mc.FloatTensor + + is_accepted: fun_mc.BooleanArray + log_accept_ratio: fun_mc.FloatArray proposed_phmc_state: fun_mc.State integrator_state: fun_mc.IntegratorState integrator_extra: fun_mc.IntegratorExtras @@ -690,10 +726,10 @@ class PersistentHamiltonianMonteCarloExtra(NamedTuple): @util.named_call def persistent_hamiltonian_monte_carlo_init( - state: fun_mc.TensorNest, + state: fun_mc.ArrayNest, target_log_prob_fn: fun_mc.PotentialFn, momentum: Optional[fun_mc.State] = None, - init_level: fun_mc.FloatTensor = 0., + init_level: float | fun_mc.FloatArray = 0.0, ) -> PersistentHamiltonianMonteCarloState: """Initializes the `PersistentHamiltonianMonteCarloState`. @@ -706,23 +742,25 @@ def persistent_hamiltonian_monte_carlo_init( Returns: hmc_state: `PersistentMetropolistHastingsState`. """ - state = util.map_tree(tf.convert_to_tensor, state) + state = util.map_tree(jnp.asarray, state) target_log_prob, state_extra, state_grads = util.map_tree( - tf.convert_to_tensor, + jnp.asarray, fun_mc.call_potential_fn_with_grads(target_log_prob_fn, state), ) return PersistentHamiltonianMonteCarloState( state=state, state_grads=state_grads, - momentum=momentum if momentum is not None else util.map_tree( - tf.zeros_like, state), + momentum=momentum + if momentum is not None + else util.map_tree(jnp.zeros_like, state), target_log_prob=target_log_prob, state_extra=state_extra, - direction=tf.ones_like(target_log_prob), + direction=jnp.ones_like(target_log_prob), pmh_state=fun_mc.persistent_metropolis_hastings_init( shape=target_log_prob.shape, dtype=target_log_prob.dtype, - init_level=init_level), + init_level=init_level, + ), ) @@ -734,31 +772,31 @@ def persistent_hamiltonian_monte_carlo_step( phmc_state: PersistentHamiltonianMonteCarloState, target_log_prob_fn: fun_mc.PotentialFn, step_size: Optional[Any] = None, - num_integrator_steps: Optional[fun_mc.IntTensor] = None, - noise_fraction: Optional[fun_mc.FloatTensor] = None, - mh_drift: Optional[fun_mc.FloatTensor] = None, + num_integrator_steps: Optional[fun_mc.IntArray] = None, + noise_fraction: Optional[fun_mc.FloatArray] = None, + mh_drift: Optional[fun_mc.FloatArray] = None, kinetic_energy_fn: Optional[fun_mc.PotentialFn] = None, momentum_sample_fn: Optional[PersistentMomentumSampleFn] = None, integrator_trace_fn: Callable[ [fun_mc.IntegratorStepState, fun_mc.IntegratorStepExtras], - fun_mc.TensorNest, + fun_mc.ArrayNest, ] = lambda *args: (), - log_uniform: Optional[fun_mc.FloatTensor] = None, + log_uniform: Optional[fun_mc.FloatArray] = None, integrator_fn: Optional[ Callable[ - [fun_mc.IntegratorState, fun_mc.FloatTensor], + [fun_mc.IntegratorState, fun_mc.FloatArray], tuple[fun_mc.IntegratorState, fun_mc.IntegratorExtras], ] ] = None, unroll_integrator: bool = False, - max_num_integrator_steps: Optional[fun_mc.IntTensor] = None, + max_num_integrator_steps: Optional[fun_mc.IntArray] = None, energy_change_fn: Callable[ [ fun_mc.IntegratorState, fun_mc.IntegratorState, fun_mc.IntegratorExtras, ], - tuple[fun_mc.FloatTensor, Any], + tuple[fun_mc.FloatArray, Any], ] = ( fun_mc._default_hamiltonian_monte_carlo_energy_change_fn # pylint: disable=protected-access ), @@ -881,48 +919,55 @@ def persistent_hamiltonian_monte_carlo_step( # Impute the optional args. if kinetic_energy_fn is None: kinetic_energy_fn = fun_mc.make_gaussian_kinetic_energy_fn( - len(target_log_prob.shape) if target_log_prob.shape is not None else tf - .rank(target_log_prob), named_axis=named_axis) + len(target_log_prob.shape), named_axis=named_axis + ) if momentum_sample_fn is None: if named_axis is None: named_axis = util.map_tree(lambda _: [], state) - def _momentum_sample_fn(old_momentum: fun_mc.State, - seed: Any) -> tuple[fun_mc.State, tuple[()]]: + def _momentum_sample_fn( + old_momentum: fun_mc.State, seed: Any + ) -> tuple[fun_mc.State, tuple[()]]: seeds = util.unflatten_tree( old_momentum, - util.split_seed(seed, len(util.flatten_tree(old_momentum)))) + util.split_seed(seed, len(util.flatten_tree(old_momentum))), + ) def _sample_part(old_momentum, seed, named_axis): seed = backend.distribute_lib.fold_in_axis_index(seed, named_axis) - return ( - tf.math.sqrt(1 - tf.square(noise_fraction)) * old_momentum + - noise_fraction * - util.random_normal(old_momentum.shape, old_momentum.dtype, seed)) - - new_momentum = util.map_tree_up_to(state, _sample_part, old_momentum, - seeds, named_axis) + return jnp.sqrt( + 1 - jnp.square(noise_fraction) + ) * old_momentum + noise_fraction * util.random_normal( + old_momentum.shape, old_momentum.dtype, seed + ) + + new_momentum = util.map_tree_up_to( + state, _sample_part, old_momentum, seeds, named_axis + ) return new_momentum momentum_sample_fn = _momentum_sample_fn if integrator_fn is None: - step_size = util.map_tree(tf.convert_to_tensor, step_size) + step_size = util.map_tree(jnp.asarray, step_size) step_size = fun_mc.maybe_broadcast_structure(step_size, state) def _integrator_fn( - state: fun_mc.IntegratorState, direction: fun_mc.FloatTensor + state: fun_mc.IntegratorState, direction: fun_mc.FloatArray ) -> tuple[fun_mc.IntegratorState, fun_mc.IntegratorExtras]: - directional_step_size = util.map_tree( lambda step_size, state: ( # pylint: disable=g-long-lambda - step_size * tf.reshape( + step_size + * jnp.reshape( direction, - list(direction.shape) + [1] * - (len(state.shape) - len(direction.shape)))), + list(direction.shape) + + [1] * (len(state.shape) - len(direction.shape)), + ) + ), step_size, - state.state) + state.state, + ) # TODO(siege): Ideally we'd pass in the direction here, but the # `hamiltonian_integrator` cannot handle dynamic direction switching like # that. @@ -933,11 +978,13 @@ def _integrator_fn( fun_mc.leapfrog_step, step_size=directional_step_size, target_log_prob_fn=target_log_prob_fn, - kinetic_energy_fn=kinetic_energy_fn), + kinetic_energy_fn=kinetic_energy_fn, + ), kinetic_energy_fn=kinetic_energy_fn, unroll=unroll_integrator, max_num_steps=max_num_integrator_steps, - integrator_trace_fn=integrator_trace_fn) + integrator_trace_fn=integrator_trace_fn, + ) integrator_fn = _integrator_fn @@ -952,8 +999,9 @@ def _integrator_fn( state_extra=state_extra, ) - integrator_state, integrator_extra = integrator_fn(initial_integrator_state, - direction) + integrator_state, integrator_extra = integrator_fn( + initial_integrator_state, direction + ) proposed_state = phmc_state._replace( state=integrator_state.state, @@ -976,12 +1024,14 @@ def _integrator_fn( ) if log_uniform is None: - pmh_state, pmh_extra = fun_mc.persistent_metropolis_hastings_step( + # There's some lint error due to the wrapper here. + pmh_state, pmh_extra = fun_mc.persistent_metropolis_hastings_step( # pytype: disable=wrong-keyword-args pmh_state, current_state=phmc_state, proposed_state=proposed_state, energy_change=energy_change, - drift=mh_drift) + drift=mh_drift, + ) is_accepted = pmh_extra.is_accepted phmc_state = pmh_extra.accepted_state else: @@ -990,7 +1040,8 @@ def _integrator_fn( current_state=phmc_state, proposed_state=proposed_state, energy_change=energy_change, - log_uniform=log_uniform) + log_uniform=log_uniform, + ) is_accepted = mh_extra.is_accepted phmc_state = typing.cast(PersistentHamiltonianMonteCarloState, phmc_state) @@ -1009,4 +1060,5 @@ def _integrator_fn( integrator_state=integrator_state, integrator_extra=integrator_extra, energy_change_extra=energy_change_extra, - initial_momentum=momentum) + initial_momentum=momentum, + ) diff --git a/spinoffs/fun_mc/fun_mc/prefab_test.py b/spinoffs/fun_mc/fun_mc/prefab_test.py index 5b7b85be3a..c9d5f82966 100644 --- a/spinoffs/fun_mc/fun_mc/prefab_test.py +++ b/spinoffs/fun_mc/fun_mc/prefab_test.py @@ -19,11 +19,10 @@ # Dependency imports -import jax +import jax as real_jax from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf - from tensorflow_probability.python.internal import test_util as tfp_test_util from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc @@ -31,7 +30,8 @@ from fun_mc import test_util from fun_mc import util_tfp -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util tfd = tfp.distributions @@ -42,8 +42,10 @@ BACKEND = None # Rewritten by backends/rewrite.py. if BACKEND == 'backend_jax': - os.environ['XLA_FLAGS'] = (f'{os.environ.get("XLA_FLAGS", "")} ' - '--xla_force_host_platform_device_count=4') + os.environ['XLA_FLAGS'] = ( + f'{os.environ.get("XLA_FLAGS", "")} ' + '--xla_force_host_platform_device_count=4' + ) def _test_seed(): @@ -51,12 +53,11 @@ def _test_seed(): class PrefabTest(tfp_test_util.TestCase): - _is_on_jax = BACKEND == 'backend_jax' def _make_seed(self, seed): if self._is_on_jax: - return jax.random.PRNGKey(seed) + return real_jax.random.PRNGKey(seed) else: return util.make_tensor_seed([seed, 0]) @@ -65,7 +66,7 @@ def _dtype(self): raise NotImplementedError() def _constant(self, value): - return tf.constant(value, self._dtype) + return jnp.array(value, self._dtype) def testAdaptiveHMC(self): num_chains = 16 @@ -75,26 +76,33 @@ def testAdaptiveHMC(self): # Setup the model and state constraints. model = tfp.distributions.JointDistributionSequential([ - tfp.distributions.Normal(loc=self._constant(0.), scale=1.), + tfp.distributions.Normal(loc=self._constant(0.0), scale=1.0), tfp.distributions.Independent( tfp.distributions.LogNormal( - loc=self._constant([1., 1.]), scale=0.5), 1), + loc=self._constant([1.0, 1.0]), scale=0.5 + ), + 1, + ), ]) bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()] transform_fn = util_tfp.bijector_to_transform_fn( - bijector, model.dtype, batch_ndims=1) + bijector, model.dtype, batch_ndims=1 + ) def target_log_prob_fn(*x): return model.log_prob(x), () # Start out at zeros (in the unconstrained space). - state, _ = transform_fn(*[ - tf.zeros([num_chains] + list(e), dtype=self._dtype) - for e in model.event_shape - ]) + state, _ = transform_fn( + *[ + jnp.zeros([num_chains] + list(e), dtype=self._dtype) + for e in model.event_shape + ] + ) reparam_log_prob_fn, reparam_state = fun_mc.reparameterize_potential_fn( - target_log_prob_fn, transform_fn, state) + target_log_prob_fn, transform_fn, state + ) # Define the kernel. def kernel(adaptive_hmc_state, seed): @@ -105,40 +113,48 @@ def kernel(adaptive_hmc_state, seed): adaptive_hmc_state, target_log_prob_fn=reparam_log_prob_fn, num_adaptation_steps=num_adapt_steps, - seed=hmc_seed)) + seed=hmc_seed, + ) + ) - return (adaptive_hmc_state, - seed), (adaptive_hmc_extra.state, adaptive_hmc_extra.is_accepted, - adaptive_hmc_extra.step_size) + return (adaptive_hmc_state, seed), ( + adaptive_hmc_extra.state, + adaptive_hmc_extra.is_accepted, + adaptive_hmc_extra.step_size, + ) seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, (state_chain, is_accepted_chain, - _) = tf.function(lambda reparam_state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(prefab.adaptive_hamiltonian_monte_carlo_init( - reparam_state, reparam_log_prob_fn), seed), + _, (state_chain, is_accepted_chain, _) = jax.jit( + lambda reparam_state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + prefab.adaptive_hamiltonian_monte_carlo_init( + reparam_state, reparam_log_prob_fn + ), + seed, + ), fn=kernel, - num_steps=num_steps))(reparam_state, seed) + num_steps=num_steps, + ) + )(reparam_state, seed) # Discard the warmup samples. state_chain = [s[num_warmup_steps:] for s in state_chain] is_accepted_chain = is_accepted_chain[num_warmup_steps:] - accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32)) + accept_rate = jnp.mean(jnp.asarray(is_accepted_chain, jnp.float32)) rhat = tfp.mcmc.potential_scale_reduction(state_chain) - sample_mean = [tf.reduce_mean(s, axis=[0, 1]) for s in state_chain] - sample_var = [tf.math.reduce_variance(s, axis=[0, 1]) for s in state_chain] + sample_mean = [jnp.mean(s, axis=[0, 1]) for s in state_chain] + sample_var = [jnp.var(s, axis=[0, 1]) for s in state_chain] - self.assertAllAssertsNested(lambda rhat: self.assertAllLess(rhat, 1.1), - rhat) + self.assertAllAssertsNested( + lambda rhat: self.assertAllLess(rhat, 1.1), rhat + ) self.assertAllClose(0.8, accept_rate, atol=0.05) self.assertAllClose(model.mean(), sample_mean, rtol=0.1, atol=0.1) self.assertAllClose(model.variance(), sample_var, rtol=0.1, atol=0.1) def testInteractiveTrace(self): - def kernel(x): return x + 1, x @@ -150,82 +166,87 @@ def progress_bar_fn(iterable): yield x_fin, x_trace = prefab.interactive_trace( - 0., kernel, num_steps=5, progress_bar_fn=progress_bar_fn) + 0.0, kernel, num_steps=5, progress_bar_fn=progress_bar_fn + ) self.assertAllClose(5, x_fin) self.assertAllClose(np.arange(5), x_trace) self.assertEqual(5, counter[0]) def testStepSizeAdaptation(self): - def log_accept_ratio_fn(step_size): - return -step_size**2 + return -(step_size**2) def kernel(ssa_state, seed): normal_seed, seed = util.split_seed(seed, 2) - log_accept_ratio = ( - log_accept_ratio_fn(ssa_state.step_size()) + - 0.01 * util.random_normal([4], self._dtype, normal_seed)) + log_accept_ratio = log_accept_ratio_fn( + ssa_state.step_size() + ) + 0.01 * util.random_normal([4], self._dtype, normal_seed) ssa_state, ssa_extra = prefab.step_size_adaptation_step( - ssa_state, log_accept_ratio, num_adaptation_steps=100) - return (ssa_state, seed), (ssa_extra.accept_prob, ssa_state.step_size(), - ssa_state.step_size(num_adaptation_steps=100)) + ssa_state, log_accept_ratio, num_adaptation_steps=100 + ) + return (ssa_state, seed), ( + ssa_extra.accept_prob, + ssa_state.step_size(), + ssa_state.step_size(num_adaptation_steps=100), + ) seed = self._make_seed(_test_seed()) _, (p_accept, step_size, rms_step_size) = fun_mc.trace( - (prefab.step_size_adaptation_init(tf.constant(0.1, self._dtype)), seed), - kernel, 200) + (prefab.step_size_adaptation_init(jnp.array(0.1, self._dtype)), seed), + kernel, + 200, + ) self.assertAllClose(0.8, p_accept[100], atol=0.1) self.assertAllClose(step_size[100], step_size[150]) self.assertAllClose(rms_step_size[100], rms_step_size[150]) def testInteractiveIterationAxis1(self): - def kernel(x): return x + 1, x state, trace = prefab.interactive_trace( - 0., + 0.0, lambda x: fun_mc.trace(x, kernel, 5), 20, iteration_axis=1, - progress_bar_fn=None) + progress_bar_fn=None, + ) - self.assertAllClose(100., state) + self.assertAllClose(100.0, state) self.assertEqual([100], list(trace.shape)) - self.assertAllClose(99., trace[-1]) + self.assertAllClose(99.0, trace[-1]) def testInteractiveIterationAxis2(self): - def kernel(x): return x + 1, x def inner(x): state, trace = fun_mc.trace(x, kernel, 5) - trace = tf.transpose(trace, [1, 0]) + trace = jnp.transpose(trace, [1, 0]) return state, trace state, trace = prefab.interactive_trace( - tf.zeros(2), inner, 20, iteration_axis=2, progress_bar_fn=None) + jnp.zeros(2), inner, 20, iteration_axis=2, progress_bar_fn=None + ) - self.assertAllClose([100., 100.], state) + self.assertAllClose([100.0, 100.0], state) self.assertEqual([2, 100], list(trace.shape)) - self.assertAllClose([99., 99.], trace[:, -1]) + self.assertAllClose([99.0, 99.0], trace[:, -1]) def testPHMC(self): step_size = self._constant(0.2) num_steps = 2000 num_leapfrog_steps = 10 - state = tf.ones([16, 2], dtype=self._dtype) + state = jnp.ones([16, 2], dtype=self._dtype) - base_mean = self._constant([2., 3.]) - base_scale = self._constant([2., 0.5]) + base_mean = self._constant([2.0, 3.0]) + base_scale = self._constant([2.0, 0.5]) def target_log_prob_fn(x): - return -tf.reduce_sum(0.5 * tf.square( - (x - base_mean) / base_scale), -1), () + return -jnp.sum(0.5 * jnp.square((x - base_mean) / base_scale), -1), () def kernel(phmc_state, seed): phmc_seed, seed = util.split_seed(seed, 2) @@ -236,40 +257,49 @@ def kernel(phmc_state, seed): target_log_prob_fn=target_log_prob_fn, noise_fraction=self._constant(0.5), mh_drift=self._constant(0.127), - seed=phmc_seed) + seed=phmc_seed, + ) return (phmc_state, seed), phmc_state.state seed = self._make_seed(_test_seed()) - # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs - # for the jit to do anything. - _, chain = tf.function(lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda - state=(prefab.persistent_hamiltonian_monte_carlo_init( - state, target_log_prob_fn), seed), - fn=kernel, - num_steps=num_steps))(state, seed) + _, chain = jax.jit( + lambda state, seed: fun_mc.trace( # pylint: disable=g-long-lambda + state=( + prefab.persistent_hamiltonian_monte_carlo_init( + state, target_log_prob_fn + ), + seed, + ), + fn=kernel, + num_steps=num_steps, + ) + )(state, seed) # Discard the warmup samples. chain = chain[1000:] - sample_mean = tf.reduce_mean(chain, axis=[0, 1]) - sample_var = tf.math.reduce_variance(chain, axis=[0, 1]) + sample_mean = jnp.mean(chain, axis=[0, 1]) + sample_var = jnp.var(chain, axis=[0, 1]) - true_samples = util.random_normal( - shape=[4096, 2], dtype=self._dtype, seed=seed) * base_scale + base_mean + true_samples = ( + util.random_normal(shape=[4096, 2], dtype=self._dtype, seed=seed) + * base_scale + + base_mean + ) - true_mean = tf.reduce_mean(true_samples, axis=0) - true_var = tf.math.reduce_variance(true_samples, axis=0) + true_mean = jnp.mean(true_samples, axis=0) + true_var = jnp.var(true_samples, axis=0) self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1) self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1) def testPHMCWithLogUniform(self): - def target_log_prob_fn(x): - return -x**2, () + return -(x**2), () phmc_state = prefab.persistent_hamiltonian_monte_carlo_init( - self._constant(0.), target_log_prob_fn) + self._constant(0.0), target_log_prob_fn + ) seed = self._make_seed(_test_seed()) _, accepted_phmc_extra = prefab.persistent_hamiltonian_monte_carlo_step( phmc_state, @@ -278,8 +308,9 @@ def target_log_prob_fn(x): num_integrator_steps=1, noise_fraction=self._constant(0.5), mh_drift=self._constant(0.1), - log_uniform=tf.math.log(self._constant(0.)), - seed=seed) + log_uniform=jnp.log(self._constant(0.0)), + seed=seed, + ) _, rejected_phmc_extra = prefab.persistent_hamiltonian_monte_carlo_step( phmc_state, target_log_prob_fn, @@ -287,8 +318,9 @@ def target_log_prob_fn(x): num_integrator_steps=1, noise_fraction=self._constant(0.5), mh_drift=self._constant(0.1), - log_uniform=0., - seed=seed) + log_uniform=0.0, + seed=seed, + ) self.assertTrue(accepted_phmc_extra.is_accepted) self.assertFalse(rejected_phmc_extra.is_accepted) @@ -298,8 +330,8 @@ def testPHMCNamedAxis(self): self.skipTest('JAX-only') state = { - 'sharded': tf.zeros([4, 1024], self._dtype), - 'shared': tf.zeros([1024], self._dtype), + 'sharded': jnp.zeros([4, 1024], self._dtype), + 'shared': jnp.zeros([1024], self._dtype), } in_axes = { 'sharded': 0, @@ -311,14 +343,21 @@ def testPHMCNamedAxis(self): } def target_log_prob_fn(sharded, shared): - return -(backend.distribute_lib.psum(tf.square(sharded), 'named_axis') + - tf.square(shared)), () + return ( + -( + backend.distribute_lib.psum(jnp.square(sharded), 'named_axis') + + jnp.square(shared) + ), + (), + ) @functools.partial( - jax.pmap, in_axes=(in_axes, None), axis_name='named_axis') + real_jax.pmap, in_axes=(in_axes, None), axis_name='named_axis' + ) def kernel(state, seed): phmc_state = prefab.persistent_hamiltonian_monte_carlo_init( - state, target_log_prob_fn=target_log_prob_fn) + state, target_log_prob_fn=target_log_prob_fn + ) phmc_state, phmc_extra = prefab.persistent_hamiltonian_monte_carlo_step( phmc_state, step_size=self._constant(0.2), @@ -327,20 +366,27 @@ def kernel(state, seed): mh_drift=0.1, target_log_prob_fn=target_log_prob_fn, named_axis=named_axis, - seed=seed) + seed=seed, + ) return phmc_state, phmc_extra seed = self._make_seed(_test_seed()) phmc_state, phmc_extra = kernel(state, seed) - self.assertAllClose(phmc_state.state['shared'][0], - phmc_state.state['shared'][1]) + self.assertAllClose( + phmc_state.state['shared'][0], phmc_state.state['shared'][1] + ) self.assertTrue( np.any( - np.abs(phmc_state.state['sharded'][0] - - phmc_state.state['sharded'][1]) > 1e-3)) + np.abs( + phmc_state.state['sharded'][0] - phmc_state.state['sharded'][1] + ) + > 1e-3 + ) + ) self.assertAllClose(phmc_extra.is_accepted[0], phmc_extra.is_accepted[1]) - self.assertAllClose(phmc_extra.log_accept_ratio[0], - phmc_extra.log_accept_ratio[1]) + self.assertAllClose( + phmc_extra.log_accept_ratio[0], phmc_extra.log_accept_ratio[1] + ) @test_util.multi_backend_test(globals(), 'prefab_test') @@ -348,7 +394,7 @@ class PrefabTest32(PrefabTest): @property def _dtype(self): - return tf.float32 + return jnp.float32 @test_util.multi_backend_test(globals(), 'prefab_test') @@ -356,7 +402,7 @@ class PrefabTest64(PrefabTest): @property def _dtype(self): - return tf.float64 + return jnp.float64 del PrefabTest diff --git a/spinoffs/fun_mc/fun_mc/sga_hmc.py b/spinoffs/fun_mc/fun_mc/sga_hmc.py index 73ea7cba5e..0e3e5f0589 100644 --- a/spinoffs/fun_mc/fun_mc/sga_hmc.py +++ b/spinoffs/fun_mc/fun_mc/sga_hmc.py @@ -30,7 +30,8 @@ from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util distribute_lib = backend.distribute_lib @@ -54,17 +55,18 @@ class HamiltonianMonteCarloWithStateGradsExtra(NamedTuple): """Extra outputs for hamiltonian_monte_carlo_with_state_grads_step.""" + hmc_extra: fun_mc.HamiltonianMonteCarloExtra - num_integrator_steps: fun_mc.IntTensor + num_integrator_steps: fun_mc.IntArray proposed_state: fun_mc.State @util.named_call def hamiltonian_monte_carlo_with_state_grads_step( hmc_state: fun_mc.HamiltonianMonteCarloState, - trajectory_length: fun_mc.FloatTensor, - scalar_step_size: fun_mc.FloatTensor, - step_size_scale: fun_mc.FloatNest = 1.0, + trajectory_length: fun_mc.FloatArray, + scalar_step_size: fun_mc.FloatArray, + step_size_scale: float | fun_mc.FloatNest = 1.0, named_axis: Optional[fun_mc.StringNest] = None, **hmc_kwargs, ) -> tuple[ @@ -105,35 +107,40 @@ def hamiltonian_monte_carlo_with_state_grads_step( consts = scalar_step_size, hmc_state, step_size_scale flat_consts = util.flatten_tree(consts) - @tf.custom_gradient + @jax.custom_gradient def hmc(*traj_and_flat_consts): trajectory_length = traj_and_flat_consts[0] - scalar_step_size, hmc_state, step_size_scale = ( - util.unflatten_tree(consts, traj_and_flat_consts[1:]) + scalar_step_size, hmc_state, step_size_scale = util.unflatten_tree( + consts, traj_and_flat_consts[1:] + ) + trajectory_length = jnp.asarray(trajectory_length) + num_integrator_steps = jnp.asarray( + jnp.ceil(trajectory_length / scalar_step_size), jnp.int32 ) - trajectory_length = tf.convert_to_tensor(trajectory_length) - num_integrator_steps = tf.cast( - tf.math.ceil(trajectory_length / scalar_step_size), tf.int32) # In case something goes negative. - num_integrator_steps = tf.maximum(1, num_integrator_steps) + num_integrator_steps = jnp.maximum(1, num_integrator_steps) new_hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step( hmc_state, num_integrator_steps=num_integrator_steps, - step_size=util.map_tree(lambda s: s * scalar_step_size, - step_size_scale), + step_size=util.map_tree( + lambda s: s * scalar_step_size, step_size_scale + ), named_axis=named_axis, - **hmc_kwargs) + **hmc_kwargs, + ) hmc_with_grads_extra = HamiltonianMonteCarloWithStateGradsExtra( proposed_state=hmc_extra.proposed_hmc_state.state, hmc_extra=hmc_extra, - num_integrator_steps=num_integrator_steps) + num_integrator_steps=num_integrator_steps, + ) res = (new_hmc_state, hmc_with_grads_extra) def grad(*grads): grads = util.unflatten_tree(res, util.flatten_tree(grads)) step_size_scale_bc = fun_mc.maybe_broadcast_structure( - step_size_scale, hmc_extra.integrator_extra.momentum_grads) + step_size_scale, hmc_extra.integrator_extra.momentum_grads + ) # We wish to compute `grads^T @ # jacobian(proposed_state(trajectory_length))`. @@ -147,14 +154,19 @@ def grad(*grads): # side of that expression is `momentum_grads * step_size_scale` by the # chain rule. Since the Jacobian in question has 1 row, the # vector-Jacobian product is simply the dot product. - state_grads = util.map_tree(lambda s, m, g: s * m * g, step_size_scale_bc, - hmc_extra.integrator_extra.momentum_grads, - grads[1].proposed_state) + state_grads = util.map_tree( + lambda s, m, g: s * m * g, + step_size_scale_bc, + hmc_extra.integrator_extra.momentum_grads, + grads[1].proposed_state, + ) def do_sum(x, named_axis): return distribute_lib.reduce_sum( - x, list(range(len(trajectory_length.shape), len(x.shape))), - named_axis) + x, + list(range(len(trajectory_length.shape), len(x.shape))), + named_axis, + ) if named_axis is None: named_axis_bc = util.map_tree(lambda _: [], state_grads) @@ -163,8 +175,11 @@ def do_sum(x, named_axis): traj_grad = sum( util.flatten_tree( - util.map_tree_up_to(state_grads, do_sum, state_grads, - named_axis_bc))) + util.map_tree_up_to( + state_grads, do_sum, state_grads, named_axis_bc + ) + ) + ) return (traj_grad,) + (None,) * len(flat_consts) return res, grad @@ -176,13 +191,13 @@ def do_sum(x, named_axis): def chees_criterion( previous_state: fun_mc.State, proposed_state: fun_mc.State, - accept_prob: fun_mc.FloatTensor, - trajectory_length: Optional[fun_mc.FloatTensor] = None, + accept_prob: fun_mc.FloatArray, + trajectory_length: Optional[fun_mc.FloatArray] = None, state_mean: Optional[fun_mc.State] = None, - state_mean_weight: fun_mc.FloatNest = 0., + state_mean_weight: float | fun_mc.FloatNest = 0.0, named_axis: Optional[fun_mc.StringNest] = None, chain_named_axis: Optional[fun_mc.StringNest] = None, -) -> tuple[fun_mc.FloatTensor, fun_mc.FloatTensor]: +) -> tuple[fun_mc.FloatArray, fun_mc.FloatArray]: """The ChEES criterion from [1]. ChEES stands for Change in the Estimator of the Expected Square. @@ -212,15 +227,15 @@ def chees_criterion( posterior expectations. Args: - previous_state: (Possibly nested) floating point `Tensor`. The previous - state of the MCMC chain. - proposed_state: (Possibly nested) floating point `Tensor`. The proposed - state of the MCMC chain. - accept_prob: Floating `Tensor`. Probability of acceping the proposed state. + previous_state: (Possibly nested) floating point `Array`. The previous state + of the MCMC chain. + proposed_state: (Possibly nested) floating point `Array`. The proposed state + of the MCMC chain. + accept_prob: Floating `Array`. Probability of acceping the proposed state. trajectory_length: Ignored. - state_mean: (Possibly nested) floating point `Tensor`. Optional estimate of + state_mean: (Possibly nested) floating point `Array`. Optional estimate of the MCMC chain mean. - state_mean_weight: Floating point `Tensor`. Used to weight `state_mean` with + state_mean_weight: Floating point `Array`. Used to weight `state_mean` with the mean computed by averaging across the previous/proposed state. Setting it to effectively uses `state_mean` as the only source of the MCMCM chain mean. @@ -237,7 +252,6 @@ def chees_criterion( [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In preparation. - """ del trajectory_length batch_ndims = len(accept_prob.shape) @@ -245,8 +259,9 @@ def chees_criterion( no_state_mean = object() if state_mean is None: state_mean = fun_mc.maybe_broadcast_structure(no_state_mean, previous_state) - state_mean_weight = fun_mc.maybe_broadcast_structure(state_mean_weight, - previous_state) + state_mean_weight = fun_mc.maybe_broadcast_structure( + state_mean_weight, previous_state + ) if named_axis is None: named_axis_bc = util.map_tree(lambda _: [], previous_state) else: @@ -257,17 +272,19 @@ def chees_criterion( def _center_previous_state(x, mx, mw): x_center = distribute_lib.reduce_mean( - x, axis=batch_axes, named_axis=chain_named_axis) + x, axis=batch_axes, named_axis=chain_named_axis + ) if mx is not no_state_mean: x_center = x_center * (1 - mw) + mx * mw # The empirical mean here is a stand-in for the true mean, so we drop the # gradient that flows through this term. - return x - tf.stop_gradient(x_center) + return x - jax.lax.stop_gradient(x_center) def _center_proposed_state(x, mx, mw): expand_shape = list(accept_prob.shape) + [1] * ( - len(x.shape) - len(accept_prob.shape)) - expanded_accept_prob = tf.reshape(accept_prob, expand_shape) + len(x.shape) - len(accept_prob.shape) + ) + expanded_accept_prob = jnp.reshape(accept_prob, expand_shape) # Weight the proposed state by the acceptance probability. The goal here is # to get a reliable diagnostic of the underlying dynamics, rather than @@ -275,22 +292,24 @@ def _center_proposed_state(x, mx, mw): # accept_prob is zero when x is NaN, but we still want to sanitize such # values. - x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) + x_safe = jnp.where(jnp.isfinite(x), x, jnp.zeros_like(x)) # If all accept_prob's are zero, the x_center will have a nonsense value, # but well set the overall criterion to zero in this case, so its fine. - x_center = ( + x_center = distribute_lib.reduce_sum( + expanded_accept_prob * x_safe, + axis=batch_axes, + named_axis=chain_named_axis, + ) / ( distribute_lib.reduce_sum( - expanded_accept_prob * x_safe, - axis=batch_axes, - named_axis=chain_named_axis) / - (distribute_lib.reduce_sum( - expanded_accept_prob, axis=batch_axes, named_axis=chain_named_axis) - + 1e-20)) + expanded_accept_prob, axis=batch_axes, named_axis=chain_named_axis + ) + + 1e-20 + ) if mx is not no_state_mean: x_center = x_center * (1 - mw) + mx * mw # The empirical mean here is a stand-in for the true mean, so we drop the # gradient that flows through this term. - return x - tf.stop_gradient(x_center) + return x - jax.lax.stop_gradient(x_center) def _sum_event_part(x, named_axis): event_axes = tuple(range(batch_ndims, len(x.shape))) @@ -304,49 +323,55 @@ def _sum_event(x): _sum_event_part, x, named_axis_bc, - ))) + ) + ) + ) def _square(x): - return util.map_tree(tf.square, x) + return util.map_tree(jnp.square, x) def _sub(x, y): return util.map_tree(lambda x, y: x - y, x, y) - previous_state = util.map_tree(_center_previous_state, previous_state, - state_mean, state_mean_weight) - proposed_state = util.map_tree(_center_proposed_state, proposed_state, - state_mean, state_mean_weight) - chees = 0.25 * tf.square( - _sum_event(_sub(_square(proposed_state), _square(previous_state)))) + previous_state = util.map_tree( + _center_previous_state, previous_state, state_mean, state_mean_weight + ) + proposed_state = util.map_tree( + _center_proposed_state, proposed_state, state_mean, state_mean_weight + ) + chees = 0.25 * jnp.square( + _sum_event(_sub(_square(proposed_state), _square(previous_state))) + ) # Zero-out per-chain ChEES values where acceptance probability is low. Those # values are probably not reflective of the underlying dynamics. - chees = tf.where(accept_prob > 1e-4, chees, 0.) + chees = jnp.where(accept_prob > 1e-4, chees, 0.0) accept_prob = accept_prob / distribute_lib.reduce_sum( - accept_prob + 1e-20, named_axis=chain_named_axis) + accept_prob + 1e-20, named_axis=chain_named_axis + ) chees = chees * accept_prob return distribute_lib.reduce_mean(chees, named_axis=chain_named_axis), chees class ChEESPerGradExtra(NamedTuple): - chees: fun_mc.FloatTensor - per_chain_chees: fun_mc.FloatTensor - per_chain_chees_per_grad: fun_mc.FloatTensor + chees: fun_mc.FloatArray + per_chain_chees: fun_mc.FloatArray + per_chain_chees_per_grad: fun_mc.FloatArray @util.named_call def chees_per_grad_criterion( previous_state: fun_mc.State, proposed_state: fun_mc.State, - accept_prob: fun_mc.FloatTensor, - trajectory_length: fun_mc.FloatTensor, - power: fun_mc.FloatTensor = 1., + accept_prob: fun_mc.FloatArray, + trajectory_length: fun_mc.FloatArray, + power: float | fun_mc.FloatArray = 1.0, state_mean: Optional[fun_mc.State] = None, - state_mean_weight: fun_mc.FloatNest = 0., + state_mean_weight: float | fun_mc.FloatNest = 0.0, named_axis: Optional[fun_mc.StringNest] = None, chain_named_axis: Optional[fun_mc.StringNest] = None, -) -> tuple[fun_mc.FloatTensor, ChEESPerGradExtra]: +) -> tuple[fun_mc.FloatArray, ChEESPerGradExtra]: """ChEES per gradient criterion. This criterion is computed as: @@ -365,17 +390,17 @@ def chees_per_grad_criterion( less likely to be found. Args: - previous_state: (Possibly nested) floating point `Tensor`. The previous - state of the MCMC chain. - proposed_state: (Possibly nested) floating point `Tensor`. The proposed - state of the MCMC chain. - accept_prob: Floating `Tensor`. Probability of acceping the proposed state. + previous_state: (Possibly nested) floating point `Array`. The previous state + of the MCMC chain. + proposed_state: (Possibly nested) floating point `Array`. The proposed state + of the MCMC chain. + accept_prob: Floating `Array`. Probability of acceping the proposed state. trajectory_length: Trajectory length associated with the transition from `previous_state` to `proposed_state`. - power: Floating `Tensor`. Used to scale the `trajectory_length` term. - state_mean: (Possibly nested) floating point `Tensor`. Optional estimate of + power: Floating `Array`. Used to scale the `trajectory_length` term. + state_mean: (Possibly nested) floating point `Array`. Optional estimate of the MCMC chain mean. - state_mean_weight: Floating point `Tensor`. Used to weight `state_mean` with + state_mean_weight: Floating point `Array`. Used to weight `state_mean` with the mean computed by averaging across the previous/proposed state. Setting it to effectively uses `state_mean` as the only source of the MCMCM chain mean. @@ -394,41 +419,52 @@ def chees_per_grad_criterion( state_mean=state_mean, state_mean_weight=state_mean_weight, named_axis=named_axis, - chain_named_axis=chain_named_axis) + chain_named_axis=chain_named_axis, + ) per_chain_chees_per_grad = per_chain_chees / distribute_lib.pbroadcast( - trajectory_length**power, chain_named_axis) + trajectory_length**power, chain_named_axis + ) extra = ChEESPerGradExtra( chees=chees, per_chain_chees=per_chain_chees, per_chain_chees_per_grad=per_chain_chees_per_grad, ) - return distribute_lib.reduce_mean( - per_chain_chees_per_grad, named_axis=chain_named_axis), extra + return ( + distribute_lib.reduce_mean( + per_chain_chees_per_grad, named_axis=chain_named_axis + ), + extra, + ) @util.named_call -def _halton(float_index: fun_mc.FloatTensor, - max_bits: fun_mc.FloatTensor = 10) -> fun_mc.FloatTensor: - float_index = tf.convert_to_tensor(float_index) - bit_masks = 2**tf.range(max_bits, dtype=float_index.dtype) - return tf.einsum('i,i->', tf.math.mod((float_index + 1) // bit_masks, 2), - 0.5 / bit_masks) +def _halton( + float_index: fun_mc.FloatArray, max_bits: float | fun_mc.FloatArray = 10 +) -> fun_mc.FloatArray: + float_index = jnp.asarray(float_index) + bit_masks = 2 ** jnp.arange(max_bits, dtype=float_index.dtype) + return jnp.einsum( + 'i,i->', jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks + ) class DefaultTrajectoryLengthParams(NamedTuple): """Learnable trajectory length parameters.""" - log_mean_trajectory_length: fun_mc.FloatTensor + + log_mean_trajectory_length: fun_mc.FloatArray @util.named_call - def mean_trajectory_length(self) -> fun_mc.FloatTensor: + def mean_trajectory_length(self) -> fun_mc.FloatArray: """Computes the mean trajectory length.""" - return tf.exp(self.log_mean_trajectory_length) + return jnp.exp(self.log_mean_trajectory_length) @util.named_call def default_trajectory_length_sample( trajectory_length_params: DefaultTrajectoryLengthParams, - step: fun_mc.IntTensor, seed: Any) -> fun_mc.FloatTensor: + step: fun_mc.IntArray, + seed: Any, +) -> fun_mc.FloatArray: """Samples a trajectory length. The trajectory length is sampled from `[0, 2 * mean_trajectory_length]`. The @@ -444,18 +480,24 @@ def default_trajectory_length_sample( trajectory_length: Sampled trajectory. """ del seed - mean_trajectory_length = tf.exp( - fun_mc.clip_grads(trajectory_length_params.log_mean_trajectory_length, - 1.)) - trajectory_length = 2 * _halton(tf.cast( - step, mean_trajectory_length.dtype)) * mean_trajectory_length + mean_trajectory_length = jnp.exp( + fun_mc.clip_grads( + trajectory_length_params.log_mean_trajectory_length, 1.0 + ) + ) + trajectory_length = ( + 2 + * _halton(jnp.asarray(step, mean_trajectory_length.dtype)) + * mean_trajectory_length + ) return trajectory_length @util.named_call def default_trajectory_length_constrain( trajectory_length_params: DefaultTrajectoryLengthParams, - max_trajectory_length: fun_mc.FloatTensor = 3.) -> fun_mc.FloatTensor: + max_trajectory_length: float | fun_mc.FloatArray = 3.0, +) -> DefaultTrajectoryLengthParams: """Constrains the trajectory parameters. Args: @@ -465,19 +507,22 @@ def default_trajectory_length_constrain( Returns: trajectory_length_params: Constrained trajectory params. """ - max_trajectory_length = tf.convert_to_tensor( + max_trajectory_length = jnp.asarray( max_trajectory_length, - trajectory_length_params.log_mean_trajectory_length.dtype) + trajectory_length_params.log_mean_trajectory_length.dtype, + ) return trajectory_length_params._replace( - log_mean_trajectory_length=tf.minimum( + log_mean_trajectory_length=jnp.minimum( trajectory_length_params.log_mean_trajectory_length, - tf.math.log(max_trajectory_length))) + jnp.log(max_trajectory_length), + ) + ) @util.named_call def default_trajectory_length_init( - init_trajectory_length: fun_mc.FloatTensor + init_trajectory_length: fun_mc.FloatArray, ) -> DefaultTrajectoryLengthParams: """Initializes trajectory parameters. @@ -488,24 +533,27 @@ def default_trajectory_length_init( trajectory_length_params: Initialized trajectory parameters. """ return DefaultTrajectoryLengthParams( - log_mean_trajectory_length=tf.math.log(init_trajectory_length)) + log_mean_trajectory_length=jnp.log(init_trajectory_length) + ) class StochasticGradientAscentHMCState(NamedTuple): """Stochastic Gradient Ascent Hamiltonian Monte Carlo state.""" + hmc_state: fun_mc.HamiltonianMonteCarloState - step: fun_mc.IntTensor + step: fun_mc.IntArray trajectory_length_params_opt_state: fun_mc.AdamState trajectory_length_params_rmean_state: fun_mc.RunningMeanState class StochasticGradientAscentHMCExtra(NamedTuple): """Stochastic Gradient Ascent Hamiltonian Monte Carlo extra.""" + hmc_extra: fun_mc.HamiltonianMonteCarloExtra - num_integrator_steps: fun_mc.IntTensor + num_integrator_steps: fun_mc.IntArray trajectory_length_params_opt_extra: fun_mc.AdamExtra trajectory_length_params: Any - criterion: fun_mc.FloatTensor + criterion: fun_mc.FloatArray criterion_extra: Any @@ -513,9 +561,11 @@ class StochasticGradientAscentHMCExtra(NamedTuple): def stochastic_gradient_ascent_hmc_init( state: fun_mc.State, target_log_prob_fn: fun_mc.PotentialFn, - init_trajectory_length: fun_mc.FloatTensor, - trajectory_length_params_init_fn: - Callable[[fun_mc.FloatTensor], Any] = default_trajectory_length_init): + init_trajectory_length: fun_mc.FloatArray, + trajectory_length_params_init_fn: Callable[ + [fun_mc.FloatArray], Any + ] = default_trajectory_length_init, +): """Initialize Stochastic Gradient Ascent HMC state. Args: @@ -529,14 +579,16 @@ def stochastic_gradient_ascent_hmc_init( Returns: sga_hmc_state: New Stochastic Gradient Ascent HMC state. """ - init_trajectory_length = tf.convert_to_tensor(init_trajectory_length) + init_trajectory_length = jnp.asarray(init_trajectory_length) init_trajectory_length_params = trajectory_length_params_init_fn( - init_trajectory_length) + init_trajectory_length + ) return StochasticGradientAscentHMCState( hmc_state=fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn), - step=tf.ones([], tf.int32), + step=jnp.ones([], jnp.int32), trajectory_length_params_opt_state=fun_mc.adam_init( - init_trajectory_length_params), + init_trajectory_length_params + ), trajectory_length_params_rmean_state=fun_mc.running_mean_init( util.map_tree(lambda x: x.shape, init_trajectory_length_params), util.map_tree(lambda x: x.dtype, init_trajectory_length_params), @@ -549,12 +601,12 @@ def stochastic_gradient_ascent_hmc_step( sga_hmc_state: StochasticGradientAscentHMCState, scalar_step_size: fun_mc.FloatNest, criterion_fn: Callable[ - [fun_mc.State, fun_mc.State, fun_mc.FloatTensor, fun_mc.FloatTensor], - tuple[fun_mc.FloatTensor, Any], + [fun_mc.State, fun_mc.State, fun_mc.FloatArray, fun_mc.FloatArray], + tuple[fun_mc.FloatArray, Any], ], - trajectory_length_adaptation_rate: fun_mc.FloatTensor = 0.05, + trajectory_length_adaptation_rate: float | fun_mc.FloatArray = 0.05, trajectory_length_sample_fn: Callable[ - [Any, fun_mc.IntTensor, Any], fun_mc.FloatTensor + [Any, fun_mc.IntArray, Any], fun_mc.FloatArray ] = (default_trajectory_length_sample), trajectory_length_constrain_fn: Callable[[Any], Any] = ( default_trajectory_length_constrain @@ -562,8 +614,8 @@ def stochastic_gradient_ascent_hmc_step( adam_kwargs: Mapping[str, Any] = immutabledict.immutabledict( {'beta_1': 0.0, 'beta_2': 0.5} ), - averaging_window_steps: fun_mc.IntTensor = 100, - adapt: fun_mc.BooleanTensor = True, + averaging_window_steps: int | fun_mc.IntArray = 100, + adapt: bool | fun_mc.BooleanArray = True, seed: Any = None, **hmc_kwargs: Mapping[str, Any], ): @@ -631,22 +683,27 @@ def loss_fn(*args, **kwargs): rmean_params = sga_hmc_state.trajectory_length_params_rmean_state.mean adapting_params = fun_mc.recover_state_from_args(args, kwargs, rmean_params) params = fun_mc.choose(adapt, adapting_params, rmean_params) - trajectory_length = trajectory_length_sample_fn(params, sga_hmc_state.step, - sample_seed) + trajectory_length = trajectory_length_sample_fn( + params, sga_hmc_state.step, sample_seed + ) - hmc_state, hmc_extra = hamiltonian_monte_carlo_with_state_grads_step( + hmc_state, hmc_extra = hamiltonian_monte_carlo_with_state_grads_step( # pytype: disable=wrong-keyword-args sga_hmc_state.hmc_state, trajectory_length=trajectory_length, scalar_step_size=scalar_step_size, seed=hmc_seed, - **hmc_kwargs) + **hmc_kwargs, + ) - accept_prob = tf.exp( - tf.minimum( - tf.zeros_like(hmc_extra.hmc_extra.log_accept_ratio), - hmc_extra.hmc_extra.log_accept_ratio)) - accept_prob = tf.where( - tf.math.is_finite(accept_prob), accept_prob, tf.zeros_like(accept_prob)) + accept_prob = jnp.exp( + jnp.minimum( + jnp.zeros_like(hmc_extra.hmc_extra.log_accept_ratio), + hmc_extra.hmc_extra.log_accept_ratio, + ) + ) + accept_prob = jnp.where( + jnp.isfinite(accept_prob), accept_prob, jnp.zeros_like(accept_prob) + ) criterion, criterion_extra = criterion_fn( sga_hmc_state.hmc_state.state, @@ -657,37 +714,58 @@ def loss_fn(*args, **kwargs): trajectory_length + scalar_step_size, ) - return -criterion, (hmc_state, hmc_extra, criterion, criterion_extra, - params) + return -criterion, ( + hmc_state, + hmc_extra, + criterion, + criterion_extra, + params, + ) # Adapt trajectory. - trajectory_length_params_opt_state, trajectory_length_params_opt_extra = fun_mc.adam_step( - sga_hmc_state.trajectory_length_params_opt_state, - loss_fn, - learning_rate=trajectory_length_adaptation_rate, - **adam_kwargs, + trajectory_length_params_opt_state, trajectory_length_params_opt_extra = ( + fun_mc.adam_step( + sga_hmc_state.trajectory_length_params_opt_state, + loss_fn, + learning_rate=trajectory_length_adaptation_rate, + **adam_kwargs, + ) ) - (hmc_state, hmc_extra, criterion, criterion_extra, - trajectory_length_params) = trajectory_length_params_opt_extra.loss_extra + ( + hmc_state, + hmc_extra, + criterion, + criterion_extra, + trajectory_length_params, + ) = trajectory_length_params_opt_extra.loss_extra # Constrain trajectory params. trajectory_length_params_opt_state = fun_mc.choose( - adapt, trajectory_length_params_opt_state, - sga_hmc_state.trajectory_length_params_opt_state) + adapt, + trajectory_length_params_opt_state, + sga_hmc_state.trajectory_length_params_opt_state, + ) constrained_trajectory_length_params = trajectory_length_constrain_fn( - trajectory_length_params_opt_state.state) - trajectory_length_params_opt_state = trajectory_length_params_opt_state._replace( - state=constrained_trajectory_length_params) + trajectory_length_params_opt_state.state + ) + trajectory_length_params_opt_state = ( + trajectory_length_params_opt_state._replace( + state=constrained_trajectory_length_params + ) + ) # Update the running mean for trajectory params. trajectory_length_params_rmean_state, _ = fun_mc.running_mean_step( sga_hmc_state.trajectory_length_params_rmean_state, trajectory_length_params_opt_state.state, - window_size=averaging_window_steps) + window_size=averaging_window_steps, + ) trajectory_length_params_rmean_state = fun_mc.choose( - adapt, trajectory_length_params_rmean_state, - sga_hmc_state.trajectory_length_params_rmean_state) + adapt, + trajectory_length_params_rmean_state, + sga_hmc_state.trajectory_length_params_rmean_state, + ) sga_hmc_state = sga_hmc_state._replace( hmc_state=hmc_state, diff --git a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py index 4cdee429ce..7da2538907 100644 --- a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py +++ b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py @@ -20,17 +20,16 @@ # Dependency imports from absl.testing import parameterized -import jax from jax import config as jax_config import tensorflow.compat.v2 as real_tf - from tensorflow_probability.python.internal import test_util as tfp_test_util from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc from fun_mc import sga_hmc from fun_mc import test_util -tf = backend.tf +jax = backend.jax +jnp = backend.jnp tfp = backend.tfp util = backend.util tfd = tfp.distributions @@ -39,13 +38,16 @@ real_tf.enable_v2_behavior() +real_tf.experimental.numpy.experimental_enable_numpy_behavior() jax_config.update('jax_enable_x64', True) BACKEND = None # Rewritten by backends/rewrite.py. if BACKEND == 'backend_jax': - os.environ['XLA_FLAGS'] = (f'{os.environ.get("XLA_FLAGS", "")} ' - '--xla_force_host_platform_device_count=4') + os.environ['XLA_FLAGS'] = ( + f'{os.environ.get("XLA_FLAGS", "")} ' + '--xla_force_host_platform_device_count=4' + ) def _test_seed(): @@ -53,7 +55,6 @@ def _test_seed(): class SGAHMCTest(tfp_test_util.TestCase): - _is_on_jax = BACKEND == 'backend_jax' def _make_seed(self, seed): @@ -67,28 +68,30 @@ def _dtype(self): raise NotImplementedError() def _constant(self, value): - return tf.constant(value, self._dtype) + return jnp.array(value, self._dtype) def testHMCWithStateGrads(self): - trajectory_length = 1. + trajectory_length = 1.0 epsilon = 1e-3 seed = self._make_seed(_test_seed()) def hmc_step(trajectory_length, axis_name=()): - @tfp.experimental.distribute.JointDistributionCoroutine def model(): - z = yield Root(tfd.Normal(0., 1)) + z = yield Root(tfd.Normal(0.0, 1)) yield tfp.experimental.distribute.Sharded( - tfd.Sample(tfd.Normal(z, 1.), 8), axis_name) + tfd.Sample(tfd.Normal(z, 1.0), 8), axis_name + ) @tfp.experimental.distribute.JointDistributionCoroutine def momentum_dist(): - yield Root(tfd.Normal(0., 2)) + yield Root(tfd.Normal(0.0, 2)) yield Root( tfp.experimental.distribute.Sharded( - tfd.Sample(tfd.Normal(0., 3.), 8), axis_name)) + tfd.Sample(tfd.Normal(0.0, 3.0), 8), axis_name + ) + ) def target_log_prob_fn(x): return model.log_prob(x), () @@ -106,31 +109,41 @@ def momentum_sample_fn(seed): hmc_state, trajectory_length=trajectory_length, scalar_step_size=epsilon, - step_size_scale=util.map_tree(lambda x: 1. + tf.abs(x), state), + step_size_scale=util.map_tree(lambda x: 1.0 + jnp.abs(x), state), target_log_prob_fn=target_log_prob_fn, seed=seed, kinetic_energy_fn=kinetic_energy_fn, momentum_sample_fn=momentum_sample_fn, - named_axis=model.experimental_shard_axis_names)) + named_axis=model.experimental_shard_axis_names, + ) + ) def sum_state(x, axis_name): - res = tf.reduce_sum(x**2) + res = jnp.sum(x**2) if axis_name: res = backend.distribute_lib.psum(res, axis_name) return res - sum_sq = util.map_tree_up_to(hmc_extra.proposed_state, sum_state, - hmc_extra.proposed_state, - model.experimental_shard_axis_names) + sum_sq = util.map_tree_up_to( + hmc_extra.proposed_state, + sum_state, + hmc_extra.proposed_state, + model.experimental_shard_axis_names, + ) sum_sq = sum(util.flatten_tree(sum_sq)) return sum_sq, () def finite_diff_grad(f, epsilon, x): - return (fun_mc.call_potential_fn(f, util.map_tree( - lambda x: x + epsilon, x))[0] - fun_mc.call_potential_fn( - f, util.map_tree(lambda x: x - epsilon, x))[0]) / (2 * epsilon) - - f = tf.function(hmc_step) + return ( + fun_mc.call_potential_fn(f, util.map_tree(lambda x: x + epsilon, x))[ + 0 + ] + - fun_mc.call_potential_fn( + f, util.map_tree(lambda x: x - epsilon, x) + )[0] + ) / (2 * epsilon) + + f = jax.jit(hmc_step) auto_diff = util.value_and_grad(f, trajectory_length)[2] finite_diff = finite_diff_grad(f, epsilon, trajectory_length) @@ -140,13 +153,16 @@ def finite_diff_grad(f, epsilon, x): @functools.partial(jax.pmap, axis_name='i') def run(_): - f = tf.function(lambda trajectory_length: hmc_step( # pylint: disable=g-long-lambda - trajectory_length, axis_name='i')) + f = jax.jit( + lambda trajectory_length: hmc_step( # pylint: disable=g-long-lambda + trajectory_length, axis_name='i' + ) + ) auto_diff = util.value_and_grad(f, trajectory_length)[2] finite_diff = finite_diff_grad(f, epsilon, trajectory_length) return auto_diff, finite_diff - auto_diff, finite_diff = run(tf.ones(4)) + auto_diff, finite_diff = run(jnp.ones(4)) self.assertAllClose(auto_diff, finite_diff, rtol=0.01) @parameterized.named_parameters( @@ -180,25 +196,34 @@ def testCriterion(self, criterion_fn, with_trajectory, with_state_mean): seed = self._make_seed(_test_seed()) seeds = util.split_seed(seed, 4) previous_state = { - 'global': - self._constant(util.random_normal([2, 3], self._dtype, seeds[0])), - 'local': - self._constant( - util.random_normal([2, 2, 3], self._dtype, seeds[1])) + 'global': self._constant( + util.random_normal([2, 3], self._dtype, seeds[0]) + ), + 'local': self._constant( + util.random_normal([2, 2, 3], self._dtype, seeds[1]) + ), } named_axis = util.map_tree(lambda _: [], previous_state) chain_named_axis = [] trajectory_length = self._constant(0.5) accept_prob = self._constant([0.1, 0.5]) state_mean = { - 'global': - self._constant(util.random_normal([3], self._dtype, seeds[2])), - 'local': - self._constant(util.random_normal([2, 3], self._dtype, seeds[3])) + 'global': self._constant( + util.random_normal([3], self._dtype, seeds[2]) + ), + 'local': self._constant( + util.random_normal([2, 3], self._dtype, seeds[3]) + ), } - def eval_criterion(trajectory_length, previous_state, accept_prob, - chain_named_axis, state_mean, named_axis): + def eval_criterion( + trajectory_length, + previous_state, + accept_prob, + chain_named_axis, + state_mean, + named_axis, + ): extra_kwargs = {} if with_trajectory: extra_kwargs.update(trajectory_length=trajectory_length) @@ -208,13 +233,15 @@ def eval_criterion(trajectory_length, previous_state, accept_prob, def proposed_state_part(previous_state, named_axis): if BACKEND == 'backend_jax': part_trajectory_length = distribute_lib.pbroadcast( - trajectory_length, [chain_named_axis, named_axis]) + trajectory_length, [chain_named_axis, named_axis] + ) else: part_trajectory_length = trajectory_length return (part_trajectory_length + 1) * previous_state - proposed_state = util.map_tree_up_to(previous_state, proposed_state_part, - previous_state, named_axis) + proposed_state = util.map_tree_up_to( + previous_state, proposed_state_part, previous_state, named_axis + ) return criterion_fn( previous_state=previous_state, @@ -222,7 +249,8 @@ def proposed_state_part(previous_state, named_axis): accept_prob=accept_prob, named_axis=named_axis, chain_named_axis=chain_named_axis, - **extra_kwargs) + **extra_kwargs, + ) value, _, grad = fun_mc.call_potential_fn_with_grads( functools.partial( @@ -231,14 +259,16 @@ def proposed_state_part(previous_state, named_axis): chain_named_axis=chain_named_axis, accept_prob=accept_prob, state_mean=state_mean, - named_axis=named_axis), trajectory_length) + named_axis=named_axis, + ), + trajectory_length, + ) self.assertEqual(self._dtype, value.dtype) self.assertEqual(self._dtype, grad.dtype) - self.assertAllGreater(tf.abs(grad), 0.) + self.assertAllGreater(jnp.abs(grad), 0.0) if BACKEND == 'backend_jax': - named_axis = { 'global': [], 'local': 'local', @@ -251,9 +281,9 @@ def proposed_state_part(previous_state, named_axis): @functools.partial(jax.pmap, axis_name='chain') def run_chain(previous_state, accept_prob): - @functools.partial( - jax.pmap, axis_name='local', in_axes=(in_axes, in_axes)) + jax.pmap, axis_name='local', in_axes=(in_axes, in_axes) + ) def run_state(previous_state, state_mean): value, _, grad = fun_mc.call_potential_fn_with_grads( functools.partial( @@ -262,7 +292,10 @@ def run_state(previous_state, state_mean): chain_named_axis=chain_named_axis, accept_prob=accept_prob, state_mean=state_mean, - named_axis=named_axis), trajectory_length) + named_axis=named_axis, + ), + trajectory_length, + ) return value, grad return run_state(previous_state, state_mean) @@ -272,44 +305,59 @@ def run_state(previous_state, state_mean): self.assertAllClose(grad, sharded_grad[0, 0]) def testSGAHMC(self): - @tfd.JointDistributionCoroutine def model(): - x = yield Root(tfd.Normal(self._constant(0.), 1.)) - yield tfd.Sample(tfd.Normal(x, 1.), 2) + x = yield Root(tfd.Normal(self._constant(0.0), 1.0)) + yield tfd.Sample(tfd.Normal(x, 1.0), 2) def target_log_prob_fn(x): return model.log_prob(x), () - @tf.function + @jax.jit def kernel(sga_hmc_state, step, seed): adapt = step < num_adapt_steps seed, hmc_seed = util.split_seed(seed, 2) - sga_hmc_state, sga_hmc_extra = sga_hmc.stochastic_gradient_ascent_hmc_step( - sga_hmc_state, - scalar_step_size=self._constant(0.1), - step_size_scale=self._constant(1.), - target_log_prob_fn=target_log_prob_fn, - criterion_fn=sga_hmc.chees_criterion, - adapt=adapt, - seed=hmc_seed, + sga_hmc_state, sga_hmc_extra = ( + sga_hmc.stochastic_gradient_ascent_hmc_step( + sga_hmc_state, + scalar_step_size=self._constant(0.1), + step_size_scale=self._constant(1.0), + target_log_prob_fn=target_log_prob_fn, + criterion_fn=sga_hmc.chees_criterion, + adapt=adapt, + seed=hmc_seed, + ) ) - return (sga_hmc_state, step + 1, seed - ), sga_hmc_extra.trajectory_length_params.mean_trajectory_length() + return ( + sga_hmc_state, + step + 1, + seed, + ), sga_hmc_extra.trajectory_length_params.mean_trajectory_length() init_trajectory_length = self._constant(0.1) num_adapt_steps = 10 _, trajectory_length = fun_mc.trace( - (sga_hmc.stochastic_gradient_ascent_hmc_init( - util.map_tree_up_to( - model.dtype, lambda dtype, shape: tf.zeros( # pylint: disable=g-long-lambda - (16,) + tuple(shape), dtype), model.dtype, - model.event_shape), - target_log_prob_fn, - init_trajectory_length=init_trajectory_length), 0, - self._make_seed(_test_seed())), kernel, num_adapt_steps + 2) + ( + sga_hmc.stochastic_gradient_ascent_hmc_init( + util.map_tree_up_to( + model.dtype, + lambda dtype, shape: jnp.zeros( # pylint: disable=g-long-lambda + (16,) + tuple(shape), dtype + ), + model.dtype, + model.event_shape, + ), + target_log_prob_fn, + init_trajectory_length=init_trajectory_length, + ), + 0, + self._make_seed(_test_seed()), + ), + kernel, + num_adapt_steps + 2, + ) # We expect it to increase as part of adaptation. self.assertAllGreater(trajectory_length[-1], init_trajectory_length) @@ -322,7 +370,7 @@ class SGAHMCTest32(SGAHMCTest): @property def _dtype(self): - return tf.float32 + return jnp.float32 @test_util.multi_backend_test(globals(), 'sga_hmc_test') @@ -330,7 +378,7 @@ class SGAHMCTest64(SGAHMCTest): @property def _dtype(self): - return tf.float64 + return jnp.float64 del SGAHMCTest diff --git a/spinoffs/fun_mc/fun_mc/test_util.py b/spinoffs/fun_mc/fun_mc/test_util.py index 9ff4d870a4..2fd9d912c8 100644 --- a/spinoffs/fun_mc/fun_mc/test_util.py +++ b/spinoffs/fun_mc/fun_mc/test_util.py @@ -19,10 +19,12 @@ BACKEND = None # Rewritten by backends/rewrite.py. -def multi_backend_test(globals_dict, - relative_module_name, - backends=('jax', 'tensorflow'), - test_case=None): +def multi_backend_test( + globals_dict, + relative_module_name, + backends=('jax', 'tensorflow'), + test_case=None, +): """Multi-backend test decorator. The end goal of this decorator is that the decorated test case is removed, and @@ -61,14 +63,16 @@ def multi_backend_test(globals_dict, return lambda test_case: multi_backend_test( # pylint: disable=g-long-lambda globals_dict=globals_dict, relative_module_name=relative_module_name, - test_case=test_case) + test_case=test_case, + ) if BACKEND is not None: return test_case if relative_module_name == '__main__': raise ValueError( - 'module_name should be written out manually, not by passing __name__.') + 'module_name should be written out manually, not by passing __name__.' + ) # This assumes `test_util` is 1 levels deep inside of `fun_mc`. If we # move it, we'd change the `-1` to equal the (negative) nesting level. @@ -81,16 +85,19 @@ def multi_backend_test(globals_dict, new_test_case_names = [] for backend in backends: new_module_name_comps = ( - root_name_comps + ['dynamic', 'backend_{}'.format(backend)] + - relative_module_name_comps) + root_name_comps + + ['dynamic', 'backend_{}'.format(backend)] + + relative_module_name_comps + ) # Rewrite the module. new_module = importlib.import_module('.'.join(new_module_name_comps)) # Subclass the test case so that we can rename it (absl uses the class name # in its UI). base_new_test = getattr(new_module, test_case.__name__) - new_test = type('{}_{}'.format(test_case.__name__, backend), - (base_new_test,), {}) + new_test = type( + '{}_{}'.format(test_case.__name__, backend), (base_new_test,), {} + ) new_test_case_names.append(new_test.__name__) globals_dict[new_test.__name__] = new_test diff --git a/spinoffs/fun_mc/fun_mc/using_jax.py b/spinoffs/fun_mc/fun_mc/using_jax.py index a4ebef3f61..d1f86ee9d0 100644 --- a/spinoffs/fun_mc/fun_mc/using_jax.py +++ b/spinoffs/fun_mc/fun_mc/using_jax.py @@ -19,6 +19,7 @@ from fun_mc.dynamic.backend_jax import api # pytype: disable=import-error # pylint: disable=wildcard-import from fun_mc.dynamic.backend_jax.api import * # pytype: disable=import-error + del rewrite __all__ = api.__all__ diff --git a/spinoffs/fun_mc/fun_mc/using_tensorflow.py b/spinoffs/fun_mc/fun_mc/using_tensorflow.py index 08b9e96543..acdc43c97c 100644 --- a/spinoffs/fun_mc/fun_mc/using_tensorflow.py +++ b/spinoffs/fun_mc/fun_mc/using_tensorflow.py @@ -19,6 +19,7 @@ from fun_mc.dynamic.backend_tensorflow import api # pytype: disable=import-error # pylint: disable=wildcard-import from fun_mc.dynamic.backend_tensorflow.api import * # pytype: disable=import-error + del rewrite __all__ = api.__all__ diff --git a/spinoffs/fun_mc/fun_mc/util_tfp.py b/spinoffs/fun_mc/fun_mc/util_tfp.py index 74caeb01dc..084654be5b 100644 --- a/spinoffs/fun_mc/fun_mc/util_tfp.py +++ b/spinoffs/fun_mc/fun_mc/util_tfp.py @@ -15,12 +15,11 @@ """FunMC utilities implemented via TensorFlow Probability.""" import functools - from typing import Any, Optional from fun_mc import backend from fun_mc import fun_mc_lib -tf = backend.tf +jnp = backend.jnp tfp = backend.tfp tfb = tfp.bijectors util = backend.util @@ -29,7 +28,8 @@ def bijector_to_transform_fn( bijector: fun_mc_lib.BijectorNest, state_structure: Any, - batch_ndims: fun_mc_lib.IntTensor = 0) -> fun_mc_lib.TransitionOperator: + batch_ndims: int = 0, +) -> fun_mc_lib.TransitionOperator: """Creates a TransitionOperator that transforms the state using a bijector. The returned operator has the following signature: @@ -60,34 +60,46 @@ def bijector_to_transform_fn( transform_fn: The created transformation. """ bijector_structure = util.get_shallow_tree( - lambda b: isinstance(b, tfb.Bijector), bijector) + lambda b: isinstance(b, tfb.Bijector), bijector + ) def transform_fn(bijector, state_structure, *args, **kwargs): """Transport map implemented via the bijector.""" state = fun_mc_lib.recover_state_from_args(args, kwargs, state_structure) - value = util.map_tree_up_to(bijector_structure, lambda b, x: b(x), bijector, - state) + value = util.map_tree_up_to( + bijector_structure, lambda b, x: b(x), bijector, state + ) ldj_parts = util.map_tree_up_to( bijector_structure, lambda b, x: b.forward_log_det_jacobian( # pylint: disable=g-long-lambda x, - event_ndims=util.map_tree(lambda x: tf.rank(x) - batch_ndims, x)), + event_ndims=util.map_tree(lambda x: len(x.shape) - batch_ndims, x), + ), bijector, - state) + state, + ) ldj = sum(util.flatten_tree(ldj_parts)) return value, ((), ldj) - inverse_bijector = util.map_tree_up_to(bijector_structure, - tfp.bijectors.Invert, bijector) + inverse_bijector = util.map_tree_up_to( + bijector_structure, tfp.bijectors.Invert, bijector + ) - forward_transform_fn = functools.partial(transform_fn, bijector, - state_structure) + forward_transform_fn = functools.partial( + transform_fn, bijector, state_structure + ) inverse_transform_fn = functools.partial( - transform_fn, inverse_bijector, - util.map_tree_up_to(bijector_structure, lambda b, s: b.forward_dtype(s), - bijector, state_structure)) + transform_fn, + inverse_bijector, + util.map_tree_up_to( + bijector_structure, + lambda b, s: b.forward_dtype(s), + bijector, + state_structure, + ), + ) forward_transform_fn.inverse = inverse_transform_fn inverse_transform_fn.inverse = forward_transform_fn @@ -96,8 +108,10 @@ def transform_fn(bijector, state_structure, *args, **kwargs): def transition_kernel_wrapper( - current_state: fun_mc_lib.FloatNest, kernel_results: Optional[Any], - kernel: tfp.mcmc.TransitionKernel) -> tuple[fun_mc_lib.FloatNest, Any]: + current_state: fun_mc_lib.FloatNest, + kernel_results: Optional[Any], + kernel: tfp.mcmc.TransitionKernel, +) -> tuple[fun_mc_lib.FloatNest, Any]: """Wraps a `tfp.mcmc.TransitionKernel` as a `TransitionOperator`. Args: @@ -113,7 +127,10 @@ def transition_kernel_wrapper( extra: An empty tuple. """ flat_current_state = util.flatten_tree(current_state) - flat_current_state, kernel_results = kernel.one_step(flat_current_state, - kernel_results) - return (util.unflatten_tree(current_state, - flat_current_state), kernel_results), () + flat_current_state, kernel_results = kernel.one_step( + flat_current_state, kernel_results + ) + return ( + util.unflatten_tree(current_state, flat_current_state), + kernel_results, + ), () diff --git a/spinoffs/fun_mc/fun_mc/util_tfp_test.py b/spinoffs/fun_mc/fun_mc/util_tfp_test.py index 6315f8e6e0..7e4d7fd78e 100644 --- a/spinoffs/fun_mc/fun_mc/util_tfp_test.py +++ b/spinoffs/fun_mc/fun_mc/util_tfp_test.py @@ -20,17 +20,17 @@ from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf - from fun_mc import backend from fun_mc import fun_mc_lib as fun_mc from fun_mc import test_util from fun_mc import util_tfp -tf = backend.tf +jnp = backend.jnp tfp = backend.tfp util = backend.util real_tf.enable_v2_behavior() +real_tf.experimental.numpy.experimental_enable_numpy_behavior() jax_config.update('jax_enable_x64', True) @@ -43,7 +43,8 @@ def __init__(self): inverse_min_event_ndims=[0, 0], validate_args=False, parameters={}, - name='dup') + name='dup', + ) def forward(self, x, **kwargs): return [x, x] @@ -58,10 +59,10 @@ def inverse_event_shape(self, y_shape, **kwargs): return y_shape def forward_log_det_jacobian(self, x, event_ndims, **kwargs): - return 0. + return 0.0 def inverse_log_det_jacobian(self, y, event_ndims, **kwargs): - return 0. + return 0.0 def forward_dtype(self, x_dtype, **kwargs): return [x_dtype, x_dtype] @@ -77,10 +78,9 @@ def _dtype(self): raise NotImplementedError() def _constant(self, value): - return tf.constant(value, self._dtype) + return jnp.array(value, self._dtype) def testWrapTransitionKernel(self): - class TestKernel(tfp.mcmc.TransitionKernel): def one_step(self, current_state, previous_kernel_results): @@ -93,45 +93,52 @@ def is_calibrated(self): return True def kernel(state, pkr): - return util_tfp.transition_kernel_wrapper( - state, pkr, TestKernel()) + return util_tfp.transition_kernel_wrapper(state, pkr, TestKernel()) - state = {'x': self._constant(0.), 'y': self._constant(1.)} - kr = 1. + state = {'x': self._constant(0.0), 'y': self._constant(1.0)} + kr = 1.0 (final_state, final_kr), _ = fun_mc.trace( (state, kr), kernel, 2, trace_fn=lambda *args: (), ) - self.assertAllEqual({ - 'x': 2., - 'y': 3. - }, util.map_tree(np.array, final_state)) - self.assertAllEqual(1. + 2., final_kr) + self.assertAllEqual( + {'x': 2.0, 'y': 3.0}, util.map_tree(np.array, final_state) + ) + self.assertAllEqual(1.0 + 2.0, final_kr) def testBijectorToTransformFn(self): bijectors = [ tfp.bijectors.Identity(), - tfp.bijectors.Scale(self._constant([ - [1., 2.], - [3., 4.], - ])) + tfp.bijectors.Scale( + self._constant([ + [1.0, 2.0], + [3.0, 4.0], + ]) + ), ] state = [ - tf.ones([2, 1], dtype=self._dtype), - tf.ones([2, 2], dtype=self._dtype) + jnp.ones([2, 1], dtype=self._dtype), + jnp.ones([2, 2], dtype=self._dtype), ] transform_fn = util_tfp.bijector_to_transform_fn( - bijectors, state_structure=state, batch_ndims=1) + bijectors, state_structure=state, batch_ndims=1 + ) fwd, (_, fwd_ldj1), fwd_ldj2 = fun_mc.call_transport_map_with_ldj( - transform_fn, state) + transform_fn, state + ) self.assertAllClose( - [np.ones([2, 1]), np.array([ - [1., 2.], - [3., 4], - ])], fwd) + [ + np.ones([2, 1]), + np.array([ + [1.0, 2.0], + [3.0, 4], + ]), + ], + fwd, + ) true_fwd_ldj = np.array([ np.log(1) + np.log(2), @@ -143,38 +150,47 @@ def testBijectorToTransformFn(self): inverse_transform_fn = util.inverse_fn(transform_fn) inv, (_, inv_ldj1), inv_ldj2 = fun_mc.call_transport_map_with_ldj( - inverse_transform_fn, state) + inverse_transform_fn, state + ) self.assertAllClose( - [np.ones([2, 1]), - np.array([ - [1., 1. / 2.], - [1. / 3., 1. / 4.], - ])], inv) + [ + np.ones([2, 1]), + np.array([ + [1.0, 1.0 / 2.0], + [1.0 / 3.0, 1.0 / 4.0], + ]), + ], + inv, + ) self.assertAllClose(-true_fwd_ldj, inv_ldj1) self.assertAllClose(-true_fwd_ldj, inv_ldj2) def testBijectorToTransformFnMulti(self): bijector = DupBijector() - state = tf.ones([1, 2], dtype=self._dtype) + state = jnp.ones([1, 2], dtype=self._dtype) transform_fn = util_tfp.bijector_to_transform_fn( - bijector, state_structure=state, batch_ndims=1) + bijector, state_structure=state, batch_ndims=1 + ) fwd, (_, fwd_ldj1), fwd_ldj2 = fun_mc.call_transport_map_with_ldj( - transform_fn, state) + transform_fn, state + ) self.assertAllClose([np.ones([1, 2]), np.ones([1, 2])], fwd) - self.assertAllClose(0., fwd_ldj1) - self.assertAllClose(0., fwd_ldj2) + self.assertAllClose(0.0, fwd_ldj1) + self.assertAllClose(0.0, fwd_ldj2) inverse_transform_fn = util.inverse_fn(transform_fn) inv, (_, inv_ldj1), inv_ldj2 = fun_mc.call_transport_map_with_ldj( - inverse_transform_fn, [ - tf.ones([1, 2], dtype=self._dtype), - tf.ones([2, 1], dtype=self._dtype) - ]) + inverse_transform_fn, + [ + jnp.ones([1, 2], dtype=self._dtype), + jnp.ones([2, 1], dtype=self._dtype), + ], + ) self.assertAllClose(np.ones([1, 2]), inv) - self.assertAllClose(0., inv_ldj1) - self.assertAllClose(0., inv_ldj2) + self.assertAllClose(0.0, inv_ldj1) + self.assertAllClose(0.0, inv_ldj2) @test_util.multi_backend_test(globals(), 'util_tfp_test') @@ -182,7 +198,7 @@ class UtilTFPTest32(UtilTFPTest): @property def _dtype(self): - return tf.float32 + return jnp.float32 @test_util.multi_backend_test(globals(), 'util_tfp_test') @@ -190,7 +206,7 @@ class UtilTFPTest64(UtilTFPTest): @property def _dtype(self): - return tf.float64 + return jnp.float64 del UtilTFPTest