Hi all, I've been prototyping a statistical model and wanted to make use of `jax.experimental.sparse`. The general framework does something like:

```python
for loop in main_loop:
    params = _inner_loop(data, params)

@sparsify
def _inner_loop(data, params):
    params = ...  # use jax.grad to perform numerical optimization for some of the params
    params = ...  # use some analytic results to update other params
    return params
```

There is a need to […]. However, JAX complains that it cannot evaluate the gradient due to unsupported operations between […]. Running the linked gist results in the following output:
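For context, here is a minimal runnable sketch of the `sparsify` pattern above, with a toy sparse matvec standing in for the real model; the `BCOO` matrix, the objective `f`, and its inputs are illustrative assumptions, not the original code:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Toy dense matrix and its sparse (BCOO) representation -- a stand-in
# for the real model's data, chosen only to illustrate sparsify.
M = jnp.array([[0.0, 1.0],
               [2.0, 0.0]])
M_sp = sparse.BCOO.fromdense(M)

@sparse.sparsify
def f(mat, x):
    # sparsify rewrites supported primitives (here dot_general and
    # reduce_sum) to operate on the sparse representation directly
    return (mat @ x).sum()

x = jnp.ones(2)
print(f(M_sp, x))  # agrees with the dense computation (M @ x).sum()
```

The sparse call returns the same value as the dense computation, which is the contract `sparsify` provides for the primitives it supports.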
I would greatly appreciate any help or comments on the matter.
Thanks for the question! `jax.experimental.sparse` is under active development, and we typically implement new primitives as they are needed. You are the first user I'm aware of to need the sparsification rule for `add_any`. Would you be able to put together a minimal example of code that requires this path?
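For what it's worth, the `add_any` primitive typically arises when the backward pass has to accumulate two cotangents for the same operand, e.g. when a matrix is reused in two linear positions. A dense sketch of that shape (a toy function of my own, not the poster's model; whether it reaches `add_any` under `sparsify` may depend on the JAX version):

```python
import jax
import jax.numpy as jnp

def f(m, x):
    # m is used in two linear positions, so the transpose (backward)
    # pass must sum two cotangents for m; that accumulation is the
    # kind of step JAX expresses with the add_any primitive
    return jnp.sum(m @ x) + jnp.sum(m @ (2.0 * x))

m = jnp.ones((2, 2))
x = jnp.ones(2)

# d f / d m_ij = x_j + 2 x_j = 3 x_j, so every gradient entry is 3 here
g = jax.grad(f)(m, x)
print(g)

# Inspecting the gradient's jaxpr shows where the accumulation happens
print(jax.make_jaxpr(jax.grad(f))(m, x))
```

Wrapping a function of this shape in `sparsify` with a `BCOO` operand would be one candidate for the minimal repro requested above.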