From db7e64ea3496229a41fc696d68e8986c976a2e03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 26 Oct 2024 18:28:52 -0400 Subject: [PATCH] fix: move the `BandedArrays` extension --- Project.toml | 2 -- ext/NonlinearSolveBandedMatricesExt.jl | 11 ----------- lib/NonlinearSolveBase/Project.toml | 3 +++ .../ext/NonlinearSolveBaseBandedMatricesExt.jl | 16 ++++++++++++++++ .../ext/NonlinearSolveBaseSparseArraysExt.jl | 4 +++- lib/NonlinearSolveBase/src/utils.jl | 2 ++ 6 files changed, 24 insertions(+), 14 deletions(-) delete mode 100644 ext/NonlinearSolveBandedMatricesExt.jl create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl diff --git a/Project.toml b/Project.toml index 03e6519dd..2ae6f7e4d 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,6 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] -BandedMatrices = "aae01518-5342-5314-be14-df237901396f" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" @@ -48,7 +47,6 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" [extensions] -NonlinearSolveBandedMatricesExt = "BandedMatrices" NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" diff --git a/ext/NonlinearSolveBandedMatricesExt.jl b/ext/NonlinearSolveBandedMatricesExt.jl deleted file mode 100644 index b79df3578..000000000 --- a/ext/NonlinearSolveBandedMatricesExt.jl +++ /dev/null @@ -1,11 +0,0 @@ -module NonlinearSolveBandedMatricesExt - -using BandedMatrices: BandedMatrix -using LinearAlgebra: Diagonal -using NonlinearSolve: NonlinearSolve -using SparseArrays: sparse - -# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg -@inline NonlinearSolve._vcat(B::BandedMatrix, D::Diagonal) = vcat(sparse(B), D) - -end diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 52fdb6aff..31f14f3b9 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -27,6 +27,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -34,6 +35,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [extensions] +NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseLinearSolveExt = "LinearSolve" @@ -45,6 +47,7 @@ ADTypes = "1.9" Adapt = "4.1.0" Aqua = "0.8.7" ArrayInterface = "7.9" +BandedMatrices = "1.5" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl new file mode 100644 index 000000000..7f2ac7f90 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl @@ -0,0 +1,16 @@ +module NonlinearSolveBaseBandedMatricesExt + +using BandedMatrices: BandedMatrix +using LinearAlgebra: Diagonal +using NonlinearSolveBase: NonlinearSolveBase, Utils + +# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg +@inline function Utils.faster_vcat(B::BandedMatrix, D::Diagonal) + if Utils.is_extension_loaded(Val(:SparseArrays)) + @warn "Load `SparseArrays` for an optimized vcat for BandedMatrices." + return vcat(B, D) + end + return vcat(Utils.make_sparse(B), D) +end + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl index fca3793ad..9ffadf3a1 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl @@ -1,7 +1,7 @@ module NonlinearSolveBaseSparseArraysExt using NonlinearSolveBase: NonlinearSolveBase, Utils -using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros +using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros, sparse function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC) return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x)) @@ -11,4 +11,6 @@ NonlinearSolveBase.sparse_or_structured_prototype(::AbstractSparseMatrix) = true Utils.maybe_symmetric(x::AbstractSparseMatrix) = x +Utils.make_sparse(x) = sparse(x) + end diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 628a4e873..14e6b1faa 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -144,4 +144,6 @@ function evaluate_f!!(f::NonlinearFunction, fu, u, p) return f(u, p) end +function make_sparse end + end