From 2b60a8334d2b3a8c4b810b39fcea77e0619b12cc Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 12 Dec 2024 21:00:45 +0000 Subject: [PATCH 01/17] Add ForwardDiff extension --- Project.toml | 4 +++- ext/AbstractFFTsForwardDiffExt.jl | 38 +++++++++++++++++++++++++++++++ test/abstractfftsforwarddiff.jl | 0 test/runtests.jl | 1 + 4 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 ext/AbstractFFTsForwardDiffExt.jl create mode 100644 test/abstractfftsforwarddiff.jl diff --git a/Project.toml b/Project.toml index 86ce0d0..e95e812 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" +version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -9,10 +9,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +AbstractFFTsForwardDiffExt = "ForwardDiff" AbstractFFTsTestExt = "Test" [compat] diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl new file mode 100644 index 0000000..7891df4 --- /dev/null +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -0,0 +1,38 @@ +module AbstractFFTsForwardDiffExt + +using AbstractFFTs +import ForwardDiff +import ForwardDiff: Dual +import AbstractFFTs: Plan + +for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities + @eval begin + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + end +end + +mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x) + +AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im + +AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + +dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) + + +for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) + @eval begin + AbstractFFTs.$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims) + end +end + + + +end # module \ No newline at end of file diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl new file mode 100644 index 0000000..e69de29 diff --git a/test/runtests.jl b/test/runtests.jl index 0560174..f46b180 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -274,3 +274,4 @@ end end end +include("abstractfftsforwarddiff.jl") \ No newline at end of file From 8a88755b8077bd3594e38ad92575af342710e8ec Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 12 Dec 2024 21:19:08 +0000 Subject: [PATCH 02/17] add tests --- test/TestPlans.jl | 8 ++--- test/abstractfftsforwarddiff.jl | 57 +++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 1c3459a..1c04eb6 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -95,8 +95,8 @@ function LinearAlgebra.mul!( dft!(y, x, p.region, 1) end -Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) -Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::TestPlan{T}, x::AbstractArray{T}) where T = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::InverseTestPlan{T}, x::AbstractArray{T}) where T = mul!(similar(x, complex(float(eltype(x)))), p, x) mutable struct TestRPlan{T,N,G} <: Plan{T} region::G @@ -219,7 +219,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Complex, N}, p::TestRPlan, x::Abs return y end -function Base.:*(p::TestRPlan, x::AbstractArray) +function Base.:*(p::TestRPlan{T}, x::AbstractArray{T}) where T # create output array firstdim = first(p.region)::Int d = size(x, firstdim) @@ -241,7 +241,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Real, N}, p::InverseTestRPlan, x: real_invdft!(y, x, p.region) end -function Base.:*(p::InverseTestRPlan, x::AbstractArray) +function Base.:*(p::InverseTestRPlan{T}, x::AbstractArray{T}) where T # create output array firstdim = first(p.region)::Int d = p.d diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index e69de29..548033c 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -0,0 +1,57 @@ +using AbstractFFTs +using ForwardDiff +using Test +using ForwardDiff: Dual, partials, value + +@testset "ForwardDiff extension tests" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + + @test AbstractFFTs.complexfloat(x1)[1] === AbstractFFTs.complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im + @test AbstractFFTs.realfloat(x1)[1] === AbstractFFTs.realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + + @test fft(x1, 1)[1] isa Complex{<:Dual} + + @testset "$f" for f in (fft, ifft, rfft, bfft) + @test value.(f(x1)) == f(value.(x1)) + @test partials.(real(f(x1)), 1) + im*partials.(imag(f(x1)), 1) == f(partials.(x1, 1)) + @test partials.(real(f(x1)), 2) + im*partials.(imag(f(x1)), 2) == f(partials.(x1, 2)) + end + + @test ifft(fft(x1)) ≈ x1 + @test irfft(rfft(x1), length(x1)) ≈ x1 + @test brfft(rfft(x1), length(x1)) ≈ 4x1 + + f = x -> real(fft([x; 0; 0])[1]) + @test derivative(f,0.1) ≈ 1 + + r = x -> real(rfft([x; 0; 0])[1]) + @test derivative(r,0.1) ≈ 1 + + + n = 100 + θ = range(0,2π; length=n+1)[1:end-1] + # emperical from Mathematical + @test derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 + + # c = x -> dct([x; 0; 0])[1] + # @test derivative(c,0.1) ≈ 1 + + @testset "matrix" begin + A = x1 * (1:10)' + @test value.(fft(A)) == fft(value.(A)) + @test partials.(fft(A), 1) == fft(partials.(A, 1)) + @test partials.(fft(A), 2) == fft(partials.(A, 2)) + + @test value.(fft(A, 1)) == fft(value.(A), 1) + @test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) + @test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) + + @test value.(fft(A, 2)) == fft(value.(A), 2) + @test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) + @test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) + end + + c1 = complex.(x1) + @test mul!(similar(c1), FFTW.plan_fft(x1), x1) == fft(x1) + @test mul!(similar(c1), FFTW.plan_fft(c1), c1) == fft(c1) +end \ No newline at end of file From 497763390fce00b5a10bbe41bab1b236a219acfc Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 12 Dec 2024 22:50:30 +0000 Subject: [PATCH 03/17] add tests --- Project.toml | 5 ++++- ext/AbstractFFTsForwardDiffExt.jl | 2 +- src/AbstractFFTs.jl | 1 + test/TestPlans.jl | 4 ++-- test/abstractfftsforwarddiff.jl | 4 ++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index e95e812..cee270a 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -22,6 +23,7 @@ Aqua = "0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" FiniteDifferences = "0.12" +ForwardDiff = "0.10" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" Test = "<0.0.1, 1" @@ -33,9 +35,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"] +test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "ForwardDiff", "Random", "Test", "Unitful"] diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index 7891df4..77ed4e7 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -8,7 +8,7 @@ import AbstractFFTs: Plan for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities @eval begin Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) end end diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 3225916..9b7c1c8 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -11,6 +11,7 @@ include("TestUtils.jl") if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") include("../ext/AbstractFFTsTestExt.jl") + include("../ext/AbstractFFTsForwardDiffExt.jl") end end # module diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 1c04eb6..594acaf 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -219,7 +219,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Complex, N}, p::TestRPlan, x::Abs return y end -function Base.:*(p::TestRPlan{T}, x::AbstractArray{T}) where T +function Base.:*(p::TestRPlan{Typ}, x::AbstractArray{Typ}) where Typ # create output array firstdim = first(p.region)::Int d = size(x, firstdim) @@ -241,7 +241,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Real, N}, p::InverseTestRPlan, x: real_invdft!(y, x, p.region) end -function Base.:*(p::InverseTestRPlan{T}, x::AbstractArray{T}) where T +function Base.:*(p::InverseTestRPlan{T}, x::AbstractArray{Complex{T}}) where T # create output array firstdim = first(p.region)::Int d = p.d diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index 548033c..637d537 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -52,6 +52,6 @@ using ForwardDiff: Dual, partials, value end c1 = complex.(x1) - @test mul!(similar(c1), FFTW.plan_fft(x1), x1) == fft(x1) - @test mul!(similar(c1), FFTW.plan_fft(c1), c1) == fft(c1) + @test mul!(similar(c1), plan_fft(x1), x1) == fft(x1) + @test mul!(similar(c1), plan_fft(c1), c1) == fft(c1) end \ No newline at end of file From 1695c7e3a93e20bbe45b2e93e98ca2a3bd119b8f Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 13 Dec 2024 08:39:20 +0000 Subject: [PATCH 04/17] Add plan_mul to capture partial interface implementation --- ext/AbstractFFTsForwardDiffExt.jl | 4 ++-- src/definitions.jl | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index 77ed4e7..ad786f7 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -7,8 +7,8 @@ import AbstractFFTs: Plan for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities @eval begin - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) end end diff --git a/src/definitions.jl b/src/definitions.jl index f4f1c19..8294ff2 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -221,7 +221,9 @@ rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(real plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...) # only require implementation to provide *(::Plan{T}, ::Array{T}) -*(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x) +plan_mul(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x) +plan_mul(p::Plan{T}, x::AbstractArray{T}) where {T} = error("The plan interface requires overloading *(::MyPlan{T}, ::AbstractArray{T}) where T") +*(p::Plan, x::AbstractArray) = plan_mul(p, x) # Implementations should also implement mul!(Y, plan, X) so as to support # pre-allocated output arrays. We don't define * in terms of mul! From 6028c0e2e38641f9a60d66fbde7666c37d666ecc Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 13 Dec 2024 09:35:15 +0000 Subject: [PATCH 05/17] Add DualPlan --- ext/AbstractFFTsForwardDiffExt.jl | 42 +++++++++++++++++++++++-------- test/abstractfftsforwarddiff.jl | 25 ++++++++++-------- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index ad786f7..abf1e98 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -1,18 +1,11 @@ module AbstractFFTsForwardDiffExt using AbstractFFTs +using AbstractFFTs.LinearAlgebra import ForwardDiff import ForwardDiff: Dual -import AbstractFFTs: Plan +import AbstractFFTs: Plan, mul! -for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities - @eval begin - AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) - end -end - -mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x) AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im @@ -26,10 +19,37 @@ array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, rea array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) +######## +# DualPlan +# represents a plan acting on dual numbers. We wrap a plan acting on a higher dimensional tensor +# as an array of duals can be reinterpreted as a higher dimensional array. +# This allows standard FFTW plans to act on arrays of duals. +##### +struct DualPlan{T,P} <: Plan{T} + p::P + DualPlan{T,P}(p) where {T,P} = new(p) +end + +DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p) +DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p) +Base.size(p::DualPlan) = Base.tail(size(p.p)) +Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) +Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) + +function LinearAlgebra.mul!(y::AbstractArray{<:Dual}, p::DualPlan, x::AbstractArray{<:Dual}) + LinearAlgebra.mul!(dual2array(y), p.p, dual2array(x)) # even though `Dual` are immutable, when in an `Array` they can be modified. + y +end + +function LinearAlgebra.mul!(y::AbstractArray{<:Complex{<:Dual}}, p::DualPlan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) + copyto!(y, p*x) # Complex duals cannot be reinterpret in-place +end + + for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) @eval begin - AbstractFFTs.$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims) + AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) end end diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index 637d537..23c8505 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -3,6 +3,9 @@ using ForwardDiff using Test using ForwardDiff: Dual, partials, value +# Needed until https://github.com/JuliaDiff/ForwardDiff.jl/pull/732 is merged +complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) + @testset "ForwardDiff extension tests" begin x1 = Dual.(1:4.0, 2:5, 3:6) @@ -13,8 +16,8 @@ using ForwardDiff: Dual, partials, value @testset "$f" for f in (fft, ifft, rfft, bfft) @test value.(f(x1)) == f(value.(x1)) - @test partials.(real(f(x1)), 1) + im*partials.(imag(f(x1)), 1) == f(partials.(x1, 1)) - @test partials.(real(f(x1)), 2) + im*partials.(imag(f(x1)), 2) == f(partials.(x1, 2)) + @test complexpartials.(f(x1), 1) == f(partials.(x1, 1)) + @test complexpartials.(f(x1), 2) == f(partials.(x1, 2)) end @test ifft(fft(x1)) ≈ x1 @@ -22,16 +25,16 @@ using ForwardDiff: Dual, partials, value @test brfft(rfft(x1), length(x1)) ≈ 4x1 f = x -> real(fft([x; 0; 0])[1]) - @test derivative(f,0.1) ≈ 1 + @test ForwardDiff.derivative(f,0.1) ≈ 1 r = x -> real(rfft([x; 0; 0])[1]) - @test derivative(r,0.1) ≈ 1 + @test ForwardDiff.derivative(r,0.1) ≈ 1 n = 100 θ = range(0,2π; length=n+1)[1:end-1] # emperical from Mathematical - @test derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 + @test ForwardDiff.derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 # c = x -> dct([x; 0; 0])[1] # @test derivative(c,0.1) ≈ 1 @@ -39,16 +42,16 @@ using ForwardDiff: Dual, partials, value @testset "matrix" begin A = x1 * (1:10)' @test value.(fft(A)) == fft(value.(A)) - @test partials.(fft(A), 1) == fft(partials.(A, 1)) - @test partials.(fft(A), 2) == fft(partials.(A, 2)) + @test complexpartials.(fft(A), 1) == fft(partials.(A, 1)) + @test complexpartials.(fft(A), 2) == fft(partials.(A, 2)) @test value.(fft(A, 1)) == fft(value.(A), 1) - @test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) - @test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) + @test complexpartials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) + @test complexpartials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) @test value.(fft(A, 2)) == fft(value.(A), 2) - @test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) - @test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) + @test complexpartials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) + @test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) end c1 = complex.(x1) From cc6e3d8d1a3bd22a105345e29d7f6cce8e925f2d Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 13 Dec 2024 09:40:38 +0000 Subject: [PATCH 06/17] revert definitions --- src/definitions.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 8294ff2..f4f1c19 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -221,9 +221,7 @@ rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(real plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...) # only require implementation to provide *(::Plan{T}, ::Array{T}) -plan_mul(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x) -plan_mul(p::Plan{T}, x::AbstractArray{T}) where {T} = error("The plan interface requires overloading *(::MyPlan{T}, ::AbstractArray{T}) where T") -*(p::Plan, x::AbstractArray) = plan_mul(p, x) +*(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x) # Implementations should also implement mul!(Y, plan, X) so as to support # pre-allocated output arrays. We don't define * in terms of mul! From 159253176aa7337fcb0520f5606fbf7388dcadff Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 13 Dec 2024 09:40:56 +0000 Subject: [PATCH 07/17] Revert TestPlans --- test/TestPlans.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 594acaf..1c3459a 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -95,8 +95,8 @@ function LinearAlgebra.mul!( dft!(y, x, p.region, 1) end -Base.:*(p::TestPlan{T}, x::AbstractArray{T}) where T = mul!(similar(x, complex(float(eltype(x)))), p, x) -Base.:*(p::InverseTestPlan{T}, x::AbstractArray{T}) where T = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) mutable struct TestRPlan{T,N,G} <: Plan{T} region::G @@ -219,7 +219,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Complex, N}, p::TestRPlan, x::Abs return y end -function Base.:*(p::TestRPlan{Typ}, x::AbstractArray{Typ}) where Typ +function Base.:*(p::TestRPlan, x::AbstractArray) # create output array firstdim = first(p.region)::Int d = size(x, firstdim) @@ -241,7 +241,7 @@ function LinearAlgebra.mul!(y::AbstractArray{<:Real, N}, p::InverseTestRPlan, x: real_invdft!(y, x, p.region) end -function Base.:*(p::InverseTestRPlan{T}, x::AbstractArray{Complex{T}}) where T +function Base.:*(p::InverseTestRPlan, x::AbstractArray) # create output array firstdim = first(p.region)::Int d = p.d From 8ffa7dfa032d39a807f994f2bb6acc988bf5a47f Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 17 Dec 2024 22:11:16 +0000 Subject: [PATCH 08/17] tests pass --- ext/AbstractFFTsForwardDiffExt.jl | 14 +++++++++++--- src/AbstractFFTs.jl | 5 +++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index abf1e98..089ac84 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -4,7 +4,7 @@ using AbstractFFTs using AbstractFFTs.LinearAlgebra import ForwardDiff import ForwardDiff: Dual -import AbstractFFTs: Plan, mul! +import AbstractFFTs: Plan, mul!, dualplan, dual2array AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) @@ -32,6 +32,7 @@ end DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p) DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p) +dualplan(D, p) = DualPlan(D, p) Base.size(p::DualPlan) = Base.tail(size(p.p)) Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) @@ -48,11 +49,18 @@ end for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) @eval begin - AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) end end +for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? + @eval begin + AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, mdims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims)) + end +end + end # module \ No newline at end of file diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 9b7c1c8..5f8feff 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -8,10 +8,15 @@ export fft, ifft, bfft, fft!, ifft!, bfft!, include("definitions.jl") include("TestUtils.jl") +# Create function used by multiple extension as loading order is not guaranteed +function dualplan end +function dual2array end + if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") include("../ext/AbstractFFTsTestExt.jl") include("../ext/AbstractFFTsForwardDiffExt.jl") end + end # module From 4943167ae4df3e6fb42676b4e119db53bbcb34bd Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 17 Dec 2024 22:20:51 +0000 Subject: [PATCH 09/17] Update AbstractFFTsForwardDiffExt.jl --- ext/AbstractFFTsForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index 089ac84..d09de9b 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -58,7 +58,7 @@ end for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? @eval begin AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, mdims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims)) end end From 35f794bb6fb19ff393ab4ec5acc1d21b36194adc Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 10:27:21 +0000 Subject: [PATCH 10/17] Update Project.toml Co-authored-by: David Widmann --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index cee270a..c02f119 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From a516bd0555a3fc8ac9b89902dc3fad9dd1b6dacf Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 10:31:58 +0000 Subject: [PATCH 11/17] Overload only _fftfloat --- ext/AbstractFFTsForwardDiffExt.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index d09de9b..18f177f 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -7,11 +7,7 @@ import ForwardDiff: Dual import AbstractFFTs: Plan, mul!, dualplan, dual2array -AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) -AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im - -AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) -AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) +AbstractFFTs._fftfloat(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,AbstractFFTs._fftfloat(V),N} dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) From a968cd197565eb376e2d8c89014f703907c0dc93 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 10:33:19 +0000 Subject: [PATCH 12/17] Only load/test ForwardDiff on versions that support extensions --- src/AbstractFFTs.jl | 1 - test/runtests.jl | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 5f8feff..2c64a70 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -15,7 +15,6 @@ function dual2array end if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") include("../ext/AbstractFFTsTestExt.jl") - include("../ext/AbstractFFTsForwardDiffExt.jl") end diff --git a/test/runtests.jl b/test/runtests.jl index f46b180..ceba516 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ Random.seed!(1234) # Load example plan implementation. include("TestPlans.jl") -# Run interface tests for TestPlans +# Run interface tests for TestPlans AbstractFFTs.TestUtils.test_complex_ffts(Array) AbstractFFTs.TestUtils.test_real_ffts(Array) @@ -180,17 +180,17 @@ end p0 = plan_fft(zeros(ComplexF64, 3)) p = TestPlans.WrapperTestPlan(p0) u = rand(ComplexF64, 3) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u # rfft p0 = plan_rfft(zeros(3)) p = TestPlans.WrapperTestPlan(p0) u = rand(ComplexF64, 2) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u # brfft p0 = plan_brfft(zeros(ComplexF64, 3), 5) p = TestPlans.WrapperTestPlan(p0) u = rand(Float64, 5) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u end @testset "ChainRules" begin @@ -238,7 +238,7 @@ end test_frule(f, complex_x, dims) test_rrule(f, complex_x, dims) end - for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) + for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) test_frule(*, pf(x, dims), x) test_rrule(*, pf(x, dims), x) test_frule(*, pf(complex_x, dims), complex_x) @@ -248,7 +248,7 @@ end @test_throws ArgumentError ChainRulesCore.rrule(*, pf!(complex_x, dims), complex_x) end - # rfft + # rfft test_frule(rfft, x, dims) test_rrule(rfft, x, dims) test_frule(*, plan_rfft(x, dims), x) @@ -266,12 +266,14 @@ end for pf in (plan_irfft, plan_brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) test_frule(*, pf(complex_x, d, dims), complex_x) - test_rrule(*, pf(complex_x, d, dims), complex_x) + test_rrule(*, pf(complex_x, d, dims), complex_x) end end end end end end - -include("abstractfftsforwarddiff.jl") \ No newline at end of file + +if isdefined(Base, :get_extension) + include("abstractfftsforwarddiff.jl") +end \ No newline at end of file From 76b776e6c36d0ab5a2e813428a52499d0d0db041 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 10:39:10 +0000 Subject: [PATCH 13/17] Update abstractfftsforwarddiff.jl --- test/abstractfftsforwarddiff.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index 23c8505..9eacda6 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -9,8 +9,8 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) @testset "ForwardDiff extension tests" begin x1 = Dual.(1:4.0, 2:5, 3:6) - @test AbstractFFTs.complexfloat(x1)[1] === AbstractFFTs.complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im - @test AbstractFFTs.realfloat(x1)[1] === AbstractFFTs.realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + @test AbstractFFTs.complexfloat(x1)[1] === Dual(1.0, 2.0, 3.0) + 0im + @test AbstractFFTs.realfloat(x1)[1] === Dual(1.0, 2.0, 3.0) @test fft(x1, 1)[1] isa Complex{<:Dual} From 81168f3d1b9b1543d5555899c3edb5e4d2cc3d4f Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 10:41:48 +0000 Subject: [PATCH 14/17] Update src/AbstractFFTs.jl Co-authored-by: David Widmann --- src/AbstractFFTs.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 2c64a70..52538bf 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -17,5 +17,4 @@ if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsTestExt.jl") end - end # module From cde9caaf118f85ba28a3dab15d79faccb110ee25 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Wed, 18 Dec 2024 12:16:32 +0000 Subject: [PATCH 15/17] add complex tests --- ext/AbstractFFTsForwardDiffExt.jl | 9 +++------ test/abstractfftsforwarddiff.jl | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index 18f177f..2439699 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -45,17 +45,14 @@ end for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) @eval begin - AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) + AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...)) end end for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? - @eval begin - AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims)) - end + @eval AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims; kwds...)) end diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index 9eacda6..08e53fc 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -8,6 +8,7 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) @testset "ForwardDiff extension tests" begin x1 = Dual.(1:4.0, 2:5, 3:6) + c1 = Dual.(1:4.0, 2:5, 3:6) + im*Dual.(2:5.0, 3:6, 3:6) @test AbstractFFTs.complexfloat(x1)[1] === Dual(1.0, 2.0, 3.0) + 0im @test AbstractFFTs.realfloat(x1)[1] === Dual(1.0, 2.0, 3.0) @@ -54,7 +55,15 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) @test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) end - c1 = complex.(x1) - @test mul!(similar(c1), plan_fft(x1), x1) == fft(x1) - @test mul!(similar(c1), plan_fft(c1), c1) == fft(c1) + @testset "complex" begin + @test fft(c1) ≈ fft(real(c1)) + im*fft(imag(c1)) + dest = similar(c1) + @test mul!(dest, plan_fft(x1), x1) == fft(x1) == dest + @test mul!(dest, plan_fft(c1), c1) == fft(c1) == dest + + C = c1 * ((1:10) .+ im*(2:11))' + @test fft(C) ≈ fft(real(C)) + im*fft(imag(C)) + dest = similar(C) + @test mul!(dest, plan_fft(C), C) == fft(C) == dest + end end \ No newline at end of file From 2ea43f39a3ecb1b6d85840973d4153053d37ee0e Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 19 Dec 2024 21:51:17 +0000 Subject: [PATCH 16/17] Generalise dual2array/array2dual for strided --- ext/AbstractFFTsForwardDiffExt.jl | 8 ++++---- test/abstractfftsforwarddiff.jl | 3 --- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl index 2439699..029e09d 100644 --- a/ext/AbstractFFTsForwardDiffExt.jl +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -9,10 +9,10 @@ import AbstractFFTs: Plan, mul!, dualplan, dual2array AbstractFFTs._fftfloat(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,AbstractFFTs._fftfloat(V),N} -dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) -dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) -array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) -array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) +dual2array(x::StridedArray{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::StridedArray{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::StridedArray{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::StridedArray{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) ######## diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index 08e53fc..e612241 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -37,9 +37,6 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) # emperical from Mathematical @test ForwardDiff.derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 - # c = x -> dct([x; 0; 0])[1] - # @test derivative(c,0.1) ≈ 1 - @testset "matrix" begin A = x1 * (1:10)' @test value.(fft(A)) == fft(value.(A)) From 4fe2464ff968fa3a16305f9657101fceeb2c6246 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 19 Dec 2024 22:19:43 +0000 Subject: [PATCH 17/17] Update abstractfftsforwarddiff.jl --- test/abstractfftsforwarddiff.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl index e612241..8a2b3e7 100644 --- a/test/abstractfftsforwarddiff.jl +++ b/test/abstractfftsforwarddiff.jl @@ -14,6 +14,8 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) @test AbstractFFTs.realfloat(x1)[1] === Dual(1.0, 2.0, 3.0) @test fft(x1, 1)[1] isa Complex{<:Dual} + @test plan_fft(x1, 1) * x1 == fft(x1, 1) + @test size(plan_fft(x1,1)) == (4,) @testset "$f" for f in (fft, ifft, rfft, bfft) @test value.(f(x1)) == f(value.(x1))