From b21689ab87a5f771137824933772ef455c7c40f7 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 20 Jun 2024 13:25:23 -0700 Subject: [PATCH] FunMC: Allow specifying `max_steps` to enable tracing even when `num_steps` is dynamic. PiperOrigin-RevId: 645136707 --- .../fun_mc/fun_mc/dynamic/backend_jax/util.py | 47 +++++++++++++------ .../fun_mc/dynamic/backend_tensorflow/util.py | 32 +++++++++---- spinoffs/fun_mc/fun_mc/fun_mc_lib.py | 7 ++- spinoffs/fun_mc/fun_mc/fun_mc_test.py | 31 +++++++++++- 4 files changed, 92 insertions(+), 25 deletions(-) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index 116d66321a..0a88d94f6a 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -166,33 +166,47 @@ def random_categorical(logits, num_samples, seed): return jax.vmap(_searchsorted)(flat_cum_sum, flat_eta).reshape(eta.shape).T -def trace(state, fn, num_steps, unroll, **_): +def trace(state, fn, num_steps, unroll, max_steps, **_): """Implementation of `trace` operator, without the calling convention.""" # We need the shapes and dtypes of the outputs of `fn`. _, untraced_spec, traced_spec = jax.eval_shape( fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state)) - untraced_init = map_tree(lambda spec: jnp.zeros(spec.shape, spec.dtype), - untraced_spec) + untraced_init, traced_init = map_tree( + lambda spec: jnp.zeros(spec.shape, spec.dtype), + (untraced_spec, traced_spec), + ) try: num_steps = int(num_steps) use_scan = True except TypeError: use_scan = False - if flatten_tree(traced_spec): - raise ValueError( - 'Cannot trace values when `num_steps` is not statically known. Pass ' - 'False to `trace_mask` or return an empty structure (e.g. `()`) as ' - 'the extra output.') - if unroll: - raise ValueError( - 'Cannot unroll when `num_steps` is not statically known.') + if max_steps is None: + if flatten_tree(traced_spec): + raise ValueError( # pylint: disable=raise-missing-from + 'Cannot trace values when `num_steps` is not statically known and ' + '`max_steps` is not specified. Pass `False` to `trace_mask` or ' + 'return an empty structure (e.g. `()`) as ' + 'the extra output.' + ) + if unroll: + raise ValueError( # pylint: disable=raise-missing-from + 'Cannot unroll when `num_steps` is not statically known and ' + '`max_steps` is not specified.' + ) + if max_steps is not None: + use_scan = False if unroll: + num_outputs = num_steps if max_steps is None else max_steps + traced_lists = map_tree(lambda _: [], traced_spec) untraced = untraced_init - for _ in range(num_steps): - state, untraced, traced_element = fn(state) + for step in range(num_outputs): + if step < num_steps: + state, untraced, traced_element = fn(state) + else: + traced_element = traced_init map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists, traced_element) # Using asarray instead of stack to handle empty arrays correctly. @@ -213,8 +227,13 @@ def wrapper(state_untraced, _): length=num_steps, ) else: + num_outputs = num_steps if max_steps is None else max_steps + num_steps = ( + num_steps if max_steps is None else jnp.minimum(num_steps, max_steps) + ) + trace_arrays = map_tree( - lambda spec: jnp.zeros((num_steps,) + spec.shape, spec.dtype), + lambda spec: jnp.zeros((num_outputs,) + spec.shape, spec.dtype), traced_spec) def wrapper(i, state_untraced_traced): diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py index 9041acd500..cb3cd8ae6e 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py @@ -177,14 +177,16 @@ def compiled_fn(x): return compiled_fn, output_spec -def trace(state, fn, num_steps, unroll, parallel_iterations=10): +def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): """TF implementation of `trace` operator, without the calling convention.""" + num_outputs = num_steps if max_steps is None else max_steps + if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly(): state, first_untraced, first_traced = fn(state) arrays = tf.nest.map_structure( lambda v: tf.TensorArray( # pylint: disable=g-long-lambda v.dtype, - size=num_steps, + size=num_outputs, element_shape=v.shape).write(0, v), first_traced) start_idx = 1 @@ -198,7 +200,7 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10): arrays = tf.nest.map_structure( lambda spec: tf.TensorArray( # pylint: disable=g-long-lambda spec.dtype, - size=num_steps, + size=num_outputs, element_shape=spec.shape), traced_spec) first_untraced = tf.nest.map_structure( @@ -214,22 +216,36 @@ def cond(i, *_): return i < num_steps static_num_steps = tf.get_static_value(num_steps) + static_num_outputs = tf.get_static_value(num_outputs) loop_vars = (start_idx, state, first_untraced, arrays) if unroll: if static_num_steps is None: raise ValueError( - 'Cannot unroll when `num_steps` is not statically known.') + 'Cannot unroll when `num_steps` is not statically known or ' + '`max_steps` is None.' + ) + static_num_iters = ( + static_num_steps + if max_steps is None + else min(static_num_steps, max_steps) + ) # TODO(siege): Investigate if using lists instead of TensorArray's is faster # (like is done in the JAX backend). - for _ in range(start_idx, static_num_steps): + for _ in range(start_idx, static_num_iters): loop_vars = body(*loop_vars) _, state, untraced, arrays = loop_vars else: if static_num_steps is None: - maximum_iterations = None + if max_steps is None: + maximum_iterations = None + else: + maximum_iterations = max_steps - start_idx else: - maximum_iterations = static_num_steps - start_idx + if max_steps is None: + maximum_iterations = static_num_steps - start_idx + else: + maximum_iterations = min(static_num_steps, max_steps) - start_idx _, state, untraced, arrays = tf.while_loop( cond=cond, body=body, @@ -241,7 +257,7 @@ def cond(i, *_): traced = tf.nest.map_structure(lambda a: a.stack(), arrays) def _merge_static_length(x): - x.set_shape(tf.TensorShape(static_num_steps).concatenate(x.shape[1:])) + x.set_shape(tf.TensorShape(static_num_outputs).concatenate(x.shape[1:])) return x traced = tf.nest.map_structure(_merge_static_length, traced) diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index a507e5d2ba..03069a14f1 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -220,6 +220,7 @@ def trace( trace_fn: Callable[[State, ArrayNest], ArrayNest] = _trace_extra, trace_mask: bool | BooleanNest = True, unroll: bool = False, + max_steps: int | None = None, parallel_iterations: int = 10, ) -> tuple[State, ArrayNest]: """`TransitionOperator` that runs `fn` repeatedly and traces its outputs. @@ -243,8 +244,11 @@ def trace( unroll: Whether to unroll the loop. This can occasionally lead to improved performance at the cost of increasing the XLA optimization time. Only works if `num_steps` is statically known. + max_steps: If `num_steps` is not statically known and you still want to + trace values, you can use `max_steps` to allocate output trace to be of + this length. Only elements up to `num_steps` will be valid, however. parallel_iterations: Number of iterations of the while loop to run in - parallel. + parallel (TensorFlow-only). Returns: state: The final state returned by `fn`. @@ -295,6 +299,7 @@ def wrapper(state): fn=wrapper, num_steps=num_steps, unroll=unroll, + max_steps=max_steps, parallel_iterations=parallel_iterations, ) diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index d45b9c879a..21382bf439 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -211,10 +211,37 @@ def testTraceDynamic(self): @jax.jit def trace_n(num_steps): - return fun_mc.trace(0, lambda x: (x + 1, ()), num_steps)[0] + return fun_mc.trace( + 0, + lambda x: (x + 1, (10 * x, 100 * x)), + num_steps, + max_steps=6, + trace_mask=(True, False), + ) - x = trace_n(5) + x, (traced, untraced) = trace_n(5) + self.assertAllEqual(5, x) + self.assertAllEqual(40, traced[4]) + self.assertEqual(6, traced.shape[0]) + self.assertAllEqual(400, untraced) + + @parameterized.named_parameters( + ('Unrolled', True), + ('NotUnrolled', False), + ) + def testTraceMaxSteps(self, unroll): + x, (traced, untraced) = fun_mc.trace( + 0, + lambda x: (x + 1, (10 * x, 100 * x)), + 5, + max_steps=6, + unroll=unroll, + trace_mask=(True, False), + ) self.assertAllEqual(5, x) + self.assertAllEqual(40, traced[4]) + self.assertEqual(6, traced.shape[0]) + self.assertAllEqual(400, untraced) @parameterized.named_parameters( ('Unrolled', True),