Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add DFSane method #214

Merged
merged 30 commits into from
Oct 17, 2023
Merged

WIP: Add DFSane method #214

merged 30 commits into from
Oct 17, 2023

Conversation

axla-io
Copy link
Contributor

@axla-io axla-io commented Sep 18, 2023

This PR adds a DFSane solver, similar to the ones in SimpleNonlinearSolve, here and here.

The implementation in this PR improves on the SimpleNonlinearSolve version by adding a cached solver with non allocating iterations.

Checklist:

  • Algorithm and cache
  • Init function
  • In place steps
  • Bounds check for spectral parameter
  • Out of place steps
  • Benchmark against SimpleNonlinearSolve
  • Add tests
  • Docstrings

@axla-io
Copy link
Contributor Author

axla-io commented Sep 19, 2023

Started implementation of OOP solver but this doesn't work (error: f not found):

using NonlinearSolve
using Random
Random.seed!(123)

function f!(du, u, p)
    @. du .= u .* u .- p
    return nothing
end

f = (u, p) -> u .* u .- p

n_test = 10
u0 = rand(n_test) 
p = rand(n_test) .* 5

prob_iip = NonlinearProblem{true}(f!, u0, p);
prob_oop = NonlinearProblem{false}(f, u0, p);

alg = NonlinearSolve.DFSane()
sol = solve(prob_iip, alg) # works
sol = solve(prob_oop, alg) # doesn't work

@axla-io
Copy link
Contributor Author

axla-io commented Sep 19, 2023

Stacktrace:

ERROR: UndefVarError: `f` not defined
Stacktrace:
  [1] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; alias_u0::Bool, maxiters::Int64, abstol::Float64, internalnorm::Function, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:118
  [2] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:88
  [3] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:455
  [4] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:433
  [5] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:505
  [6] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:475
  [7] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:468
  [8] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:459
  [9] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:32
 [10] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:29
 [11] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:539
 [12] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:509
 [13] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:1008
 [14] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:973
 [15] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:967
 [16] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:957
 [17] top-level scope
    @ ~/Desktop/PrincetonCourses/MIT/dfsane_test/mwe_oop.jl:21

src/dfsane.jl Outdated
Comment on lines 113 to 163
f(dx, x) = prob.f(dx, x, p)
f(fuₙ₋₁, uₙ₋₁)

else
f(x) = prob.f(x, p)
fuₙ₋₁ = f(uₙ₋₁)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error is because f is being overwritten due to branching. Changing the name might fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the name does sort of fix it. But I'm sorry, I really don't understand why this happens. I never had this problem before. Do you have some quick reference?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You just cannot define standard functions in branches, make them anonymous.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

src/dfsane.jl Outdated Show resolved Hide resolved
src/dfsane.jl Outdated Show resolved Hide resolved
src/dfsane.jl Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Sep 21, 2023

Codecov Report

Merging #214 (0df8aaf) into master (7890cef) will increase coverage by 86.11%.
The diff coverage is 99.32%.

@@             Coverage Diff             @@
##           master     #214       +/-   ##
===========================================
+ Coverage    0.00%   86.11%   +86.11%     
===========================================
  Files          13       14        +1     
  Lines        1054     1203      +149     
===========================================
+ Hits            0     1036     +1036     
+ Misses       1054      167      -887     
Files Coverage Δ
src/NonlinearSolve.jl 89.47% <ø> (+89.47%) ⬆️
src/dfsane.jl 99.32% <99.32%> (ø)

... and 12 files with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

src/dfsane.jl Outdated Show resolved Hide resolved
src/dfsane.jl Outdated Show resolved Hide resolved
@avik-pal
Copy link
Member

avik-pal commented Oct 5, 2023

How close are we from getting this in?

src/dfsane.jl Outdated Show resolved Hide resolved
@axla-io
Copy link
Contributor Author

axla-io commented Oct 17, 2023

Tests are added! Everything works except for that ForwardDiff fails in some cases, see this MWE:

using NonlinearSolve
using FiniteDiff, ForwardDiff

quadratic_f(u, p) = u .* u .- p

function benchmark_nlsolve_oop(f, u0, p=2.0)
    prob = NonlinearProblem{false}(f, u0, p)
    return solve(prob, DFSane(), abstol=1e-9)
end

broken_forwarddiff = [3.0, 4.0, 81.0]
for p in broken_forwarddiff
    analytical_derivative = 1 / (2 * sqrt(p))
    forward_diff = abs(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p))
    finite_diff = abs(FiniteDiff.finite_difference_derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p))
    println("p = $p, Analytical: $analytical_derivative, ForwardDiff: $forward_diff, FiniteDiff: $finite_diff")
end

Which prints out:

p = 3.0, Analytical: 0.2886751345948129, ForwardDiff: 1776.530469223857, FiniteDiff: 0.2886751347781091
p = 4.0, Analytical: 0.25, ForwardDiff: 1.0, FiniteDiff: 0.25000000015714613
p = 81.0, Analytical: 0.05555555555555555, ForwardDiff: 0.1, FiniteDiff: 0.05555555555505331

@ChrisRackauckas
Copy link
Member

Everything works except for that ForwardDiff fails in some cases

That's fine. We shouldn't ForwardDiff the solver anyways. Someone should handle that separately.

@ChrisRackauckas
Copy link
Member

Specifically #245

@ChrisRackauckas ChrisRackauckas merged commit 3d85a52 into SciML:master Oct 17, 2023
7 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants