Hello everyone,

In my project, I am dealing with a complex composition of functions that can be represented as a Directed Acyclic Graph (DAG). I have already implemented the computation pipeline in JAX, and I am now exploring ways to optimize it by avoiding redundant computations. I previously asked how to avoid recomputing 'expensive' functions when calculating multiple gradients, and one of the recommendations I received was to use VJPs (#16464).
To give a clearer picture of my use case, I have defined the DAG as a dictionary mapping each state to its transition function and the states it depends on:
dag = {
    'A': (x_to_A, ['x']),  # A = x_to_A(x)
    'B': (A_to_B, ['A']),  # B = A_to_B(A)
    'C': (B_to_C, ['B']),  # C = B_to_C(B)
    'D': (C_to_D, ['C']),  # D = C_to_D(C)
    'E': (B_to_E, ['B']),  # E = B_to_E(B)
}
In my current code, I manually compute gradients using VJPs, as shown in the example below:
import jax
import jax.numpy as jnp

def common_pipeline(x):
    print("Computing the pipeline steps")
    A = x_to_A(x)
    B = A_to_B(A)
    C = B_to_C(B)
    return B, C
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)
# Shared prefix: its outputs and the VJP function are computed once
(B, C), a_bc_vjp_fn = jax.vjp(common_pipeline, x)
# Final outputs and their VJP functions
E, E_vjp_fn = jax.vjp(B_to_E, B)
D, D_vjp_fn = jax.vjp(C_to_D, C)
# Gradients of E and D with respect to B and C
E_B_grad = E_vjp_fn(1.0)[0]
D_C_grad = D_vjp_fn(1.0)[0]
# Gradients of E and D with respect to x; each zero cotangent must match
# the output it zeroes out (C for E's pullback, B for D's)
E_x_grad = a_bc_vjp_fn((E_B_grad, jnp.zeros_like(C)))[0]
D_x_grad = a_bc_vjp_fn((jnp.zeros_like(B), D_C_grad))[0]
print(f"Final D state is {D} and its derivative with respect to the input is {D_x_grad}")
print(f"Final E state is {E} and its derivative with respect to the input is {E_x_grad}")
While trying to generalize this to other DAGs, I realized that I am essentially implementing backpropagation (reverse-mode autodiff) by hand, which feels like reinventing the wheel and not taking full advantage of JAX's capabilities; a sketch of the kind of generic machinery I keep rewriting is below. I am looking for suggestions on how to leverage JAX to compute gradients efficiently over such a DAG. Specifically, I want to avoid recomputing intermediate values and gradients while still using JAX's automatic differentiation. Note that some of the state-transition functions are not jittable, so I cannot JIT the entire pipeline and rely on common subexpression elimination. I was also wondering whether I could leverage ML packages like Haiku for this, since backpropagation on DAGs is such a common task.
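For concreteness, here is a minimal sketch of the generic version I keep re-deriving by hand. The helper names (forward_with_vjps, backward) are hypothetical; I assume every transition function takes its dependencies as positional arguments and returns a single array, that outputs are scalars so a unit cotangent makes sense as a seed, and that graphlib from the standard library (Python 3.9+) is available for the topological order.

import jax
import jax.numpy as jnp
from graphlib import TopologicalSorter

def forward_with_vjps(dag, inputs):
    # Run the DAG once in topological order, caching each node's value
    # and its VJP function so nothing is recomputed later.
    values = dict(inputs)   # e.g. {'x': x}
    vjp_fns = {}
    deps_of = {node: deps for node, (_, deps) in dag.items()}
    for node in TopologicalSorter(deps_of).static_order():
        if node in values:  # graph inputs have no transition function
            continue
        fn, deps = dag[node]
        values[node], vjp_fns[node] = jax.vjp(fn, *(values[d] for d in deps))
    return values, vjp_fns

def backward(dag, values, vjp_fns, output):
    # Pull a unit cotangent from `output` back through the cached VJPs,
    # accumulating cotangents at every ancestor node.
    deps_of = {node: deps for node, (_, deps) in dag.items()}
    order = list(TopologicalSorter(deps_of).static_order())
    cotangents = {output: jnp.ones_like(values[output])}
    for node in reversed(order):
        if node not in cotangents or node not in vjp_fns:
            continue
        dep_cts = vjp_fns[node](cotangents[node])
        for dep, ct in zip(deps_of[node], dep_cts):
            cotangents[dep] = cotangents.get(dep, 0.0) + ct
    return cotangents

# Usage with the dag above: one forward pass, then one backward pass per output.
values, vjp_fns = forward_with_vjps(dag, {'x': x})
dE_dx = backward(dag, values, vjp_fns, 'E')['x']
dD_dx = backward(dag, values, vjp_fns, 'D')['x']

This is exactly the pattern I would rather not maintain myself, which is why I am asking whether JAX (or a library on top of it) already provides something equivalent.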
During my exploration, I came across the Autodidact repository, which has been very insightful. @mattjj, I was wondering if you might have any insights or advice on implementing this efficiently. Thanks!