Skip to content

Commit

Permalink
rem. support for low-D plans to be applied to ND
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Aug 29, 2023
1 parent 31b4042 commit c9aa2cd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 61 deletions.
109 changes: 48 additions & 61 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using LinearAlgebra
# Apply plan `p` to `x` after passing `x` through `copy1(T, x)`, where `T` is the
# plan's element-type parameter.
function Base.:(*)(p::Plan{T}, x::DenseCuArray) where {T}
    converted = copy1(T, x)
    return p * converted
end
# Apply a scaled plan: run the wrapped plan `p.p` on `x`, then scale the result
# in place by `p.scale`.
function Base.:(*)(p::ScaledPlan, x::DenseCuArray)
    unscaled = p.p * x
    return rmul!(unscaled, p.scale)
end


## plan structure

# K is an integer flag for forward/backward
Expand All @@ -34,7 +33,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
end

mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace}
handle::cufftHandle
# Handle to the low-level CUDA plan. Note that this plan sometimes has fewer dimensions
# than the Julia array, in order to handle more transform cases such as individual directions.
handle::cufftHandle
ctx::CuContext
stream::CuStream
sz::NTuple{N,Int} # Julia size of input array
Expand Down Expand Up @@ -170,11 +171,10 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
region = Tuple(region)

md = plan_max_dims(region, size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]
handle = cufftGetPlan(xtype, sizex, region)

cCuFFTPlan{T,K,inplace,md}(handle, sizex, sizex, region, xtype)
cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype)
end


Expand All @@ -185,12 +185,10 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
region = Tuple(region)

md = plan_max_dims(region, size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]

handle = cufftGetPlan(xtype, sizex, region)

cCuFFTPlan{T,K,inplace,md}(handle, sizex, sizex, region, xtype)
cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype)
end

# out-of-place complex
Expand All @@ -201,11 +199,10 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
region = Tuple(region)

md = plan_max_dims(region,size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]
handle = cufftGetPlan(xtype, sizex, region)

cCuFFTPlan{T,K,inplace,md}(handle,sizex, sizex, region, xtype)
cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype)
end

function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
Expand All @@ -215,11 +212,10 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
region = Tuple(region)

md = plan_max_dims(region,size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]
handle = cufftGetPlan(xtype, sizex, region)

cCuFFTPlan{T,K,inplace,md}(handle, sizex, sizex, region, xtype)
cCuFFTPlan{T,K,inplace,N}(handle, size(X), size(X), region, xtype)
end

# out-of-place real-to-complex
Expand All @@ -235,10 +231,10 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}

handle = cufftGetPlan(xtype, sizex, region)

ydims = collect(sizex)
ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2)+1

rCuFFTPlan{T,K,inplace,md}(handle, sizex, (ydims...,), region, xtype)
rCuFFTPlan{T,K,inplace,N}(handle, size(X), (ydims...,), region, xtype)
end

function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N}
Expand All @@ -247,16 +243,12 @@ function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cuf
xtype = (T == cufftComplex) ? CUFFT_C2R : CUFFT_Z2D
region = Tuple(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
# X = front_view(X, md)

ydims = collect(sizex)
ydims = collect(size(X))
ydims[region[1]] = d

handle = cufftGetPlan(xtype, (ydims...,), region)

rCuFFTPlan{T,K,inplace,md}(handle, sizex, (ydims...,), region, xtype)
rCuFFTPlan{T,K,inplace,N}(handle, size(X), (ydims...,), region, xtype)
end


Expand All @@ -267,44 +259,39 @@ function plan_inv(p::cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}) where {T,N,inplace}
md = plan_max_dims(p.region, p.sz)
sizex = p.sz[1:md]
handle = cufftGetPlan(p.xtype, sizex, p.region)
ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,md}(handle, sizex, sizex, p.region,
ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}(handle, p.sz, p.sz, p.region,
p.xtype),
normalization(real(T), sizex, p.region))
normalization(real(T), p.sz, p.region))
end

# Build the inverse of a backward complex plan: a forward plan with the same sizes,
# region and transform type, wrapped in a `ScaledPlan` carrying the normalization
# factor computed from the full plan size `p.sz`.
#
# The low-level cuFFT handle is created only for the leading `md` dimensions of
# `p.sz` (presumably the dimensions needed to cover `p.region` -- see `plan_max_dims`),
# while the Julia-side plan keeps the full N-dimensional size.
function plan_inv(p::cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}) where {T,N,inplace}
    md = plan_max_dims(p.region, p.sz)
    sizex = p.sz[1:md]
    handle = cufftGetPlan(p.xtype, sizex, p.region)
    ScaledPlan(cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}(handle, p.sz, p.sz, p.region,
                                                     p.xtype),
               normalization(real(T), p.sz, p.region))
end

# Build the inverse of a complex-to-real plan: a real-to-complex forward plan whose
# input size is the original output size (`p.osz`) and whose output size is the
# original input size (`p.sz`), wrapped in a `ScaledPlan` normalized over `p.osz`.
#
# The transform type is flipped (C2R -> R2C, otherwise D2Z). The low-level cuFFT
# handle is created for the leading `md_osz` dimensions of the real-side size only.
function plan_inv(p::rCuFFTPlan{T,CUFFT_INVERSE,inplace,N}
                  ) where {T<:cufftComplexes,N,inplace}
    md_osz = plan_max_dims(p.region, p.osz)
    sz_X = p.osz[1:md_osz]
    xtype = p.xtype == CUFFT_C2R ? CUFFT_R2C : CUFFT_D2Z
    handle = cufftGetPlan(xtype, sz_X, p.region)
    ScaledPlan(rCuFFTPlan{real(T),CUFFT_FORWARD,inplace,N}(handle, p.osz, p.sz,
                                                           p.region, xtype),
               normalization(real(T), p.osz, p.region))
end

# Build the inverse of a real-to-complex plan: a complex-to-real backward plan,
# wrapped in a `ScaledPlan` normalized over the real-side size `p.sz`.
#
# The transform type is flipped (R2C -> C2R, otherwise Z2D). The low-level cuFFT
# handle is created for the leading `md_sz` dimensions of `p.sz` only, while the
# Julia-side plan is constructed with the full `p.osz` and `p.sz`.
function plan_inv(p::rCuFFTPlan{T,CUFFT_FORWARD,inplace,N}
                  ) where {T<:cufftReals,N,inplace}
    md_sz = plan_max_dims(p.region, p.sz)
    sz_Y = p.sz[1:md_sz]
    xtype = p.xtype == CUFFT_R2C ? CUFFT_C2R : CUFFT_Z2D
    handle = cufftGetPlan(xtype, sz_Y, p.region)
    ScaledPlan(rCuFFTPlan{complex(T),CUFFT_INVERSE,inplace,N}(handle, p.osz, p.sz,
                                                              p.region, xtype),
               normalization(real(T), p.sz, p.region))
end


Expand All @@ -315,14 +302,14 @@ end
# see JuliaGPU/CuArrays.jl#345, NVIDIA/cuFFT#2714055.

# Check that plan `p` may be applied to input `X`.
#
# The size of `X` must be exactly the input size stored in the plan (`p.sz`).
# Exact equality is required: an ordering comparison such as `>=` on tuples is
# lexicographic and would accept arrays of the wrong shape or dimensionality.
#
# Throws `ArgumentError` if the sizes differ; otherwise evaluates to `true`.
function assert_applicable(p::CuFFTPlan{T}, X::DenseCuArray{T}) where {T}
    (size(X) == p.sz) ||
        throw(ArgumentError("CuFFT plan applied to wrong-size input"))
end

function assert_applicable(p::CuFFTPlan{T,K,inplace}, X::DenseCuArray{T},
Y::DenseCuArray) where {T,K,inplace}
assert_applicable(p, X)
if size(Y)[1:length(p.osz)] != p.osz
if size(Y) != p.osz
throw(ArgumentError("CuFFT plan applied to wrong-size output"))
elseif inplace != (pointer(X) == pointer(Y))
throw(ArgumentError(string("CuFFT ",
Expand All @@ -333,84 +320,87 @@ function assert_applicable(p::CuFFTPlan{T,K,inplace}, X::DenseCuArray{T},
end
end

# Execute a `cufftComplex` complex-to-complex plan on `x`, writing into `y`.
#
# `K` is the plan's forward/backward flag and is passed to cuFFT as the direction.
# The plan's dimension parameter `M` is independent of the array dimension `N`, so
# this method also applies when the arrays have a different number of dimensions
# than the plan was created with. No size checks are performed here.
function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,<:Any,M},
                         x::DenseCuArray{cufftComplex,N},
                         y::DenseCuArray{cufftComplex,N}) where {K,M,N}
    @assert plan.xtype == CUFFT_C2C
    update_stream(plan)
    cufftExecC2C(plan, x, y, K)
end

# Execute an in-place (`inplace == true`) `cufftComplex` complex-to-real plan,
# passing `x` directly to cuFFT as the input and `y` as the output.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
# No size checks are performed here.
function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,true,M},
                         x::DenseCuArray{cufftComplex,N},
                         y::DenseCuArray{cufftReal,N}) where {K,M,N}
    @assert plan.xtype == CUFFT_C2R
    update_stream(plan)
    cufftExecC2R(plan, x, y)
end
# Execute an out-of-place (`inplace == false`) `cufftComplex` complex-to-real plan.
#
# The transform runs on a temporary copy of `x`, which is freed afterwards, so the
# caller's `x` itself is never handed to cuFFT.
# NOTE(review): the copy presumably exists because cuFFT C2R execution may overwrite
# its input -- confirm against the cuFFT documentation.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,false,M},
                         x::DenseCuArray{cufftComplex,N},
                         y::DenseCuArray{cufftReal}) where {K,M,N}
    @assert plan.xtype == CUFFT_C2R
    x = copy(x)  # work on a scratch copy rather than the caller's array
    update_stream(plan)
    cufftExecC2R(plan, x, y)
    unsafe_free!(x)
end

# Execute a `cufftReal` real-to-complex plan on `x`, writing into `y`.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
# No size checks are performed here.
function unsafe_execute!(plan::rCuFFTPlan{cufftReal,K,<:Any,M},
                         x::DenseCuArray{cufftReal,N},
                         y::DenseCuArray{cufftComplex,N}) where {K,M,N}
    @assert plan.xtype == CUFFT_R2C
    update_stream(plan)
    cufftExecR2C(plan, x, y)
end

# Execute a `cufftDoubleComplex` complex-to-complex plan on `x`, writing into `y`.
#
# `K` is the plan's forward/backward flag and is passed to cuFFT as the direction.
# The plan's dimension parameter `M` is independent of the array dimension `N`.
# No size checks are performed here.
function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,<:Any,M},
                         x::DenseCuArray{cufftDoubleComplex,N},
                         y::DenseCuArray{cufftDoubleComplex}) where {K,M,N}
    @assert plan.xtype == CUFFT_Z2Z
    update_stream(plan)
    cufftExecZ2Z(plan, x, y, K)
end

# Execute an in-place (`inplace == true`) `cufftDoubleComplex` complex-to-real plan,
# passing `x` directly to cuFFT as the input and `y` as the output.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
# No size checks are performed here.
function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,true,M},
                         x::DenseCuArray{cufftDoubleComplex,N},
                         y::DenseCuArray{cufftDoubleReal}) where {K,M,N}
    update_stream(plan)
    @assert plan.xtype == CUFFT_Z2D
    cufftExecZ2D(plan, x, y)
end
# Execute an out-of-place (`inplace == false`) `cufftDoubleComplex` complex-to-real
# plan.
#
# The transform runs on a temporary copy of `x`, which is freed afterwards, so the
# caller's `x` itself is never handed to cuFFT.
# NOTE(review): the copy presumably exists because cuFFT Z2D execution may overwrite
# its input -- confirm against the cuFFT documentation.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,false,M},
                         x::DenseCuArray{cufftDoubleComplex,N},
                         y::DenseCuArray{cufftDoubleReal}) where {K,M,N}
    @assert plan.xtype == CUFFT_Z2D
    x = copy(x)  # work on a scratch copy rather than the caller's array
    update_stream(plan)
    cufftExecZ2D(plan, x, y)
    unsafe_free!(x)
end

# Execute a `cufftDoubleReal` real-to-complex plan on `x`, writing into `y`.
#
# The plan's dimension parameter `M` is independent of the array dimension `N`.
# No size checks are performed here.
function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleReal,K,<:Any,M},
                         x::DenseCuArray{cufftDoubleReal,N},
                         y::DenseCuArray{cufftDoubleComplex,N}) where {K,M,N}
    @assert plan.xtype == CUFFT_D2Z
    update_stream(plan)
    cufftExecD2Z(plan, x, y)
end

# A version of unsafe_execute! that applies the plan to each index of the trailing dimensions not covered by the plan.
function unsafe_execute_trailing!(p,x, y)
N = ndims(p)
# Note that for plans with trailing non-transform dimensions, a view is created for each index of those dimensions.
# Each such view has fewer dimensions and is then transformed by the lower-dimensional low-level CUDA plan.
function unsafe_execute_trailing!(p, x, y)
N = plan_max_dims(p.region, p.osz)
M = ndims(x)
d = p.region[end]
if M == N # p.region[1] == 1 || ndims(x) == d
if M == N
unsafe_execute!(p,x,y)
else
front_ids = ntuple((dd)->Colon(), d)
for c in CartesianIndices(size(x)[d+1:end])
vx = @view x[front_ids..., Tuple(c)...]
vy = @view y[front_ids..., Tuple(c)...]
ids = ntuple((dd)->c[dd], M-N)
vx = @view x[front_ids..., ids...]
vy = @view y[front_ids..., ids...]
unsafe_execute!(p,vx,vy)
end
end
Expand All @@ -435,8 +425,7 @@ function Base.:(*)(p::rCuFFTPlan{T,CUFFT_FORWARD,false,N}, x::DenseCuArray{T,M}
) where {T<:cufftReals,N,M}
assert_applicable(p,x)
@assert p.xtype [CUFFT_R2C,CUFFT_D2Z]
osz = get_osz(p.osz, x)
y = CuArray{complex(T),M}(undef, osz)
y = CuArray{complex(T),M}(undef, p.osz)
unsafe_execute_trailing!(p,x, y)
y
end
Expand All @@ -445,8 +434,7 @@ function Base.:(*)(p::rCuFFTPlan{T,CUFFT_INVERSE,false,N}, x::DenseCuArray{T,M}
) where {T<:cufftComplexes,N,M}
assert_applicable(p,x)
@assert p.xtype [CUFFT_C2R,CUFFT_Z2D]
osz = get_osz(p.osz, x)
y = CuArray{real(T),M}(undef, osz)
y = CuArray{real(T),M}(undef, p.osz)
unsafe_execute_trailing!(p,x, y)
y
end
Expand All @@ -459,8 +447,7 @@ end

# Apply an out-of-place complex plan to `x` and return the newly allocated result.
#
# `x` is first checked with `assert_applicable`. The output is allocated
# uninitialized with the plan's stored output size `p.osz` and then filled by
# `unsafe_execute_trailing!`.
function Base.:(*)(p::cCuFFTPlan{T,K,false,N}, x::DenseCuArray{T,M}) where {T,K,N,M}
    assert_applicable(p, x)
    y = CuArray{T,M}(undef, p.osz)
    unsafe_execute_trailing!(p, x, y)
    y
end
3 changes: 3 additions & 0 deletions test/libraries/cufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
d_X = CuArray(X)
p = plan_fft(d_X,region)
d_Y = p * d_X
d_X2 = reshape(d_X, (size(d_X)..., 1))
@test_throws ArgumentError p * d_X2

Y = collect(d_Y)
@test isapprox(Y, fftw_X, rtol = MYRTOL, atol = MYATOL)

Expand Down

0 comments on commit c9aa2cd

Please sign in to comment.