From 5442811f927bc66f9e856966d58de780267aed8d Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 10 Sep 2023 16:17:21 -0400 Subject: [PATCH] transform --- src/operator.jl | 6 +++++- src/transform.jl | 32 +++++++++++++++++--------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/operator.jl b/src/operator.jl index 3fa2371c..f6313434 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -119,6 +119,10 @@ function OpConv(ch_in::Int, ch_out::Int, modes::NTuple{D, Int}; ) where{D} transform = isnothing(transform) ? (FFTW.rfft, FFTW.irfft) : transform + # TODO + if isnothing(transform) + transform = FourierTransform(mesh...) + end OpConv(ch_in, ch_out, modes, transform, init) end @@ -129,7 +133,7 @@ function Lux.initialparameters(rng::Random.AbstractRNG, l::OpConv) scale = one(Float32) / (l.ch_in * l.ch_out) (; - W = scale * l.init(rng, ComplexF32, dims...), + W = scale * l.init(rng, ComplexF32, dims...), # TODO eltype(l.transform) ) end diff --git a/src/transform.jl b/src/transform.jl index a234e674..872ce69b 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -1,24 +1,26 @@ - +# +# TODO: subtype AbstractTransform <: Lux.AbstractExplicitLayer and make +# TODO: OpConv(Bilinear) a hyper network/ Lux Container layer. +# TODO: then we can think about trainable transform types abstract type AbstractTransform{D} end """ $TYPEDEF - """ -struct FourierTransform{D} <: AbstractTransform{D} - dims::NTuple{D, Int} +struct FourierTransform{D} <: AbstractTransform{D} # apply rfft on [1:D] + mesh::NTuple{D, Int} end +FourierTransform(mesh::Int...) = FourierTransform(mesh) Base.eltype(::FourierTransform) = ComplexF32 Base.ndims(::FourierTransform{D}) where{D} = D -Base.size(F::FourierTransform) = 1 -function Base.:*(F::FourierTransform, x::AbstractArray) - FFTW.rfft(x, F.dims) +function Base.:*(F::FourierTransform{D}, x::AbstractArray) where{D} + FFTW.rfft(x, 1:D) end -function Base.:\(F::FourierTransform, x::AbstractArray) - FFTW.irfft(x, F.d, F.dims) +function Base.:\(F::FourierTransform{D}, x::AbstractArray) where{D} + FFTW.irfft(x, F.mesh[1], 1:D) end """ @@ -26,18 +28,18 @@ $TYPEDEF """ struct CosineTransform{D} <: AbstractTransform{D} - dims::NTuple{D, Int} + mesh::NTuple{D, Int} end +CosineTransform(mesh::Int...) = FourierTransform(mesh) Base.eltype(::CosineTransform) = Float32 Base.ndims(::CosineTransform{D}) where{D} = D -Base.size(F::CosineTransform) = 1 -function Base.:*(F::CosineTransform, x::AbstractArray) - FFTW.dct(x, F.dims) +function Base.:*(F::CosineTransform{D}, x::AbstractArray) where{D} + dct(x, 1:D) end -function Base.:\(F::CosineTransform, x::AbstractArray) - FFTW.idct(x, F.dims) +function Base.:\(F::CosineTransform{D}, x::AbstractArray) where{D} + idct(x, 1:D) end #