Skip to content

Commit

Permalink
Fix GN
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 1, 2023
1 parent 21e9ed4 commit 0e3efd7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
return GaussNewtonCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, dfu, p, uf,
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), trace)
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), trace)
end

function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
Expand All @@ -117,14 +117,14 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
# Use normal form to solve the Linear Problem
if cache.JᵀJ !== nothing
__update_JᵀJ!(Val{iip}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu1)
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu)
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
else
A, b = cache.J, _vec(cache.fu)
end

linres = dolinsolve(alg.precs, linsolve; A, b, linu = _vec(du), cache.p,
reltol = cache.abstol)
linres = dolinsolve(cache.alg.precs, cache.linsolve; A, b, linu = _vec(cache.du),
cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.du = _restructure(cache.du, linres.u)

Expand All @@ -136,7 +136,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
check_and_update!(cache.tc_cache_1, cache, cache.fu, cache.u, cache.u_cache)
if !cache.force_stop
@bb @. cache.dfu = cache.fu .- cache.dfu
check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_prev)
check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_cache)
end

@bb copyto!(cache.u_cache, cache.u)
Expand Down
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
return fu
end

function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
if iip
f(fu, u, p)
return fu
else
return f(u, p)
end
end

function evaluate_f(cache, u, p)
if isinplace(cache)
cache.prob.f(get_fu(cache), u, p)
Expand Down

0 comments on commit 0e3efd7

Please sign in to comment.