Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce gauge condition in preconditioners for ConjugateGradientPoissonSolver #3802

Merged
merged 13 commits into from
Oct 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -80,47 +80,72 @@ function ConjugateGradientPoissonSolver(grid;
return ConjugateGradientPoissonSolver(grid, rhs, conjugate_gradient_solver)
end

@kernel function compute_source_term!(rhs, grid, Δt, U★)
i, j, k = @index(Global, NTuple)
δ = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w)
inactive = !inactive_cell(i, j, k, grid)
@inbounds rhs[i, j, k] = δ / Δt * inactive
end

function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver, Δt, U★)
# We may want a criteria like this:
# min_Δt = eps(typeof(Δt))
# Δt <= min_Δt && return pressure

rhs = solver.right_hand_side
grid = solver.grid
arch = architecture(grid)
launch!(arch, grid, :xyz, compute_source_term!, rhs, grid, Δt, U★)

# Solve pressure Pressure equation for pressure, given rhs
# @info "Δt before pressure solve: $(Δt)"
solve!(pressure, solver.conjugate_gradient_solver, rhs)

return pressure
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, U★)
return solve!(pressure, solver.conjugate_gradient_solver, rhs)
end

#####
##### A preconditioner based on the FFT solver
#####

@kernel function fft_preconditioner_right_hand_side!(preconditioner_rhs, rhs)
@kernel function fft_preconditioner_rhs!(preconditioner_rhs, rhs)
i, j, k = @index(Global, NTuple)
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k]
end

function precondition!(p, solver::FFTBasedPoissonSolver, rhs, args...)
@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::XDirection, grid, rhs)
i, j, k = @index(Global, NTuple)
@inbounds preconditioner_rhs[i, j, k] = Δxᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
end

@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::YDirection, grid, rhs)
i, j, k = @index(Global, NTuple)
@inbounds preconditioner_rhs[i, j, k] = Δyᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
end

@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::ZDirection, grid, rhs)
i, j, k = @index(Global, NTuple)
@inbounds preconditioner_rhs[i, j, k] = Δzᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
end

function compute_preconditioner_rhs!(solver::FFTBasedPoissonSolver, rhs)
grid = solver.grid
arch = architecture(grid)
launch!(arch, grid, :xyz, fft_preconditioner_rhs!, solver.storage, rhs)
return nothing
end

function compute_preconditioner_rhs!(solver::FourierTridiagonalPoissonSolver, rhs)
grid = solver.grid
arch = architecture(grid)
tridiagonal_dir = solver.batched_tridiagonal_solver.tridiagonal_direction
launch!(arch, grid, :xyz, fourier_tridiagonal_preconditioner_rhs!,
solver.storage, tridiagonal_dir, rhs)
return nothing
end

function precondition!(p, solver, rhs, args...)
compute_preconditioner_rhs!(solver, rhs)
p = solve!(p, solver)

P = mean(p)
grid = solver.grid
arch = architecture(grid)
launch!(arch, grid, :xyz, fft_preconditioner_right_hand_side!, solver.storage, rhs)
p = solve!(p, solver, solver.storage)
launch!(arch, grid, :xyz, subtract_and_mask!, p, grid, P)

return p
end

@kernel function subtract_and_mask!(a, grid, b)
i, j, k = @index(Global, NTuple)
active = !inactive_cell(i, j, k, grid)
a[i, j, k] = (a[i, j, k] - b) * active
end

#####
##### The "DiagonallyDominantPreconditioner" used by MITgcm
#####
Expand All @@ -133,6 +158,12 @@ Base.summary(::DiagonallyDominantPreconditioner) = "DiagonallyDominantPreconditi
arch = architecture(p)
fill_halo_regions!(r)
launch!(arch, grid, :xyz, _diagonally_dominant_precondition!, p, grid, r)

P = mean(p)
grid = solver.grid
arch = architecture(grid)
launch!(arch, grid, :xyz, subtract_and_mask!, p, grid, P)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xkykai I think this is the crucial lines. Without this the pressure does not have a zero mean and I suspect that can throw off the CG solver. But I'm checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jm-c might be good to have your input

Copy link
Collaborator

@navidcy navidcy Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have found that I needed to enforce the zero mean (e.g., when solving Laplace's or Poisson's equation) when I was using the conjugate gradient solver with @elise-palethorpe to compare with MultiGrid

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a key ingredient missing from previous implementation. The FFT-based preconditioner zeros out the mean over the underlying grid, but does not zero out the mean on the immersed grid.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked the derivation, looks good to me!


return p
end

Expand Down
97 changes: 43 additions & 54 deletions src/Models/NonhydrostaticModels/solve_for_pressure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,88 +7,77 @@ using Oceananigans.Grids: XDirection, YDirection, ZDirection
##### Calculate the right-hand-side of the non-hydrostatic pressure Poisson equation.
#####

@kernel function calculate_pressure_source_term_fft_based_solver!(rhs, grid, Δt, U★)
@kernel function _compute_source_term!(rhs, grid, Δt, U★)
i, j, k = @index(Global, NTuple)
@inbounds rhs[i, j, k] = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w) / Δt
active = !inactive_cell(i, j, k, grid)
δ = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w)
@inbounds rhs[i, j, k] = active * δ / Δt
end

@kernel function calculate_pressure_source_term_fourier_tridiagonal_solver!(rhs, grid, Δt, U★, ::XDirection)
@kernel function _fourier_tridiagonal_source_term!(rhs, ::XDirection, grid, Δt, U★)
i, j, k = @index(Global, NTuple)
@inbounds rhs[i, j, k] = Δxᶜᶜᶜ(i, j, k, grid) * divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w) / Δt
active = !inactive_cell(i, j, k, grid)
δ = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w)
@inbounds rhs[i, j, k] = active * Δxᶜᶜᶜ(i, j, k, grid) * δ / Δt
end

@kernel function calculate_pressure_source_term_fourier_tridiagonal_solver!(rhs, grid, Δt, U★, ::YDirection)
@kernel function _fourier_tridiagonal_source_term!(rhs, ::YDirection, grid, Δt, U★)
i, j, k = @index(Global, NTuple)
@inbounds rhs[i, j, k] = Δyᶜᶜᶜ(i, j, k, grid) * divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w) / Δt
active = !inactive_cell(i, j, k, grid)
δ = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w)
@inbounds rhs[i, j, k] = active * Δyᶜᶜᶜ(i, j, k, grid) * δ / Δt
end

@kernel function calculate_pressure_source_term_fourier_tridiagonal_solver!(rhs, grid, Δt, U★, ::ZDirection)
@kernel function _fourier_tridiagonal_source_term!(rhs, ::ZDirection, grid, Δt, U★)
i, j, k = @index(Global, NTuple)
@inbounds rhs[i, j, k] = Δzᶜᶜᶜ(i, j, k, grid) * divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w) / Δt
active = !inactive_cell(i, j, k, grid)
δ = divᶜᶜᶜ(i, j, k, grid, U★.u, U★.v, U★.w)
@inbounds rhs[i, j, k] = active * Δzᶜᶜᶜ(i, j, k, grid) * δ / Δt
end

#####
##### Solve for pressure
#####

function solve_for_pressure!(pressure, solver::DistributedFFTBasedPoissonSolver, Δt, U★)
function compute_source_term!(pressure, solver::DistributedFFTBasedPoissonSolver, Δt, U★)
rhs = solver.storage.zfield
arch = architecture(solver)
grid = solver.local_grid

launch!(arch, grid, :xyz, calculate_pressure_source_term_fft_based_solver!,
rhs, grid, Δt, U★)

# Solve pressure Poisson equation for pressure, given rhs
solve!(pressure, solver)

return pressure
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, U★)
return nothing
end

function solve_for_pressure!(pressure, solver::FFTBasedPoissonSolver, Δt, U★)

# Calculate right hand side:
rhs = solver.storage
function compute_source_term!(pressure, solver::DistributedFourierTridiagonalPoissonSolver, Δt, U★)
rhs = solver.storage.zfield
arch = architecture(solver)
grid = solver.grid

launch!(arch, grid, :xyz, calculate_pressure_source_term_fft_based_solver!,
rhs, grid, Δt, U★)

# Solve pressure Poisson given for pressure, given rhs
solve!(pressure, solver, rhs)

grid = solver.local_grid
tridiagonal_dir = solver.batched_tridiagonal_solver.tridiagonal_direction
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!,
rhs, grid, Δt, U★, tridiagonal_dir)
return nothing
end

function solve_for_pressure!(pressure, solver::DistributedFourierTridiagonalPoissonSolver, Δt, U★)

# Calculate right hand side:
rhs = solver.storage.zfield
function compute_source_term!(pressure, solver::FourierTridiagonalPoissonSolver, Δt, U★)
rhs = solver.source_term
arch = architecture(solver)
grid = solver.local_grid

launch!(arch, grid, :xyz, calculate_pressure_source_term_fourier_tridiagonal_solver!,
rhs, grid, Δt, U★, solver.batched_tridiagonal_solver.tridiagonal_direction)

# Pressure Poisson rhs, scaled by the spacing in the stretched direction at ᶜᶜᶜ, is stored in solver.source_term:
solve!(pressure, solver)

grid = solver.grid
tridiagonal_dir = solver.batched_tridiagonal_solver.tridiagonal_direction
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!,
rhs, grid, Δt, U★, tridiagonal_dir)
return nothing
end

function solve_for_pressure!(pressure, solver::FourierTridiagonalPoissonSolver, Δt, U★)

# Calculate right hand side:
rhs = solver.source_term
function compute_source_term!(pressure, solver::FFTBasedPoissonSolver, Δt, U★)
rhs = solver.storage
arch = architecture(solver)
grid = solver.grid
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, U★)
return nothing
end

launch!(arch, grid, :xyz, calculate_pressure_source_term_fourier_tridiagonal_solver!,
rhs, grid, Δt, U★, solver.batched_tridiagonal_solver.tridiagonal_direction)
#####
##### Solve for pressure
#####

# Pressure Poisson rhs, scaled by the spacing in the stretched direction at ᶜᶜᶜ, is stored in solver.source_term:
function solve_for_pressure!(pressure, solver, Δt, U★)
compute_source_term!(pressure, solver, Δt, U★)
solve!(pressure, solver)

return nothing
return pressure
end

3 changes: 2 additions & 1 deletion src/Solvers/fft_based_poisson_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ elements (typically the same type as `solver.storage`).
Equation ``(∇² + m) ϕ = b`` is sometimes referred to as the "screened Poisson" equation
when ``m < 0``, or the Helmholtz equation when ``m > 0``.
"""
function solve!(ϕ, solver::FFTBasedPoissonSolver, b, m=0)
function solve!(ϕ, solver::FFTBasedPoissonSolver, b=solver.storage, m=0)
arch = architecture(solver)
topo = TX, TY, TZ = topology(solver.grid)
Nx, Ny, Nz = size(solver.grid)
Expand Down Expand Up @@ -131,3 +131,4 @@ end

@inbounds ϕ[i′, j′, k′] = real(ϕc[i, j, k])
end