Adding custom linear solvers #47

Open
aadityacs opened this issue Dec 10, 2024 · 2 comments


@aadityacs

Great library!

I think it would be nice if the library offered a wider range of linear solvers for larger problems. For instance, I have a set of solvers, such as multigrid solvers and Intel's PARDISO (via pypardiso), integrated with JAX here.

What would be the best way to integrate these solvers into the framework?
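
One common way to wire an external CPU solver into JAX is `jax.pure_callback`, paired with `jax.custom_vjp` so gradients go through an adjoint solve instead of through the solver itself. This is a minimal sketch under stated assumptions, not the code in the linked repository: it uses `scipy.sparse.linalg.spsolve` as a stand-in (pypardiso provides an `spsolve(A, b)` function with the same calling convention, which is the assumed swap), it keeps the matrix fixed so only `b` is differentiated, and `make_sparse_solve` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve  # assumed stand-in for pypardiso.spsolve

jax.config.update("jax_enable_x64", True)


def make_sparse_solve(A_csr):
    """Hypothetical helper: return a JAX-callable b -> A^{-1} b run on the host.

    The matrix is fixed and closed over, so gradients flow only through b.
    """
    A_T = A_csr.T.tocsr()

    def host_call(mat, rhs):
        def _host(r):
            r = np.asarray(r)
            # Solve on the host, then cast back to the dtype JAX expects.
            return np.asarray(spsolve(mat, r), dtype=r.dtype)

        out = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
        return jax.pure_callback(_host, out, rhs)

    @jax.custom_vjp
    def solve(b):
        return host_call(A_csr, b)

    def solve_fwd(b):
        # No residuals needed: the backward pass only requires A^T.
        return host_call(A_csr, b), None

    def solve_bwd(_, g):
        # Adjoint solve: the cotangent of b is A^{-T} g.
        return (host_call(A_T, g),)

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


# Usage on a small symmetric positive definite test matrix.
n = 5
A = sp.diags([-1.0, 2.0, -1.0], [-1, 0, 1], shape=(n, n), format="csr")
solve = make_sparse_solve(A)
b = jnp.arange(1.0, n + 1)
x = jax.jit(solve)(b)
loss_grad = jax.grad(lambda v: jnp.sum(solve(v) ** 2))(b)
```

The custom VJP matters because the host solver is opaque to JAX's autodiff; expressing the backward pass as a second solve with the transposed matrix keeps differentiation working without tracing into the solver.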


@tianjuxue
Collaborator

Hello Aaditya, thanks for mentioning this and for sharing your linear solver examples. Indeed, the overall performance of JAX-FEM depends on scalable linear solvers. I just made a small update to the code so that users can at least define their own preferred solver, e.g., the Pardiso solver from your example.

Please see the source code file solver.py as well as the user file example.py.
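
The idea of letting users plug in their own solver can be illustrated with a Newton loop that takes the linear solver as a callable. This is a hypothetical sketch, not the actual interface of JAX-FEM's solver.py: `newton_solve`, `dense_solver`, and `cg_solver` are illustrative names, and the Jacobian is formed densely with `jax.jacfwd`, which only makes sense for small systems.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


def dense_solver(A, b):
    # Direct dense solve; fine for small illustrative systems.
    return jnp.linalg.solve(A, b)


def cg_solver(A, b):
    # Iterative solve; assumes A is symmetric positive definite.
    x, _ = cg(lambda v: A @ v, b, tol=1e-12)
    return x


def newton_solve(residual, x0, linear_solver=dense_solver, tol=1e-10, max_iter=50):
    """Hypothetical Newton loop with a user-supplied linear solver (A, b) -> x."""
    jacobian = jax.jacfwd(residual)
    x = x0
    for _ in range(max_iter):
        r = residual(x)
        if jnp.linalg.norm(r) < tol:
            break
        # The only place the linear solver is used: J(x) dx = -r(x).
        x = x + linear_solver(jacobian(x), -r)
    return x


# Usage: K u + u^3 = f, whose Jacobian K + 3 diag(u^2) is SPD, so CG applies.
n = 8
K = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
f = 0.1 * jnp.ones(n)
residual = lambda u: K @ u + u**3 - f
u_direct = newton_solve(residual, jnp.zeros(n), dense_solver)
u_cg = newton_solve(residual, jnp.zeros(n), cg_solver)
```

Because the linear solver is just an argument, swapping in a sparse direct solver or a multigrid solver only requires a function with the same `(A, b) -> x` shape; a real FE code would pass a sparse matrix rather than a dense Jacobian.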

Let me know your thoughts! The next step for the library is really to support massively parallel solvers (on multiple CPUs or even multiple GPUs), but I am not yet sure what the best practice is at this point.

@aadityacs
Author

I think we should be able to leverage existing GPU solvers to speed up the computation.

For instance, see this notebook, where simply calling CuPy code from JAX gives a speed-up from 33 s (on CPU) to 2 s (on GPU).
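
An alternative to calling CuPy, and not the approach used in the notebook, is to stay inside JAX: `jax.scipy.sparse.linalg.cg` accepts a matrix-vector callable, and a `jax.experimental.sparse.BCOO` matrix can supply it, so the same jitted solve runs on whichever device JAX is using, GPU included. This minimal sketch uses unpreconditioned CG on an assumed 1D Poisson test matrix; `poisson_1d` and `cg_solve` are hypothetical names and no timings are claimed.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


def poisson_1d(n):
    """Tridiagonal [-1, 2, -1] test matrix stored in JAX's BCOO sparse format."""
    dense = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
    return sparse.BCOO.fromdense(dense)


@jax.jit
def cg_solve(A, b):
    # Matrix-free view of A: CG only needs the product A @ v.
    def matvec(v):
        return A @ v

    x, _ = cg(matvec, b, tol=1e-10)  # second output is a placeholder
    return x


n = 64
A = poisson_1d(n)
b = jnp.ones(n)
x = cg_solve(A, b)
print(jax.devices())  # shows whether this ran on CPU or GPU
```

Whether this beats a direct solver or the CuPy route depends on problem size and conditioning; unpreconditioned CG can converge slowly on ill-conditioned FE stiffness matrices, so a preconditioner would usually be needed in practice.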
