Element-wise multiplication of BCOO sparse matrices reports out of memory error #15070
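I would like to use `bcoo_multiply_sparse` to multiply BCOO matrices element-wise:

```python
import jax
import jax.experimental.sparse

# Random boolean array of shape (5, 1000, 1000) with about 1% nonzero entries.
X = jax.random.bernoulli(jax.random.PRNGKey(1234), 0.01, (5, 1000, 1000))
# Convert the dense array to BCOO format.
X_sp = jax.experimental.sparse.BCOO.fromdense(X)
# Element-wise product of two sparse arrays.
jax.experimental.sparse.bcoo_multiply_sparse(X_sp, X_sp)
```

But this code reports the following error message:

I think element-wise multiplication of COO matrices should need no more memory than the original matrices. How can I fix this problem?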
Hi - thanks for the question. You're right that in theory this should be possible: the problem is that `jax.experimental.sparse` is built on the operations available in XLA, and XLA doesn't have the set-like operations (such as intersecting two sets of indices) that many sparse-sparse operations require. So `jax.experimental.sparse` relies on a very inefficient method to compute sparse-sparse element-wise multiplication. I don't know of any workaround aside from using smaller matrices or a larger machine, or perhaps converting your matrices to a more structured sparse form (incidentally, issues like this are one reason JAX's sparse support is still experimental: it works well in many cases, but for sparse-sparse operations in pa…
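To make the point about set-like operations concrete, here is a small NumPy sketch. It is an illustration only, not how `jax.experimental.sparse` is actually implemented, and the function names are hypothetical. With a set intersection such as `np.intersect1d`, the product needs memory proportional to the number of stored elements. Without set operations, one simple fallback (assumed here purely for illustration) is to compare every stored index of one operand with every stored index of the other, which builds an `nse_a × nse_b` table. For the `X` in the question, roughly 1% of 5 × 1000 × 1000 entries, about 50,000, are nonzero, so such a table would have about 2.5 × 10^9 entries (around 2.5 GB even as one-byte booleans).

```python
import numpy as np


def _linear_indices(idx, shape):
    # Map each row of an (nse, ndim) coordinate array to one flat integer index.
    return np.ravel_multi_index(tuple(idx.T), shape)


def coo_multiply_intersect(idx_a, val_a, idx_b, val_b, shape):
    """Element-wise product via a set intersection: O(nse_a + nse_b) memory."""
    lin_a = _linear_indices(idx_a, shape)
    lin_b = _linear_indices(idx_b, shape)
    # Set-like operation; assume_unique is valid only for duplicate-free inputs.
    common, ia, ib = np.intersect1d(
        lin_a, lin_b, assume_unique=True, return_indices=True
    )
    out_idx = np.stack(np.unravel_index(common, shape), axis=1)
    return out_idx, val_a[ia] * val_b[ib]


def coo_multiply_all_pairs(idx_a, val_a, idx_b, val_b, shape):
    """Element-wise product without set ops: builds an (nse_a, nse_b) table."""
    lin_a = _linear_indices(idx_a, shape)
    lin_b = _linear_indices(idx_b, shape)
    # Compare every stored index of A with every stored index of B.
    match = lin_a[:, None] == lin_b[None, :]
    ia, ib = np.nonzero(match)
    out_idx = np.stack(np.unravel_index(lin_a[ia], shape), axis=1)
    return out_idx, val_a[ia] * val_b[ib]


if __name__ == "__main__":
    shape = (2, 3)
    idx_a = np.array([[0, 0], [0, 2], [1, 1]])
    val_a = np.array([1.0, 2.0, 3.0])
    idx_b = np.array([[0, 2], [1, 0], [1, 1]])
    val_b = np.array([10.0, 20.0, 30.0])
    out_idx, out_val = coo_multiply_intersect(idx_a, val_a, idx_b, val_b, shape)
    print(out_idx.tolist(), out_val.tolist())  # → [[0, 2], [1, 1]] [20.0, 90.0]
```

Both functions return the same result, and the output never has more stored elements than either input, which matches the intuition in the question. Only the intersection version keeps its working memory proportional to the inputs. Both sketches assume each operand has no duplicate indices.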