
How to define a custom_vjp for function that takes another function as an argument? #16540

Answered by jakevdp
mfkasim1 asked this question in Q&A

Take a look at the "jax.custom_vjp with nondiff_argnums" section of the JAX documentation. It covers essentially this question, although its example is a bit misleading: the backward rule there never actually computes the gradient of the input function.
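The following is a minimal sketch in the spirit of that documentation example (the names `app`, `app_fwd`, `app_bwd` and the constant `5.0` are illustrative). It shows what "misleading" means here. With `nondiff_argnums=(0,)`, the backward rule receives the non-differentiable argument first, then the residuals, then the output cotangent. It returns one cotangent per differentiable argument. This particular rule ignores `f` entirely, so the "gradient" it reports has nothing to do with `f`.

```python
from functools import partial

import jax


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def app(f, x):
    return f(x)


def app_fwd(f, x):
    # No residuals are saved: the backward rule below never uses f or x.
    return app(f, x), None


def app_bwd(f, res, y_bar):
    # f arrives first because it is listed in nondiff_argnums.
    # This rule ignores f, so the result is NOT the derivative of f.
    return (5.0 * y_bar,)


app.defvjp(app_fwd, app_bwd)

cube = lambda x: x ** 3
print(jax.grad(app, argnums=1)(cube, 3.0))  # 5.0, not the true derivative 27.0
```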

In the case of your example function, the full solution might look something like this:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(g, x):
    return g(x)

def f_fwd(g, x):
    # Note: g_x is equivalent to f(g, x) here, but in general this will not be the case.
    g_x, g_vjp = jax.vjp(g, x)
    return f(g, x), (g_x, g_vjp)

def f_bwd(g, res, g_bar):
    # Note: g_x unneeded for this simple function f, but in gen…
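The excerpt above is cut off inside `f_bwd`. What follows is a minimal self-contained sketch of the whole pattern. It assumes the backward rule simply pulls the output cotangent back through the saved `g_vjp` closure. The residual tuple `(g_x, g_vjp)` mirrors the excerpt. Carrying `g_vjp` as a residual works because the pullback returned by `jax.vjp` is a `jax.tree_util.Partial`, which is a pytree. The test function `jnp.sin` and the input values are arbitrary choices for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(g, x):
    return g(x)


def f_fwd(g, x):
    # Run g once, keeping both its output and its pullback (VJP) function.
    g_x, g_vjp = jax.vjp(g, x)
    return g_x, (g_x, g_vjp)


def f_bwd(g, res, g_bar):
    # g comes first because it is listed in nondiff_argnums.
    g_x, g_vjp = res  # g_x is unused in this simple case
    # g_vjp returns one cotangent per primal input of g; g has one input (x).
    (x_bar,) = g_vjp(g_bar)
    # Return one cotangent per differentiable argument of f, i.e. for x only.
    return (x_bar,)


f.defvjp(f_fwd, f_bwd)

# Usage: the custom rule should agree with ordinary autodiff of g.
x = jnp.array([0.1, 0.5, 1.0])
custom = jax.grad(lambda x: f(jnp.sin, x).sum())(x)
plain = jax.grad(lambda x: jnp.sin(x).sum())(x)
print(jnp.allclose(custom, plain))  # True
```

Because `g` is listed in `nondiff_argnums`, no cotangent is returned for it. Gradients flow only to `x`, through `g`'s own VJP.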

Replies: 1 comment 2 replies

@mfkasim1
@jakevdp
Answer selected by mfkasim1
Category
Q&A
2 participants