Skip to content

Commit

Permalink
Fix a race condition(?) in test_riccati_custom_adjoint_solver.
Browse files Browse the repository at this point in the history
Also, simplify the test to make it even more obvious what the expected result
should be.

PiperOrigin-RevId: 565116879
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Sep 13, 2023
1 parent d4fbef1 commit c7e8b59
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions tensorflow_probability/python/math/ode/ode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ def __init__(self, make_solver_fn, first_step_size):
)

def _solve(self, **kwargs):
step_size = kwargs.pop('previous_solver_internal_state')
step_size, solve_count = kwargs.pop('previous_solver_internal_state')
results = self._make_solver_fn(step_size).solve(**kwargs)
return results._replace(
solver_internal_state=results.solver_internal_state.step_size)
solver_internal_state=(
results.solver_internal_state.step_size,
solve_count + 1,
)
)

def _initialize_solver_internal_state(self, **kwargs):
del kwargs
return self._first_step_size
# The second value is solve count, for testing.
return (self._first_step_size, 0)

def _adjust_solver_internal_state_for_state_jump(self, **kwargs):
return kwargs['previous_solver_internal_state']
Expand Down Expand Up @@ -447,17 +452,17 @@ def test_riccati_custom_adjoint_solver(self, solver, solution_times_fn):
# Instrument the adjoint solver for testing. We have to do this because the
# API doesn't provide access to the adjoint solver's diagnostics.
first_step_size = np.float64(1.)
last_initial_step_size = tf.Variable(0., dtype=tf.float64)
self.evaluate(last_initial_step_size.initializer)
solve_count = tf.Variable(0, dtype=tf.int32)
self.evaluate(solve_count.initializer)

class _InstrumentedSolver(StepSizeHeuristicAdjointSolver):

def solve(self, **kwargs):
with tf.control_dependencies([
last_initial_step_size.assign(
kwargs['previous_solver_internal_state'])
]):
return super(_InstrumentedSolver, self).solve(**kwargs)
results = super(_InstrumentedSolver, self).solve(**kwargs)
with tf.control_dependencies(
[solve_count.assign(results.solver_internal_state[1])]
):
return tf.nest.map_structure(tf.identity, results)

adjoint_solver = _InstrumentedSolver(
make_solver_fn=lambda step_size: solver( # pylint: disable=g-long-lambda
Expand All @@ -479,13 +484,14 @@ def grad_fn(initial_state):
final_state = results.states[-1]
return final_state
_, grad = tfp_gradient.value_and_gradient(grad_fn, initial_state)
grad, last_initial_step_size = self.evaluate((grad, last_initial_step_size))
grad = self.evaluate(grad)
# There's a race condition if we evaluate solve_count right away. Evaluate
# it after we're done with the computation that produces `grad`.
solve_count = self.evaluate(solve_count)
grad_exact = 1. / (1. - initial_state_value * final_time)**2
self.assertAllClose(grad, grad_exact, rtol=1e-3, atol=1e-3)
# This indicates that the adaptation carried over to the final solve. We
# expect the step size to decrease because we purposefully made the initial
# step size way too large.
self.assertLess(last_initial_step_size, first_step_size)
# This indicates that the adaptation carried over to the final solve.
self.assertGreater(solve_count, 0)

def test_linear_ode(self, solver, solution_times_fn):
if not tf1.control_flow_v2_enabled():
Expand Down

0 comments on commit c7e8b59

Please sign in to comment.