Replies: 2 comments 1 reply
-
For JIT to work, even temporary variables, such as `ix` and `W_sub` below, need shapes that are known at trace time:

```python
import jax
import jax.numpy as jnp
import jax.random as jrnd

key = jrnd.PRNGKey(0)
key, wishart_key, choice_key = jrnd.split(key, 3)

p = 4
dof = p + 1
scale = jnp.eye(p)

# Create a p x p Wishart-distributed random matrix
u = jrnd.multivariate_normal(wishart_key, mean=jnp.zeros((p,)), cov=scale, shape=(dof,))
Sigma = jnp.einsum('dp,dq->pq', u, u)
W = Sigma

# Create a p x p symmetric, binary matrix
G = jnp.array([[1, 1, 0, 0], [1, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]])
def take_dynamic_submatrices(i, G):
    # select all variables that i connects to in G
    ix = jnp.where(G[i] == 1)[0]
    W_sub = W[ix[:, None], ix]
    return W_sub
def take_dynamic_submatrices_padded(i, G):
    # select all variables that i connects to in G: mix maps each
    # selected variable to its position in the compacted submatrix
    mix = jax.lax.associative_scan(lambda a, b: a + b, G[i]) - 1
    size = mix[-1] + 1
    # send unselected variables to an out-of-bounds index...
    mix = jnp.where(G[i] == 1, mix, len(G[i]) + 1)
    # ...so the scatter drops them; the result has a fixed p x p shape,
    # with the submatrix in the top-left corner and zero padding elsewhere
    W_sub = jnp.zeros(W.shape).at[mix[:, None], mix].set(W, mode='fill')
    return W_sub, size
# Pick a node
i = 1
W_sub = take_dynamic_submatrices(i, G)
print('W_sub:\n', W_sub)
# How can we make this work within a vmap or lax.scan?
W_sub_jit, size = jax.jit(take_dynamic_submatrices_padded)(i, G)
print('W_sub_jit:\n', W_sub_jit)
print('W_sub_jit size:\n', size)
```
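For what it's worth, the padded variant does seem to compose with `vmap` (and `lax.scan` should be analogous), since all of its shapes are static. A minimal sketch that extracts the padded submatrix for every node at once:

```python
# Sketch: map the padded extraction over all nodes; every shape is
# static, so this also works under jit.
W_subs, sizes = jax.vmap(take_dynamic_submatrices_padded, in_axes=(0, None))(jnp.arange(p), G)
print(W_subs.shape, sizes)  # (4, 4, 4) plus the per-node block sizes
```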
-
You cannot do what you're asking, i.e. create a matrix `W_sub` whose shape depends on the runtime values in `G`. In general you probably can do something similar to what you have in mind by side-stepping the actual creation of the dynamically-shaped matrix. For example:
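One pattern that often works (a sketch, assuming the downstream operation is something like a Cholesky factorization or a solve; adapt as needed) is to keep the full $p \times p$ shape and pad the unselected rows and columns with the identity, so standard linear-algebra routines behave as if they were applied to the submatrix alone:

```python
import jax.numpy as jnp

def masked_identity_pad(W, mask):
    # Zero the rows/columns not selected by `mask`, then put ones on the
    # corresponding diagonal entries. Up to a permutation the result is
    # block-diagonal, [[W_sub, 0], [0, I]], so cholesky/solve/slogdet
    # act on it as if applied to W_sub alone.
    m = mask.astype(W.dtype)
    return W * (m[:, None] * m[None, :]) + jnp.diag(1.0 - m)

# e.g. factor the effective submatrix without ever materializing it:
# L = jnp.linalg.cholesky(masked_identity_pad(W, G[i] == 1))
```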
If you need specific advice, you could expand your MWE to show the kinds of operations you're trying to do.
-
I'm trying to implement a specific reversible-jump MCMC algorithm. By definition, an RJ algorithm traverses solutions of different dimensionalities, which is obviously tricky with JAX, but I'm hoping I'm simply missing a clever trick to make this work :-)
Here is a MWE of what I am trying to do (see the code above):
Throughout the algorithm, `Sigma`/`W` and `G` would change values, but remain of shape $p \times p$. However, depending on random steps, `G` would change, and that would mean `ix`, and hence `W_sub`, would/could be different in size across iterations of the algorithm. I do not really need to store elements of different sizes, as I would assign `W_sub` back to a subpart of `W`, keeping what I track over iterations of the same shape. I would need to have them as variables on the fly, though.

I have tried a couple of directions: `static_argnums` (a sketch of this attempt is below), and simply masking the matrices that I update. The first one does not work because `G` is not hashable (as it is an array), and the second approach failed as well, because later I have to do some linear algebra operations on `W_sub`, and that does not work if it is masked.

It's definitely possible that what I am trying to do is simply impossible with JAX due to the requirement that array shapes be known, but since the objects I really care about do keep the same shape, do you perhaps have a suggestion on how to proceed?
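For reference, the closest I got with the first direction was side-stepping the hashability issue by passing `G` as a tuple of tuples, along the lines of this sketch, but every distinct `G` then triggers a recompile, which is impractical inside a long MCMC chain:

```python
import numpy as np
import jax
from functools import partial

@partial(jax.jit, static_argnums=(1, 2))
def take_submatrix_static(W, G_tuple, i):
    # G_tuple is a hashable tuple-of-tuples copy of G and i is a plain
    # int, so ix is concrete at trace time and W_sub has a static shape.
    ix = np.where(np.array(G_tuple)[i] == 1)[0]
    return W[ix[:, None], ix]

# usage: take_submatrix_static(W, tuple(map(tuple, np.asarray(G).tolist())), 1)
```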