From e86b165e6801c411956c271205f4cdc9cae45837 Mon Sep 17 00:00:00 2001 From: Rainer Heintzmann Date: Tue, 29 Aug 2023 14:21:56 +0200 Subject: [PATCH] CUFFT: Add support for more transform directions (#1903) Co-authored-by: Tim Besard --- lib/cufft/fft.jl | 208 ++++++++++++++++++++++++++++------------ lib/cufft/wrappers.jl | 29 ++++-- src/indexing.jl | 4 +- test/libraries/cufft.jl | 16 +++- 4 files changed, 182 insertions(+), 75 deletions(-) diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index 7a09aa75bf..c431453d1f 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -3,7 +3,7 @@ @reexport using AbstractFFTs import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!, plan_ifft, - plan_rfft, plan_brfft, plan_inv, normalization, fft, bfft, ifft, rfft, + plan_rfft, plan_brfft, plan_inv, normalization, fft, bfft, ifft, rfft, irfft, Plan, ScaledPlan using LinearAlgebra @@ -11,7 +11,6 @@ using LinearAlgebra Base.:(*)(p::Plan{T}, x::DenseCuArray) where {T} = p * copy1(T, x) Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale) - ## plan structure # K is an integer flag for forward/backward @@ -34,7 +33,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan) end mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace} - handle::cufftHandle + # handle to Cuda low level plan. Note that this plan sometimes has lower dimensions + # to handle more transform cases such as individual directions + handle::cufftHandle ctx::CuContext stream::CuStream sz::NTuple{N,Int} # Julia size of input array @@ -43,14 +44,20 @@ mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace} region::Any pinv::ScaledPlan # required by AbstractFFT API - function cCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{T,N}, + function cCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, sizex::NTuple{N, Int}, sizey::Tuple, region, xtype ) where {T<:cufftNumber,K,inplace,N} # TODO: enforce consistency of sizey - p = new(handle, context(), stream(), size(X), sizey, xtype, region) + p = new(handle, context(), stream(), sizex, sizey, xtype, region) finalizer(unsafe_free!, p) p end + + function cCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{T,N}, + sizey::Tuple, region, xtype + ) where {T<:cufftNumber,K,inplace,N} + cCuFFTPlan{T,K,inplace,N}(handle, size(X), sizey, region, xtype) + end end mutable struct rCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace} @@ -63,14 +70,18 @@ mutable struct rCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace} region::Any pinv::ScaledPlan # required by AbstractFFT API - function rCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{T,N}, + function rCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, sizex::NTuple{N, Int}, sizey::Tuple, region, xtype ) where {T<:cufftNumber,K,inplace,N} # TODO: enforce consistency of sizey - p = new(handle, context(), stream(), size(X), sizey, xtype, region) + p = new(handle, context(), stream(), sizex, sizey, xtype, region) finalizer(unsafe_free!, p) p end + function rCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{T,N}, + sizey::Tuple, region, xtype) where {T<:cufftNumber,K,inplace,N} + rCuFFTPlan{T,K,inplace,N}(handle, size(X), sizey, region, xtype) + end end const xtypenames = Dict{cufftType,String}(CUFFT_R2C => "real-to-complex", @@ -130,6 +141,29 @@ end rfft(x::DenseCuArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region) plan_rfft(x::DenseCuArray{<:Real}, region) = plan_rfft(realfloat(x), region) +function irfft(x::DenseCuArray{<:Union{Real,Integer,Rational}}, d::Integer, region=1:ndims(x)) + irfft(complexfloat(x), d, region) +end + +# yields the maximal dimensions of the plan, for plans starting at dim 1 or ending at the size vector, +# this is always the full input size +function plan_max_dims(region, sz) + if (region[1] == 1 && (length(region) <=1 || all(diff(collect(region)) .== 1))) + return length(sz) + else + return region[end] + end +end + +# retrieves the size to allocate even if the trailing dimensions do no transform +get_osz(osz, x) = ntuple((d)->(d>length(osz) ? size(x, d) : osz[d]), ndims(x)) + +# returns a view of the front part of the dimensions of the array up to md dimensions +function front_view(X, md) + t = ntuple((d)->ifelse(d<=md, Colon(), 1), ndims(X)) + @view X[t...] +end + # region is an iterable subset of dimensions # spec. an integer, range, tuple, or array @@ -138,18 +172,25 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_FORWARD inplace = true xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z + region = Tuple(region) - handle = cufftGetPlan(xtype, size(X), region) + md = plan_max_dims(region, size(X)) + sizex = size(X)[1:md] + handle = cufftGetPlan(xtype, sizex, region) cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype) end + function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_INVERSE inplace = true xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z + region = Tuple(region) - handle = cufftGetPlan(xtype, size(X), region) + md = plan_max_dims(region, size(X)) + sizex = size(X)[1:md] + handle = cufftGetPlan(xtype, sizex, region) cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype) end @@ -159,8 +200,11 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_FORWARD xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z inplace = false + region = Tuple(region) - handle = cufftGetPlan(xtype, size(X), region) + md = plan_max_dims(region,size(X)) + sizex = size(X)[1:md] + handle = cufftGetPlan(xtype, sizex, region) cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype) end @@ -169,10 +213,13 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_INVERSE inplace = false xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z + region = Tuple(region) - handle = cufftGetPlan(xtype, size(X), region) + md = plan_max_dims(region,size(X)) + sizex = size(X)[1:md] + handle = cufftGetPlan(xtype, sizex, region) - cCuFFTPlan{T,K,inplace,N}(handle, X, size(X), region, xtype) + cCuFFTPlan{T,K,inplace,N}(handle, size(X), size(X), region, xtype) end # out-of-place real-to-complex @@ -180,65 +227,75 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} K = CUFFT_FORWARD inplace = false xtype = (T == cufftReal) ? CUFFT_R2C : CUFFT_D2Z + region = Tuple(region) + + md = plan_max_dims(region,size(X)) + # X = front_view(X, md) + sizex = size(X)[1:md] - handle = cufftGetPlan(xtype, size(X), region) + handle = cufftGetPlan(xtype, sizex, region) ydims = collect(size(X)) ydims[region[1]] = div(ydims[region[1]],2)+1 - rCuFFTPlan{T,K,inplace,N}(handle, X, (ydims...,), region, xtype) + rCuFFTPlan{T,K,inplace,N}(handle, size(X), (ydims...,), region, xtype) end function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N} K = CUFFT_INVERSE inplace = false xtype = (T == cufftComplex) ? CUFFT_C2R : CUFFT_Z2D + region = Tuple(region) + ydims = collect(size(X)) ydims[region[1]] = d handle = cufftGetPlan(xtype, (ydims...,), region) - rCuFFTPlan{T,K,inplace,N}(handle, X, (ydims...,), region, xtype) + rCuFFTPlan{T,K,inplace,N}(handle, size(X), (ydims...,), region, xtype) end + # FIXME: plan_inv methods allocate needlessly (to provide type parameters) # Perhaps use FakeArray types to avoid this. function plan_inv(p::cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}) where {T,N,inplace} - X = CuArray{T}(undef, p.sz) - handle = cufftGetPlan(p.xtype, p.sz, p.region) - ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}(handle, X, p.sz, p.region, + md = plan_max_dims(p.region, p.sz) + sizex = p.sz[1:md] + handle = cufftGetPlan(p.xtype, sizex, p.region) + ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}(handle, p.sz, p.sz, p.region, p.xtype), - normalization(X, p.region)) + normalization(real(T), p.sz, p.region)) end function plan_inv(p::cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}) where {T,N,inplace} - X = CuArray{T}(undef, p.sz) - handle = cufftGetPlan(p.xtype, p.sz, p.region) - ScaledPlan(cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}(handle, X, p.sz, p.region, + md = plan_max_dims(p.region,p.sz) + sizex = p.sz[1:md] + handle = cufftGetPlan(p.xtype, sizex, p.region) + ScaledPlan(cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}(handle, p.sz, p.sz, p.region, p.xtype), - normalization(X, p.region)) + normalization(real(T), p.sz, p.region)) end function plan_inv(p::rCuFFTPlan{T,CUFFT_INVERSE,inplace,N} ) where {T<:cufftComplexes,N,inplace} - X = CuArray{real(T)}(undef, p.osz) - Y = CuArray{T}(undef, p.sz) + md_osz = plan_max_dims(p.region, p.osz) + sz_X = p.osz[1:md_osz] xtype = p.xtype == CUFFT_C2R ? CUFFT_R2C : CUFFT_D2Z - handle = cufftGetPlan(xtype, p.osz, p.region) - ScaledPlan(rCuFFTPlan{real(T),CUFFT_FORWARD,inplace,N}(handle, X, p.sz, p.region, xtype), - normalization(X, p.region)) + handle = cufftGetPlan(xtype, sz_X, p.region) + ScaledPlan(rCuFFTPlan{real(T),CUFFT_FORWARD,inplace,N}(handle, p.osz, p.sz, p.region, xtype), + normalization(real(T), p.osz, p.region)) end function plan_inv(p::rCuFFTPlan{T,CUFFT_FORWARD,inplace,N} ) where {T<:cufftReals,N,inplace} - X = CuArray{complex(T)}(undef, p.osz) - Y = CuArray{T}(undef, p.sz) + md_sz = plan_max_dims(p.region,p.sz) + sz_Y = p.sz[1:md_sz] xtype = p.xtype == CUFFT_R2C ? CUFFT_C2R : CUFFT_Z2D - handle = cufftGetPlan(xtype, p.sz, p.region) - ScaledPlan(rCuFFTPlan{complex(T),CUFFT_INVERSE,inplace,N}(handle, X, p.sz, + handle = cufftGetPlan(xtype, sz_Y, p.region) + ScaledPlan(rCuFFTPlan{complex(T),CUFFT_INVERSE,inplace,N}(handle, p.osz, p.sz, p.region, xtype), - normalization(Y, p.region)) + normalization(real(T), p.sz, p.region)) end @@ -267,24 +324,24 @@ function assert_applicable(p::CuFFTPlan{T,K,inplace}, X::DenseCuArray{T}, end end -function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,<:Any,N}, +function unsafe_execute!(plan::cCuFFTPlan{cufftComplex,K,<:Any,M}, x::DenseCuArray{cufftComplex,N}, - y::DenseCuArray{cufftComplex,N}) where {K,N} + y::DenseCuArray{cufftComplex,N}) where {K,M,N} @assert plan.xtype == CUFFT_C2C update_stream(plan) cufftExecC2C(plan, x, y, K) end -function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,true,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,true,M}, x::DenseCuArray{cufftComplex,N}, - y::DenseCuArray{cufftReal,N}) where {K,N} + y::DenseCuArray{cufftReal,N}) where {K,M,N} @assert plan.xtype == CUFFT_C2R update_stream(plan) cufftExecC2R(plan, x, y) end -function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,false,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,false,M}, x::DenseCuArray{cufftComplex,N}, - y::DenseCuArray{cufftReal}) where {K,N} + y::DenseCuArray{cufftReal}) where {K,M,N} @assert plan.xtype == CUFFT_C2R x = copy(x) update_stream(plan) @@ -292,32 +349,32 @@ function unsafe_execute!(plan::rCuFFTPlan{cufftComplex,K,false,N}, unsafe_free!(x) end -function unsafe_execute!(plan::rCuFFTPlan{cufftReal,K,<:Any,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftReal,K,<:Any,M}, x::DenseCuArray{cufftReal,N}, - y::DenseCuArray{cufftComplex,N}) where {K,N} + y::DenseCuArray{cufftComplex,N}) where {K,M,N} @assert plan.xtype == CUFFT_R2C update_stream(plan) cufftExecR2C(plan, x, y) end -function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,<:Any,N}, +function unsafe_execute!(plan::cCuFFTPlan{cufftDoubleComplex,K,<:Any,M}, x::DenseCuArray{cufftDoubleComplex,N}, - y::DenseCuArray{cufftDoubleComplex}) where {K,N} + y::DenseCuArray{cufftDoubleComplex}) where {K,M,N} @assert plan.xtype == CUFFT_Z2Z update_stream(plan) cufftExecZ2Z(plan, x, y, K) end -function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,true,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,true,M}, x::DenseCuArray{cufftDoubleComplex,N}, - y::DenseCuArray{cufftDoubleReal}) where {K,N} + y::DenseCuArray{cufftDoubleReal}) where {K,M,N} update_stream(plan) @assert plan.xtype == CUFFT_Z2D cufftExecZ2D(plan, x, y) end -function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,false,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,false,M}, x::DenseCuArray{cufftDoubleComplex,N}, - y::DenseCuArray{cufftDoubleReal}) where {K,N} + y::DenseCuArray{cufftDoubleReal}) where {K,M,N} @assert plan.xtype == CUFFT_Z2D x = copy(x) update_stream(plan) @@ -325,51 +382,76 @@ function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleComplex,K,false,N}, unsafe_free!(x) end -function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleReal,K,<:Any,N}, +function unsafe_execute!(plan::rCuFFTPlan{cufftDoubleReal,K,<:Any,M}, x::DenseCuArray{cufftDoubleReal,N}, - y::DenseCuArray{cufftDoubleComplex,N}) where {K,N} + y::DenseCuArray{cufftDoubleComplex,N}) where {K,M,N} @assert plan.xtype == CUFFT_D2Z update_stream(plan) cufftExecD2Z(plan, x, y) end +# a version of unsafe_execute which applies the plan to each element of trailing dimensions not covered by the plan. +# Note that for plans, with trailing non-transform dimensions views are created for each of such elements. +# Such views each have lower dimensions and are then transformed by the lower dimension low-level Cuda plan. +function unsafe_execute_trailing!(p, x, y) + N = plan_max_dims(p.region, p.osz) + M = ndims(x) + d = p.region[end] + if M == N + unsafe_execute!(p,x,y) + else + front_ids = ntuple((dd)->Colon(), d) + for c in CartesianIndices(size(x)[d+1:end]) + ids = ntuple((dd)->c[dd], M-N) + vx = @view x[front_ids..., ids...] + vy = @view y[front_ids..., ids...] + unsafe_execute!(p,vx,vy) + end + end +end ## high-level integrations function LinearAlgebra.mul!(y::DenseCuArray{Ty}, p::CuFFTPlan{T}, x::DenseCuArray{T} ) where {Ty, T} assert_applicable(p,x,y) - unsafe_execute!(p,x,y) + unsafe_execute_trailing!(p,x, y) return y end -function Base.:(*)(p::cCuFFTPlan{T,K,true,N}, x::DenseCuArray{T,N}) where {T,K,N} +function Base.:(*)(p::cCuFFTPlan{T,K,true,N}, x::DenseCuArray{T,M}) where {T,K,N,M} assert_applicable(p,x) - unsafe_execute!(p,x,x) + unsafe_execute_trailing!(p,x, x) x end -function Base.:(*)(p::rCuFFTPlan{T,CUFFT_FORWARD,false,N}, x::DenseCuArray{T,N} - ) where {T<:cufftReals,N} +function Base.:(*)(p::rCuFFTPlan{T,CUFFT_FORWARD,false,N}, x::DenseCuArray{T,M} + ) where {T<:cufftReals,N,M} assert_applicable(p,x) @assert p.xtype ∈ [CUFFT_R2C,CUFFT_D2Z] - y = CuArray{complex(T),N}(undef, p.osz) - unsafe_execute!(p,x,y) + y = CuArray{complex(T),M}(undef, p.osz) + unsafe_execute_trailing!(p,x, y) y end -function Base.:(*)(p::rCuFFTPlan{T,CUFFT_INVERSE,false,N}, x::DenseCuArray{T,N} - ) where {T<:cufftComplexes,N} +function Base.:(*)(p::rCuFFTPlan{T,CUFFT_INVERSE,false,N}, x::DenseCuArray{T,M} + ) where {T<:cufftComplexes,N,M} assert_applicable(p,x) @assert p.xtype ∈ [CUFFT_C2R,CUFFT_Z2D] - y = CuArray{real(T),N}(undef, p.osz) - unsafe_execute!(p,x,y) + y = CuArray{real(T),M}(undef, p.osz) + unsafe_execute_trailing!(p,x, y) y end -function Base.:(*)(p::cCuFFTPlan{T,K,false,N}, x::DenseCuArray{T,N}) where {T,K,N} +function Base.:(*)(p::rCuFFTPlan{T,CUFFT_INVERSE,false,N}, x::DenseCuArray{T2,M} + ) where {T<:cufftComplexes,N,M, T2<:cufftReals} + x = complex.(x) + p*x +end + +function Base.:(*)(p::cCuFFTPlan{T,K,false,N}, x::DenseCuArray{T,M}) where {T,K,N,M} assert_applicable(p,x) - y = CuArray{T,N}(undef, p.osz) - unsafe_execute!(p,x,y) + y = CuArray{T,M}(undef, p.osz) + unsafe_execute_trailing!(p,x, y) y end diff --git a/lib/cufft/wrappers.jl b/lib/cufft/wrappers.jl index b7ffcfd025..3ab01f1eb9 100644 --- a/lib/cufft/wrappers.jl +++ b/lib/cufft/wrappers.jl @@ -11,12 +11,17 @@ version() = VersionNumber(cufftGetProperty(CUDA.MAJOR_VERSION), cufftGetProperty(CUDA.PATCH_LEVEL)) function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region) + if any(diff(collect(region)) .< 1) + throw(ArgumentError("region must be an increasing sequence")) + end + if any(region .< 1 .|| region .> length(xdims)) + throw(ArgumentError("region can only refer to valid dimensions")) + end nrank = length(region) sz = [xdims[i] for i in region] csz = copy(sz) csz[1] = div(sz[1],2) + 1 batch = prod(xdims) ÷ prod(sz) - # initialize the plan handle handle_ref = Ref{cufftHandle}() cufftCreate(handle_ref) @@ -24,6 +29,7 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region) # make the plan worksize_ref = Ref{Csize_t}() + # 1d, 2d and 3d plans can only be used for a single batch (i.e. the full array being transformed) if (nrank == 1) && (batch == 1) cufftMakePlan1d(handle, sz[1], xtype, 1, worksize_ref) elseif (nrank == 2) && (batch == 1) @@ -32,11 +38,23 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region) cufftMakePlan3d(handle, sz[3], sz[2], sz[1], xtype, worksize_ref) else rsz = (length(sz) > 1) ? rsz = reverse(sz) : sz + if nrank > 3 + throw(ArgumentError("only up to three transform dimensions are allowed in one plan")) + end if ((region...,) == ((1:nrank)...,)) - # handle simple case ... simply! (for robustness) - cufftMakePlanMany(handle, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1, + # handle simple case, transforming the first nrank dimensions, ... simply! (for robustness) + # arguments are: plan, rank, transform-sizes, inembed, istride, idist, onembed, ostride, odist, type batch + cufftMakePlanMany(handle, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1, xtype, batch, worksize_ref) else + # reduce the array to the final transform direction. This situation will be picked up in the application of the plan later. + if region[end] != length(xdims) + # just make a plan for a smaller dimension number + xdims = xdims[1:region[end]] + batch = prod(xdims) ÷ prod(sz) + # throw(ArgumentError("batching dims must be sequential")) + end + if nrank==1 || all(diff(collect(region)) .== 1) # _stride: successive elements in innermost dimension # _dist: distance between first elements of batches @@ -45,9 +63,6 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region) idist = prod(sz) cdist = prod(csz) else - if region[end] != length(xdims) - throw(ArgumentError("batching dims must be sequential")) - end istride = prod(xdims[1:region[1]-1]) idist = 1 cdist = 1 @@ -67,6 +82,7 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region) inembed = cnembed end else + # multiple non-sequential transforms if any(diff(collect(region)) .< 1) throw(ArgumentError("region must be an increasing sequence")) end @@ -139,6 +155,7 @@ end const cufftHandleCacheKey = Tuple{CuContext, cufftType_t, Dims, Any} const idle_handles = HandleCache{cufftHandleCacheKey, cufftHandle}() function cufftGetPlan(args...) + ctx = context() handle = pop!(idle_handles, (ctx, args...)) do # make the plan diff --git a/src/indexing.jl b/src/indexing.jl index 013cbff1af..687a4348e8 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -11,9 +11,9 @@ using Base.Cartesian # we cannot use Base.LogicalIndex, which does not support indexing but requires iteration. # TODO: it should still be possible to use the same technique; # Base.LogicalIndex basically contains the same as our `findall` here does. -Base.to_index(::AbstractGPUArray, I::AbstractArray{Bool}) = findall(I) +Base.to_index(::CuArray, I::AbstractArray{Bool}) = findall(I) ## same for the trailing Array{Bool} optimization (see `_maybe_linear_logical_index` in Base) -Base.to_indices(A::AbstractGPUArray, inds, +Base.to_indices(A::CuArray, inds, I::Tuple{Union{Array{Bool,N}, BitArray{N}}}) where {N} = (Base.to_index(A, I[1]),) diff --git a/test/libraries/cufft.jl b/test/libraries/cufft.jl index 52499a68b8..6c3cbfa200 100644 --- a/test/libraries/cufft.jl +++ b/test/libraries/cufft.jl @@ -39,6 +39,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} d_Z = pinv2 * d_Y Z = collect(d_Z) @test isapprox(Z, X, rtol = MYRTOL, atol = MYATOL) + end function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} @@ -60,6 +61,9 @@ function batched(X::AbstractArray{T,N},region) where {T <: Complex,N} d_X = CuArray(X) p = plan_fft(d_X,region) d_Y = p * d_X + d_X2 = reshape(d_X, (size(d_X)..., 1)) + @test_throws ArgumentError p * d_X2 + Y = collect(d_Y) @test isapprox(Y, fftw_X, rtol = MYRTOL, atol = MYATOL) @@ -67,6 +71,10 @@ function batched(X::AbstractArray{T,N},region) where {T <: Complex,N} d_Z = pinv * d_Y Z = collect(d_Z) @test isapprox(Z, X, rtol = MYRTOL, atol = MYATOL) + + ldiv!(d_Z, p, d_Y) + Z = collect(d_Z) + @test isapprox(Z, X, rtol = MYRTOL, atol = MYATOL) end @testset for T in [ComplexF32, ComplexF64] @@ -143,11 +151,11 @@ end @testset "Batch 2D (in 4D)" begin dims = (N1,N2,N3,N4) - for region in [(1,2),(1,4),(3,4)] + for region in [(1,2),(1,4),(3,4),(1,3),(2,3),(2,),(3,)] X = rand(T, dims) batched(X,region) end - for region in [(1,3),(2,3),(2,4)] + for region in [(2,4)] X = rand(T, dims) @test_throws ArgumentError batched(X,region) end @@ -236,11 +244,11 @@ end @testset "Batch 2D (in 4D)" begin dims = (N1,N2,N3,N4) - for region in [(1,2),(1,4),(3,4)] + for region in [(1,2),(1,4),(3,4),(1,3),(2,3)] X = rand(T, dims) batched(X,region) end - for region in [(1,3),(2,3),(2,4)] + for region in [(2,4)] X = rand(T, dims) @test_throws ArgumentError batched(X,region) end