-
I am writing a finite element analysis (FEA) code in JAX. In order to solve the system of equations (KU = F), I currently use the Jacobian of the residual to find K (the stiffness matrix). This is easy in JAX since the residual is a differentiable function. To solve the system of equations, I currently use an iterative solver, with the matrix-vector products computed matrix-free via a JVP of the residual.
But someone told me that since K is sparse, I should not use the JVP, and should instead build an explicit sparse matrix out of K and write an explicit matrix-vector product function. This is shown below.
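Roughly, the two options look like this (a simplified toy sketch, not my actual FEA residual; names like `matvec_jvp` and the small tridiagonal K are just for illustration):

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Toy stand-in for the FEA residual R(u) = K u - F, where K is a
# small tridiagonal (hence sparse) stiffness-like matrix.
n = 6
K_dense = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
F = jnp.ones(n)

def residual(u):
    return K_dense @ u - F

u0 = jnp.zeros(n)

# Option 1: matrix-free matvec via the JVP of the residual.
def matvec_jvp(v):
    # jax.jvp linearises `residual` at u0; the tangent output is K @ v.
    _, Kv = jax.jvp(residual, (u0,), (v,))
    return Kv

# Option 2: materialise K once via the Jacobian, convert it to a
# sparse (BCOO) matrix, and use an explicit sparse matvec.
K = jax.jacfwd(residual)(u0)
K_sp = sparse.BCOO.fromdense(K)

def matvec_sparse(v):
    return K_sp @ v

v = jnp.arange(1.0, n + 1.0)
print(jnp.allclose(matvec_jvp(v), matvec_sparse(v)))  # both compute K @ v
```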
Is the second one really faster? This matters to me because the matvec function will be called many times in the solver.
-
In general, sparse operations will not be faster than dense operations, particularly on accelerators like GPUs (this is not just true in JAX, but in virtually all systems that run on modern hardware). But sparse operations can be useful if the dense version of your matrix is too large to fit in memory, or if your matrix is extremely sparse (i.e. ~99.9% sparse) so that the indexing overhead does not dominate the cost of a dense matmul.
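To make the memory point concrete, here's a small illustration (a sketch only; the 1000×1000 diagonal matrix is just a convenient way to get ~99.9% sparsity):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A 1000x1000 matrix with nonzeros only on the diagonal is 99.9% sparse:
# the BCOO form stores ~1000 entries instead of 10^6.
n = 1000
dense = jnp.zeros((n, n)).at[jnp.arange(n), jnp.arange(n)].set(1.0)
sp = sparse.BCOO.fromdense(dense)

dense_bytes = dense.size * dense.dtype.itemsize
sparse_bytes = (sp.data.size * sp.data.dtype.itemsize
                + sp.indices.size * sp.indices.dtype.itemsize)
print(dense_bytes, sparse_bytes)  # the sparse form is far smaller
```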
-
Oh, I see. So I can just keep using the dense JVP-based approach. Thanks a lot Jake!
-
To go into a little more detail: there are several facets to deciding which of these is best, and all kinds of trade-offs to be made, so generally speaking you'll need to try both (or know something about the structure of your problem) to figure out which wins with respect to runtime / compile time / memory usage. Finally, allow me to point you at Lineax, which is our enhanced suite of linear solvers for JAX. In particular, it includes a number of kinds of linear operator, which would allow you to switch between the two representations fairly easily.
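As a concrete way to act on the "try both" advice even without Lineax (a minimal sketch with a hypothetical 3×3 SPD system): `jax.scipy.sparse.linalg.cg` accepts any callable as the operator, so the JVP-based and explicit-sparse matvecs are drop-in interchangeable inside the same iterative solver:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# A small symmetric positive-definite "stiffness matrix" and load vector.
K = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
F = jnp.array([1.0, 2.0, 3.0])

def residual(u):
    return K @ u - F

u0 = jnp.zeros(3)

# Matvec 1: matrix-free, via the JVP of the residual.
matvec_jvp = lambda v: jax.jvp(residual, (u0,), (v,))[1]

# Matvec 2: explicit sparse matrix-vector product.
K_sp = sparse.BCOO.fromdense(K)
matvec_sparse = lambda v: K_sp @ v

# Both can be passed to the same CG solver; time each to see which wins.
u1, _ = jax.scipy.sparse.linalg.cg(matvec_jvp, F)
u2, _ = jax.scipy.sparse.linalg.cg(matvec_sparse, F)
```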