Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fft_poisson_solver constructor #3890

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Solvers/Solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ const GridWithFourierTridiagonalSolver = Union{XYRegularRG, XZRegularRG, YZRegul

fft_poisson_solver(grid::XYZRegularRG) = FFTBasedPoissonSolver(grid)
fft_poisson_solver(grid::GridWithFourierTridiagonalSolver) =
FourierTridiagonalPoissonSolver(grid.underlying_grid)
FourierTridiagonalPoissonSolver(grid)

end # module
6 changes: 3 additions & 3 deletions src/Solvers/conjugate_gradient_poisson_solver.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Oceananigans.Operators: divᶜᶜᶜ, ∇²ᶜᶜᶜ
using Oceananigans.Operators
using Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid
using Statistics: mean

Expand Down Expand Up @@ -107,15 +107,15 @@ function compute_preconditioner_rhs!(solver::FourierTridiagonalPoissonSolver, rh
arch = architecture(grid)
tridiagonal_dir = solver.batched_tridiagonal_solver.tridiagonal_direction
launch!(arch, grid, :xyz, fourier_tridiagonal_preconditioner_rhs!,
solver.storage, tridiagonal_dir, rhs)
solver.storage, tridiagonal_dir, grid, rhs)
return nothing
end

const FFTBasedPreconditioner = Union{FFTBasedPoissonSolver, FourierTridiagonalPoissonSolver}

function precondition!(p, preconditioner::FFTBasedPreconditioner, r, args...)
compute_preconditioner_rhs!(preconditioner, r)
p = solve!(p, preconditioner)
solve!(p, preconditioner)

mean_p = mean(p)
grid = p.grid
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ CUDA.allowscalar() do
if group == :poisson_solvers_2 || group == :all
@testset "Poisson Solvers 2" begin
include("test_poisson_solvers_stretched_grids.jl")
include("test_conjugate_gradient_poisson_solver.jl")
end
end

Expand Down
28 changes: 28 additions & 0 deletions test/test_conjugate_gradient_poisson_solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
include("dependencies_for_runtests.jl")
using Oceananigans.Solvers: fft_poisson_solver, ConjugateGradientPoissonSolver

@testset "Conjugate gradient Poisson solver" begin
@info "Testing Conjugate gradient poisson solver..."
for arch in archs
@testset "Conjugate gradient Poisson solver unit tests [$arch]" begin
@info "Unit testing Conjugate gradient poisson solver..."

# Test the generic fft_poisson_solver constructor
x = y = (0, 1)
z = (0, 1)
grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z)
solver = ConjugateGradientPoissonSolver(grid, preconditioner=fft_poisson_solver(grid))
pressure = CenterField(grid)
solve!(pressure, solver.conjugate_gradient_solver, solver.right_hand_side)
@test solver isa ConjugateGradientPoissonSolver

z = [0, 0.2, 1]
grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z)
solver = ConjugateGradientPoissonSolver(grid, preconditioner=fft_poisson_solver(grid))
pressure = CenterField(grid)
solve!(pressure, solver.conjugate_gradient_solver, solver.right_hand_side)
@test solver isa ConjugateGradientPoissonSolver
end
end
end

14 changes: 14 additions & 0 deletions test/test_poisson_solvers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
include("dependencies_for_runtests.jl")
include("dependencies_for_poisson_solvers.jl")

using Oceananigans.Solvers: fft_poisson_solver

#####
##### Run pressure solver tests 1
#####
Expand Down Expand Up @@ -38,6 +40,18 @@ two_dimensional_topologies = [(Flat, Bounded, Bounded),
@test poisson_solver_instantiates(grid, FFTW.ESTIMATE)
@test poisson_solver_instantiates(grid, FFTW.MEASURE)
end

# Test the generic fft_poisson_solver constructor
x = y = (0, 1)
z = (0, 1)
regular_grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z)
fft_based_solver = fft_poisson_solver(regular_grid)
@test fft_based_solver isa FFTBasedPoissonSolver

z = [0, 0.2, 1]
vertically_stretched_grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z)
fourier_tridiagonal_solver = fft_poisson_solver(vertically_stretched_grid)
@test fourier_tridiagonal_solver isa FourierTridiagonalPoissonSolver
end
end

Expand Down