Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use index guessing functionality moved to FindFirstFunctions.jl (Guesser) #323

Merged
merged 7 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ DataInterpolationsSymbolicsExt = "Symbolics"
Aqua = "0.8"
BenchmarkTools = "1"
ChainRulesCore = "1.24"
FindFirstFunctions = "1.1"
FindFirstFunctions = "1.3"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
Expand Down
4 changes: 2 additions & 2 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

function u_tangent(A::LinearInterpolation, t, Δ)
out = zero(A.u)
idx = get_idx(A, t, A.idx_prev[])
idx = get_idx(A, t, A.iguesser)
t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])
out[idx] = Δ * (one(eltype(out)) - t_factor)
out[idx + 1] = Δ * t_factor
Expand All @@ -61,7 +61,7 @@ end

function u_tangent(A::QuadraticInterpolation, t, Δ)
out = zero(A.u)
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[])
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.iguesser)
t₀ = A.t[i₀]
t₁ = A.t[i₁]
t₂ = A.t[i₂]
Expand Down
6 changes: 1 addition & 5 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
bracketstrictlymontonic
Guesser

include("parameter_caches.jl")
include("interpolation_caches.jl")
Expand All @@ -22,10 +22,6 @@ include("online.jl")
include("show.jl")

(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)
function (interp::AbstractInterpolation)(t::Number, i::Integer)
interp.idx_prev[] = i
_interpolate(interp, t)
end

function (interp::AbstractInterpolation)(t::AbstractVector)
u = get_u(interp.u, t)
Expand Down
42 changes: 19 additions & 23 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
function derivative(A, t, order = 1)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
iguess = A.idx_prev[]
iguess = A.iguesser

return if order == 1
val, idx = _derivative(A, t, iguess)
A.idx_prev[] = idx
val
_derivative(A, t, iguess)
elseif order == 2
ForwardDiff.derivative(t -> begin
val, idx = _derivative(A, t, iguess)
A.idx_prev[] = idx
val
_derivative(A, t, iguess)
end, t)
else
throw(DerivativeNotFoundError())
Expand All @@ -20,7 +16,7 @@ end
function _derivative(A::LinearInterpolation, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = -1, ub_shift = -1, side = :first)
slope = get_parameters(A, idx)
slope, idx
slope
end

function _derivative(A::QuadraticInterpolation, t::Number, iguess)
Expand All @@ -29,7 +25,7 @@ function _derivative(A::QuadraticInterpolation, t::Number, iguess)
du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
return @views @. du₀ + du₁ + du₂, i₀
return @views @. du₀ + du₁ + du₂
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
Expand Down Expand Up @@ -101,21 +97,21 @@ function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, idx)
_derivative(A, t), idx
_derivative(A, t)
end
function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, idx)
_derivative(A, t), idx
_derivative(A, t)
end

function _derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = -1, side = :first)
j = min(idx, length(A.c)) # for smooth derivative at A.t[end]
wj = t - A.t[idx]
(@evalpoly wj A.b[idx] 2A.c[j] 3A.d[j]), idx
@evalpoly wj A.b[idx] 2A.c[j] 3A.d[j]
end

function _derivative(A::ConstantInterpolation, t::Number, iguess)
return zero(first(A.u)), iguess
return zero(first(A.u))
end

function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
Expand All @@ -132,7 +128,7 @@ end
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; lb = 2, ub_shift = 0, side = :first)
σ = get_parameters(A, idx - 1)
A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx
A.z[idx - 1] + 2σ * (t - A.t[idx - 1])
end

# CubicSpline Interpolation
Expand All @@ -144,13 +140,13 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
c₁, c₂ = get_parameters(A, idx)
dC = c₁
dD = -c₂
dI + dC + dD, idx
dI + dC + dD
end

function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
t < A.t[1] && return zero(A.u[1])
t > A.t[end] && return zero(A.u[end])
idx = get_idx(A, t, iguess)
n = length(A.t)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
Expand All @@ -165,14 +161,14 @@ function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Num
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale, idx
ducum * A.d * scale
end

# BSpline Curve Approx
function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
t < A.t[1] && return zero(A.u[1])
t > A.t[end] && return zero(A.u[end])
idx = get_idx(A, t, iguess)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
Expand All @@ -186,7 +182,7 @@ function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, ig
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale, idx
ducum * A.d * scale
end

# Cubic Hermite Spline
Expand All @@ -198,7 +194,7 @@ function _derivative(
out = A.du[idx]
c₁, c₂ = get_parameters(A, idx)
out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂))
out, idx
out
end

# Quintic Hermite Spline
Expand All @@ -211,5 +207,5 @@ function _derivative(
c₁, c₂, c₃ = get_parameters(A, idx)
out += Δt₀^2 *
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
out, idx
out
end
14 changes: 7 additions & 7 deletions src/integral_inverses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ invert_integral(A::AbstractInterpolation) = throw(IntegralInverseNotFoundError()
_integral(A::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())

function _derivative(A::AbstractIntegralInverseInterpolation, t::Number, iguess)
inv(A.itp(A(t))), A.idx_prev[]
inv(A.itp(A(t)))
end

"""
Expand All @@ -38,11 +38,11 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <:
u::uType
t::tType
extrapolate::Bool
idx_prev::Base.RefValue{Int}
iguesser::Guesser{tType}
itp::itpType
function LinearInterpolationIntInv(u, t, A)
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
u, t, A.extrapolate, Ref(1), A)
u, t, A.extrapolate, Guesser(t), A)
end
end

Expand All @@ -64,7 +64,7 @@ function _interpolate(
x = A.itp.u[idx]
slope = get_parameters(A.itp, idx)
u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt))
u, idx
u
end

"""
Expand All @@ -84,11 +84,11 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <:
u::uType
t::tType
extrapolate::Bool
idx_prev::Base.RefValue{Int}
iguesser::Guesser{tType}
itp::itpType
function ConstantInterpolationIntInv(u, t, A)
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
u, t, A.extrapolate, Ref(1), A
u, t, A.extrapolate, Guesser(t), A
)
end
end
Expand All @@ -112,5 +112,5 @@ function _interpolate(
# :right means that value to the right is used for interpolation
idx_ = get_idx(A, t, idx; side = :first, lb = 1, ub_shift = 0)
end
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_], idx
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_]
end
Loading
Loading