I have a function that takes another function as an argument, and I want to define a custom VJP for it. However, I can't figure out how to make `jax.custom_vjp` work when one of the arguments is a function. Here is my code:

```python
import jax
import jax.numpy as np
from functools import partial

# Define the function
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(g, x):
    return g(x)

# Define the forward pass and backward pass (vjp)
def f_fwd(g, x):
    y = g(x)
    return y, (g, x)

def f_bwd(res, g_bar):
    g, x = res
    return (None, g(x) * g_bar)

# Associate them with the function
f.defvjp(f_fwd, f_bwd)

# Test it out
def square(x):
    return np.square(x)

x = np.array([2., 3., 4.])
print(jax.grad(f, argnums=1)(square, x))  # should print [4., 6., 8.]
```

When I run it, I get this error:
Is there any way to define …
Take a look at `jax.custom_vjp` with `nondiff_argnums`. It basically covers exactly this question, though the example there is a bit misleading because it doesn't actually compute the gradient of the input function.

In the case of your example function, the full solution might look something like this:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(g, x):
    return g(x)

def f_fwd(g, x):
    # Note: g_x is equivalent to f(g, x) here, but in general this will not be the case.
    # Store g's VJP function as a residual so the backward pass can pull cotangents back through g.
    g_x, g_vjp = jax.vjp(g, x)
    return f(g, x), (g_x, g_vjp)

def f_bwd(g, res, g_bar):
    # With nondiff_argnums, the non-differentiable argument g is passed first.
    # Note: g_x unneeded for this simple function f, but in general the gradient will depend on it.
    g_x, g_vjp = res
    # g_vjp returns a 1-tuple: one cotangent for the single differentiable argument x.
    return g_vjp(g_bar)

f.defvjp(f_fwd, f_bwd)

def square(x):
    return jnp.square(x)

x = jnp.array([2., 3., 4.])
print(jax.vmap(jax.grad(f, argnums=1), in_axes=(None, 0))(square, x))
# [4. 6. 8.]
```

(Note I had to use …
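As the comments in the answer's code point out, `f(g, x) = g(x)` is a special case: in general the forward output is not `g(x)` and the backward pass needs `g_x`. A minimal sketch of that general case follows, assuming a hypothetical `h(g, x) = sin(g(x))` chosen only for illustration (it is not from the thread). It checks the custom rule against ordinary autodiff, and it differentiates a scalar sum instead of using `vmap`, which gives the per-element gradients only because `h` acts elementwise here.

```python
import jax
import jax.numpy as jnp
from functools import partial


# Hypothetical variant (not from the thread): h(g, x) = sin(g(x)).
# Unlike f(g, x) = g(x), its output differs from g(x) and its gradient depends on g(x).
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def h(g, x):
    return jnp.sin(g(x))


def h_fwd(g, x):
    # Evaluate g once and keep both g(x) and the VJP of g as residuals.
    g_x, g_vjp = jax.vjp(g, x)
    return jnp.sin(g_x), (g_x, g_vjp)


def h_bwd(g, res, y_bar):
    # g arrives first because it is listed in nondiff_argnums.
    g_x, g_vjp = res
    # Chain rule: first through sin (derivative cos(g_x)), then back through g.
    # g_vjp returns a 1-tuple, i.e. one cotangent for the single
    # differentiable argument x.
    return g_vjp(jnp.cos(g_x) * y_bar)


h.defvjp(h_fwd, h_bwd)


def square(x):
    return jnp.square(x)


x = jnp.array([2., 3., 4.])

# h acts elementwise, so the gradient of the summed output gives the
# per-element derivatives directly (an alternative to vmap over grad).
custom = jax.grad(lambda v: h(square, v).sum())(x)

# Reference: plain autodiff of the same function, without the custom rule.
reference = jax.grad(lambda v: jnp.sin(square(v)).sum())(x)

print(custom)
print(jnp.allclose(custom, reference))
```

If the custom rule is correct, `custom` matches `reference`, and both equal `2 * x * cos(x**2)` elementwise.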