Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

addec CUDA.jl support for czt #44

Merged
merged 3 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 82 additions & 57 deletions src/czt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export czt, iczt, plan_czt

"""
get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0)
get_kernel_1d(arr::AbstractArray{T,D}, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D}

calculates the kernel for the Bluestein algorithm. Note that the length depends on the destination size.
Note the the resulting kernel-size is computed based on the minimum required length for the task.
Expand All @@ -21,12 +21,16 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
returns: a tuple of three arrays for the initial multiplication (A*W), the convolution
(already fourier-transformed) and the post multiplication.
"""
function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0)
# intorduce also sscale ??
# the size needed to avoid wrap
function get_kernel_1d(arr::AT, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N), extra_phase=0.0, global_phase=0.0) where {T,D, AT <: AbstractArray{T,D}}
# the size is needed to avoid wrap
RT = real(T)
CT = (RT <: Real) ? Complex{RT} : RT
RT = real(CT)
# nowrap_size = N + ceil(N÷2)

# converts ShiftedArrays.CircShiftedArray into a plain array type:
tmp = similar(arr, RT, (1,))
RAT = real_arr_type(typeof(tmp), Val(1))

# the maximal size where the convolution does not yield zero
# max_size = 2*N-1
# the minimum size needed for the convolution
Expand All @@ -35,13 +39,15 @@ function get_kernel_1d(RT::Type, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N)

#Note that the source size ssz is used here
# W of the source for FFTs.
n = (0:N-1)
n = RAT(0:N-1)
# pre-calculate the product of a.^(-n) and w to be later multiplied with the input x
# late casting is important, since the rounding errors are overly large if the original calculations are done in Float32.
aw = CT.((a .^ (-n)) .* w .^ ((n .^ 2) ./ 2))

conv_kernel = zeros(CT, L) # Array{CT}(undef, L)
m = (0:M-1)
conv_kernel = similar(arr, CT, L) # Array{CT}(undef, L)
fill!(conv_kernel, zero(CT))

m = RAT(0:M-1)
conv_kernel[1:M] .= w .^ (-(m .^ 2) ./ 2)
right_start = L-N+1
n = (1:N-1)
Expand All @@ -59,7 +65,7 @@ end

# type for planning. The arrays are 1D but oriented
"""
CZTPlan_1D{CT, D} # <: AbstractArray{T,D}
CZTPlan_1D{CT<:Complex, D<:Integer, AT<:AbstractArray{CT, D}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}

type used for the onedimensional plan of the chirp Z transformation (CZT).
containing
Expand All @@ -70,20 +76,18 @@ containing
`aw`: factor to multiply input with
`fft_fv`: fourier-transform (FFTW) of the convolutio kernel
`wd`: factor to multiply the result of the convolution by
`fftw_plan`: plan for the forward FFTW of the convolution kernel
`ifftw_plan`: plan for the inverse FFTW of the convolution kernel
`fftw_plan!`: plan for the forward FFTW of the convolution kernel
`ifftw_plan!`: plan for the inverse FFTW of the convolution kernel
"""
struct CZTPlan_1D{CT, PT, D} # <: AbstractArray{T,D}
struct CZTPlan_1D{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}
d :: Int
pad_value :: PT
pad_ranges :: NTuple{2,UnitRange{Int64}}
aw :: Array{CT, D}
fft_fv :: Array{CT, D}
wd :: Array{CT, D}
fftw_plan :: FFTW.cFFTWPlan
ifftw_plan :: AbstractFFTs.ScaledPlan
# dimension of this transformation
# as :: Array{T, D} # not needed since it is just the conjugate of ws
pad_ranges :: NTuple{2, UnitRange{Int64}}
aw :: AT
fft_fv :: AT
wd :: AT
fftw_plan! :: PFFT
ifftw_plan! :: PIFFT
end

"""
Expand All @@ -94,8 +98,8 @@ containing
# Members:
`plans`: vector of CZTPlan_1D for each of the directions of the ND array to transform
"""
struct CZTPlan_ND{CT, PT, D} # <: AbstractArray{T,D}
plans :: Vector{CZTPlan_1D{CT,PT, D}}
struct CZTPlan_ND{CT<:Complex, AT<:AbstractArray{CT}, PT<:Number, PFFT<:AbstractFFTs.Plan, PIFFT<:AbstractFFTs.ScaledPlan}
plans :: Vector{CZTPlan_1D{CT, AT, PT, PFFT, PIFFT}}
end

function get_invalid_ranges(sz, scaled, dsize, dst_center)
Expand Down Expand Up @@ -138,8 +142,8 @@ end
creates a plan for an one-dimensional chirp z-transformation (CZT). The generated plan is then applied via
muliplication. For details about the arguments, see `czt_1d()`.
"""
function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2,
dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function plan_czt_1d(xin::AT, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, extra_phase=nothing, global_phase=nothing, damp=1.0, src_center=(size(xin,d)+1)/2,
dst_center=dsize÷2+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {AT}

a = isnothing(a) ? exp(-1im*(dst_center-1)*2pi/(scaled*size(xin,d))) : a
w = isnothing(w) ? cispi(-2/(scaled*size(xin,d))) : w
Expand All @@ -148,23 +152,26 @@ function plan_czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, ex
extra_phase = isnothing(extra_phase) ? exp(1im*2pi*(src_center-1)/(scaled*size(xin,d))) : extra_phase
global_phase = isnothing(global_phase) ? a ^ (src_center-1) : global_phase

aw, fft_fv, wd = get_kernel_1d(eltype(xin), size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase)
aw, fft_fv, wd = get_kernel_1d(xin, size(xin, d), dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase)

# set pad ranges to empty ranges:
start_range = 1:0
end_range = 1:0
stop_range = 1:0

if remove_wrap
start_range, stop_range = get_invalid_ranges(size(xin, d), scaled, dsize, dst_center)

wd[start_range] .= zero(eltype(wd))
wd[stop_range] .= zero(eltype(wd))
end

nsz = ntuple((dd) -> (d==dd) ? size(fft_fv, 1) : size(xin, dd), Val(ndims(xin)))
y = Array{eltype(aw), ndims(xin)}(undef, nsz)
y = similar(xin, eltype(aw), nsz)

fft_p = plan_fft(y, (d,); flags=fft_flags)
ifft_p = plan_ifft(y, (d,); flags=fft_flags) # inv(fft_p)
fft_p! = (typeof(y) <: Array) ? plan_fft!(y, (d,); flags=fft_flags) : plan_fft!(y, (d,))
ifft_p! = (typeof(y) <: Array) ? plan_ifft!(y, (d,); flags=fft_flags) : plan_ifft!(y, (d,))

plan = CZTPlan_1D(d, pad_value, (start_range, end_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p, ifft_p)
plan = CZTPlan_1D(d, pad_value, (start_range, stop_range), reorient(aw, d, Val(ndims(xin))), reorient(fft_fv, d, Val(ndims(xin))), reorient(wd, d, Val(ndims(xin))), fft_p!, ifft_p!)
return plan
end

Expand All @@ -175,21 +182,30 @@ end
creates a plan for an N-dimensional chirp z-transformation (CZT). The generated plan is then applied via
muliplication. For details about the arguments, see `czt()`.
"""
function plan_czt(xin, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function plan_czt(xin::AbstractArray{U,D}, scale, dims, dsize=size(xin); a=nothing, w=nothing, damp=ones(ndims(xin)),
src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE) where {U,D}
CT = (eltype(xin) <: Real) ? Complex{eltype(xin)} : eltype(xin)
D = ndims(xin)
plans = [] # Vector{CZT1DPlan{CT,D}}
sz = size(xin)
for d in dims
xin = Array{eltype(xin)}(undef, sz)

d = dims[1]
p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
plans = Vector{typeof(p)}(undef, length(dims))
sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin))
n=1
plans[n]=p
n+=1
for d in dims[2:end]
xin = Array{eltype(xin)}(undef, sz)
p = plan_czt_1d(xin, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
sz = ntuple((dd)-> (dd==d) ? dsize[d] : sz[dd], ndims(xin))
push!(plans, p)
plans[n]=p
n += 1
end
return CZTPlan_ND{CT, typeof(pad_value),D}(plans)
return CZTPlan_ND(plans)
end

function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_ND, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U),D} where {U,D}
xout = xin
for pd in p.plans
xout = czt_1d(xout, pd)
Expand Down Expand Up @@ -230,13 +246,13 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
+ `remove_wrap`: if true, the positions that represent a wrap-around will be set to zero
+ `pad_value`: the value to pad wrapped data with.
"""
function czt_1d(xin, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)
function czt_1d(xin::AbstractArray{U,D}, scaled, d, dsize=size(xin,d); a=nothing, w=nothing, damp=1.0, src_center=size(xin,d)÷2+1,
dst_center=dsize÷2+1, extra_phase=nothing, global_phase=nothing, remove_wrap=false, pad_value=zero(U), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(U), D} where {U,D}
plan = plan_czt_1d(xin, scaled, d, dsize; a=a, w=w, extra_phase=extra_phase, global_phase=global_phase, damp, src_center=src_center, dst_center=dst_center, remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags);
return plan * xin
end

function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...) where {U,D} # Complex{U}
function Base.:*(p::CZTPlan_1D, xin::AbstractArray{U,D}; kargs...)::AbstractArray{complex(U), D} where {U,D} # Complex{U}
return czt_1d(xin, p)
end

Expand All @@ -258,7 +274,7 @@ The code is based on Rabiner, Schafer & Rader 1969, IEEE Trans. on Audio and El
# Arguments
`plan`: A plan created via plan_czt_1d()
"""
function czt_1d(xin, plan::CZTPlan_1D)
function czt_1d(xin::AbstractArray{U,D}, plan::CZTPlan_1D)::AbstractArray{complex(U), D} where {U,D}
# destination position
# cispi(-1/scaled * half_pix_shift)
#
Expand All @@ -267,25 +283,30 @@ function czt_1d(xin, plan::CZTPlan_1D)
# which (intentionally) leads to non-real results for even-sized arrays at non-unit zoom

L = size(plan.fft_fv, plan.d)
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(ndims(xin)))
nsz = ntuple((dd) -> (dd==plan.d) ? L : size(xin, dd), Val(D))
# append zeros
y = zeros(eltype(plan.aw), nsz)
myrange = ntuple((dd) -> (1:size(xin,dd)), Val(ndims(xin)))
y[myrange...] = xin .* plan.aw
# corner = ntuple((x)->1, Val(ndims(xin)))
# select_region(xin .* plan.aw, new_size=nsz, center=corner, dst_center=corner)

# g = ifft(fft(y, d) .* plan.fft_fv, d)
g = plan.ifftw_plan * (plan.fftw_plan * y .* plan.fft_fv)
tmp = eltype(plan.aw).(xin .* plan.aw)

corner = ntuple((x)->1, Val(D))
y = NDTools.select_region(tmp, nsz; center=corner, dst_center=corner)

# in-place application to y:
plan.fftw_plan! * y
y .*= plan.fft_fv
# in-place application to y:
plan.ifftw_plan! * y

# dsz = ntuple((dd) -> (d==dd) ? dsize : size(xin), Val(ndims(xin)))
# return only the wanted (valid) part
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd,plan.d)) : (1:size(xin, dd)), Val(ndims(xin)))
res = g[myrange...] .* plan.wd
myrange = ntuple((dd) -> (dd==plan.d) ? (1:size(plan.wd, plan.d)) : (1:size(xin, dd)), Val(D))
res = y[myrange...] .* plan.wd
# pad_value=0 means that it is either already handled by plan.wd or no padding is wanted.
if plan.pad_value != 0
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(ndims(xin)))
# first the start_range (plan.pad_ranges[1]):
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[1] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(ndims(xin)))
# first the stop_range (plan.pad_ranges[2]):
myrange = ntuple((dd) -> (dd==plan.d) ? plan.pad_ranges[2] : Colon(), Val(D))
res[myrange...] .= plan.pad_value
end
return res
Expand Down Expand Up @@ -355,18 +376,22 @@ julia> zoomed = real.(ift(xft))
0.0239759 -0.028264 0.0541186 -0.0116475 -0.261294 0.312719 -0.261294 -0.0116475 0.0541186 -0.028264
```
"""
function czt(xin::AbstractArray{T,N}, scale, dims=1:ndims(xin), dsize=size(xin);
a=nothing, w=nothing, damp=ones(ndims(xin)), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1, remove_wrap=false, pad_value=zero(eltype(xin)), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T),N} where {T,N}
function czt(xin::AbstractArray{T,D}, scale, dims=1:D, dsize=size(xin);
a=nothing, w=nothing, damp=ones(D), src_center=size(xin).÷2 .+1, dst_center=dsize.÷2 .+1,
remove_wrap=false, pad_value=zero(T), fft_flags=FFTW.ESTIMATE)::AbstractArray{complex(T), D} where {T,D}
xout = xin
if length(scale) != ndims(xin)
error("Every of the $(ndims(xin)) dimension needs exactly one corresponding scale (zoom) factor, which should be equal to 1.0 for dimensions not contained in the dims argument.")
end
for d = 1:ndims(xin)
# check all the dims:
for d = 1:D
if !(d in dims) && scale[d] != 1.0 && !isnothing(scale[d])
error("The scale factor $(scale[d]) needs to be nothing or 1.0, if this dimension is not in the list of dimensions to transform.")
end
end

for d in dims
# in-place assignement is not possible, since with a zoom the size always changes.
xout = czt_1d(xout, scale[d], d, dsize[d]; a=a, w=w, damp=damp[d], src_center=src_center[d], dst_center=dst_center[d], remove_wrap=remove_wrap, pad_value=pad_value, fft_flags=fft_flags)
end
return xout
Expand Down
10 changes: 5 additions & 5 deletions test/czt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ using NDTools # this is needed for the select_region! function below.
# @test ≈(czt(y,zoom, src_center=(size(y).+1)./2), select_region(upsample2(ft(y), fix_center=true), new_size=size(y)), rtol=1e-5)

# for uneven sizes this works:
@test ≈(czt(y[1:5,1:5],zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5)
@test ≈(czt(y[1:5,1:5], zoom, (1,2), (10,10)), upsample2(ft(y[1:5,1:5]), fix_center=true), rtol=1e-5)
p_czt = plan_czt(y, zoom, (1,2), (11,12))
@test ≈(p_czt * y, czt(y, zoom, (1,2), (11,12)))
# zoom smaller 1.0 causes wrap around:
zoom = (0.5,2.0)
@test abs(czt(y,zoom)[1,1]) > 1e-5
zoom = (2.0, 0.5)
# check if the remove_wrap works
@test abs(czt(y,zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(iczt(y,zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(czt(y,zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0
@test abs(iczt(y,zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0
@test abs(czt(y, zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(iczt(y, zoom; remove_wrap=true)[1,1]) == 0.0
@test abs(czt(y, zoom; pad_value=0.2, remove_wrap=true)[1,1]) == 0.2f0
@test abs(iczt(y, zoom; pad_value=0.5f0, remove_wrap=true)[1,1]) == 0.5f0
end
end
Loading