Implicit and competitive differentiation in JAX.
Our "competitive differentiation" approach uses Competitive Gradient Descent (CGD) to solve the equality-constrained nonlinear program associated with the fixed-point problem. A standalone implementation of CGD is provided under fax/competitive/cga.py, and the equality-constrained solvers derived from it can be accessed via fax.constrained.cga_lagrange_min or fax.constrained.cga_ecp. An implementation of implicit differentiation based on Christianson's two-phase reverse accumulation algorithm is also available through the function fax.implicit.two_phase_solve.
See fax/constrained/constrained_test.py for examples. Please note that the API is subject to change.
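For intuition about the competitive approach, here is a minimal dense sketch of one CGD step for a zero-sum game min_x max_y f(x, y), following the zero-sum update rule of Schäfer and Anandkumar's CGD paper. It is not the API of fax/competitive/cga.py: cgd_step and lagrangian are hypothetical names, the mixed Hessian is built explicitly (so it only suits tiny problems), and the toy program (minimize x0^2 + x1^2 subject to x0 + x1 = 1) only illustrates how an equality-constrained problem becomes a min-max problem over its Lagrangian.

```python
import jax
import jax.numpy as jnp


def cgd_step(f, x, y, eta):
    """One zero-sum CGD step for min_x max_y f(x, y); x and y are 1-D arrays.

    Hypothetical dense sketch (not fax's cga.py):
      dx = -eta (I + eta^2 Dxy Dyx)^-1 (grad_x f + eta Dxy grad_y f)
      dy =  eta (I + eta^2 Dyx Dxy)^-1 (grad_y f - eta Dyx grad_x f)
    """
    gx = jax.grad(f, argnums=0)(x, y)
    gy = jax.grad(f, argnums=1)(x, y)
    # Dxy: Jacobian of grad_x f with respect to y, shape (len(x), len(y)).
    dxy = jax.jacfwd(jax.grad(f, argnums=0), argnums=1)(x, y)
    dyx = dxy.T
    ix = jnp.eye(x.shape[0])
    iy = jnp.eye(y.shape[0])
    dx = -eta * jnp.linalg.solve(ix + eta**2 * dxy @ dyx, gx + eta * dxy @ gy)
    dy = eta * jnp.linalg.solve(iy + eta**2 * dyx @ dxy, gy - eta * dyx @ gx)
    return x + dx, y + dy


# Toy equality-constrained program written as min_x max_lam L(x, lam).
def lagrangian(x, lam):
    return jnp.sum(x**2) + lam[0] * (jnp.sum(x) - 1.0)


step = jax.jit(lambda x, lam: cgd_step(lagrangian, x, lam, 0.2))
x, lam = jnp.array([2.0, -1.0]), jnp.zeros(1)
for _ in range(200):
    x, lam = step(x, lam)
print(x, lam)  # x approaches [0.5, 0.5], lam approaches [-1.0]
```

The (I + eta^2 ...)^-1 factors are what distinguish CGD from simultaneous gradient descent-ascent: each player's step accounts for the other player's anticipated response through the mixed second derivatives.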
To get the latest version from GitHub:
pip install git+https://github.com/gehring/fax.git
Otherwise on PyPI:
pip install jax-fixedpoint
The main entry point for Christianson's two-phase reverse accumulation is fax.implicit.two_phase_solve. For example, imagine that you have a fixed-point iteration method such as power iteration and want to compute the gradient of a function of its output. You could write something like:
import jax
import jax.numpy as jnp
from fax import implicit

def power_iteration(A):
    # Returns the fixed-point map b -> A b / ||A b|| for the matrix A.
    def _power_iteration_step(b):
        b = A @ b
        return b / jnp.linalg.norm(b)
    return _power_iteration_step

def objective(A):
    b0 = jnp.ones(A.shape[0])
    # Solve b = step(b); gradients are taken through the implicit form.
    b = implicit.two_phase_solve(power_iteration, b0, A)
    # Rayleigh quotient, i.e. the dominant eigenvalue estimate.
    return (b.T @ A @ b) / (b.T @ b)

A = jnp.array([[1, 2], [3, 4.]])
print(jax.grad(objective)(A))
# Output array should be close to:
# DeviceArray([[0.23888351, 0.52223295],
#              [0.34815535, 0.76111656]], dtype=float32)
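These values can be cross-checked without fax. At the fixed point, b is the dominant eigenvector of A, so the objective equals the dominant eigenvalue λ1, and the classical derivative of a simple eigenvalue is dλ1/dA = u v^T / (u^T v), where v and u are its right and left eigenvectors. A small NumPy sketch of that check (the variable names are illustrative only):

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
# Right eigenvectors are columns of v; left eigenvectors are eigenvectors of A.T.
w, v = np.linalg.eig(A)
wl, u = np.linalg.eig(A.T)
v1 = v[:, np.argmax(w)]
u1 = u[:, np.argmax(wl)]
# Derivative of a simple eigenvalue w.r.t. the entries of A.
# The arbitrary signs returned by np.linalg.eig cancel in the ratio.
grad_lambda = np.outer(u1, v1) / (u1 @ v1)
print(grad_lambda)
```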
Given a function and an initial guess, fax.implicit.two_phase_solve solves the fixed-point problem in such a way that the result is differentiated using the implicit form of the returned fixed point. Behind the scenes, fax.implicit.two_phase_solve tells jax to apply a custom VJP rule, which fax derives from the fixed-point iteration function that it receives.
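To make the custom-VJP idea concrete, here is a minimal sketch of a two-phase scheme written directly with jax.custom_vjp. It is not fax's implementation: fixed_point, _iterate and NUM_ITERS are hypothetical names, both phases run a fixed number of iterations with no convergence test, and the iteration is assumed to be a contraction near the fixed point. Phase one iterates x <- f(x, θ) to x*; phase two iterates the adjoint w <- x̄ + w^T ∂f/∂x(x*, θ) and returns θ̄ = w^T ∂f/∂θ(x*, θ).

```python
from functools import partial

import jax
import jax.numpy as jnp

NUM_ITERS = 100  # hypothetical fixed budget for both phases


def _iterate(step, x0):
    # Apply x <- step(x) NUM_ITERS times (no convergence check, for brevity).
    return jax.lax.fori_loop(0, NUM_ITERS, lambda _, x: step(x), x0)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(param_func, params, x0):
    # Phase 1 (forward): iterate x <- param_func(params)(x).
    return _iterate(param_func(params), x0)


def _fixed_point_fwd(param_func, params, x0):
    x_star = fixed_point(param_func, params, x0)
    return x_star, (params, x_star)


def _fixed_point_bwd(param_func, res, x_star_bar):
    params, x_star = res
    # VJP of a single iteration step, linearized at the fixed point.
    _, step_vjp = jax.vjp(lambda p, x: param_func(p)(x), params, x_star)
    # Phase 2 (reverse): iterate the adjoint w <- x_star_bar + w^T df/dx.
    w = _iterate(lambda w: x_star_bar + step_vjp(w)[1], x_star_bar)
    params_bar, _ = step_vjp(w)
    # The fixed point does not depend on the initial guess.
    return params_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def power_iteration(A):
    def _power_iteration_step(b):
        b = A @ b
        return b / jnp.linalg.norm(b)
    return _power_iteration_step


def objective(A):
    b = fixed_point(power_iteration, A, jnp.ones(A.shape[0]))
    return (b @ A @ b) / (b @ b)


A = jnp.array([[1, 2], [3, 4.]])
print(jax.grad(objective)(A))
```

Because the backward pass only evaluates VJPs of a single step at x*, its memory use does not grow with the number of forward iterations, unlike backpropagating through the unrolled loop.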
Not only does this provide numerical and computational benefits over backpropagating through the fixed-point iteration loop, it also lets us define gradients even when the fixed-point solver isn't differentiable. The two_phase_solve function allows us to reproduce our power iteration example using a solver that jax is incapable of differentiating:
import numpy as np
import jax
import jax.numpy as jnp
from fax import implicit

def numpy_max_eig(A):
    # Plain NumPy eigensolver: jax cannot trace or differentiate through it.
    w, v = np.linalg.eig(A)
    return v[:, np.argmax(w)]

def power_iteration(A):
    def _power_iteration_step(b):
        b = A @ b
        return b / jnp.linalg.norm(b)
    return _power_iteration_step

def objective_non_diff_solver(A):
    b0 = jnp.ones(A.shape[0])
    b = implicit.two_phase_solve(
        power_iteration,
        b0,
        A,
        # Use the NumPy eigensolver instead of the default iterative solver.
        solvers=(lambda f, init_b, matrix: numpy_max_eig(matrix),),
    )
    return (b.T @ A @ b) / (b.T @ b)

A = jnp.array([[1, 2], [3, 4.]])
print(jax.grad(objective_non_diff_solver)(A))
# Output array should be close to:
# DeviceArray([[0.23888351, 0.52223295],
#              [0.34815535, 0.76111656]], dtype=float32)
NOTE: this example will not work when jit'ed with jax.jit, since jax won't be able to compile the "external" NumPy call. It is only meant to demonstrate that implicit differentiation doesn't care whether the solver itself is differentiable; it only requires the fixed-point function to be.
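One possible way around the jit limitation, assuming a JAX release that provides jax.pure_callback, is to stage the NumPy call out as a host callback. The sketch below only jits the solver on its own; whether fax's two_phase_solve accepts such a solver under jax.jit has not been checked here. jax cannot differentiate through pure_callback by itself, which is exactly the situation implicit differentiation is meant to handle. The helper names are hypothetical, and the output dtype is fixed to float32 as an assumption.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _max_eig_host(matrix):
    # Runs on the host with NumPy: eigenvector of the largest eigenvalue
    # (by real part), cast to float32 to match the declared output spec.
    w, v = np.linalg.eig(np.asarray(matrix))
    return np.real(v[:, np.argmax(np.real(w))]).astype(np.float32)


@jax.jit
def jitted_max_eig(A):
    # Declare the callback's output shape/dtype so jax can trace around it.
    out_spec = jax.ShapeDtypeStruct((A.shape[0],), jnp.float32)
    return jax.pure_callback(_max_eig_host, out_spec, A)


A = jnp.array([[1, 2], [3, 4.]])
b = jitted_max_eig(A)
print(b)  # a unit-norm dominant eigenvector (its sign is arbitrary)
```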
Citing competitive differentiation:
@inproceedings{bacon2019optrl,
author={Bacon, Pierre-Luc and Schafer, Florian and Gehring, Clement and Anandkumar, Animashree and Brunskill, Emma},
title={A Lagrangian Method for Inverse Problems in Reinforcement Learning},
booktitle={NeurIPS Optimization Foundations for Reinforcement Learning Workshop},
year={2019},
url={http://lis.csail.mit.edu/pubs/bacon-optrl-2019.pdf},
keywords={Optimization, Reinforcement Learning, Lagrangian}
}
Citing this repo:
@misc{gehring2019fax,
author = {Gehring, Clement and Bacon, Pierre-Luc and Schaefer, Florian},
title = {{FAX: differentiating fixed point problems in JAX}},
note = {Available at: https://github.com/gehring/fax},
year = {2019}
}