Skip to content

Commit

Permalink
Merge pull request #568 from N5N3/cbfix
Browse files Browse the repository at this point in the history
Some recursion tuning to allow more eager inference.
  • Loading branch information
mkitti authored Nov 23, 2023
2 parents 43fe00f + 21f8b76 commit b47eddc
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Interpolations"
uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
version = "0.15.0"
version = "0.15.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
31 changes: 20 additions & 11 deletions src/Interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,25 +278,34 @@ struct InterpGetindex{N,A<:AbstractArray{<:Any,N}}
InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A)
end
@inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} =
interp_getindex(A.coeffs, ntuple(_ -> 0, Val(N)), map(indexflag, I)...)
indexflag(I::Int) = I
@inline indexflag(I::WeightedIndex) = indextuple(I), weights(I)
interp_getindex(A.coeffs, ntuple(zero, Val(N)), map(indexflag, I)...)
@inline indexflag(I) = indextuple(I), weights(I)

# Direct recursion would allow more eager inference before julia 1.11.
# Normalize all index into the same format.
struct One end # Singleton for express weights of no-interp dims
indextuple(I::Int) = (I,)
weights(::Int) = (One(),)

struct Zero end # Singleton for dim expansion termination

# A recursion-based `interp_getindex`, which follows a "move processed indexes to the back" strategy
# `I` contains the processed index, and (wi1, wis...) contains the yet-to-be-processed indexes
# Here we meet a no-interp dim, just append the index to `I`'s end.
@inline interp_getindex(A, I, wi1::Int, wis...) =
interp_getindex(A, (Base.tail(I)..., wi1), wis...)
# Here we handle the expansion of a single dimension.
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any,Vararg{Any,N}}}, wis...) where {N} =
wi1[2][end] * interp_getindex(A, (Base.tail(I)..., wi1[1][end]), wis...) +
interp_getindex(A, I, map(Base.front, wi1), wis...)
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any}}, wis...) =
wi1[2][1] * interp_getindex(A, (Base.tail(I)..., wi1[1][1]), wis...)
@inline function interp_getindex(A, I, (is, ws)::NTuple{2,Tuple}, wis...)
itped1 = interp_getindex(A, (Base.tail(I)..., is[end]), wis...)
witped = interp_getindex(A, I, (Base.front(is), Base.front(ws)), wis...)
_weight_itp(ws[end], itped1, witped)
end
interp_getindex(_, _, ::NTuple{2,Tuple{}}, ::Vararg) = Zero()
# Termination
@inline interp_getindex(A::AbstractArray{T,N}, I::Dims{N}) where {T,N} =
@inbounds A[I...] # all bounds-checks have already happened

_weight_itp(w, i, wir) = w * i + wir
_weight_itp(::One, i, ::Zero) = i
_weight_itp(w, i, ::Zero) = w * i

"""
w = value_weights(degree, δx)
Expand Down
17 changes: 9 additions & 8 deletions src/b-splines/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ function weightedindexes(parts::Vararg{Union{Int,GradParts},N}) where N
slot_substitute(parts, map(positions, parts), map(valuecoefs, parts), map(gradcoefs, parts))
end

# Skip over NoInterp dimensions
slot_substitute(kind::Tuple{Int,Vararg{Any}}, p, v, g) = slot_substitute(Base.tail(kind), p, v, g)
# Substitute the dth dimension's gradient coefs for the remaining coefs
slot_substitute(kind, p, v, g) = (map(maybe_weightedindex, p, substitute_ruled(v, kind, g)), slot_substitute(Base.tail(kind), p, v, g)...)
function slot_substitute(kind, p, v, g)
rest = slot_substitute(Base.tail(kind), p, v, g)
kind[1] isa Int && return rest # Skip over NoInterp dimensions
(map(maybe_weightedindex, p, substitute_ruled(v, kind, g)), rest...)
end
# Termination
slot_substitute(kind::Tuple{}, p, v, g) = ()

Expand All @@ -132,15 +134,14 @@ function _column(kind1::K, kind2::K, p, v, g, h) where {K<:Tuple}
ss = substitute_ruled(v, kind1, h)
(map(maybe_weightedindex, p, ss), _column(Base.tail(kind1), kind2, p, v, g, h)...)
end
_column(kind1::K, kind2::K, p, v, g, h) where {K<:Tuple{Int,Vararg}} = () # Skip over NoInterp dimensions
function _column(kind1::Tuple, kind2::Tuple, p, v, g, h)
rest = _column(Base.tail(kind1), kind2, p, v, g, h)
kind1[1] isa Int && return rest # Skip over NoInterp dimensions
ss = substitute_ruled(substitute_ruled(v, kind1, g), kind2, g)
(map(maybe_weightedindex, p, ss), _column(Base.tail(kind1), kind2, p, v, g, h)...)
(map(maybe_weightedindex, p, ss), rest...)
end
_column(::Tuple{}, ::Tuple, p, v, g, h) = ()
# Skip over NoInterp dimensions
slot_substitute(kind::Tuple{Int,Vararg{Any}}, p, v, g, h) = slot_substitute(Base.tail(kind), p, v, g, h)
_column(kind1::Tuple{Int,Vararg{Any}}, kind2::Tuple, p, v, g, h) =
_column(Base.tail(kind1), kind2, p, v, g, h)

weightedindex_parts(fs::F, itpflag::BSpline, ax, x) where F =
weightedindex_parts(fs, degree(itpflag), ax, x)
Expand Down
2 changes: 1 addition & 1 deletion test/issues/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ using Interpolations, Test, ForwardDiff
end
@testset "issue 469" begin
# We have different inference result on different version.
max_dim = VERSION < v"1.3" ? 3 : isdefined(Base, :Any32) ? 7 : 5
max_dim = isdefined(Base, :Any32) ? 7 : 5
for dims in 3:max_dim
A = zeros(Float64, ntuple(_ -> 5, dims))
itp = interpolate(A, BSpline(Quadratic(Reflect(OnCell()))))
Expand Down
12 changes: 12 additions & 0 deletions test/nointerp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,15 @@
# @test ae[0,1] === NaN
# @test_throws InexactError ae(1.5,2)
end

@testset "Stability of mixtrue with NoInterp and Interp" begin
A = zeros(Float64, 5, 5, 5, 5, 5, 5, 5)
st = BSpline(Quadratic(Reflect(OnCell()))), NoInterp(),
BSpline(Linear()), NoInterp(),
BSpline(Quadratic()), NoInterp(),
BSpline(Quadratic(Reflect(OnCell())))
itp = interpolate(A, st)
@test (@inferred Interpolations.hessian(itp, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == zeros(4,4)
@test (@inferred Interpolations.gradient(itp, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == zeros(4)
@test (@inferred itp(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == 0
end

1 comment on commit b47eddc

@N5N3
Copy link
Contributor

@N5N3 N5N3 commented on b47eddc Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we have blocking the release for quite a while. @mkitti

Please sign in to comment.