Skip to content

Commit

Permalink
Dispatch on Val(Ns)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvillacorta committed Jun 18, 2024
1 parent 2bce2e4 commit 872d16a
Showing 1 changed file with 43 additions and 35 deletions.
78 changes: 43 additions & 35 deletions KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,23 @@
# Degree = Linear,Cubic....
# ETPType = Periodic, Flat...

const Interpolator = Interpolations.GriddedInterpolation{
T,N,V,Itp,K
const Interpolator1D = Interpolations.GriddedInterpolation{
T,1,V,Itp,K
} where {
T<:Real,
N,
V<:AbstractArray{T},
Itp<:Tuple{ Vararg{ Union{Interpolations.Gridded{Linear{Throw{OnGrid}}},Interpolations.NoInterp} } },
K<:Tuple{Vararg{AbstractRange{T}}},
Itp<:Tuple{Interpolations.Gridded{Linear{Throw{OnGrid}}}},
K<:Tuple{AbstractRange{T}},
}

function GriddedInterpolation(
nodes::NType,
A::AType
) where {T<:Real, AType<:AbstractArray{T}, NType<:Tuple{Vararg{AbstractRange{T}}}}
Ns, _ = size(A)
if Ns > 1
ITPType = Tuple{NoInterp, Gridded{Linear{Throw{OnGrid}}}}
return Interpolations.GriddedInterpolation{T, 2, typeof(A), ITPType, typeof(nodes)}(nodes, A, (NoInterp(), Gridded(Linear())))
else
ITPType = Tuple{Gridded{Linear{Throw{OnGrid}}}}
return Interpolations.GriddedInterpolation{T, 1, typeof(A[:]), ITPType, typeof((nodes[2], ))}((nodes[2], ), A[:], (Gridded(Linear()), ))
end
end
const Interpolator2D = Interpolations.GriddedInterpolation{
T,2,V,Itp,K
} where {
T<:Real,
V<:AbstractArray{T},
Itp<:Tuple{Interpolations.NoInterp, Interpolations.Gridded{Linear{Throw{OnGrid}}}},
K<:Tuple{AbstractRange{T}, AbstractRange{T}},
}

"""
motion = ArbitraryMotion(period_durations, dx, dy, dz)
Expand Down Expand Up @@ -73,6 +67,11 @@ function Base.getindex(
)
return ArbitraryMotion(motion.t_start, motion.t_end, motion.dx[p,:], motion.dy[p,:], motion.dz[p,:])
end
function Base.view(

Check warning on line 70 in KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl#L70

Added line #L70 was not covered by tests
motion::ArbitraryMotion, p::Union{AbstractRange,AbstractVector,Colon}
)
return ArbitraryMotion(motion.t_start, motion.t_end, @view(motion.dx[p,:]), @view(motion.dy[p,:]), @view(motion.dz[p,:]))

Check warning on line 73 in KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl

View check run for this annotation

Codecov / codecov/patch

KomaMRIBase/src/datatypes/phantom/motion/ArbitraryMotion.jl#L73

Added line #L73 was not covered by tests
end

Base.:(==)(m1::ArbitraryMotion, m2::ArbitraryMotion) = reduce(&, [getfield(m1, field) == getfield(m2, field) for field in fieldnames(ArbitraryMotion)])
Base.:()(m1::ArbitraryMotion, m2::ArbitraryMotion) = reduce(&, [getfield(m1, field) getfield(m2, field) for field in fieldnames(ArbitraryMotion)])
Expand All @@ -89,27 +88,36 @@ function times(motion::ArbitraryMotion)
return range(motion.t_start, motion.t_end, length=size(motion.dx, 2))
end

function GriddedInterpolation(nodes, A, ITP)
return Interpolations.GriddedInterpolation{eltype(A), length(nodes), typeof(A), typeof(ITP), typeof(nodes)}(nodes, A, ITP)
end

function interpolate(motion::ArbitraryMotion{T}, Ns::Val{1}) where {T<:Real}
_, Nt = size(motion.dx)
t = range(zero(T), oneunit(T), Nt)
itpx = GriddedInterpolation((t, ), motion.dx[:], (Gridded(Linear()), ))
itpy = GriddedInterpolation((t, ), motion.dy[:], (Gridded(Linear()), ))
itpz = GriddedInterpolation((t, ), motion.dz[:], (Gridded(Linear()), ))
return itpx, itpy, itpz
end

function interpolate(motion::ArbitraryMotion{T}) where {T<:Real}
function interpolate(motion::ArbitraryMotion{T}, Ns::Val) where {T<:Real}
Ns, Nt = size(motion.dx)
itpx = GriddedInterpolation((one(T):Ns, range(0,one(T),Nt)), motion.dx)
itpy = GriddedInterpolation((one(T):Ns, range(0,one(T),Nt)), motion.dy)
itpz = GriddedInterpolation((one(T):Ns, range(0,one(T),Nt)), motion.dz)
id = one(T):Ns
t = range(zero(T), oneunit(T), Nt)
itpx = GriddedInterpolation((id, t), motion.dx, (NoInterp(), Gridded(Linear())))
itpy = GriddedInterpolation((id, t), motion.dy, (NoInterp(), Gridded(Linear())))
itpz = GriddedInterpolation((id, t), motion.dz, (NoInterp(), Gridded(Linear())))
return itpx, itpy, itpz
end

function resample(
itpx::Interpolator{T},
itpy::Interpolator{T},
itpz::Interpolator{T},
t::AbstractArray{T}
) where {T<:Real}
Ns = ndims(itpx.coefs) == 1 ? 1 : size(itpx.coefs,1)
if Ns > 1
return itpx.(1:Ns, t), itpy.(1:Ns, t), itpz.(1:Ns, t)
else
return itpx.(t), itpy.(t), itpz.(t)
end
function resample(itpx::Interpolator1D{T}, itpy::Interpolator1D{T}, itpz::Interpolator1D{T}, t::AbstractArray{T}) where {T<:Real}
return itpx.(t), itpy.(t), itpz.(t)
end

function resample(itpx::Interpolator2D{T}, itpy::Interpolator2D{T}, itpz::Interpolator2D{T}, t::AbstractArray{T}) where {T<:Real}
Ns = size(itpx.coefs, 1)
return itpx.(1:Ns, t), itpy.(1:Ns, t), itpz.(1:Ns, t)
end

function get_spin_coords(
Expand All @@ -119,7 +127,7 @@ function get_spin_coords(
z::AbstractVector{T},
t::AbstractArray{T}
) where {T<:Real}
motion_functions = interpolate(motion)
motion_functions = interpolate(motion, Val(size(x,1)))
ux, uy, uz = resample(motion_functions..., unit_time(t, motion.t_start, motion.t_end))
return x .+ ux, y .+ uy, z .+ uz
end

0 comments on commit 872d16a

Please sign in to comment.