diff --git a/src/Solvers/Solvers.jl b/src/Solvers/Solvers.jl index bbad459ea9..c89a56e40e 100644 --- a/src/Solvers/Solvers.jl +++ b/src/Solvers/Solvers.jl @@ -53,6 +53,6 @@ const GridWithFourierTridiagonalSolver = Union{XYRegularRG, XZRegularRG, YZRegul fft_poisson_solver(grid::XYZRegularRG) = FFTBasedPoissonSolver(grid) fft_poisson_solver(grid::GridWithFourierTridiagonalSolver) = - FourierTridiagonalPoissonSolver(grid.underlying_grid) + FourierTridiagonalPoissonSolver(grid) end # module diff --git a/src/Solvers/conjugate_gradient_poisson_solver.jl b/src/Solvers/conjugate_gradient_poisson_solver.jl index 7a4ba2f117..eb9ec9db21 100644 --- a/src/Solvers/conjugate_gradient_poisson_solver.jl +++ b/src/Solvers/conjugate_gradient_poisson_solver.jl @@ -1,4 +1,4 @@ -using Oceananigans.Operators: divᶜᶜᶜ, ∇²ᶜᶜᶜ +using Oceananigans.Operators using Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid using Statistics: mean @@ -107,7 +107,7 @@ function compute_preconditioner_rhs!(solver::FourierTridiagonalPoissonSolver, rh arch = architecture(grid) tridiagonal_dir = solver.batched_tridiagonal_solver.tridiagonal_direction launch!(arch, grid, :xyz, fourier_tridiagonal_preconditioner_rhs!, - solver.storage, tridiagonal_dir, rhs) + solver.storage, tridiagonal_dir, grid, rhs) return nothing end @@ -115,7 +115,7 @@ const FFTBasedPreconditioner = Union{FFTBasedPoissonSolver, FourierTridiagonalPo function precondition!(p, preconditioner::FFTBasedPreconditioner, r, args...) compute_preconditioner_rhs!(preconditioner, r) - p = solve!(p, preconditioner) + solve!(p, preconditioner) mean_p = mean(p) grid = p.grid diff --git a/test/runtests.jl b/test/runtests.jl index 816d8d578a..aa386f0994 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -79,6 +79,7 @@ CUDA.allowscalar() do if group == :poisson_solvers_2 || group == :all @testset "Poisson Solvers 2" begin include("test_poisson_solvers_stretched_grids.jl") + include("test_conjugate_gradient_poisson_solver.jl") end end diff --git a/test/test_conjugate_gradient_poisson_solver.jl b/test/test_conjugate_gradient_poisson_solver.jl new file mode 100644 index 0000000000..5e92a8f241 --- /dev/null +++ b/test/test_conjugate_gradient_poisson_solver.jl @@ -0,0 +1,28 @@ +include("dependencies_for_runtests.jl") +using Oceananigans.Solvers: fft_poisson_solver, ConjugateGradientPoissonSolver + +@testset "Conjugate gradient Poisson solver" begin + @info "Testing Conjugate gradient poisson solver..." + for arch in archs + @testset "Conjugate gradient Poisson solver unit tests [$arch]" begin + @info "Unit testing Conjugate gradient poisson solver..." + + # Test the generic fft_poisson_solver constructor + x = y = (0, 1) + z = (0, 1) + grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z) + solver = ConjugateGradientPoissonSolver(grid, preconditioner=fft_poisson_solver(grid)) + pressure = CenterField(grid) + solve!(pressure, solver.conjugate_gradient_solver, solver.right_hand_side) + @test solver isa ConjugateGradientPoissonSolver + + z = [0, 0.2, 1] + grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z) + solver = ConjugateGradientPoissonSolver(grid, preconditioner=fft_poisson_solver(grid)) + pressure = CenterField(grid) + solve!(pressure, solver.conjugate_gradient_solver, solver.right_hand_side) + @test solver isa ConjugateGradientPoissonSolver + end + end +end + diff --git a/test/test_poisson_solvers.jl b/test/test_poisson_solvers.jl index 7e37253a3e..ac12c92dbd 100644 --- a/test/test_poisson_solvers.jl +++ b/test/test_poisson_solvers.jl @@ -1,6 +1,8 @@ include("dependencies_for_runtests.jl") include("dependencies_for_poisson_solvers.jl") +using Oceananigans.Solvers: fft_poisson_solver + ##### ##### Run pressure solver tests 1 ##### @@ -38,6 +40,18 @@ two_dimensional_topologies = [(Flat, Bounded, Bounded), @test poisson_solver_instantiates(grid, FFTW.ESTIMATE) @test poisson_solver_instantiates(grid, FFTW.MEASURE) end + + # Test the generic fft_poisson_solver constructor + x = y = (0, 1) + z = (0, 1) + regular_grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z) + fft_based_solver = fft_poisson_solver(regular_grid) + @test fft_based_solver isa FFTBasedPoissonSolver + + z = [0, 0.2, 1] + vertically_stretched_grid = RectilinearGrid(arch, size=(2, 2, 2); x, y, z) + fourier_tridiagonal_solver = fft_poisson_solver(vertically_stretched_grid) + @test fourier_tridiagonal_solver isa FourierTridiagonalPoissonSolver end end