From 4aff70ca5469b560644c6bc345b3776c4cee913c Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jul 2024 07:58:22 +0000 Subject: [PATCH 01/35] docs: changes in joss paper according to reviews --- joss/paper.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/joss/paper.md b/joss/paper.md index f2d2a4be..3777da64 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -31,11 +31,11 @@ bibliography: paper.bib # Summary -Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. +Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Piecewise Cubic Hermite Interpolating Polynomial (PCHIP), Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. # Statement of need -Interpolations are a very important component of many modeling workflows. In many models, inputs which are sampled or measured need to be represented as a continuous function or a smooth curve for simulation. In many scientific machine learning workflows, we need interpolations of data to learn continuous models. There already have been a few interpolation packages in Julia like Interpolations.jl but it has a limitation of assuming uniformly spaced data which is not usually the case with data collected from real world. DataInterpolations.jl provides fast interpolation methods for arbitrary spaced 1D data with a consistent and simple interface. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. +Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. # Example From 601491dc8c1b576d3ecc9bcea408491065871e8d Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jul 2024 09:22:37 +0000 Subject: [PATCH 02/35] docs: include all suggestions for joss paper --- joss/paper.md | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/joss/paper.md b/joss/paper.md index 3777da64..ce70c1ca 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -31,11 +31,26 @@ bibliography: paper.bib # Summary -Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Piecewise Cubic Hermite Interpolating Polynomial (PCHIP), Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. +Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include: + + - Constant Interpolation + - Linear Interpolation + - Quadratic Interpolation + - Lagrange Interpolation [@lagrange] + - Quadratic Splines + - Cubic Splines [@Schoenberg1988] + - Akima Splines [@10.1145/321607.321609] + - Cubic Hermite Splines + - Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) [@doi:10.1137/0905021] + - Quintic Hermite Splines + - B-Splines [@Curry1988] [@DEBOOR197250] + - Regression based B-Splines + +and a continually growing list. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. # Statement of need -Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. +Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. Additionally, DataInterpolations.jl includes many novel techniques for accelerating the interpolation searches with specialized caching, quasi-linear guessing, and more to improve the performance algorithmically, beyond the simple computational optimizations. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. # Example From 4f688f9df5270d7ed28bed43ec39b99247cf3c61 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jul 2024 09:22:59 +0000 Subject: [PATCH 03/35] docs: add citation for pchip interpolation --- joss/paper.bib | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/joss/paper.bib b/joss/paper.bib index d754d9cc..101b0181 100644 --- a/joss/paper.bib +++ b/joss/paper.bib @@ -134,3 +134,17 @@ @book{lagrange1898lectures year={1898}, publisher={Open court publishing Company} } + +@article{doi:10.1137/0905021, + author = {Fritsch, F. N. and Butland, J.}, + title = {A Method for Constructing Local Monotone Piecewise Cubic Interpolants}, + journal = {SIAM Journal on Scientific and Statistical Computing}, + volume = {5}, + number = {2}, + pages = {300-304}, + year = {1984}, + doi = {10.1137/0905021}, + URL = {https://doi.org/10.1137/0905021}, + eprint = {https://doi.org/10.1137/0905021}, + abstract = { A method is described for producing monotone piecewise cubic interpolants to monotone data which is completely local and which is extremely simple to implement. } +} From 69dbed7ddab1e7132db9c71e4cf8916d5d4f000d Mon Sep 17 00:00:00 2001 From: fjebaker Date: Mon, 22 Jul 2024 23:07:22 +0200 Subject: [PATCH 04/35] fix: type inference for interface The interface for, e.g. LinearInterpolation, is not type stable because `cummulative_integral` returns a union. This is because it does a runtime `isempty` check on the methods table. Replacing this with a `hasmethod` check allows the boolean to be determined during compilation, fixing the types. --- src/interpolation_utils.jl | 2 +- test/interface.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 0f749316..db4393b2 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -122,7 +122,7 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) + if !hasmethod(_integral, Tuple{typeof(A), Number, Number}) return nothing end integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) diff --git a/test/interface.jl b/test/interface.jl index 5d02a22a..b4ce60ba 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,6 +1,7 @@ using DataInterpolations u = 2.0collect(1:10) t = 1.0collect(1:10) +@inferred LinearInterpolation(u, t) A = LinearInterpolation(u, t) for i in 1:10 From 1e1e59a0f825763fea698227d727723040fc2cbe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 23:59:15 +0000 Subject: [PATCH 05/35] build(deps): bump julia-actions/setup-julia from 2.2.0 to 2.3.0 Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 2.2.0 to 2.3.0. - [Release notes](https://github.com/julia-actions/setup-julia/releases) - [Commits](https://github.com/julia-actions/setup-julia/compare/v2.2...v2.3) --- updated-dependencies: - dependency-name: julia-actions/setup-julia dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/CompatHelper.yml | 2 +- .github/workflows/Downgrade.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 36e59135..35cc34ba 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -13,7 +13,7 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@3645a07f58c7f83b9f82ac8e0bb95583e69149e6 + - uses: julia-actions/setup-julia@780022b48dfc0c2c6b94cfee6a9284850107d037 with: version: 1.3 - name: Pkg.add("CompatHelper") diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index c0d0123e..4546ebd0 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -28,7 +28,7 @@ jobs: - windows-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2.2.0 + - uses: julia-actions/setup-julia@v2.3.0 with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 From 864c8a6726392cc92b98eb3ef9aa47a24a5e3a01 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 06:45:36 +0000 Subject: [PATCH 06/35] fix: idxs for derivatives for LinearInterpolation --- src/derivatives.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..4d8f7189 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -18,7 +18,7 @@ function derivative(A, t, order = 1) end function _derivative(A::LinearInterpolation, t::Number, iguess) - idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) + idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) A.p.slope[idx], idx end From e1e6b05e61327d37bfcc2a11acab92343eeb3931 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 06:46:03 +0000 Subject: [PATCH 07/35] test: add test for derivatives with two points --- test/derivative_tests.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 37351d0d..50abe4ac 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -82,6 +82,12 @@ end u = vcat(2.0collect(1:10)', 3.0collect(1:10)') test_derivatives( LinearInterpolation; args = [u, t], name = "Linear Interpolation (Matrix)") + + # Issue: https://github.com/SciML/DataInterpolations.jl/issues/303 + u = [3.0, 3.0] + t = [0.0, 2.0] + test_derivatives( + LinearInterpolation; args = [u, t], name = "Linear Interpolation with two points") end @testset "Quadratic Interpolation" begin From c80d7bdc135d94f71cd3c37345d01d18852562d7 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 10:29:43 +0000 Subject: [PATCH 08/35] refactor: collect parameters such that they are type stable --- src/online.jl | 2 +- src/parameter_caches.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/online.jl b/src/online.jl index 0fab5d44..60c53d79 100644 --- a/src/online.jl +++ b/src/online.jl @@ -68,7 +68,7 @@ function append!( append!(A.t.parent, t) parameters = quadratic_interpolation_parameters.( Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + l₀, l₁, l₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) append!(A.p.l₀, l₀) append!(A.p.l₁, l₁) append!(A.p.l₂, l₂) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..fcd78156 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -24,7 +24,7 @@ end function QuadraticParameterCache(u, t) parameters = quadratic_interpolation_parameters.( Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + l₀, l₁, l₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) return QuadraticParameterCache(l₀, l₁, l₂) end @@ -72,7 +72,7 @@ end function CubicSplineParameterCache(u, h, z) parameters = cubic_spline_parameters.( Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) + c₁, c₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) return CubicSplineParameterCache(c₁, c₂) end @@ -90,7 +90,7 @@ end function CubicHermiteParameterCache(du, u, t) parameters = cubic_hermite_spline_parameters.( Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) + c₁, c₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) return CubicHermiteParameterCache(c₁, c₂) end @@ -114,7 +114,7 @@ end function QuinticHermiteParameterCache(ddu, du, u, t) parameters = quintic_hermite_spline_parameters.( Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) + c₁, c₂, c₃ = collect.(eachrow(reduce(hcat, collect.(parameters)))) return QuinticHermiteParameterCache(c₁, c₂, c₃) end From c292467df27a39dae60850c47c61ec083976eaf3 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 11:22:54 +0000 Subject: [PATCH 09/35] test: add tests for inference --- test/interface.jl | 61 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index b4ce60ba..3e910547 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,21 +1,52 @@ using DataInterpolations -u = 2.0collect(1:10) -t = 1.0collect(1:10) -@inferred LinearInterpolation(u, t) -A = LinearInterpolation(u, t) +using Symbolics -for i in 1:10 - @test u[i] == A.u[i] -end +@testset "Interface" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + + for i in 1:10 + @test u[i] == A.u[i] + end -for i in 1:10 - @test t[i] == A.t[i] + for i in 1:10 + @test t[i] == A.t[i] + end end -using Symbolics -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +@testset "Symbolics" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + @variables t x(t) + substitute(A(t), Dict(t => x)) +end -@variables t x(t) -substitute(A(t), Dict(t => x)) +@testset "Type Inference" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + methods = [ + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, + QuadraticSpline, CubicSpline, AkimaInterpolation + ] + @testset "$method" for method in methods + @inferred method(u, t) + end + @testset "BSplineInterpolation" begin + @inferred BSplineInterpolation(u, t, 3, :Uniform, :Uniform) + @inferred BSplineInterpolation(u, t, 3, :ArcLen, :Average) + end + @testset "BSplineApprox" begin + @inferred BSplineApprox(u, t, 3, 5, :Uniform, :Uniform) + @inferred BSplineApprox(u, t, 3, 5, :ArcLen, :Average) + end + du = ones(10) + ddu = zeros(10) + @testset "Hermite Splines" begin + @inferred CubicHermiteSpline(du, u, t) + @inferred PCHIPInterpolation(u, t) + @inferred QuinticHermiteSpline(ddu, du, u, t) + end +end From 1121c945dedf5ce699960738f7f585997b7ab205 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 11:40:01 +0000 Subject: [PATCH 10/35] test: label online tests --- test/online_tests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/online_tests.jl b/test/online_tests.jl index 3cf832f0..160dd06f 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,7 +9,8 @@ u2 = [1.0, 2.0, 1.0] ts_append = 1.0:0.5:6.0 ts_push = 1.0:0.5:4.0 -for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] +@testset "$method" for method in [ + LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] func1 = method(u1, t1) append!(func1, u2, t2) func2 = method(vcat(u1, u2), vcat(t1, t2)) From 826f0a70fbc8de6f7fa735c0ddaebc4a48b631ae Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 12:01:09 +0000 Subject: [PATCH 11/35] refactor: use stack instead of reduce hcat --- src/online.jl | 2 +- src/parameter_caches.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/online.jl b/src/online.jl index 60c53d79..630d5cf0 100644 --- a/src/online.jl +++ b/src/online.jl @@ -68,7 +68,7 @@ function append!( append!(A.t.parent, t) parameters = quadratic_interpolation_parameters.( Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) append!(A.p.l₀, l₀) append!(A.p.l₁, l₁) append!(A.p.l₂, l₂) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index fcd78156..e6689b64 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -24,7 +24,7 @@ end function QuadraticParameterCache(u, t) parameters = quadratic_interpolation_parameters.( Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) return QuadraticParameterCache(l₀, l₁, l₂) end @@ -72,7 +72,7 @@ end function CubicSplineParameterCache(u, h, z) parameters = cubic_spline_parameters.( Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) return CubicSplineParameterCache(c₁, c₂) end @@ -90,7 +90,7 @@ end function CubicHermiteParameterCache(du, u, t) parameters = cubic_hermite_spline_parameters.( Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(reduce(hcat, collect.(parameters)))) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) return CubicHermiteParameterCache(c₁, c₂) end @@ -114,7 +114,7 @@ end function QuinticHermiteParameterCache(ddu, du, u, t) parameters = quintic_hermite_spline_parameters.( Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(reduce(hcat, collect.(parameters)))) + c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) return QuinticHermiteParameterCache(c₁, c₂, c₃) end From a6a339927fe0b3a74da981eca3fae4de279b0c4a Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 23 Jul 2024 14:17:53 +0200 Subject: [PATCH 12/35] Add fix + test --- src/parameter_caches.jl | 15 +++++++++++++-- test/interpolation_tests.jl | 7 +++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..690e27ac 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -7,8 +7,19 @@ function LinearParameterCache(u, t) return LinearParameterCache(slope) end -function linear_interpolation_parameters(u, t, idx) - Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] +""" +Prevent e.g. Inf - Inf = NaN +""" +function safe_diff(b, a::T) where T + b == a ? zero(T) : b - a +end + +function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where T + Δu = if u isa AbstractMatrix + [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] + else + safe_diff(u[idx + 1], u[idx]) + end Δt = t[idx + 1] - t[idx] slope = Δu / Δt slope = iszero(Δt) ? zero(slope) : slope diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9549038c..acd0394a 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -161,6 +161,13 @@ end @test A(5.5) == fill(11.0) @test A(11) == fill(22) + # Test constant -Inf interpolation + u = [-Inf, -Inf] + t = [0.0, 1.0] + A = LinearInterpolation(u, t) + @test A(0.0) == -Inf + @test A(0.5) == -Inf + # Test extrapolation u = 2.0collect(1:10) t = 1.0collect(1:10) From b738c7a38d932f16610fcb2a7b68d9daa415d1b5 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 23 Jul 2024 14:18:32 +0200 Subject: [PATCH 13/35] Formatting --- src/parameter_caches.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 690e27ac..77007039 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -10,15 +10,15 @@ end """ Prevent e.g. Inf - Inf = NaN """ -function safe_diff(b, a::T) where T - b == a ? zero(T) : b - a +function safe_diff(b, a::T) where {T} + b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where T +function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} Δu = if u isa AbstractMatrix [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] else - safe_diff(u[idx + 1], u[idx]) + safe_diff(u[idx + 1], u[idx]) end Δt = t[idx + 1] - t[idx] slope = Δu / Δt From 485dc328ba4a88e07c5d40d70fe84ba51a79799c Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 23 Jul 2024 14:22:55 +0200 Subject: [PATCH 14/35] Remove docstring --- src/parameter_caches.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 77007039..97d5e2ae 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -7,9 +7,7 @@ function LinearParameterCache(u, t) return LinearParameterCache(slope) end -""" -Prevent e.g. Inf - Inf = NaN -""" +# Prevent e.g. Inf - Inf = NaN function safe_diff(b, a::T) where {T} b == a ? zero(T) : b - a end From 5edaa381fa25f2c31a6716eab6838df6e1acd184 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 14:07:57 +0000 Subject: [PATCH 15/35] docs: add statement of need in docs --- docs/src/index.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 9b1277f3..f2075f8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,10 +1,6 @@ # DataInterpolations.jl -DataInterpolations.jl is a library for performing interpolations of one-dimensional data. By -"data interpolations" we mean techniques for interpolating possibly noisy data, and thus -some methods are mixtures of regressions with interpolations (i.e. do not hit the data -points exactly, smoothing out the lines). This library can be used to fill in intermediate -data points in applications like timeseries data. +DataInterpolations.jl is a library for performing interpolations of one-dimensional data. Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. By "data interpolations" we mean techniques for interpolating possibly noisy data, and thus some methods are mixtures of regressions with interpolations (i.e. do not hit the data points exactly, smoothing out the lines). ## Installation From 3d9c3d6fa3a372802459d5419b9ea08bdba3ab44 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jul 2024 11:08:13 +0000 Subject: [PATCH 16/35] docs: small tutorial on using DataInterpolations with Symbolics/MTK --- docs/make.jl | 3 +- docs/src/symbolics.md | 65 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 docs/src/symbolics.md diff --git a/docs/make.jl b/docs/make.jl index 94546761..6438f482 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -13,6 +13,7 @@ makedocs(modules = [DataInterpolations], format = Documenter.HTML(assets = ["assets/favicon.ico"], canonical = "https://docs.sciml.ai/DataInterpolations/stable/"), pages = ["index.md", "Methods" => "methods.md", - "Interface" => "interface.md", "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) + "Interface" => "interface.md", "Using with Symbolics/ModelingToolkit" => "symbolics.md", + "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) deploydocs(repo = "github.com/SciML/DataInterpolations.jl"; push_preview = true) diff --git a/docs/src/symbolics.md b/docs/src/symbolics.md new file mode 100644 index 00000000..77d45ead --- /dev/null +++ b/docs/src/symbolics.md @@ -0,0 +1,65 @@ +# Using DataInterpolations.jl with Symbolics.jl and ModelingToolkit.jl + +All interpolation methods can be integrated with [Symbolics.jl](https://symbolics.juliasymbolics.org/stable/) and [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) seamlessly. + +## Using with Symbolics.jl + +### Expressions + +```@example symbolics +using DataInterpolations, Symbolics +using Test + +u = [0.0, 1.5, 0.0] +t = [0.0, 0.5, 1.0] +A = LinearInterpolation(u, t) + +@variables τ + +# Simple Expression +ex = cos(τ) * A(τ) +@test substitute(ex, Dict(τ => 0.5)) == cos(0.5) * A(0.5) # true +``` + +### Symbolic Derivatives + +```@example symbolics +D = Differential(τ) + +ex1 = A(τ) + +# Derivative of interpolation +ex2 = expand_derivatives(D(ex1)) + +@test substitute(ex2, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5) # true + +# Higher Order Derivatives +ex3 = expand_derivatives(D(D(A(τ)))) + +@test substitute(ex3, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5, 2) # true +``` + +## Using with ModelingToolkit.jl + +Most common use case with [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) is to plug in interpolation objects as input functions. This can be done using `TimeVaryingFunction` component of [ModelingToolkitStandardLibrary.jl](https://docs.sciml.ai/ModelingToolkitStandardLibrary/stable/). + +```@example mtk +using DataInterpolations +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +us = [0.0, 1.5, 0.0] +times = [0.0, 0.5, 1.0] +A = LinearInterpolation(us, times) + +@named src = TimeVaryingFunction(A) +vars = @variables x(t) out(t) +eqs = [out ~ src.output.u, D(x) ~ 1 + out] +@named sys = ODESystem(eqs, t, vars, []; systems = [src]) + +sys = structural_simplify(sys) +prob = ODEProblem(sys, [x => 0.0], (times[1], times[end])) +sol = solve(prob) +``` From 558c6834262da61c18e03af9bca3833526f98389 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 17 Jul 2024 11:08:40 +0000 Subject: [PATCH 17/35] build(docs): add Symbolics, MTK and reqired deps --- docs/Project.toml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 30a52bff..ffa0d3ad 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,15 +1,23 @@ [deps] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] DataInterpolations = "5" Documenter = "1" +ModelingToolkit = "9" +ModelingToolkitStandardLibrary = "2" Optim = "1" +OrdinaryDiffEq = "6" Plots = "1" RegularizationTools = "0.6" -StableRNGs = "1" \ No newline at end of file +StableRNGs = "1" +Symbolics = "5.29" From 8f35994361d35f43430936b2bc15f45e84671b78 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <35105271+sathvikbhagavan@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:20:59 +0530 Subject: [PATCH 18/35] build: bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7f08abb0..d4133013 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.3.0" +version = "5.3.1" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" From 7cb35b60a98201bd0c15999746c40b405ef9bcb5 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Wed, 24 Jul 2024 05:58:00 +0000 Subject: [PATCH 19/35] docs: add reference to symbolics/mtk tutorial in joss paper --- joss/paper.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/joss/paper.md b/joss/paper.md index ce70c1ca..66d15aa0 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -54,7 +54,7 @@ Interpolations are a very important component of many modeling workflows. Often, # Example -The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. +The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. [3](https://docs.sciml.ai/DataInterpolations/stable/symbolics/) provides how to use interpolation objects with Symbolics.jl and ModelingToolkit.jl. A simple demonstration here: From c8b322455377bcbfe6b08ea14373979def5d00f3 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 17:17:57 +0200 Subject: [PATCH 20/35] Refactor parameter caching, add zygote tests --- Project.toml | 5 +- docs/src/interface.md | 17 +- ext/DataInterpolationsOptimExt.jl | 5 +- ...ataInterpolationsRegularizationToolsExt.jl | 28 +- src/DataInterpolations.jl | 19 +- src/derivatives.jl | 23 +- src/integral_inverses.jl | 15 +- src/integrals.jl | 54 ++- src/interpolation_caches.jl | 335 +++++++++++------- src/interpolation_methods.jl | 30 +- src/interpolation_utils.jl | 84 +++-- src/online.jl | 96 ++--- test/interpolation_tests.jl | 1 - test/online_tests.jl | 8 +- test/parameter_tests.jl | 12 +- test/runtests.jl | 1 + test/zygote_tests.jl | 66 ++++ 17 files changed, 497 insertions(+), 302 deletions(-) create mode 100644 test/zygote_tests.jl diff --git a/Project.toml b/Project.toml index b06d2adb..e492399f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -33,7 +32,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -55,6 +53,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..19cb47c0 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -88,12 +87,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, @@ -126,12 +119,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..01eb18bb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -19,14 +19,16 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) - A.p.slope[idx], idx + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..38c14b14 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index c7274471..286bf6bc 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,7 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -21,23 +21,33 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy) + u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy) +function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + LinearParameterCache(u, t) + else + LinearParameterCache(nothing) + end + + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false) + QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -51,7 +61,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -61,25 +71,35 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + QuadraticParameterCache(u, t) + else + QuadraticParameterCache(nothing, nothing, nothing) + end + + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + end + + A end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -96,7 +116,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -107,8 +126,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -118,23 +136,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. Extrapolation extends the last cubic polynomial on each side. @@ -147,7 +164,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -159,8 +176,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -170,13 +187,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -197,13 +214,18 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) + end + + A end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -218,7 +240,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -228,22 +250,28 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + cache_parameters::Bool + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) + end + + A end """ - QuadraticSpline(u, t; extrapolate = false) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -256,7 +284,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -269,8 +297,8 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -281,15 +309,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -301,15 +329,27 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -322,14 +362,23 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end """ - CubicSpline(u, t; extrapolate = false) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -342,7 +391,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -353,8 +402,8 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + cache_parameters::Bool + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -364,15 +413,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -389,15 +439,25 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -411,10 +471,20 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end """ @@ -434,7 +504,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -449,7 +518,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineInterpolation(u, t, d, @@ -459,8 +527,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -471,15 +538,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -543,11 +609,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) + BSplineApprox(u, t, d, h, pVecType, knotVecType) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. @@ -565,7 +631,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -581,7 +646,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineApprox(u, t, d, @@ -592,8 +656,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -606,15 +669,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -698,11 +760,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -715,7 +778,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -725,24 +788,33 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + CubicHermiteParameterCache(du, u, t) + else + CubicHermiteParameterCache(nothing, nothing) + end + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -756,7 +828,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -768,19 +840,28 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + QuinticHermiteParameterCache(ddu, du, u, t) + else + QuinticHermiteParameterCache(nothing, nothing, nothing) + end + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) + end + + A end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..c409d5fc 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,15 +10,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -38,7 +38,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -50,9 +51,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -149,7 +151,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -158,8 +161,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -205,7 +209,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -216,6 +221,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..17b08328 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,20 +88,12 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -130,3 +114,51 @@ function cumulative_integral(A) pushfirst!(integral_values, zero(first(integral_values))) return cumsum(integral_values) end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) + end +end diff --git a/src/online.jl b/src/online.jl index 0fab5d44..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9549038c..9562c7b4 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) diff --git a/test/online_tests.jl b/test/online_tests.jl index f9c3e1dd..3ae6438e 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,9 +9,9 @@ u2 = [1.0, 2.0, 1.0] ts = 1.0:0.5:6.0 for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -20,9 +20,9 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio @test func1(ts) == func2(ts) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..6887ddd2 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,66 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], kwargs = [], name::String) + func = method(args..., u, t; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + zdiff == nothing && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1:5), 2 * collect(6:10)) + t = 1.0collect(1:10) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end From 89757b62b97ecbbef722429824e4eb23c528c0cd Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 17:21:54 +0200 Subject: [PATCH 21/35] Refactor parameter caching, add zygote tests --- Project.toml | 6 +- docs/src/interface.md | 17 +- ext/DataInterpolationsOptimExt.jl | 5 +- ...ataInterpolationsRegularizationToolsExt.jl | 28 +- src/DataInterpolations.jl | 19 +- src/derivatives.jl | 23 +- src/integral_inverses.jl | 15 +- src/integrals.jl | 54 ++- src/interpolation_caches.jl | 335 +++++++++++------- src/interpolation_methods.jl | 30 +- src/interpolation_utils.jl | 84 +++-- src/online.jl | 96 ++--- test/interpolation_tests.jl | 1 - test/online_tests.jl | 8 +- test/parameter_tests.jl | 12 +- test/runtests.jl | 1 + test/zygote_tests.jl | 66 ++++ 17 files changed, 498 insertions(+), 302 deletions(-) create mode 100644 test/zygote_tests.jl diff --git a/Project.toml b/Project.toml index b06d2adb..963ee8cd 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -33,7 +32,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -41,6 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" +Zygote = "0.6" julia = "1.10" [extras] @@ -55,6 +54,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..19cb47c0 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -88,12 +87,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, @@ -126,12 +119,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..01eb18bb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -19,14 +19,16 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) - A.p.slope[idx], idx + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..38c14b14 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index c7274471..286bf6bc 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,7 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -21,23 +21,33 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy) + u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy) +function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + LinearParameterCache(u, t) + else + LinearParameterCache(nothing) + end + + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false) + QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -51,7 +61,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -61,25 +71,35 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + QuadraticParameterCache(u, t) + else + QuadraticParameterCache(nothing, nothing, nothing) + end + + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + end + + A end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -96,7 +116,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -107,8 +126,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -118,23 +136,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. Extrapolation extends the last cubic polynomial on each side. @@ -147,7 +164,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -159,8 +176,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -170,13 +187,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -197,13 +214,18 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) + end + + A end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -218,7 +240,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -228,22 +250,28 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + cache_parameters::Bool + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) + end + + A end """ - QuadraticSpline(u, t; extrapolate = false) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -256,7 +284,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -269,8 +297,8 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -281,15 +309,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -301,15 +329,27 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -322,14 +362,23 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end """ - CubicSpline(u, t; extrapolate = false) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -342,7 +391,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -353,8 +402,8 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + cache_parameters::Bool + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -364,15 +413,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -389,15 +439,25 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -411,10 +471,20 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end """ @@ -434,7 +504,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -449,7 +518,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineInterpolation(u, t, d, @@ -459,8 +527,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -471,15 +538,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -543,11 +609,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) + BSplineApprox(u, t, d, h, pVecType, knotVecType) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. @@ -565,7 +631,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -581,7 +646,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineApprox(u, t, d, @@ -592,8 +656,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -606,15 +669,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -698,11 +760,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -715,7 +778,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -725,24 +788,33 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + CubicHermiteParameterCache(du, u, t) + else + CubicHermiteParameterCache(nothing, nothing) + end + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -756,7 +828,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -768,19 +840,28 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + QuinticHermiteParameterCache(ddu, du, u, t) + else + QuinticHermiteParameterCache(nothing, nothing, nothing) + end + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) + end + + A end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..c409d5fc 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,15 +10,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -38,7 +38,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -50,9 +51,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -149,7 +151,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -158,8 +161,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -205,7 +209,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -216,6 +221,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..17b08328 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,20 +88,12 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -130,3 +114,51 @@ function cumulative_integral(A) pushfirst!(integral_values, zero(first(integral_values))) return cumsum(integral_values) end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) + end +end diff --git a/src/online.jl b/src/online.jl index 0fab5d44..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9549038c..9562c7b4 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) diff --git a/test/online_tests.jl b/test/online_tests.jl index f9c3e1dd..3ae6438e 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,9 +9,9 @@ u2 = [1.0, 2.0, 1.0] ts = 1.0:0.5:6.0 for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -20,9 +20,9 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio @test func1(ts) == func2(ts) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..6887ddd2 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,66 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], kwargs = [], name::String) + func = method(args..., u, t; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + zdiff == nothing && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1:5), 2 * collect(6:10)) + t = 1.0collect(1:10) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end From f69ad15d61d8c46d7a786f01e5c84930f186dc71 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 20:55:48 +0200 Subject: [PATCH 22/35] Merge remote-tracking branch 'upstream/master' into cache_parameters_opt_in --- .github/workflows/CompatHelper.yml | 2 +- .github/workflows/Downgrade.yml | 2 +- Project.toml | 2 +- README.md | 7 + docs/Project.toml | 10 +- docs/make.jl | 3 +- docs/src/index.md | 13 +- docs/src/manual.md | 1 + docs/src/methods.md | 51 ++--- docs/src/symbolics.md | 65 ++++++ joss/paper.bib | 14 ++ joss/paper.md | 21 +- src/DataInterpolations.jl | 14 +- src/derivatives.jl | 2 +- src/integral_inverses.jl | 2 +- src/interpolation_caches.jl | 178 +++++---------- src/interpolation_utils.jl | 50 ++++- src/parameter_caches.jl | 105 ++++++--- src/plot_rec.jl | 349 +++++++++++++++++++++-------- test/derivative_tests.jl | 6 + test/interface.jl | 60 +++-- test/interpolation_tests.jl | 20 ++ test/online_tests.jl | 10 +- 23 files changed, 678 insertions(+), 309 deletions(-) create mode 100644 docs/src/symbolics.md diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 36e59135..35cc34ba 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -13,7 +13,7 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@3645a07f58c7f83b9f82ac8e0bb95583e69149e6 + - uses: julia-actions/setup-julia@780022b48dfc0c2c6b94cfee6a9284850107d037 with: version: 1.3 - name: Pkg.add("CompatHelper") diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index c0d0123e..4546ebd0 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -28,7 +28,7 @@ jobs: - windows-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2.2.0 + - uses: julia-actions/setup-julia@v2.3.0 with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 diff --git a/Project.toml b/Project.toml index 963ee8cd..ccef602b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.2.0" +version = "5.3.1" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" diff --git a/README.md b/README.md index 8a326579..5348685f 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ corresponding to `(u,t)` pairs. + `knotVec` - Symbol to Knot Vector, `knotVec = :Uniform` for uniform knot vector, `knotVec = :Average` for average spaced knot vector. - `BSplineApprox(u,t,d,h,pVec,knotVec)` - A regression B-spline which smooths the fitting curve. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h "methods.md", - "Interface" => "interface.md", "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) + "Interface" => "interface.md", "Using with Symbolics/ModelingToolkit" => "symbolics.md", + "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) deploydocs(repo = "github.com/SciML/DataInterpolations.jl"; push_preview = true) diff --git a/docs/src/index.md b/docs/src/index.md index f5b1512b..f2075f8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,10 +1,6 @@ # DataInterpolations.jl -DataInterpolations.jl is a library for performing interpolations of one-dimensional data. By -"data interpolations" we mean techniques for interpolating possibly noisy data, and thus -some methods are mixtures of regressions with interpolations (i.e. do not hit the data -points exactly, smoothing out the lines). This library can be used to fill in intermediate -data points in applications like timeseries data. +DataInterpolations.jl is a library for performing interpolations of one-dimensional data. Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. By "data interpolations" we mean techniques for interpolating possibly noisy data, and thus some methods are mixtures of regressions with interpolations (i.e. do not hit the data points exactly, smoothing out the lines). ## Installation @@ -35,6 +31,7 @@ corresponding to `(u,t)` pairs. + `knotVec` - Symbol to Knot Vector, `knotVec = :Uniform` for uniform knot vector, `knotVec = :Average` for average spaced knot vector. - `BSplineApprox(u,t,d,h,pVec,knotVec)` - A regression B-spline which smooths the fitting curve. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h 0.5)) == cos(0.5) * A(0.5) # true +``` + +### Symbolic Derivatives + +```@example symbolics +D = Differential(τ) + +ex1 = A(τ) + +# Derivative of interpolation +ex2 = expand_derivatives(D(ex1)) + +@test substitute(ex2, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5) # true + +# Higher Order Derivatives +ex3 = expand_derivatives(D(D(A(τ)))) + +@test substitute(ex3, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5, 2) # true +``` + +## Using with ModelingToolkit.jl + +Most common use case with [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) is to plug in interpolation objects as input functions. This can be done using `TimeVaryingFunction` component of [ModelingToolkitStandardLibrary.jl](https://docs.sciml.ai/ModelingToolkitStandardLibrary/stable/). + +```@example mtk +using DataInterpolations +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +us = [0.0, 1.5, 0.0] +times = [0.0, 0.5, 1.0] +A = LinearInterpolation(us, times) + +@named src = TimeVaryingFunction(A) +vars = @variables x(t) out(t) +eqs = [out ~ src.output.u, D(x) ~ 1 + out] +@named sys = ODESystem(eqs, t, vars, []; systems = [src]) + +sys = structural_simplify(sys) +prob = ODEProblem(sys, [x => 0.0], (times[1], times[end])) +sol = solve(prob) +``` diff --git a/joss/paper.bib b/joss/paper.bib index d754d9cc..101b0181 100644 --- a/joss/paper.bib +++ b/joss/paper.bib @@ -134,3 +134,17 @@ @book{lagrange1898lectures year={1898}, publisher={Open court publishing Company} } + +@article{doi:10.1137/0905021, + author = {Fritsch, F. N. and Butland, J.}, + title = {A Method for Constructing Local Monotone Piecewise Cubic Interpolants}, + journal = {SIAM Journal on Scientific and Statistical Computing}, + volume = {5}, + number = {2}, + pages = {300-304}, + year = {1984}, + doi = {10.1137/0905021}, + URL = {https://doi.org/10.1137/0905021}, + eprint = {https://doi.org/10.1137/0905021}, + abstract = { A method is described for producing monotone piecewise cubic interpolants to monotone data which is completely local and which is extremely simple to implement. } +} diff --git a/joss/paper.md b/joss/paper.md index f2d2a4be..66d15aa0 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -31,15 +31,30 @@ bibliography: paper.bib # Summary -Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. +Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include: + + - Constant Interpolation + - Linear Interpolation + - Quadratic Interpolation + - Lagrange Interpolation [@lagrange] + - Quadratic Splines + - Cubic Splines [@Schoenberg1988] + - Akima Splines [@10.1145/321607.321609] + - Cubic Hermite Splines + - Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) [@doi:10.1137/0905021] + - Quintic Hermite Splines + - B-Splines [@Curry1988] [@DEBOOR197250] + - Regression based B-Splines + +and a continually growing list. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. # Statement of need -Interpolations are a very important component of many modeling workflows. In many models, inputs which are sampled or measured need to be represented as a continuous function or a smooth curve for simulation. In many scientific machine learning workflows, we need interpolations of data to learn continuous models. There already have been a few interpolation packages in Julia like Interpolations.jl but it has a limitation of assuming uniformly spaced data which is not usually the case with data collected from real world. DataInterpolations.jl provides fast interpolation methods for arbitrary spaced 1D data with a consistent and simple interface. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. +Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. Additionally, DataInterpolations.jl includes many novel techniques for accelerating the interpolation searches with specialized caching, quasi-linear guessing, and more to improve the performance algorithmically, beyond the simple computational optimizations. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. # Example -The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. +The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. [3](https://docs.sciml.ai/DataInterpolations/stable/symbolics/) provides how to use interpolation objects with Symbolics.jl and ModelingToolkit.jl. A simple demonstration here: diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 19cb47c0..7f44c878 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -22,7 +22,11 @@ include("online.jl") include("show.jl") (interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t) -(interp::AbstractInterpolation)(t::Number, i::Integer) = _interpolate(interp, t, i) +function (interp::AbstractInterpolation)(t::Number, i::Integer) + interp.idx_prev[] = i + _interpolate(interp, t) +end + function (interp::AbstractInterpolation)(t::AbstractVector) u = get_u(interp.u, t) interp(u, t) @@ -43,16 +47,14 @@ function get_u(u::AbstractMatrix, t) end function (interp::AbstractInterpolation)(u::AbstractMatrix, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(t) - u[:, i], iguess = interp(t[i], iguess) + u[:, i] = interp(t[i]) end u end function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(u, t) - u[i], iguess = interp(t[i], iguess) + u[i] = interp(t[i]) end u end @@ -89,7 +91,7 @@ end export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, - BSplineInterpolation, BSplineApprox, CubicHermiteSpline, + BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv # added for RegularizationSmooth, JJS 11/27/21 diff --git a/src/derivatives.jl b/src/derivatives.jl index 01eb18bb..75872095 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -18,7 +18,7 @@ function derivative(A, t, order = 1) end function _derivative(A::LinearInterpolation, t::Number, iguess) - idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) + idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) slope = get_parameters(A, idx) slope, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 38c14b14..31d0853c 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -50,7 +50,7 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end -get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I +get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 286bf6bc..7f6991b8 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -30,24 +30,14 @@ end function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) - p = if cache_parameters - LinearParameterCache(u, t) - else - LinearParameterCache(nothing) - end - + p = LinearParameterCache(u, t, cache_parameters) A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) end """ - QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) + QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -82,20 +72,10 @@ end function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) - p = if cache_parameters - QuadraticParameterCache(u, t) - else - QuadraticParameterCache(nothing, nothing, nothing) - end - + p = QuadraticParameterCache(u, t, cache_parameters) A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) end function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) @@ -103,7 +83,7 @@ function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = fa end """ - LagrangeInterpolation(u, t, n = length(t) - 1; extrapolate = false) + LagrangeInterpolation(u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) It is the method of interpolation using Lagrange polynomials of (k-1)th order passing through all the data points where k is the number of data points. @@ -153,7 +133,7 @@ end """ AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) -It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. +It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: [https://en.wikipedia.org/wiki/Akima_spline](https://en.wikipedia.org/wiki/Akima_spline). Extrapolation extends the last cubic polynomial on each side. ## Arguments @@ -215,13 +195,8 @@ function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) end """ @@ -261,13 +236,8 @@ function ConstantInterpolation( u, t; dir = :left, extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) end """ @@ -330,20 +300,10 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = if cache_parameters - QuadraticSplineParameterCache(z, t) - else - QuadraticSplineParameterCache(nothing) - end - + p = QuadraticSplineParameterCache(z, t, cache_parameters) A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end function QuadraticSpline( @@ -362,19 +322,11 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = if cache_parameters - QuadraticSplineParameterCache(z, t) - else - QuadraticSplineParameterCache(nothing) - end - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) - end - A + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end """ @@ -439,19 +391,11 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = if cache_parameters - CubicSplineParameterCache(u, h, z) - else - CubicSplineParameterCache(nothing, nothing) - end - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - if cache_parameters - I = cumulative_integral(A) - A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - end - - A + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end function CubicSpline( @@ -471,26 +415,17 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = if cache_parameters - CubicSplineParameterCache(u, h, z) - else - CubicSplineParameterCache(nothing, nothing) - end + p = CubicSplineParameterCache(u, h, z, cache_parameters) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end """ - BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false) + BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) -It is a curve defined by the linear combination of `n` basis functions of degree `d` where `n` is the number of data points. For more information, refer https://pages.mtu.edu/~shene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html. +It is a curve defined by the linear combination of `n` basis functions of degree `d` where `n` is the number of data points. For more information, refer [https://pages.mtu.edu/~shene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html](https://pages.mtu.edu/%7Eshene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html). Extrapolation is a constant polynomial of the end points on each side. ## Arguments @@ -613,10 +548,10 @@ function BSplineInterpolation( end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType) + BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. -For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. +For more information, refer [http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf](http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf). Extrapolation is a constant polynomial of the end points on each side. ## Arguments @@ -798,23 +733,37 @@ end function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." u, t = munge_data(u, t) - p = if cache_parameters - CubicHermiteParameterCache(du, u, t) - else - CubicHermiteParameterCache(nothing, nothing) - end + p = CubicHermiteParameterCache(du, u, t, cache_parameters) A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) +end - if cache_parameters - I = cumulative_integral(A) - A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) - end +""" + PCHIPInterpolation(u, t; extrapolate = false, safetycopy = true) + +It is a PCHIP Interpolation, which is a type of [`CubicHermiteSpline`](@ref) where the derivative values `du` are derived from the input data +in such a way that the interpolation never overshoots the data. See [here](https://www.mathworks.com/content/dam/mathworks/mathworks-dot-com/moler/interp.pdf), +section 3.4 for more details. + +## Arguments - A + - `u`: data points. + - `t`: time points. + +## Keyword Arguments + + - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. +""" +function PCHIPInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + du = du_PCHIP(u, t) + CubicHermiteSpline(du, u, t; extrapolate, cache_parameters) end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -851,17 +800,8 @@ end function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." u, t = munge_data(u, t) - p = if cache_parameters - QuinticHermiteParameterCache(ddu, du, u, t) - else - QuinticHermiteParameterCache(nothing, nothing, nothing) - end + p = QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 17b08328..2a2392bd 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -105,14 +105,15 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) - return nothing +function cumulative_integral(A, cache_parameters) + if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number}) + integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1)] + pushfirst!(integral_values, zero(first(integral_values))) + cumsum(integral_values) + else + promote_type(eltype(A.u), eltype(A.t))[] end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) - return cumsum(integral_values) end function get_parameters(A::LinearInterpolation, idx) @@ -162,3 +163,38 @@ function get_parameters(A::QuinticHermiteSpline, idx) quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) end end + +function du_PCHIP(u, t) + h = diff(u) + δ = h ./ diff(t) + s = sign.(δ) + + function _du(k) + sₖ₋₁, sₖ = if k == 1 + s[1], s[2] + elseif k == lastindex(t) + s[end - 1], s[end] + else + s[k - 1], s[k] + end + + if sₖ₋₁ == 0 && sₖ == 0 + zero(eltype(δ)) + elseif sₖ₋₁ == sₖ + if k == 1 + ((2 * h[1] + h[2]) * δ[1] - h[1] * δ[2]) / (h[1] + h[2]) + elseif k == lastindex(t) + ((2 * h[end] + h[end - 1]) * δ[end] - h[end] * δ[end - 1]) / + (h[end] + h[end - 1]) + else + w₁ = 2h[k] + h[k - 1] + w₂ = h[k] + 2h[k - 1] + δ[k - 1] * δ[k] * (w₁ + w₂) / (w₁ * δ[k] + w₂ * δ[k - 1]) + end + else + zero(eltype(δ)) + end + end + + return _du.(eachindex(t)) +end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..0701b3a2 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,13 +2,28 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) - return LinearParameterCache(slope) +function LinearParameterCache(u, t, cache_parameters) + if cache_parameters + slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + LinearParameterCache(slope) + else + # Compute parameters once to infer types + slope = linear_interpolation_parameters(u, t, 1) + LinearParameterCache(typeof(slope)[]) + end +end + +# Prevent e.g. Inf - Inf = NaN +function safe_diff(b, a::T) where {T} + b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u, t, idx) - Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] +function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} + Δu = if u isa AbstractMatrix + [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] + else + safe_diff(u[idx + 1], u[idx]) + end Δt = t[idx + 1] - t[idx] slope = Δu / Δt slope = iszero(Δt) ? zero(slope) : slope @@ -21,11 +36,18 @@ struct QuadraticParameterCache{pType} l₂::pType end -function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( - Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuadraticParameterCache(l₀, l₁, l₂) +function QuadraticParameterCache(u, t, cache_parameters) + if cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(u), Ref(t), 1:(length(t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) + QuadraticParameterCache(l₀, l₁, l₂) + else + # Compute parameters once to infer types + l₀, l₁, l₂ = quadratic_interpolation_parameters(u, t, 1) + pType = typeof(l₀) + QuadraticParameterCache(pType[], pType[], pType[]) + end end function quadratic_interpolation_parameters(u, t, idx) @@ -54,9 +76,15 @@ struct QuadraticSplineParameterCache{pType} σ::pType end -function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) +function QuadraticSplineParameterCache(z, t, cache_parameters) + if cache_parameters + σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) + QuadraticSplineParameterCache(σ) + else + # Compute parameters once to infer types + σ = quadratic_spline_parameters(z, t, 1) + QuadraticSplineParameterCache(typeof(σ)[]) + end end function quadratic_spline_parameters(z, t, idx) @@ -69,11 +97,18 @@ struct CubicSplineParameterCache{pType} c₂::pType end -function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( - Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicSplineParameterCache(c₁, c₂) +function CubicSplineParameterCache(u, h, z, cache_parameters) + if cache_parameters + parameters = cubic_spline_parameters.( + Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicSplineParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_spline_parameters(u, h, z, 1) + pType = typeof(c₁) + CubicSplineParameterCache(pType[], pType[]) + end end function cubic_spline_parameters(u, h, z, idx) @@ -87,11 +122,18 @@ struct CubicHermiteParameterCache{pType} c₂::pType end -function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( - Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicHermiteParameterCache(c₁, c₂) +function CubicHermiteParameterCache(du, u, t, cache_parameters) + if cache_parameters + parameters = cubic_hermite_spline_parameters.( + Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicHermiteParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_hermite_spline_parameters(du, u, t, 1) + pType = typeof(c₁) + CubicHermiteParameterCache(pType[], pType[]) + end end function cubic_hermite_spline_parameters(du, u, t, idx) @@ -111,11 +153,18 @@ struct QuinticHermiteParameterCache{pType} c₃::pType end -function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( - Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuinticHermiteParameterCache(c₁, c₂, c₃) +function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + if cache_parameters + parameters = quintic_hermite_spline_parameters.( + Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) + QuinticHermiteParameterCache(c₁, c₂, c₃) + else + # Compute parameters once to infer types + c₁, c₂, c₃ = quintic_hermite_spline_parameters(ddu, du, u, t, 1) + pType = typeof(c₁) + QuinticHermiteParameterCache(pType[], pType[], pType[]) + end end function quintic_hermite_spline_parameters(ddu, du, u, t, idx) diff --git a/src/plot_rec.jl b/src/plot_rec.jl index a7bd8afc..6c576c49 100644 --- a/src/plot_rec.jl +++ b/src/plot_rec.jl @@ -16,7 +16,16 @@ function to_plottable(A::AbstractInterpolation; plotdensity = 10_000, denseplot end @recipe function f(A::AbstractInterpolation; plotdensity = 10_000, denseplot = true) - to_plottable(A; plotdensity = plotdensity, denseplot = denseplot) + @series begin + seriestype := :path + label --> string(nameof(typeof(A))) + to_plottable(A; plotdensity = plotdensity, denseplot = denseplot) + end + @series begin + seriestype := :scatter + label --> "Data points" + A.t, A.u + end end ################################################################################ @@ -35,18 +44,26 @@ end x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Linear fit" - - nx, ny = to_plottable(LinearInterpolation(y, x); + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(LinearInterpolation(T.(y), T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "LinearInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -57,18 +74,60 @@ end x, y, z; + mode = :Forward, + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + QuadraticInterpolation(T.(y), + T.(x), mode; extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "QuadraticInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - label --> "Quadratic fit" +######################################## +# Lagrange Interpolation # +######################################## - nx, ny = to_plottable(QuadraticInterpolation(T.(y), - T.(x)); +@recipe function f(::Type{Val{:lagrange_interp}}, + x, y, z; + n = length(x) - 1, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(LagrangeInterpolation(T.(y), + T.(x), + n; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "LagrangeInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -79,96 +138,125 @@ end x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Quadratic Spline" - T = promote_type(eltype(y), eltype(x)) - nx, ny = to_plottable(QuadraticSpline(T.(y), - T.(x)); + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "QuadraticSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Lagrange Interpolation # +# Cubic Spline # ######################################## -@recipe function f(::Type{Val{:lagrange_interp}}, - x, y, z; - n = length(x) - 1, +@recipe function f(::Type{Val{:cubic_spline}}, + x, + y, + z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Lagrange Fit" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(LagrangeInterpolation(T.(y), - T.(x), - n); + nx, ny = to_plottable(CubicSpline(T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "CubicSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Cubic Spline # +# Akima interpolation # ######################################## -@recipe function f(::Type{Val{:cubic_spline}}, +@recipe function f(::Type{Val{:akima_interp}}, x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Cubic Spline" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(CubicSpline(T.(y), - T.(x)); + nx, ny = to_plottable(AkimaInterpolation(T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "AkimaInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end +######################################## +# B-spline Interpolation # +######################################## + @recipe function f(::Type{Val{:bspline_interp}}, x, y, z; d = 5, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, + pVecType = :ArcLen, + knotVecType = :Average, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "B-Spline" - - @show x y eltype(x) - - # T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(BSplineInterpolation(T.(y), + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + BSplineInterpolation(T.(y), T.(x), d, - pVec, - knotVec); + pVecType, + knotVecType; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "BSplineInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -179,48 +267,133 @@ end x, y, z; d = 5, h = length(x) - 1, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, + pVecType = :ArcLen, + knotVecType = :Average, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "B-Spline" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(BSplineApprox(T.(y), + nx, ny = to_plottable( + BSplineApprox(T.(y), T.(x), d, h, - pVec, - knotVec); + pVecType, + knotVecType; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "BSplineApprox" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Akima interpolation # +# Cubic Hermite Spline # ######################################## -@recipe function f(::Type{Val{:akima}}, +@recipe function f(::Type{Val{:cubic_hermite_spline}}, x, y, z; - plotdensity = length(x) * 6, + du = nothing, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path + isnothing(du) && error("Provide `du` as a keyword argument.") + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + CubicHermiteSpline(T.(du), T.(y), + T.(x); extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "CubicHermiteSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - label --> "Akima" +######################################## +# PCHIP Interpolation # +######################################## +@recipe function f(::Type{Val{:pchip_interp}}, + x, + y, + z; + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(PCHIPInterpolation(T.(y), + T.(x); extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "PCHIP Interpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - nx, ny = to_plottable(AkimaInterpolation(T.(y), - T.(x)); +######################################## +# Quintic Hermite Spline # +######################################## + +@recipe function f(::Type{Val{:quintic_hermite_spline}}, + x, + y, + z; + du = nothing, + ddu = nothing, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) + (isnothing(du) || isnothing(ddu)) && + error("Provide `du` and `ddu` as keyword arguments.") + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + QuinticHermiteSpline(T.(ddu), T.(du), T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "QuinticHermiteSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 37351d0d..50abe4ac 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -82,6 +82,12 @@ end u = vcat(2.0collect(1:10)', 3.0collect(1:10)') test_derivatives( LinearInterpolation; args = [u, t], name = "Linear Interpolation (Matrix)") + + # Issue: https://github.com/SciML/DataInterpolations.jl/issues/303 + u = [3.0, 3.0] + t = [0.0, 2.0] + test_derivatives( + LinearInterpolation; args = [u, t], name = "Linear Interpolation with two points") end @testset "Quadratic Interpolation" begin diff --git a/test/interface.jl b/test/interface.jl index 5d02a22a..3e910547 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,20 +1,52 @@ using DataInterpolations -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +using Symbolics -for i in 1:10 - @test u[i] == A.u[i] -end +@testset "Interface" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + + for i in 1:10 + @test u[i] == A.u[i] + end -for i in 1:10 - @test t[i] == A.t[i] + for i in 1:10 + @test t[i] == A.t[i] + end end -using Symbolics -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +@testset "Symbolics" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + @variables t x(t) + substitute(A(t), Dict(t => x)) +end -@variables t x(t) -substitute(A(t), Dict(t => x)) +@testset "Type Inference" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + methods = [ + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, + QuadraticSpline, CubicSpline, AkimaInterpolation + ] + @testset "$method" for method in methods + @inferred method(u, t) + end + @testset "BSplineInterpolation" begin + @inferred BSplineInterpolation(u, t, 3, :Uniform, :Uniform) + @inferred BSplineInterpolation(u, t, 3, :ArcLen, :Average) + end + @testset "BSplineApprox" begin + @inferred BSplineApprox(u, t, 3, 5, :Uniform, :Uniform) + @inferred BSplineApprox(u, t, 3, 5, :ArcLen, :Average) + end + du = ones(10) + ddu = zeros(10) + @testset "Hermite Splines" begin + @inferred CubicHermiteSpline(du, u, t) + @inferred PCHIPInterpolation(u, t) + @inferred QuinticHermiteSpline(ddu, du, u, t) + end +end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9562c7b4..53fff25e 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -160,6 +160,13 @@ end @test A(5.5) == fill(11.0) @test A(11) == fill(22) + # Test constant -Inf interpolation + u = [-Inf, -Inf] + t = [0.0, 1.0] + A = LinearInterpolation(u, t) + @test A(0.0) == -Inf + @test A(0.5) == -Inf + # Test extrapolation u = 2.0collect(1:10) t = 1.0collect(1:10) @@ -169,6 +176,7 @@ end A = LinearInterpolation(u, t) @test_throws DataInterpolations.ExtrapolationError A(-1.0) @test_throws DataInterpolations.ExtrapolationError A(11.0) + @test_throws DataInterpolations.ExtrapolationError A([-1.0, 11.0]) end @testset "Quadratic Interpolation" begin @@ -669,6 +677,18 @@ end @test_throws AssertionError CubicHermiteSpline(du, u, t) end +@testset "PCHIPInterpolation" begin + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 250.0] + A = PCHIPInterpolation(u, t) + @test A isa CubicHermiteSpline + ts = 0.0:0.1:250.0 + us = A(ts) + @test all(minimum(u) .<= us) + @test all(maximum(u) .>= us) + @test all(A.du[3:4] .== 0.0) +end + @testset "Quintic Hermite Spline" begin test_interpolation_type(QuinticHermiteSpline) diff --git a/test/online_tests.jl b/test/online_tests.jl index 3ae6438e..1872e0cc 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -6,9 +6,11 @@ u1 = [0.0, 1.0, 0.0] t2 = [4.0, 5.0, 6.0] u2 = [1.0, 2.0, 1.0] -ts = 1.0:0.5:6.0 +ts_append = 1.0:0.5:6.0 +ts_push = 1.0:0.5:4.0 -for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] +@testset "$method" for method in [ + LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @@ -17,7 +19,7 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_append) == func2(ts_append) @test func1.I == func2.I func1 = method(copy(u1), copy(t1); cache_parameters = true) @@ -28,6 +30,6 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_push) == func2(ts_push) @test func1.I == func2.I end From a5e0f2061a3e0dfa8492e2d1365efbcf878b0207 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:08:17 +0200 Subject: [PATCH 23/35] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ccef602b..9219a059 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0.6" +Zygote = "0" julia = "1.10" [extras] From 8f44fa36f0b0f426b07b327f2235063d9ca2a4ff Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:11:27 +0200 Subject: [PATCH 24/35] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9219a059..baf7d6ac 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0" +Zygote = "0.*" julia = "1.10" [extras] From 53905c141302e681b69a1e9c368b873f2c8b80f7 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:13:57 +0200 Subject: [PATCH 25/35] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index baf7d6ac..518bfc2f 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0.*" +Zygote = "^0" julia = "1.10" [extras] From 8a3fd7ba480fa7b911d646025c150b51231dc25a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 27 Jul 2024 18:03:17 -0400 Subject: [PATCH 26/35] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 518bfc2f..ccef602b 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "^0" +Zygote = "0.6" julia = "1.10" [extras] From b8370765536e3653d6bb1a8c38e71ab4c7bbe2dc Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 27 Jul 2024 18:03:33 -0400 Subject: [PATCH 27/35] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index ccef602b..c9b7a2e9 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" From e0cce0b491e7c05e1e3cab9876f105a0ecdad442 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 28 Jul 2024 17:04:04 +0200 Subject: [PATCH 28/35] Add examples of how to speed up gradients w.r.t. u, add more zygote tests --- ext/DataInterpolationsChainRulesCoreExt.jl | 78 ++++++++++++++++++++-- test/zygote_tests.jl | 65 +++++++++++++----- 2 files changed, 124 insertions(+), 19 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..9c33b09c 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,19 +1,87 @@ module DataInterpolationsChainRulesCoreExt - if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_idx, get_parameters, + _quad_interp_indices using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_parameters, + _quad_interp_indices using ..ChainRulesCore end +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function ChainRulesCore.rrule( + ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dmode = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function u_tangent(A::LinearInterpolation, t, Δ) + out = zero(A.u) + idx = get_idx(A.t, t, A.idx_prev[]) + t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) + out[idx] = Δ * (one(eltype(out)) - t_factor) + out[idx + 1] = Δ * t_factor + out +end + +function u_tangent(A::QuadraticInterpolation, t, Δ) + out = zero(A.u) + i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[]) + t₀ = A.t[i₀] + t₁ = A.t[i₁] + t₂ = A.t[i₂] + Δt₀ = t₁ - t₀ + Δt₁ = t₂ - t₁ + Δt₂ = t₂ - t₀ + out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + out +end + +function u_tangent(A, t, Δ) + NoTangent() +end + function ChainRulesCore.rrule(::typeof(_interpolate), A::Union{ + LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, @@ -21,7 +89,9 @@ function ChainRulesCore.rrule(::typeof(_interpolate), }, t::Number) deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) + function interpolate_pullback(Δ) + (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ) + end return _interpolate(A, t), interpolate_pullback end diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 6887ddd2..1a7fc447 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -2,8 +2,8 @@ using DataInterpolations using ForwardDiff using Zygote -function test_zygote(method, u, t; args = [], kwargs = [], name::String) - func = method(args..., u, t; kwargs..., extrapolate = true) +function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) + func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) (; u, t) = func trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) trange_exclude = filter(x -> !in(x, t), trange) @@ -11,28 +11,30 @@ function test_zygote(method, u, t; args = [], kwargs = [], name::String) for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) zdiff = only(Zygote.gradient(func, _t)) - zdiff == nothing && (zdiff = 0.0) + isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end end - @testset "$name, derivatives w.r.t. u" begin - function f(u) - A = method(args..., u, t; kwargs..., extrapolate = true) - out = zero(eltype(u)) - for _t in trange - out += A(_t) + if method ∉ [LagrangeInterpolation, BSplineInterpolation, BSplineApprox] + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out end - out + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad end - zgrad = only(Zygote.gradient(f, u)) - fgrad = ForwardDiff.gradient(f, u) - @test zgrad ≈ fgrad end end @testset "LinearInterpolation" begin - u = vcat(collect(1:5), 2 * collect(6:10)) - t = 1.0collect(1:10) + u = vcat(collect(1.0:5.0), 2 * collect(6.0:10.0)) + t = collect(1.0:10.0) test_zygote( LinearInterpolation, u, t; name = "Linear Interpolation") end @@ -64,3 +66,36 @@ end test_zygote( QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") end + +@testset "Quadratic Spline" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline") +end + +@testset "Lagrange Interpolation" begin + u = [1.0, 4.0, 9.0] + t = [1.0, 2.0, 3.0] + test_zygote(LagrangeInterpolation, u, t, name = "Lagrange Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation") +end + +@testset "Cubic Spline" begin + u = [0.0, 1.0, 3.0] + t = [-1.0, 0.0, 1.0] + test_zygote(CubicSpline, u, t, name = "Cubic Spline") +end + +@testset "BSplines" begin + t = [0, 62.25, 109.66, 162.66, 205.8, 252.3] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + test_zygote(BSplineInterpolation, u, t; args_after = [2, :Uniform, :Uniform], + name = "BSpline Interpolation") + test_zygote(BSplineApprox, u, t; args_after = [2, 4, :Uniform, :Uniform], + name = "BSpline approximation") +end From 8e2e73fc643787e83753f5b386852f5d7e29323d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 28 Jul 2024 16:03:59 -0400 Subject: [PATCH 29/35] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c9b7a2e9..b581252a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.3.1" +version = "5.4.0" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" From bf37761ec5fa3bc9e61339ab6b95363bb0e3911f Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 29 Jul 2024 05:45:14 +0000 Subject: [PATCH 30/35] build: bump zygote compat to get downgrade tests working --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b581252a..d7a44eb0 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0.6" +Zygote = "0.6.70" julia = "1.10" [extras] From 21516aab9d88ad0965cb26e6ef6e802e60a676b3 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 29 Jul 2024 06:18:34 +0000 Subject: [PATCH 31/35] build: bump ChainRulesCore compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d7a44eb0..e755802a 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" [compat] Aqua = "0.8" -ChainRulesCore = "1.18" +ChainRulesCore = "1.24" FindFirstFunctions = "1.1" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" From a3e82f10265ded9846a9a9da6da6229912073755 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 29 Jul 2024 09:08:54 +0000 Subject: [PATCH 32/35] refactor: use register symbolic for interpolation objects --- ext/DataInterpolationsSymbolicsExt.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/DataInterpolationsSymbolicsExt.jl b/ext/DataInterpolationsSymbolicsExt.jl index 106535ef..b144da22 100644 --- a/ext/DataInterpolationsSymbolicsExt.jl +++ b/ext/DataInterpolationsSymbolicsExt.jl @@ -12,8 +12,7 @@ else using ..Symbolics: Num, unwrap, SymbolicUtils end -(interp::AbstractInterpolation)(t::Num) = SymbolicUtils.term(interp, unwrap(t)) -SymbolicUtils.promote_symtype(t::AbstractInterpolation, _...) = Real +@register_symbolic (interp::AbstractInterpolation)(t) Base.nameof(interp::AbstractInterpolation) = :Interpolation function derivative(interp::AbstractInterpolation, t::Num, order = 1) From 245fe5151ba75ffdb768ed62c788baee5a303c22 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 29 Jul 2024 09:11:33 +0000 Subject: [PATCH 33/35] test: add tests for symbolic interpolation compositions --- test/interface.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 3e910547..e7b2b81b 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -18,9 +18,14 @@ end @testset "Symbolics" begin u = 2.0collect(1:10) t = 1.0collect(1:10) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; extrapolate = true) + B = LinearInterpolation(u .^ 2, t; extrapolate = true) @variables t x(t) substitute(A(t), Dict(t => x)) + t_val = 2.7 + @test substitute(A(t), Dict(t => t_val)) == A(t_val) + @test substitute(B(A(t)), Dict(t => t_val)) == B(A(t_val)) + @test substitute(A(B(A(t))), Dict(t => t_val)) == A(B(A(t_val))) end @testset "Type Inference" begin From 71347ccd686781aaef594fdabbd3522d69ce2223 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan <35105271+sathvikbhagavan@users.noreply.github.com> Date: Mon, 29 Jul 2024 18:13:09 +0530 Subject: [PATCH 34/35] build: bump major version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e755802a..17fdd131 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.4.0" +version = "6.0.0" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" From be7366c586ec801d8ccd9dde57d3452b371f59ca Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 29 Jul 2024 12:49:39 +0000 Subject: [PATCH 35/35] build(docs): bump DataInterpolations compat --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index ffa0d3ad..540ffc47 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -11,7 +11,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] -DataInterpolations = "5" +DataInterpolations = "6" Documenter = "1" ModelingToolkit = "9" ModelingToolkitStandardLibrary = "2"