Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the core of initdt to be algorithm independent #2007

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 53 additions & 29 deletions src/initdt.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
@muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
prob::DiffEqBase.AbstractODEProblem{uType, tType, true
},
integrator) where {tType, uType}
prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, integrator) where {tType, uType}

sk = if !(typeof(integrator.alg) <: CompositeAlgorithm)
first(get_tmp_cache(integrator)
else
nothing
end

current_fsal = get_current_isfsal(integrator.alg, integrator.cache)
is_odeintegrator = typeof(integrator) <: ODEIntegrator
verbose = integrator.opts.verbose
alg_order = get_current_alg_order(integrator.alg, integrator.cache)

linsolve = if haskey(integrator.alg, :linsolve)
integrator.alg.linsolve
else
nothing
end

fsallast = if current_fsal
integrator.fsallast
else
nothing
end

_ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
prob, integrator.p, integrator.opts.dtmin, integrator.isdae, iscomposite,
sk, fsallast, current_fsal, is_odeintegrator, verbose, alg_order, linsolve)
end

@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator,
verbose, alg_order, linsolve) where {tType, uType}
_tType = eltype(tType)
f = prob.f
p = integrator.p
oneunit_tType = oneunit(_tType)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(integrator.opts.dtmin)
dtmin = nextfloat(dtmin)
smalldt = convert(_tType, oneunit_tType * 1 // 10^(6))

if integrator.isdae
if isdae
return tdir * max(smalldt, dtmin)
end

if eltype(u0) <: Number && !(typeof(integrator.alg) <: CompositeAlgorithm)
cache = get_tmp_cache(integrator)
sk = first(cache)
if eltype(u0) <: Number && iscomposite
if u0 isa Array && abstol isa Number && reltol isa Number
@inbounds @simd ivdep for i in eachindex(u0)
sk[i] = abstol + internalnorm(u0[i], t) * reltol
Expand All @@ -36,10 +63,9 @@
end
end

if get_current_isfsal(integrator.alg, integrator.cache) &&
typeof(integrator) <: ODEIntegrator
if current_fsal && is_odeintegrator
# Right now DelayDiffEq has issues with fsallast not being initialized
f₀ = integrator.fsallast
f₀ = fsallast
f(f₀, u0, p, t)
else
# TODO: use more caches
Expand Down Expand Up @@ -107,7 +133,7 @@
any(mm != I for mm in prob.f.mass_matrix))
ftmp = zero(f₀)
try
integrator.alg.linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true)
linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true)
copyto!(f₀, ftmp)
catch
return tdir * max(smalldt, dtmin)
Expand All @@ -127,7 +153,7 @@
# Better than checking any(x->any(isnan, x), f₀)
# because it also checks if partials are NaN
# https://discourse.julialang.org/t/incorporating-forcing-functions-in-the-ode-model/70133/26
if integrator.opts.verbose && isnan(d₁)
if verbose && isnan(d₁)
@warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.")
return tdir * dtmin
end
Expand Down Expand Up @@ -166,7 +192,7 @@

if prob.f.mass_matrix != I && (!(typeof(prob.f) <: DynamicalODEFunction) ||
any(mm != I for mm in prob.f.mass_matrix))
integrator.alg.linsolve(ftmp, prob.f.mass_matrix, f₁, false)
linsolve(ftmp, prob.f.mass_matrix, f₁, false)
copyto!(f₁, ftmp)
end

Expand All @@ -192,8 +218,7 @@
else
dt₁ = convert(_tType,
oneunit_tType *
10.0^(-(2 + log10(max_d₁d₂)) /
get_current_alg_order(integrator.alg, integrator.cache)))
10.0^(-(2 + log10(max_d₁d₂)) / alg_order))
end
return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir))
end
Expand Down Expand Up @@ -224,28 +249,28 @@ function Base.showerror(io::IO, e::TypeNotConstantError)
println(io, e.f₀)
end

@muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
prob::DiffEqBase.AbstractODEProblem{uType, tType,
false},
integrator) where {uType, tType}
false}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator,
verbose, alg_order, linsolve) where {uType, tType}
_tType = eltype(tType)
f = prob.f
p = prob.p
oneunit_tType = oneunit(_tType)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(integrator.opts.dtmin)
dtmin = nextfloat(dtmin)
smalldt = convert(_tType, oneunit_tType * 1 // 10^(6))

if integrator.isdae
if isdae
return tdir * max(smalldt, dtmin)
end

sk = @.. broadcast=false abstol+internalnorm(u0, t) * reltol
d₀ = internalnorm(u0 ./ sk, t)

f₀ = f(u0, p, t)
if integrator.opts.verbose && any(x -> any(isnan, x), f₀)
if verbose && any(x -> any(isnan, x), f₀)
@warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.")
end

Expand Down Expand Up @@ -279,19 +304,18 @@ end
dt₁ = max(smalldt, dt₀ * 1 // 10^(3))
else
dt₁ = _tType(oneunit_tType *
10^(-(2 + log10(max_d₁d₂)) /
get_current_alg_order(integrator.alg, integrator.cache)))
10^(-(2 + log10(max_d₁d₂)) / alg_order))
end
return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir))
end

@inline function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
@inline function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm,
prob::DiffEqBase.AbstractDAEProblem{duType, uType,
tType},
integrator) where {duType, uType, tType}
tType}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator,
verbose, alg_order, linsolve) where {duType, uType, tType}
_tType = eltype(tType)
tspan = prob.tspan
init_dt = abs(tspan[2] - tspan[1])
init_dt = isfinite(init_dt) ? init_dt : oneunit(_tType)
return convert(_tType, init_dt * 1 // 10^(6))
end
end