diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index 0a88d94f6a..cd9ea3e493 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -169,8 +169,12 @@ def random_categorical(logits, num_samples, seed): def trace(state, fn, num_steps, unroll, max_steps, **_): """Implementation of `trace` operator, without the calling convention.""" # We need the shapes and dtypes of the outputs of `fn`. - _, untraced_spec, traced_spec = jax.eval_shape( + _, untraced_spec, traced_spec, stop_spec = jax.eval_shape( fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state)) + if isinstance(stop_spec, tuple): + stop = () + else: + stop = False untraced_init, traced_init = map_tree( lambda spec: jnp.zeros(spec.shape, spec.dtype), (untraced_spec, traced_spec), @@ -194,7 +198,7 @@ def trace(state, fn, num_steps, unroll, max_steps, **_): 'Cannot unroll when `num_steps` is not statically known and ' '`max_steps` is not specified.' 
) - if max_steps is not None: + if max_steps is not None or not isinstance(stop_spec, tuple): use_scan = False if unroll: @@ -203,8 +207,8 @@ def trace(state, fn, num_steps, unroll, max_steps, **_): traced_lists = map_tree(lambda _: [], traced_spec) untraced = untraced_init for step in range(num_outputs): - if step < num_steps: - state, untraced, traced_element = fn(state) + if step < num_steps and not stop: + state, untraced, traced_element, stop = fn(state) else: traced_element = traced_init map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists, @@ -217,7 +221,7 @@ def trace(state, fn, num_steps, unroll, max_steps, **_): def wrapper(state_untraced, _): state, _ = state_untraced - state, untraced, traced = fn(state) + state, untraced, traced, _ = fn(state) return (state, untraced), traced (state, untraced), traced = lax.scan( @@ -234,19 +238,31 @@ def wrapper(state_untraced, _): trace_arrays = map_tree( lambda spec: jnp.zeros((num_outputs,) + spec.shape, spec.dtype), - traced_spec) + traced_spec, + ) + loop_vars = ( + jnp.zeros_like(num_steps), + stop, + state, + untraced_init, + trace_arrays, + ) - def wrapper(i, state_untraced_traced): - state, _, trace_arrays = state_untraced_traced - state, untraced, traced = fn(state) + def cond(loop_vars): + i, stop, *_ = loop_vars + return (i < num_steps) & (isinstance(stop, tuple) or ~stop) + + def body(loop_vars): + i, _, state, _, trace_arrays = loop_vars + state, untraced, traced, stop = fn(state) trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays, traced) - return (state, untraced, trace_arrays) - state, untraced, traced = lax.fori_loop( - jnp.asarray(0, num_steps.dtype), - num_steps, - wrapper, - (state, untraced_init, trace_arrays), + return i + 1, stop, state, untraced, trace_arrays + + _, _, state, untraced, traced = lax.while_loop( + cond, + body, + loop_vars, ) return state, untraced, traced diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py 
b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py index cb3cd8ae6e..92d0a5c275 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py @@ -182,7 +182,7 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): num_outputs = num_steps if max_steps is None else max_steps if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly(): - state, first_untraced, first_traced = fn(state) + state, first_untraced, first_traced, stop = fn(state) arrays = tf.nest.map_structure( lambda v: tf.TensorArray( # pylint: disable=g-long-lambda v.dtype, @@ -195,7 +195,7 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): # the `TensorArray`s etc., we can get it by pre-compiling the wrapper # function. input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state) - fn, (_, untraced_spec, traced_spec) = _eval_shape(fn, input_spec) + fn, (_, untraced_spec, traced_spec, stop_spec) = _eval_shape(fn, input_spec) arrays = tf.nest.map_structure( lambda spec: tf.TensorArray( # pylint: disable=g-long-lambda @@ -206,18 +206,23 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10): first_untraced = tf.nest.map_structure( lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec) start_idx = 0 + if isinstance(stop_spec, tuple): + stop = () + else: + stop = False - def body(i, state, _, arrays): - state, untraced, traced = fn(state) + def body(i, stop, state, untraced, arrays): + del stop, untraced + state, untraced, traced, stop = fn(state) arrays = tf.nest.map_structure(lambda a, e: a.write(i, e), arrays, traced) - return i + 1, state, untraced, arrays + return i + 1, stop, state, untraced, arrays - def cond(i, *_): - return i < num_steps + def cond(i, stop, *_): + return (i < num_steps) & (isinstance(stop, tuple) or ~stop) static_num_steps = tf.get_static_value(num_steps) static_num_outputs = 
tf.get_static_value(num_outputs) - loop_vars = (start_idx, state, first_untraced, arrays) + loop_vars = (start_idx, stop, state, first_untraced, arrays) if unroll: if static_num_steps is None: @@ -233,8 +238,10 @@ def cond(i, *_): # TODO(siege): Investigate if using lists instead of TensorArray's is faster # (like is done in the JAX backend). for _ in range(start_idx, static_num_iters): + if loop_vars[1]: + break loop_vars = body(*loop_vars) - _, state, untraced, arrays = loop_vars + _, _, state, untraced, arrays = loop_vars else: if static_num_steps is None: if max_steps is None: @@ -246,7 +253,7 @@ def cond(i, *_): maximum_iterations = static_num_steps - start_idx else: maximum_iterations = min(static_num_steps, max_steps) - start_idx - _, state, untraced, arrays = tf.while_loop( + _, _, state, untraced, arrays = tf.while_loop( cond=cond, body=body, loop_vars=loop_vars, diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index 03069a14f1..afc87faa38 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -221,6 +221,7 @@ def trace( trace_mask: bool | BooleanNest = True, unroll: bool = False, max_steps: int | None = None, + stop_fn: Callable[[State, ArrayNest], BooleanArray] | None = None, parallel_iterations: int = 10, ) -> tuple[State, ArrayNest]: """`TransitionOperator` that runs `fn` repeatedly and traces its outputs. @@ -245,8 +246,11 @@ def trace( performance at the cost of increasing the XLA optimization time. Only works if `num_steps` is statically known. max_steps: If `num_steps` is not statically known and you still want to - trace values, you can use `max_steps` to allocate output trace to be of - this length. Only elements up to `num_steps` will be valid, however. + trace values, you can use `max_steps` to allocate output trace to be of this + length. Only elements up to `num_steps` will be valid, however. 
+ stop_fn: Optional callable that takes in the outputs of `fn` and returns a + boolean. If `True`, then the iteration is stopped. Only the elements + stored into traces before `stop_fn` returned `True` are valid. parallel_iterations: Number of iterations of the while loop to run in parallel (TensorFlow-only). @@ -286,11 +290,17 @@ def wrapper(state): state, extra = util.map_tree( util.convert_to_tensor, call_transition_operator(fn, state) ) + if stop_fn is None: + # For TF compatibility, we can't use None. () is conveniently "falsy", + # which we rely on in the backend implementations. + stop = () + else: + stop = stop_fn(state, extra) trace_element = util.map_tree( util.convert_to_tensor, trace_fn(state, extra) ) untraced, traced = _split_trace(trace_element, trace_mask) - return state, untraced, traced + return state, untraced, traced, stop state = util.map_tree(util.convert_to_tensor, state) diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index 21382bf439..cb743a9a20 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -243,6 +243,42 @@ def testTraceMaxSteps(self, unroll): self.assertEqual(6, traced.shape[0]) self.assertAllEqual(400, untraced) + @parameterized.named_parameters( + ('Unrolled', True), + ('NotUnrolled', False), + ) + def testTraceStopFnSingle(self, unroll): + x, (traced, untraced) = fun_mc.trace( + 0, + lambda x: (x + 1, (10 * x, 100 * x)), + 5, + unroll=unroll, + trace_mask=(True, False), + stop_fn=lambda x, _: x == 1, + ) + self.assertAllEqual(1, x) + self.assertAllEqual(0, traced[0]) + self.assertEqual(5, traced.shape[0]) + self.assertAllEqual(0, untraced) + + @parameterized.named_parameters( + ('Unrolled', True), + ('NotUnrolled', False), + ) + def testTraceStopFnMulti(self, unroll): + x, (traced, untraced) = fun_mc.trace( + 0, + lambda x: (x + 1, (10 * x, 100 * x)), + 5, + unroll=unroll, + trace_mask=(True, False), + stop_fn=lambda x, _: x == 3, + ) + 
self.assertAllEqual(3, x) + self.assertAllEqual(20, traced[2]) + self.assertEqual(5, traced.shape[0]) + self.assertAllEqual(200, untraced) + @parameterized.named_parameters( ('Unrolled', True), ('NotUnrolled', False),