diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 36e59135..35cc34ba 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -13,7 +13,7 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@3645a07f58c7f83b9f82ac8e0bb95583e69149e6 + - uses: julia-actions/setup-julia@780022b48dfc0c2c6b94cfee6a9284850107d037 with: version: 1.3 - name: Pkg.add("CompatHelper") diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index c0d0123e..4546ebd0 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -28,7 +28,7 @@ jobs: - windows-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2.2.0 + - uses: julia-actions/setup-julia@v2.3.0 with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 diff --git a/Project.toml b/Project.toml index 7f08abb0..17fdd131 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,12 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.3.0" +version = "6.0.0" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -16,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" @@ -25,7 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" [compat] Aqua = "0.8" -ChainRulesCore = "1.18" +ChainRulesCore = "1.24" FindFirstFunctions = "1.1" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" @@ -33,7 +33,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -41,6 +40,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" +Zygote = "0.6.70" julia = "1.10" [extras] @@ -55,6 +55,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/Project.toml b/docs/Project.toml index 30a52bff..540ffc47 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,15 +1,23 @@ [deps] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] -DataInterpolations = "5" +DataInterpolations = "6" Documenter = "1" +ModelingToolkit = "9" +ModelingToolkitStandardLibrary = "2" Optim = "1" +OrdinaryDiffEq = "6" Plots = "1" RegularizationTools = "0.6" -StableRNGs = "1" \ No newline at end of file +StableRNGs = "1" +Symbolics = "5.29" diff --git a/docs/make.jl b/docs/make.jl index 94546761..6438f482 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -13,6 +13,7 @@ makedocs(modules = [DataInterpolations], format = Documenter.HTML(assets = ["assets/favicon.ico"], canonical = "https://docs.sciml.ai/DataInterpolations/stable/"), pages = ["index.md", "Methods" => "methods.md", - "Interface" => "interface.md", "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) + "Interface" => "interface.md", "Using with Symbolics/ModelingToolkit" => "symbolics.md", + "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) deploydocs(repo = "github.com/SciML/DataInterpolations.jl"; push_preview = true) diff --git a/docs/src/index.md b/docs/src/index.md index 9b1277f3..f2075f8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,10 +1,6 @@ # DataInterpolations.jl -DataInterpolations.jl is a library for performing interpolations of one-dimensional data. By -"data interpolations" we mean techniques for interpolating possibly noisy data, and thus -some methods are mixtures of regressions with interpolations (i.e. do not hit the data -points exactly, smoothing out the lines). This library can be used to fill in intermediate -data points in applications like timeseries data. +DataInterpolations.jl is a library for performing interpolations of one-dimensional data. Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. By "data interpolations" we mean techniques for interpolating possibly noisy data, and thus some methods are mixtures of regressions with interpolations (i.e. do not hit the data points exactly, smoothing out the lines). ## Installation diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/docs/src/symbolics.md b/docs/src/symbolics.md new file mode 100644 index 00000000..77d45ead --- /dev/null +++ b/docs/src/symbolics.md @@ -0,0 +1,65 @@ +# Using DataInterpolations.jl with Symbolics.jl and ModelingToolkit.jl + +All interpolation methods can be integrated with [Symbolics.jl](https://symbolics.juliasymbolics.org/stable/) and [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) seamlessly. + +## Using with Symbolics.jl + +### Expressions + +```@example symbolics +using DataInterpolations, Symbolics +using Test + +u = [0.0, 1.5, 0.0] +t = [0.0, 0.5, 1.0] +A = LinearInterpolation(u, t) + +@variables τ + +# Simple Expression +ex = cos(τ) * A(τ) +@test substitute(ex, Dict(τ => 0.5)) == cos(0.5) * A(0.5) # true +``` + +### Symbolic Derivatives + +```@example symbolics +D = Differential(τ) + +ex1 = A(τ) + +# Derivative of interpolation +ex2 = expand_derivatives(D(ex1)) + +@test substitute(ex2, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5) # true + +# Higher Order Derivatives +ex3 = expand_derivatives(D(D(A(τ)))) + +@test substitute(ex3, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5, 2) # true +``` + +## Using with ModelingToolkit.jl + +Most common use case with [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) is to plug in interpolation objects as input functions. This can be done using `TimeVaryingFunction` component of [ModelingToolkitStandardLibrary.jl](https://docs.sciml.ai/ModelingToolkitStandardLibrary/stable/). + +```@example mtk +using DataInterpolations +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +us = [0.0, 1.5, 0.0] +times = [0.0, 0.5, 1.0] +A = LinearInterpolation(us, times) + +@named src = TimeVaryingFunction(A) +vars = @variables x(t) out(t) +eqs = [out ~ src.output.u, D(x) ~ 1 + out] +@named sys = ODESystem(eqs, t, vars, []; systems = [src]) + +sys = structural_simplify(sys) +prob = ODEProblem(sys, [x => 0.0], (times[1], times[end])) +sol = solve(prob) +``` diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..9c33b09c 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,19 +1,87 @@ module DataInterpolationsChainRulesCoreExt - if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_idx, get_parameters, + _quad_interp_indices using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_parameters, + _quad_interp_indices using ..ChainRulesCore end +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function ChainRulesCore.rrule( + ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dmode = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function u_tangent(A::LinearInterpolation, t, Δ) + out = zero(A.u) + idx = get_idx(A.t, t, A.idx_prev[]) + t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) + out[idx] = Δ * (one(eltype(out)) - t_factor) + out[idx + 1] = Δ * t_factor + out +end + +function u_tangent(A::QuadraticInterpolation, t, Δ) + out = zero(A.u) + i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[]) + t₀ = A.t[i₀] + t₁ = A.t[i₁] + t₂ = A.t[i₂] + Δt₀ = t₁ - t₀ + Δt₁ = t₂ - t₁ + Δt₂ = t₂ - t₀ + out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + out +end + +function u_tangent(A, t, Δ) + NoTangent() +end + function ChainRulesCore.rrule(::typeof(_interpolate), A::Union{ + LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, @@ -21,7 +89,9 @@ function ChainRulesCore.rrule(::typeof(_interpolate), }, t::Number) deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) + function interpolate_pullback(Δ) + (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ) + end return _interpolate(A, t), interpolate_pullback end diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/ext/DataInterpolationsSymbolicsExt.jl b/ext/DataInterpolationsSymbolicsExt.jl index 106535ef..b144da22 100644 --- a/ext/DataInterpolationsSymbolicsExt.jl +++ b/ext/DataInterpolationsSymbolicsExt.jl @@ -12,8 +12,7 @@ else using ..Symbolics: Num, unwrap, SymbolicUtils end -(interp::AbstractInterpolation)(t::Num) = SymbolicUtils.term(interp, unwrap(t)) -SymbolicUtils.promote_symtype(t::AbstractInterpolation, _...) = Real +@register_symbolic (interp::AbstractInterpolation)(t) Base.nameof(interp::AbstractInterpolation) = :Interpolation function derivative(interp::AbstractInterpolation, t::Num, order = 1) diff --git a/joss/paper.bib b/joss/paper.bib index d754d9cc..101b0181 100644 --- a/joss/paper.bib +++ b/joss/paper.bib @@ -134,3 +134,17 @@ @book{lagrange1898lectures year={1898}, publisher={Open court publishing Company} } + +@article{doi:10.1137/0905021, + author = {Fritsch, F. N. and Butland, J.}, + title = {A Method for Constructing Local Monotone Piecewise Cubic Interpolants}, + journal = {SIAM Journal on Scientific and Statistical Computing}, + volume = {5}, + number = {2}, + pages = {300-304}, + year = {1984}, + doi = {10.1137/0905021}, + URL = {https://doi.org/10.1137/0905021}, + eprint = {https://doi.org/10.1137/0905021}, + abstract = { A method is described for producing monotone piecewise cubic interpolants to monotone data which is completely local and which is extremely simple to implement. } +} diff --git a/joss/paper.md b/joss/paper.md index f2d2a4be..66d15aa0 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -31,15 +31,30 @@ bibliography: paper.bib # Summary -Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. +Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include: + + - Constant Interpolation + - Linear Interpolation + - Quadratic Interpolation + - Lagrange Interpolation [@lagrange] + - Quadratic Splines + - Cubic Splines [@Schoenberg1988] + - Akima Splines [@10.1145/321607.321609] + - Cubic Hermite Splines + - Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) [@doi:10.1137/0905021] + - Quintic Hermite Splines + - B-Splines [@Curry1988] [@DEBOOR197250] + - Regression based B-Splines + +and a continually growing list. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. # Statement of need -Interpolations are a very important component of many modeling workflows. In many models, inputs which are sampled or measured need to be represented as a continuous function or a smooth curve for simulation. In many scientific machine learning workflows, we need interpolations of data to learn continuous models. There already have been a few interpolation packages in Julia like Interpolations.jl but it has a limitation of assuming uniformly spaced data which is not usually the case with data collected from real world. DataInterpolations.jl provides fast interpolation methods for arbitrary spaced 1D data with a consistent and simple interface. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. +Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. Additionally, DataInterpolations.jl includes many novel techniques for accelerating the interpolation searches with specialized caching, quasi-linear guessing, and more to improve the performance algorithmically, beyond the simple computational optimizations. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. # Example -The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. +The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. [3](https://docs.sciml.ai/DataInterpolations/stable/symbolics/) provides how to use interpolation objects with Symbolics.jl and ModelingToolkit.jl. A simple demonstration here: diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index d96ea69a..7f44c878 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -90,12 +89,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, @@ -128,12 +121,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..75872095 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -18,15 +18,17 @@ function derivative(A, t, order = 1) end function _derivative(A::LinearInterpolation, t::Number, iguess) - idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) - A.p.slope[idx], idx + idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..31d0853c 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 96234c0c..aa452589 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false, safetycopy = true) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,8 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation + computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. - `assume_linear_t`: boolean value to specify a faster index lookup behaviour for evenly-distributed abscissae. Alternatively, a numerical threshold may be specified for a test based on the normalized standard deviation of the difference with respect @@ -25,27 +26,26 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function LinearInterpolation( - u, t, I, p, extrapolate, safetycopy, assume_linear_t) + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters, assume_linear_t) linear_flag = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy, linear_flag) + u, t, I, p, extrapolate, Ref(1), cache_parameters, linear_flag) end end function LinearInterpolation( - u, t; extrapolate = false, safetycopy = true, assume_linear_t = 1e-2) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy, assume_linear_t) + u, t; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2) + u, t = munge_data(u, t) + p = LinearParameterCache(u, t, cache_parameters) + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + LinearInterpolation(u, t, I, p, extrapolate, cache_parameters, assume_linear_t) end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, safetycopy = true) + QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -59,7 +59,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -69,28 +69,28 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool function QuadraticInterpolation( - u, t, I, p, mode, extrapolate, safetycopy, assume_linear_t) + u, t, I, p, mode, extrapolate, cache_parameters, assume_linear_t) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") linear_flag = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy, linear_flag) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters, linear_flag) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = QuadraticParameterCache(u, t, cache_parameters) + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -107,7 +107,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -118,8 +117,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -129,23 +127,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: [https://en.wikipedia.org/wiki/Akima_spline](https://en.wikipedia.org/wiki/Akima_spline). Extrapolation extends the last cubic polynomial on each side. @@ -158,7 +155,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -170,8 +167,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -181,13 +178,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -208,13 +205,13 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -229,7 +226,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -239,23 +236,24 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) end """ - QuadraticSpline(u, t; extrapolate = false, safetycopy = true) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -268,7 +266,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -281,9 +279,9 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -294,15 +292,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -314,15 +312,17 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -335,14 +335,15 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end """ - CubicSpline(u, t; extrapolate = false, safetycopy = true) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -355,7 +356,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -366,9 +367,9 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -378,15 +379,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -403,15 +405,17 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -425,10 +429,11 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end """ @@ -448,7 +453,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -463,7 +467,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool use_linear_lookup::Bool function BSplineInterpolation(u, t, @@ -474,8 +477,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -486,15 +488,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -558,11 +559,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) + BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer [http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf](http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf). @@ -580,7 +581,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -596,7 +596,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool use_linear_lookup::Bool function BSplineApprox(u, t, @@ -608,8 +607,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -622,15 +620,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -714,11 +711,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -731,7 +729,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -741,21 +739,21 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = CubicHermiteParameterCache(du, u, t, cache_parameters) + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) end """ @@ -773,12 +771,12 @@ section 3.4 for more details. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ -function PCHIPInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function PCHIPInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) du = du_PCHIP(u, t) - CubicHermiteSpline(du, u, t; extrapolate, safetycopy) + CubicHermiteSpline(du, u, t; extrapolate, cache_parameters) end """ @@ -796,7 +794,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -808,20 +806,20 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool + cache_parameters::Bool use_linear_lookup::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index d8870568..99f3874d 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -18,15 +18,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -46,7 +46,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -58,9 +59,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -157,7 +159,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -166,8 +169,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -213,7 +217,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -224,6 +229,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 7d51c308..d29959dc 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,14 +88,10 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end seems_linear(assume_linear_t::Bool, _) = assume_linear_t @@ -125,10 +113,6 @@ function looks_linear(t; threshold = 1e-2) norm_var < threshold^2 end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -140,14 +124,63 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) - return nothing +function cumulative_integral(A, cache_parameters) + if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number}) + integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1)] + pushfirst!(integral_values, zero(first(integral_values))) + cumsum(integral_values) + else + promote_type(eltype(A.u), eltype(A.t))[] + end +end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) - return cumsum(integral_values) end function du_PCHIP(u, t) diff --git a/src/online.jl b/src/online.jl index 0fab5d44..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..0701b3a2 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,13 +2,28 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) - return LinearParameterCache(slope) +function LinearParameterCache(u, t, cache_parameters) + if cache_parameters + slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + LinearParameterCache(slope) + else + # Compute parameters once to infer types + slope = linear_interpolation_parameters(u, t, 1) + LinearParameterCache(typeof(slope)[]) + end +end + +# Prevent e.g. Inf - Inf = NaN +function safe_diff(b, a::T) where {T} + b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u, t, idx) - Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] +function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} + Δu = if u isa AbstractMatrix + [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] + else + safe_diff(u[idx + 1], u[idx]) + end Δt = t[idx + 1] - t[idx] slope = Δu / Δt slope = iszero(Δt) ? zero(slope) : slope @@ -21,11 +36,18 @@ struct QuadraticParameterCache{pType} l₂::pType end -function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( - Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuadraticParameterCache(l₀, l₁, l₂) +function QuadraticParameterCache(u, t, cache_parameters) + if cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(u), Ref(t), 1:(length(t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) + QuadraticParameterCache(l₀, l₁, l₂) + else + # Compute parameters once to infer types + l₀, l₁, l₂ = quadratic_interpolation_parameters(u, t, 1) + pType = typeof(l₀) + QuadraticParameterCache(pType[], pType[], pType[]) + end end function quadratic_interpolation_parameters(u, t, idx) @@ -54,9 +76,15 @@ struct QuadraticSplineParameterCache{pType} σ::pType end -function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) +function QuadraticSplineParameterCache(z, t, cache_parameters) + if cache_parameters + σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) + QuadraticSplineParameterCache(σ) + else + # Compute parameters once to infer types + σ = quadratic_spline_parameters(z, t, 1) + QuadraticSplineParameterCache(typeof(σ)[]) + end end function quadratic_spline_parameters(z, t, idx) @@ -69,11 +97,18 @@ struct CubicSplineParameterCache{pType} c₂::pType end -function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( - Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicSplineParameterCache(c₁, c₂) +function CubicSplineParameterCache(u, h, z, cache_parameters) + if cache_parameters + parameters = cubic_spline_parameters.( + Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicSplineParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_spline_parameters(u, h, z, 1) + pType = typeof(c₁) + CubicSplineParameterCache(pType[], pType[]) + end end function cubic_spline_parameters(u, h, z, idx) @@ -87,11 +122,18 @@ struct CubicHermiteParameterCache{pType} c₂::pType end -function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( - Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicHermiteParameterCache(c₁, c₂) +function CubicHermiteParameterCache(du, u, t, cache_parameters) + if cache_parameters + parameters = cubic_hermite_spline_parameters.( + Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicHermiteParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_hermite_spline_parameters(du, u, t, 1) + pType = typeof(c₁) + CubicHermiteParameterCache(pType[], pType[]) + end end function cubic_hermite_spline_parameters(du, u, t, idx) @@ -111,11 +153,18 @@ struct QuinticHermiteParameterCache{pType} c₃::pType end -function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( - Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuinticHermiteParameterCache(c₁, c₂, c₃) +function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + if cache_parameters + parameters = quintic_hermite_spline_parameters.( + Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) + QuinticHermiteParameterCache(c₁, c₂, c₃) + else + # Compute parameters once to infer types + c₁, c₂, c₃ = quintic_hermite_spline_parameters(ddu, du, u, t, 1) + pType = typeof(c₁) + QuinticHermiteParameterCache(pType[], pType[], pType[]) + end end function quintic_hermite_spline_parameters(ddu, du, u, t, idx) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 37351d0d..50abe4ac 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -82,6 +82,12 @@ end u = vcat(2.0collect(1:10)', 3.0collect(1:10)') test_derivatives( LinearInterpolation; args = [u, t], name = "Linear Interpolation (Matrix)") + + # Issue: https://github.com/SciML/DataInterpolations.jl/issues/303 + u = [3.0, 3.0] + t = [0.0, 2.0] + test_derivatives( + LinearInterpolation; args = [u, t], name = "Linear Interpolation with two points") end @testset "Quadratic Interpolation" begin diff --git a/test/interface.jl b/test/interface.jl index 5d02a22a..e7b2b81b 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,20 +1,57 @@ using DataInterpolations -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +using Symbolics -for i in 1:10 - @test u[i] == A.u[i] -end +@testset "Interface" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + + for i in 1:10 + @test u[i] == A.u[i] + end -for i in 1:10 - @test t[i] == A.t[i] + for i in 1:10 + @test t[i] == A.t[i] + end end -using Symbolics -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +@testset "Symbolics" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t; extrapolate = true) + B = LinearInterpolation(u .^ 2, t; extrapolate = true) + @variables t x(t) + substitute(A(t), Dict(t => x)) + t_val = 2.7 + @test substitute(A(t), Dict(t => t_val)) == A(t_val) + @test substitute(B(A(t)), Dict(t => t_val)) == B(A(t_val)) + @test substitute(A(B(A(t))), Dict(t => t_val)) == A(B(A(t_val))) +end -@variables t x(t) -substitute(A(t), Dict(t => x)) +@testset "Type Inference" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + methods = [ + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, + QuadraticSpline, CubicSpline, AkimaInterpolation + ] + @testset "$method" for method in methods + @inferred method(u, t) + end + @testset "BSplineInterpolation" begin + @inferred BSplineInterpolation(u, t, 3, :Uniform, :Uniform) + @inferred BSplineInterpolation(u, t, 3, :ArcLen, :Average) + end + @testset "BSplineApprox" begin + @inferred BSplineApprox(u, t, 3, 5, :Uniform, :Uniform) + @inferred BSplineApprox(u, t, 3, 5, :ArcLen, :Average) + end + du = ones(10) + ddu = zeros(10) + @testset "Hermite Splines" begin + @inferred CubicHermiteSpline(du, u, t) + @inferred PCHIPInterpolation(u, t) + @inferred QuinticHermiteSpline(ddu, du, u, t) + end +end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index d2c61c80..53fff25e 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) @@ -161,6 +160,13 @@ end @test A(5.5) == fill(11.0) @test A(11) == fill(22) + # Test constant -Inf interpolation + u = [-Inf, -Inf] + t = [0.0, 1.0] + A = LinearInterpolation(u, t) + @test A(0.0) == -Inf + @test A(0.5) == -Inf + # Test extrapolation u = 2.0collect(1:10) t = 1.0collect(1:10) diff --git a/test/online_tests.jl b/test/online_tests.jl index 3cf832f0..1872e0cc 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,10 +9,11 @@ u2 = [1.0, 2.0, 1.0] ts_append = 1.0:0.5:6.0 ts_push = 1.0:0.5:4.0 -for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) +@testset "$method" for method in [ + LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -21,9 +22,9 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio @test func1(ts_append) == func2(ts_append) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..1a7fc447 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,101 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) + func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + isnothing(zdiff) && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + if method ∉ [LagrangeInterpolation, BSplineInterpolation, BSplineApprox] + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1.0:5.0), 2 * collect(6.0:10.0)) + t = collect(1.0:10.0) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end + +@testset "Quadratic Spline" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline") +end + +@testset "Lagrange Interpolation" begin + u = [1.0, 4.0, 9.0] + t = [1.0, 2.0, 3.0] + test_zygote(LagrangeInterpolation, u, t, name = "Lagrange Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation") +end + +@testset "Cubic Spline" begin + u = [0.0, 1.0, 3.0] + t = [-1.0, 0.0, 1.0] + test_zygote(CubicSpline, u, t, name = "Cubic Spline") +end + +@testset "BSplines" begin + t = [0, 62.25, 109.66, 162.66, 205.8, 252.3] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + test_zygote(BSplineInterpolation, u, t; args_after = [2, :Uniform, :Uniform], + name = "BSpline Interpolation") + test_zygote(BSplineApprox, u, t; args_after = [2, 4, :Uniform, :Uniform], + name = "BSpline approximation") +end