Skip to content

Commit

Permalink
FunMC: Start on a better implementation of SMC.
Browse files Browse the repository at this point in the history
Other assorted changes:
- Fix exception chaining in call_transition_operator

PiperOrigin-RevId: 698451948
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Nov 20, 2024
1 parent b7ab24d commit 5536afd
Show file tree
Hide file tree
Showing 11 changed files with 545 additions and 57 deletions.
68 changes: 58 additions & 10 deletions spinoffs/fun_mc/fun_mc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ py_library(
deps = [
":fun_mc_lib",
":prefab",
":smc",
":types",
":util_tfp",
],
)
Expand Down Expand Up @@ -85,8 +87,10 @@ py_library(
srcs = ["using_jax.py"],
deps = [
":api",
# jax dep,
"//fun_mc/backends:rewrite",
"//fun_mc/dynamic/backend_jax:backend",
# tensorflow_probability/substrates:jax dep,
],
)

Expand All @@ -96,6 +100,8 @@ py_library(
srcs = ["using_tensorflow.py"],
deps = [
":api",
# tensorflow dep,
# tensorflow_probability dep,
"//fun_mc/backends:rewrite",
"//fun_mc/dynamic/backend_tensorflow:backend",
],
Expand Down Expand Up @@ -144,6 +150,33 @@ py_test(
],
)

# pytype
py_library(
name = "prefab",
srcs = ["prefab.py"],
deps = [
":backend",
":fun_mc_lib",
":malt",
":sga_hmc",
],
)

py_test(
name = "prefab_test",
srcs = ["prefab_test.py"],
shard_count = 2,
deps = [
":fun_mc",
":prefab",
":test_util",
# jax dep,
# tensorflow dep,
# tensorflow_probability/python/internal:test_util dep,
# tensorflow/compiler/jit dep,
],
)

# pytype
py_library(
name = "sga_hmc",
Expand Down Expand Up @@ -172,31 +205,46 @@ py_test(

# pytype
py_library(
name = "prefab",
srcs = ["prefab.py"],
name = "smc",
srcs = ["smc.py"],
deps = [
":backend",
":fun_mc_lib",
":malt",
":sga_hmc",
":types",
],
)

py_test(
name = "prefab_test",
srcs = ["prefab_test.py"],
shard_count = 2,
pytype_strict_contrib_test(
name = "smc_test",
srcs = ["smc_test.py"],
shard_count = 4,
deps = [
":fun_mc",
":prefab",
":backend",
":fun_mc_lib",
":smc",
":test_util",
":types",
# absl/testing:parameterized dep,
# jax dep,
# jaxtyping dep,
# mock dep,
# tensorflow dep,
# tensorflow_probability/python/internal:test_util dep,
# tensorflow/compiler/jit dep,
],
)

# pytype
py_library(
name = "types",
srcs = ["types.py"],
deps = [
":backend",
# jaxtyping dep,
# typeguard dep,
],
)

# pytype
py_library(
name = "util_tfp",
Expand Down
6 changes: 5 additions & 1 deletion spinoffs/fun_mc/fun_mc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

from fun_mc import fun_mc_lib
from fun_mc import prefab
from fun_mc import smc
from fun_mc import types
from fun_mc import util_tfp
from fun_mc.fun_mc_lib import *
from fun_mc.smc import *
from fun_mc.types import *

__all__ = [
'prefab',
'util_tfp',
] + fun_mc_lib.__all__
] + fun_mc_lib.__all__ + smc.__all__ + types.__all__
4 changes: 1 addition & 3 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ py_library(
name = "util",
srcs = ["util.py"],
deps = [
# jax dep,
# jax:stax dep,
# jaxtyping dep,
],
)

Expand All @@ -42,6 +41,5 @@ py_library(
srcs = ["backend.py"],
deps = [
":util",
# tensorflow_probability/substrates:jax dep,
],
)
49 changes: 37 additions & 12 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
from jax import random
from jax import tree_util
import jax.numpy as jnp
import jaxtyping

__all__ = [
'Array',
'assert_same_shallow_tree',
'block_until_ready',
'convert_to_tensor',
'diff',
'DType',
'flatten_tree',
'get_shallow_tree',
'inverse_fn',
Expand All @@ -39,8 +42,10 @@
'random_categorical',
'random_integer',
'random_normal',
'random_permutation',
'random_uniform',
'repeat',
'Seed',
'split_seed',
'stack_dynamic_array',
'trace',
Expand All @@ -50,6 +55,11 @@
]


Array = jaxtyping.Array
DType = jax.typing.DTypeLike
Seed = jaxtyping.PRNGKeyArray


def map_tree(fn, tree, *args):
"""Maps `fn` over the leaves of a nested structure."""
return tree_util.tree_map(fn, tree, *args)
Expand Down Expand Up @@ -102,8 +112,9 @@ def make_tensor_seed(seed):
if hasattr(seed, 'dtype') and jax.dtypes.issubdtype(
seed.dtype, jax.dtypes.prng_key
):
return seed
return jnp.asarray(seed, jnp.uint32)
return jnp.asarray(seed)
else:
return jnp.asarray(seed, jnp.uint32)


def split_seed(seed, count):
Expand All @@ -114,7 +125,8 @@ def split_seed(seed, count):
def random_uniform(shape, dtype, seed):
"""Generates a sample from uniform distribution over [0., 1)."""
return random.uniform(
shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed))
shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed)
)


def random_integer(shape, dtype, minval, maxval, seed):
Expand All @@ -124,13 +136,15 @@ def random_integer(shape, dtype, minval, maxval, seed):
dtype=dtype,
minval=minval,
maxval=maxval,
key=make_tensor_seed(seed))
key=make_tensor_seed(seed),
)


def random_normal(shape, dtype, seed):
"""Generates a sample from a standard normal distribution."""
return random.normal(
shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed))
shape=tuple(shape), dtype=dtype, key=make_tensor_seed(seed)
)


def _searchsorted(a, v):
Expand Down Expand Up @@ -158,19 +172,26 @@ def random_categorical(logits, num_samples, seed):
cum_sum = jnp.cumsum(probs, axis=-1)

eta = random.uniform(
make_tensor_seed(seed), (num_samples,) + cum_sum.shape[:-1])
make_tensor_seed(seed), (num_samples,) + cum_sum.shape[:-1]
)
cum_sum = jnp.broadcast_to(cum_sum, (num_samples,) + cum_sum.shape)

flat_cum_sum = cum_sum.reshape([-1, cum_sum.shape[-1]])
flat_eta = eta.reshape([-1])
return jax.vmap(_searchsorted)(flat_cum_sum, flat_eta).reshape(eta.shape).T


def random_permutation(value, seed):
"""Randomly permutes the array."""
return random.permutation(seed, value)


def trace(state, fn, num_steps, unroll, max_steps, **_):
"""Implementation of `trace` operator, without the calling convention."""
# We need the shapes and dtypes of the outputs of `fn`.
_, untraced_spec, traced_spec, stop_spec = jax.eval_shape(
fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state)
)
if isinstance(stop_spec, tuple):
stop = ()
else:
Expand Down Expand Up @@ -211,12 +232,16 @@ def trace(state, fn, num_steps, unroll, max_steps, **_):
state, untraced, traced_element, stop = fn(state)
else:
traced_element = traced_init
map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists,
traced_element)
map_tree_up_to(
traced_spec, lambda l, e: l.append(e), traced_lists, traced_element
)
# Using asarray instead of stack to handle empty arrays correctly.
traced = map_tree_up_to(traced_spec,
lambda l, s: jnp.asarray(l, dtype=s.dtype),
traced_lists, traced_spec)
traced = map_tree_up_to(
traced_spec,
lambda l, s: jnp.asarray(l, dtype=s.dtype),
traced_lists,
traced_spec,
)
elif use_scan:

def wrapper(state_untraced, _):
Expand Down
3 changes: 0 additions & 3 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ py_library(
srcs = ["util.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)

Expand All @@ -37,7 +36,5 @@ py_library(
srcs = ["backend.py"],
deps = [
":util",
# tensorflow dep,
# tensorflow_probability dep,
],
)
Loading

0 comments on commit 5536afd

Please sign in to comment.