diff --git a/.github/workflows/updec-linux.yml b/.github/workflows/updes-linux.yml
similarity index 96%
rename from .github/workflows/updec-linux.yml
rename to .github/workflows/updes-linux.yml
index 7c4c254..f9b50e8 100644
--- a/.github/workflows/updec-linux.yml
+++ b/.github/workflows/updes-linux.yml
@@ -1,4 +1,4 @@
-name: Updec CI/CD
+name: Updes CI/CD

 on: [push]

diff --git a/README.md b/README.md
index 1474d74..de454b6 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,10 @@ pip install updes
 ```

 The example below illustrates how to solve the Laplace equation with Dirichlet and Neumann boundary conditions:
+
+![Laplace PDE with Dirichlet and Neumann boundary conditions](docs/assets/LaplacePDE.png)
+
+
 ```python
 import updes
 import jax.numpy as jnp
@@ -70,6 +74,7 @@ cloud.visualize_field(sol.vals, cmap="jet", projection="3d", title="RBF solution

 ## To-Dos
 1. Logo, contributors guide, and developer documentation
+2. Improved handling of ill-conditioned linear systems for RBF-FD (i.e. `support_size != "max"`)
 2. More introductory examples in the documentation :
     - Integration with neural networks and [Equinox](https://github.com/patrick-kidger/equinox)
     - Non-linear and multi-dimensional PDEs
@@ -83,7 +88,7 @@ We welcome contributions from the community. Please feel free to open an issue o

 ## Dependencies

-- **Core**: [JAX](https://github.com/google/jax) - [GMSH](https://pypi.org/project/gmsh/) - [Matplotlib](https://github.com/matplotlib/matplotlib) - [Seaborn](https://github.com/mwaskom/seaborn) - [Scikit-Learn](https://github.com/scikit-learn/scikit-learn)
+- **Core**: [JAX](https://github.com/google/jax) - [GMSH](https://pypi.org/project/gmsh/) - [Lineax](https://github.com/patrick-kidger/lineax) - [Matplotlib](https://github.com/matplotlib/matplotlib) - [Seaborn](https://github.com/mwaskom/seaborn) - [Scikit-Learn](https://github.com/scikit-learn/scikit-learn)
 - **Optional**: [PyVista](https://github.com/pyvista/pyvista) - [FFMPEG](https://github.com/kkroening/ffmpeg-python) - [QuartoDoc](https://github.com/machow/quartodoc/)

 See the `pyproject.toml` file the specific versions of the dependencies.
diff --git a/demos/Laplace/30_laplace_super_scaled.py b/demos/Laplace/30_laplace_super_scaled.py
index 8512469..b2fa2c6 100644
--- a/demos/Laplace/30_laplace_super_scaled.py
+++ b/demos/Laplace/30_laplace_super_scaled.py
@@ -5,6 +5,10 @@
 """

 # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
+
+import pstats
+from updes import *
+
 import time

 import jax
@@ -16,29 +20,48 @@ import matplotlib.pyplot as plt
 import seaborn as sns

-from updes import *
+import cProfile
+

 DATAFOLDER = "./data/TempFolder/"

 RBF = partial(polyharmonic, a=1)
-# RBF = partial(gaussian, eps=1e-1)
+# RBF = partial(gaussian, eps=1e1)
 # RBF = partial(thin_plate, a=3)

-MAX_DEGREE = 0
+MAX_DEGREE = 1

-Nx = Ny = 20
+Nx = Ny = 50
 # SUPPORT_SIZE = "max"
-SUPPORT_SIZE = 20*2
+SUPPORT_SIZE = 2
 facet_types={"South":"d", "West":"d", "North":"d", "East":"d"}

 start = time.time()
+
+## benchmarking with cprofile
+# res = cProfile.run("cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)")
 cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)
+
+## Print results sorted by cumulative time
+# p = pstats.Stats(res)
+# p.sort_stats('cumulative').print_stats(10)
+
+
+## Only print the top 10 high-level functions
+# p.print_callers(10)
+
+
+
+# cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, support_size=SUPPORT_SIZE)
+
+
+
 walltime = time.time() - start

 print(f"Cloud generation walltime: {walltime:.2f} seconds")

 # cloud.visualize_cloud(s=0.5, figsize=(7,6));

-## %%
+# %%

 def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
     return nodal_laplacian(x, center, rbf, monomial)
@@ -80,8 +103,35 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



+# import jax
+# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
+# import jax.numpy as jnp
+# import jax.random as jr
+# import lineax as lx
+
+# # size = 15000
+# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
+# # matrix = jr.normal(matrix_key, (size, size))
+# # vector = jr.normal(vector_key, (size,))
+# # operator = lx.MatrixLinearOperator(matrix)
+# # solution = lx.linear_solve(operator, vector, solver=lx.QR())
+# # solution.value
+
+# # size = 8000
+# # matrix_key, vector_key = jr.split(jr.PRNGKey(0))
+# # matrix = jr.normal(matrix_key, (size, size))
+# # vector = jr.normal(vector_key, (size,))
+# # solution = jnp.linalg.solve(matrix, vector)
+# # solution
+
+
+# size = 15000
+# matrix_key, vector_key = jr.split(jr.PRNGKey(0))
+# matrix = jr.normal(matrix_key, (size, size))
+# vector = jr.normal(vector_key, (size,))
+# solution = jnp.linalg.lstsq(matrix, vector)
-
+# %%

 # ## Observing the sparsity patten of the matrices involved


@@ -98,30 +148,30 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None):



-M = compute_nb_monomials(MAX_DEGREE, 2)
-A = assemble_A(cloud, RBF, M)
-mat1 = jnp.abs(A)
+# M = compute_nb_monomials(MAX_DEGREE, 2)
+# A = assemble_A(cloud, RBF, M)
+# mat1 = jnp.abs(A)

-inv_A = assemble_invert_A(cloud, RBF, M)
-mat2 = jnp.abs(inv_A)
+# inv_A = assemble_invert_A(cloud, RBF, M)
+# mat2 = jnp.abs(inv_A)

-## Matrix B for the linear system
-mat3 = sol.mat
+# ## Matrix B for the linear system
+# mat3 = sol.mat

-## 3 figures
-fig, ax = plt.subplots(1, 3, figsize=(15,5))
+# ## 3 figures
+# fig, ax = plt.subplots(1, 3, figsize=(15,5))

-sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
-ax[0].set_title("Collocation Matrix")
+# sns.heatmap(jnp.abs(mat1), ax=ax[0], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
+# ax[0].set_title("Collocation Matrix")

-sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
-ax[1].set_title("Inverse of Collocation Matrix")
+# sns.heatmap(jnp.abs(mat2), ax=ax[1], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
+# ax[1].set_title("Inverse of Collocation Matrix")

-sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
-ax[2].set_title("Linear System Matrix (B)")
+# sns.heatmap(jnp.abs(mat3), ax=ax[2], cmap="grey", cbar=True, square=True, xticklabels=False, yticklabels=False)
+# ax[2].set_title("Linear System Matrix (B)")

-# plt.title("Sparsity Pattern of the Collocation Matrix")
-plt.show()
+# # plt.title("Sparsity Pattern of the Collocation Matrix")
+# plt.show()


 #%%
diff --git a/docs/assets/LaplacePDE.png b/docs/assets/LaplacePDE.png
new file mode 100644
index 0000000..91ffcdc
Binary files /dev/null and b/docs/assets/LaplacePDE.png differ
diff --git a/docs/assets/NextRelease.md b/docs/assets/NextRelease.md
index d366a99..6a1a943 100644
--- a/docs/assets/NextRelease.md
+++ b/docs/assets/NextRelease.md
@@ -1,4 +1,5 @@
-For the next release v1.1.0
+For the next release v1.0.2
 - [X] Added colorbar to animate fields
 - [X] Fixed the args inputs to construct the local matrix for nodal_div_grad: (if array, else, etc.)
 - [X] Implemented the Darcy flow problem
+- [X] Faster linear solves with Lineax
diff --git a/docs/assets/README_PyPI.md b/docs/assets/README_PyPI.md
index 1388f4a..f6aed10 100644
--- a/docs/assets/README_PyPI.md
+++ b/docs/assets/README_PyPI.md
@@ -60,7 +60,7 @@ cloud.visualize_field(sol.vals, cmap="jet", projection="3d", title="RBF solution

 ## Dependencies

-- **Core**: JAX - GMSH - Matplotlib - Seaborn - Scikit-Learn
+- **Core**: JAX - GMSH - Lineax - Matplotlib - Seaborn - Scikit-Learn
 - **Optional**: PyVista - FFMPEG - QuartoDoc

 See the `pyproject.toml` file the specific versions of the dependencies.
diff --git a/pyproject.toml b/pyproject.toml
index 3295e7a..df52536 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ keywords = [

 dependencies = [
     "jax >= 0.3.4",
+    "lineax",
     "gmsh",
     "pytest",
     "matplotlib>=3.4.0",
diff --git a/updes/cloud.py b/updes/cloud.py
index cc8a90b..160c665 100644
--- a/updes/cloud.py
+++ b/updes/cloud.py
@@ -99,9 +99,13 @@ def define_local_supports(self):
         assert self.support_size > 0, "Support size must be strictly greater than 0"
         assert self.support_size <= self.N, "Support size must be strictly less than or equal the number of nodes"

+        # ## If support size == coords.shape[0], then we are using all the nodes
+        # if self.support_size > coords.shape[0]:
+        #     self.local_supports = {renumb_map[i]:list(range(self.N)) for i in range(self.N)}
+        # else:
         ## Use BallTree for fast nearest neighbors search
         # ball_tree = KDTree(coords, leaf_size=40, metric='euclidean')
-        ball_tree = BallTree(coords, leaf_size=40, metric='euclidean')
+        ball_tree = BallTree(coords, leaf_size=1, metric='euclidean')
         for i in range(self.N):
             _, neighbours = ball_tree.query(self.nodes[i][jnp.newaxis], k=self.support_size)
             neighbours = neighbours[0][1:]    ## Result is a 2d list, without the first el (the node itself)
diff --git a/updes/operators.py b/updes/operators.py
index f0c3a7b..358aca7 100644
--- a/updes/operators.py
+++ b/updes/operators.py
@@ -2,6 +2,7 @@
 import jax
 import jax.numpy as jnp
 from jax.tree_util import Partial
+import lineax as lx

 from functools import cache

@@ -601,8 +602,17 @@ def pde_solver( diff_operator:callable,
     B1 = assemble_B(diff_operator, cloud, rbf, nb_monomials, diff_args, robin_coeffs)
     rhs = assemble_q(rhs_operator, boundary_conditions, cloud, rbf, nb_monomials, rhs_args)

-    ## Solve the linear system
-    sol_vals = jnp.linalg.solve(B1, rhs)
+    ## Solve the linear system using JAX's direct solver
+    # sol_vals = jnp.linalg.solve(B1, rhs)
+
+    ## Solve the linear system using Scipy's iterative solver
+    # sol_vals = jax.scipy.sparse.linalg.gmres(B1, rhs, tol=1e-5)[0]
+
+    ## Solve the linear system using Lineax
+    operator = lx.MatrixLinearOperator(B1)
+    sol_vals = lx.linear_solve(operator, rhs, solver=lx.QR()).value
+    # sol_vals = lx.linear_solve(operator, rhs, solver=lx.GMRES(rtol=1e-3, atol=1e-3)).value
+
     sol_coeffs = core_compute_coefficients(sol_vals, cloud, rbf, nb_monomials)

     return SteadySol(sol_vals, sol_coeffs, B1)
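
Note on the solver change in `updes/operators.py`: `pde_solver` now routes the assembled system through Lineax instead of `jnp.linalg.solve`. The snippet below is a minimal standalone sketch of that pattern, not the library's assembly code; the matrix `B1` and right-hand side `rhs` are random stand-ins for the assembled RBF system.

```python
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

# Random stand-ins for the assembled system matrix and right-hand side
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
B1 = jr.normal(matrix_key, (500, 500))
rhs = jr.normal(vector_key, (500,))

# Previous approach: dense direct solve
sol_direct = jnp.linalg.solve(B1, rhs)

# New approach: wrap the matrix in a Lineax operator and solve with QR,
# which also copes with rectangular or poorly conditioned systems
operator = lx.MatrixLinearOperator(B1)
sol_qr = lx.linear_solve(operator, rhs, solver=lx.QR()).value

# On a well-conditioned square system both results should agree to floating-point error
print(jnp.max(jnp.abs(sol_direct - sol_qr)))
```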
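
Note on the `leaf_size` change in `updes/cloud.py`: `define_local_supports` queries a scikit-learn `BallTree` once per node to collect its nearest neighbours. A rough self-contained sketch of that query pattern follows; the node coordinates and `support_size` below are made-up stand-ins rather than values from the library.

```python
import numpy as np
from sklearn.neighbors import BallTree

# Made-up scattered 2D nodes standing in for the cloud's coordinates
rng = np.random.default_rng(0)
nodes = rng.random((100, 2))
support_size = 5

# leaf_size trades off tree-build time and memory against query speed;
# it does not change which neighbours are returned
tree = BallTree(nodes, leaf_size=1, metric='euclidean')

local_supports = {}
for i in range(nodes.shape[0]):
    # query returns (distances, indices); drop index 0, which is the node itself
    _, neighbours = tree.query(nodes[i][np.newaxis], k=support_size)
    local_supports[i] = list(neighbours[0][1:])
```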
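
Note on the commented-out benchmarking lines added to `demos/Laplace/30_laplace_super_scaled.py`: they follow the usual `cProfile`/`pstats` workflow. A small self-contained sketch of that workflow, assuming a hypothetical `cloud.prof` output file and a throwaway function in place of the cloud construction:

```python
import cProfile
import pstats

def build_something():
    # Stand-in for the expensive call being profiled (e.g. cloud construction)
    return sum(i * i for i in range(200_000))

# Run the statement under the profiler and dump the raw stats to a file
cProfile.run("build_something()", "cloud.prof")

# Load the stats, sort by cumulative time, and print the top 10 entries
p = pstats.Stats("cloud.prof")
p.sort_stats("cumulative").print_stats(10)

# Show which callers account for those top entries
p.print_callers(10)
```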