Skip to content

Commit

Permalink
Fix PT
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 30, 2023
1 parent 4f2dec0 commit 4c61c4a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 105 deletions.
6 changes: 3 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
import ForwardDiff
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import MaybeInplace: @bb
import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
Expand Down Expand Up @@ -80,7 +80,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c
return cache
end

__reinit_internal!(cache::AbstractNonlinearSolveCache) = nothing
# Generic fallback: does no work and returns `nothing` for any
# `AbstractNonlinearSolveCache` that does not have a more specific method.
function __reinit_internal!(::AbstractNonlinearSolveCache)
    return nothing
end

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
Expand Down Expand Up @@ -157,7 +157,7 @@ include("raphson.jl")
# include("levenberg.jl")
# include("gaussnewton.jl")
# include("dfsane.jl")
# include("pseudotransient.jl")
include("pseudotransient.jl")
include("broyden.jl")
include("klement.jl")
# include("lbroyden.jl")
Expand Down
9 changes: 5 additions & 4 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
end

if linsolve_init
linprob_A = needsJᵀJ ? __maybe_symmetric(JᵀJ) : J
# linprob_A = alg isa PseudoTransient ?
# (J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
# (needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
if alg isa PseudoTransient && J isa SciMLOperators.AbstractSciMLOperator
linprob_A = J - inv(convert(eltype(u), alg.alpha_initial)) * I
else
linprob_A = needsJᵀJ ? __maybe_symmetric(JᵀJ) : J
end
linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg;
linsolve_kwargs)
else
Expand Down
6 changes: 4 additions & 2 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
cache.resets += 1
end

A = ifelse(cache.J isa SMatrix || cache.J isa Number || !fact_done, cache.J, nothing)

# u = u - J \ fu
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, cache.linsolve; A,
b = _vec(cache.fu), linu = _vec(cache.du), cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache
!iip && (cache.du = linres.u)

Expand Down
140 changes: 44 additions & 96 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S1064
alpha_initial
end

#concrete_jac(::PseudoTransient{CJ}) where {CJ} = CJ
"""
    set_ad(alg::PseudoTransient{CJ}, ad)

Build a new `PseudoTransient{CJ}` whose first constructor argument is `ad`, reusing
`alg.linsolve`, `alg.precs` and `alg.alpha_initial` from the given algorithm. `alg`
itself is not modified.
"""
function set_ad(alg::PseudoTransient{CJ}, ad) where {CJ}
    linsolve, precs = alg.linsolve, alg.precs
    alpha_initial = alg.alpha_initial
    return PseudoTransient{CJ}(ad, linsolve, precs, alpha_initial)
end
Expand All @@ -56,9 +55,9 @@ end
f
alg
u
u_prev
fu1
fu2
u_cache
fu
fu_cache
du
p
alpha
Expand Down Expand Up @@ -86,126 +85,75 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransi
alg = get_concrete_algorithm(alg_, prob)

@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
u = __maybe_unaliased(u0, alias_u0)
fu = evaluate_f(prob, u)
uf, linsolve, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)
alpha = convert(eltype(u), alg.alpha_initial)
res_norm = internalnorm(fu1)
res_norm = internalnorm(fu)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u,
@bb u_cache = copy(u)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...)
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)

return PseudoTransientCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, alpha, res_norm,
uf, linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
return PseudoTransientCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, p, alpha,
res_norm, uf, linsolve, J, jac_cache, false, maxiters, internalnorm,
ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

function perform_step!(cache::PseudoTransientCache{true})
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, alpha = cache
jacobian!!(J, cache)
function perform_step!(cache::PseudoTransientCache{iip}) where {iip}
@unpack alg = cache

inv_alpha = inv(alpha)
if J isa SciMLBase.AbstractSciMLOperator
J = J - inv_alpha * I
else
idxs = diagind(J)
if fast_scalar_indexing(J)
@inbounds for i in axes(J, 1)
J[i, i] = J[i, i] - inv_alpha
cache.J = jacobian!!(cache.J, cache)

inv_α = inv(cache.alpha)
if cache.J isa SciMLOperators.AbstractSciMLOperator
A = cache.J - inv_α * I
elseif setindex_trait(cache.J) === CanSetindex()
idxs = diagind(cache.J)
if fast_scalar_indexing(cache.J)
@inbounds for i in axes(cache.J, 1)
cache.J[i, i] = cache.J[i, i] - inv_α
end
else
@.. broadcast=false @view(J[idxs])=@view(J[idxs]) - inv_alpha
@.. broadcast=false @view(cache.J[idxs])=@view(cache.J[idxs]) - inv_α
end
A = cache.J
else
cache.J = cache.J - inv_α * I
A = cache.J
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, cache.linsolve; A, b = _vec(cache.fu),
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(fu1, u, p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), J,
cache.du)

new_norm = cache.internalnorm(fu1)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm

check_and_update!(cache, cache.fu1, cache.u, cache.u_prev)

@. u_prev = u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end
!iip && (cache.du = linres.u)

function perform_step!(cache::PseudoTransientCache{false})
@unpack u, u_prev, fu1, f, p, alg, linsolve, alpha = cache
@bb axpy!(-true, cache.du, cache.u)

cache.J = jacobian!!(cache.J, cache)

inv_alpha = inv(alpha)
cache.J = cache.J - inv_alpha * I
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)
evaluate_f(cache, cache.u, cache.p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
cache.du)
update_trace!(cache, true)

new_norm = cache.internalnorm(fu1)
new_norm = cache.internalnorm(cache.fu)
cache.alpha *= cache.res_norm / new_norm
cache.res_norm = new_norm

check_and_update!(cache, cache.fu1, cache.u, cache.u_prev)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

cache.u_prev = cache.u
@bb copyto!(cache.u_cache, cache.u)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
alpha = cache.alpha, abstol = cache.abstol, reltol = cache.reltol,
termination_condition = get_termination_mode(cache.tc_cache),
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
termination_condition)

cache.alpha = convert(eltype(cache.u), alpha)
cache.res_norm = cache.internalnorm(cache.fu1)
cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
# Reset the pseudo-transient state held in `cache`:
#   * `alpha` goes back to the algorithm's `alpha_initial`, converted to the element
#     type of `cache.u`;
#   * `res_norm` is recomputed by applying `cache.internalnorm` to the current `cache.fu`.
# Always returns `nothing`.
function __reinit_internal!(cache::PseudoTransientCache)
    T = eltype(cache.u)
    alpha0 = cache.alg.alpha_initial
    cache.alpha = convert(T, alpha0)

    fu_norm = cache.internalnorm(cache.fu)
    cache.res_norm = fu_norm
    return nothing
end

0 comments on commit 4c61c4a

Please sign in to comment.