Skip to content

Commit

Permalink
transform
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Sep 10, 2023
1 parent d60abb0 commit 5442811
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
6 changes: 5 additions & 1 deletion src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ function OpConv(ch_in::Int, ch_out::Int, modes::NTuple{D, Int};
) where{D}

transform = isnothing(transform) ? (FFTW.rfft, FFTW.irfft) : transform
# TODO
if isnothing(transform)
transform = FourierTransform(mesh...)
end

OpConv(ch_in, ch_out, modes, transform, init)
end
Expand All @@ -129,7 +133,7 @@ function Lux.initialparameters(rng::Random.AbstractRNG, l::OpConv)
scale = one(Float32) / (l.ch_in * l.ch_out)

(;
W = scale * l.init(rng, ComplexF32, dims...),
W = scale * l.init(rng, ComplexF32, dims...), # TODO eltype(l.transform)
)
end

Expand Down
32 changes: 17 additions & 15 deletions src/transform.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,45 @@

#
# TODO: subtype AbstractTransform <: Lux.AbstractExplicitLayer and make
# TODO: OpConv(Bilinear) a hyper network/ Lux Container layer.
# TODO: then we can think about trainable transform types
abstract type AbstractTransform{D} end

"""
$TYPEDEF
"""
struct FourierTransform{D} <: AbstractTransform{D}
dims::NTuple{D, Int}
struct FourierTransform{D} <: AbstractTransform{D} # apply rfft on [1:D]
mesh::NTuple{D, Int}
end

FourierTransform(mesh::Int...) = FourierTransform(mesh)
Base.eltype(::FourierTransform) = ComplexF32
Base.ndims(::FourierTransform{D}) where{D} = D
Base.size(F::FourierTransform) = 1

function Base.:*(F::FourierTransform, x::AbstractArray)
FFTW.rfft(x, F.dims)
function Base.:*(F::FourierTransform{D}, x::AbstractArray) where{D}
FFTW.rfft(x, 1:D)
end

function Base.:\(F::FourierTransform, x::AbstractArray)
FFTW.irfft(x, F.d, F.dims)
function Base.:\(F::FourierTransform{D}, x::AbstractArray) where{D}
FFTW.irfft(x, F.mesh[1], 1:D)
end

"""
$TYPEDEF
"""
struct CosineTransform{D} <: AbstractTransform{D}
dims::NTuple{D, Int}
mesh::NTuple{D, Int}
end

CosineTransform(mesh::Int...) = FourierTransform(mesh)
Base.eltype(::CosineTransform) = Float32
Base.ndims(::CosineTransform{D}) where{D} = D
Base.size(F::CosineTransform) = 1

function Base.:*(F::CosineTransform, x::AbstractArray)
FFTW.dct(x, F.dims)
function Base.:*(F::CosineTransform{D}, x::AbstractArray) where{D}
dct(x, 1:D)
end

function Base.:\(F::CosineTransform, x::AbstractArray)
FFTW.idct(x, F.dims)
function Base.:\(F::CosineTransform{D}, x::AbstractArray) where{D}
idct(x, 1:D)
end
#

0 comments on commit 5442811

Please sign in to comment.