Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for autodifferentiating hydrostatic turbulence simulation #3867

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
20e4165
Add test for hydrostatic turbulence
glwagner Oct 24, 2024
92b314c
Bypass fill_halo_regions for tuples
glwagner Oct 24, 2024
b25c2b4
Cleanup to interface
glwagner Oct 24, 2024
fd3031c
Make the test better
glwagner Oct 24, 2024
0047d2f
Updates
glwagner Oct 25, 2024
3912c7a
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
jlk9 Oct 25, 2024
45dbdf8
Merge branch 'glw/autodiff-hydrostatic-turbulence' of https://github.…
glwagner Oct 25, 2024
1522ae4
Fix z-coord bug
glwagner Oct 25, 2024
b97190b
Few more bugs
glwagner Oct 25, 2024
08292d2
Fix bug in field_tuples
glwagner Oct 25, 2024
0f48773
Tiny cleanup
glwagner Oct 25, 2024
9b9575f
Another bug
glwagner Oct 25, 2024
423281b
Back to pure ab2
glwagner Oct 25, 2024
fe0f9cf
Make set! more differentiable
glwagner Oct 25, 2024
e66af64
Restore syntax for some reason
glwagner Oct 26, 2024
15b582a
More appropriate test w tolerance
glwagner Oct 26, 2024
9706b8b
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
glwagner Oct 28, 2024
d2e8940
Merge remote-tracking branch 'origin/main' into glw/autodiff-hydrosta…
glwagner Oct 31, 2024
eced1e5
hopefully fix type stability issue with set!(u, ::Number)
glwagner Oct 31, 2024
7cbcadf
Updates
glwagner Nov 2, 2024
397abe4
Update Enzyme
glwagner Nov 4, 2024
8f65907
Down Enzyme
glwagner Nov 5, 2024
7574a4d
Changed FD computation to central difference instead of forward diffe…
jlk9 Nov 7, 2024
f20c900
Switched named of variables for FD comparison
jlk9 Nov 7, 2024
c35a751
Set AD test to be at ν1 = ν₀ + Δν to avoid a global minimum
jlk9 Nov 7, 2024
9d254fe
Removed extraneous @show statements
jlk9 Nov 7, 2024
11809b0
Testing to see if removing grid from the handle of tupled_fill_halo_r…
jlk9 Nov 10, 2024
ee6361b
Another test with grid handles
jlk9 Nov 10, 2024
22d4a95
Update src/Fields/field_tuples.jl
jlk9 Nov 10, 2024
c9d9942
More robust test for grid
jlk9 Nov 11, 2024
933aac3
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
jlk9 Nov 11, 2024
2a239e9
Modularizing different tupled_fill_halo_regions! methods
jlk9 Nov 11, 2024
a795037
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
jlk9 Nov 12, 2024
df48234
Update src/Fields/field_tuples.jl
jlk9 Nov 12, 2024
ca2108c
Merge branch 'main' of https://github.com/CliMA/Oceananigans.jl into …
jlk9 Nov 23, 2024
8947a3b
Fixed variable name conflict error
jlk9 Nov 23, 2024
975ffa0
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
jlk9 Nov 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ Base.checkbounds(f::Field, I...) = Base.checkbounds(f.data, I...)
@propagate_inbounds Base.lastindex(f::Field) = lastindex(f.data)
@propagate_inbounds Base.lastindex(f::Field, dim) = lastindex(f.data, dim)

Base.fill!(f::Field, val) = fill!(parent(f), val)
Base.parent(f::Field) = parent(f.data)
@inline Base.fill!(f::Field, val) = fill!(parent(f), val)
@inline Base.parent(f::Field) = parent(f.data)
Adapt.adapt_structure(to, f::Field) = Adapt.adapt(to, f.data)

total_size(f::Field) = total_size(f.grid, location(f), f.indices)
Expand Down
72 changes: 48 additions & 24 deletions src/Fields/field_tuples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,42 +56,66 @@ Fill halo regions for all `fields`. The algorithm:
function fill_halo_regions!(maybe_nested_tuple::Union{NamedTuple, Tuple}, args...; kwargs...)
flattened = flattened_unique_values(maybe_nested_tuple)

# Sort fields into ReducedField and Field with non-nothing boundary conditions
fields_with_bcs = filter(f -> !isnothing(boundary_conditions(f)), flattened)
reduced_fields = filter(f -> f isa ReducedField, fields_with_bcs)

for field in reduced_fields
fill_halo_regions!(field, args...; kwargs...)
# Look for grid within the flattened field tuple:
for f in flattened
if isdefined(f, :grid)
grid = f.grid
return tupled_fill_halo_regions!(flattened, grid, args...; kwargs...)
end
end

# MultiRegion fields are considered windowed_fields (indices isa MultiRegionObject))
windowed_fields = filter(f -> !(f isa FullField), fields_with_bcs)
ordinary_fields = filter(f -> (f isa FullField) && !(f isa ReducedField), fields_with_bcs)
return tupled_fill_halo_regions!(flattened, args...; kwargs...)
end

# Fill halo regions for reduced and windowed fields
for field in windowed_fields
fill_halo_regions!(field, args...; kwargs...)
end
# Version where we find grid amongst ordinary fields:
function tupled_fill_halo_regions!(fields, args...; kwargs...)

# Fill the rest
if !isempty(ordinary_fields)
ordinary_fields = produce_ordinary_fields(fields, args...; kwargs)

if !isempty(ordinary_fields) # ie not reduced, and with default_indices
grid = first(ordinary_fields).grid
tupled_fill_halo_regions!(ordinary_fields, grid, args...; kwargs...)
fill_halo_regions!(map(data, ordinary_fields),
map(boundary_conditions, ordinary_fields),
default_indices(3),
map(instantiated_location, ordinary_fields),
grid, args...; kwargs...)
end

return nothing
end

# Version where grid is provided:
function tupled_fill_halo_regions!(fields, grid::AbstractGrid, args...; kwargs...)

function tupled_fill_halo_regions!(fields, grid, args...; kwargs...)
ordinary_fields = produce_ordinary_fields(fields, args...; kwargs)

# We cannot group windowed fields together, the indices must be (:, :, :)!
indices = default_indices(3)
if !isempty(ordinary_fields) # ie not reduced, and with default_indices
fill_halo_regions!(map(data, ordinary_fields),
map(boundary_conditions, ordinary_fields),
default_indices(3),
map(instantiated_location, ordinary_fields),
grid, args...; kwargs...)
end

return nothing
end

# Helper function to create the tuple of ordinary fields:
@inline function produce_ordinary_fields(fields, args...; kwargs)

ordinary_fields = Field[]
for f in fields
if !isnothing(boundary_conditions(f))
if f isa ReducedField || !(f isa FullField)
# Windowed and reduced fields
fill_halo_regions!(f, args...; kwargs...)
else
push!(ordinary_fields, f)
end
end
end

return fill_halo_regions!(map(data, fields),
map(boundary_conditions, fields),
indices,
map(instantiated_location, fields),
grid, args...; kwargs...)
return tuple(ordinary_fields...)
end

#####
Expand Down
5 changes: 5 additions & 0 deletions src/Fields/set!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ set!(u::Field, f::Function) = set_to_function!(u, f)
set!(u::Field, a::Union{Array, CuArray, OffsetArray}) = set_to_array!(u, a)
set!(u::Field, v::Field) = set_to_field!(u, v)

function set!(u::Field, a::Number)
fill!(interior(u), a) # note all other set! only change interior
return u # return u, not parent(u), for type-stability
end

function set!(u::Field, v)
u .= v # fallback
return u
Expand Down
12 changes: 9 additions & 3 deletions src/Models/HydrostaticFreeSurfaceModels/explicit_free_surface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ on_architecture(to, free_surface::ExplicitFreeSurface) =
function materialize_free_surface(free_surface::ExplicitFreeSurface{Nothing}, velocities, grid)
η = free_surface_displacement_field(velocities, free_surface, grid)
g = convert(eltype(grid), free_surface.gravitational_acceleration)

return ExplicitFreeSurface(η, g)
end

Expand Down Expand Up @@ -85,9 +84,16 @@ explicit_ab2_step_free_surface!(free_surface, model, Δt, χ) =
@inbounds η[i, j, Nz+1] += ζⁿ * η⁻[i, j, k] + γⁿ * (η[i, j, k] + convert(FT, Δt) * Gⁿ[i, j, k])
end

@kernel function _explicit_ab2_step_free_surface!(η, Δt, χ::FT, Gηⁿ, Gη⁻, Nz) where FT
@kernel function _explicit_ab2_step_free_surface!(η, Δt, χ, Gηⁿ, Gη⁻, Nz)
i, j = @index(Global, NTuple)
@inbounds η[i, j, Nz+1] += Δt * ((FT(1.5) + χ) * Gηⁿ[i, j, Nz+1] - (FT(0.5) + χ) * Gη⁻[i, j, Nz+1])
FT0 = typeof(χ)
one_point_five = convert(FT0, 1.5)
oh_point_five = convert(FT0, 0.5)
not_euler = χ != convert(FT0, -0.5)
@inbounds begin
Gη = (one_point_five + χ) * Gηⁿ[i, j, Nz+1] - (oh_point_five + χ) * Gη⁻[i, j, Nz+1] * not_euler
η[i, j, Nz+1] += Δt * Gη
end
end

#####
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ T₀[T₀ .< 0.5] .= 0

set!(model, u=u₀, v=v₀, T=T₀)

model.velocities.u
@model.velocities.u

# output

Expand Down Expand Up @@ -61,7 +61,8 @@ model.velocities.u
end

initialize!(model)
update_state!(model; compute_tendencies = false)
update_state!(model; compute_tendencies=false)

return nothing
end

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Oceananigans.TurbulenceClosures: compute_diffusivities!
using Oceananigans.ImmersedBoundaries: mask_immersed_field!, mask_immersed_field_xy!, inactive_node
using Oceananigans.Models: update_model_field_time_series!
using Oceananigans.Models.NonhydrostaticModels: update_hydrostatic_pressure!, p_kernel_parameters
using Oceananigans.Fields: replace_horizontal_vector_halos!
using Oceananigans.Fields: replace_horizontal_vector_halos!, tupled_fill_halo_regions!

import Oceananigans.Models.NonhydrostaticModels: compute_auxiliaries!
import Oceananigans.TimeSteppers: update_state!
Expand Down Expand Up @@ -38,7 +38,7 @@ function update_state!(model::HydrostaticFreeSurfaceModel, grid, callbacks; comp
# Update the boundary conditions
@apply_regionally update_boundary_condition!(fields(model), model)

fill_halo_regions!(prognostic_fields(model), model.clock, fields(model); async = true)
tupled_fill_halo_regions!(prognostic_fields(model), grid, model.clock, fields(model), async=true)

@apply_regionally replace_horizontal_vector_halos!(model.velocities, model.grid)
@apply_regionally compute_auxiliaries!(model)
Expand Down
39 changes: 15 additions & 24 deletions src/TimeSteppers/quasi_adams_bashforth_2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,40 +80,27 @@ function time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt
Δt == 0 && @warn "Δt == 0 may cause model blowup!"

# Be paranoid and update state at iteration 0
model.clock.iteration == 0 && update_state!(model, callbacks)

ab2_timestepper = model.timestepper
model.clock.iteration == 0 && update_state!(model, callbacks; compute_tendencies=true)

# Change the default χ if necessary, which occurs if:
# Take an euler step if:
# * We detect that the time-step size has changed.
# * We detect that this is the "first" time-step, which means we
# need to take an euler step. Note that model.clock.last_Δt is
# initialized as Inf
# * The user has passed euler=true to time_step!
euler = euler || (Δt != model.clock.last_Δt)

euler && @debug "Taking a forward Euler step."

# If euler, then set χ = -0.5
minus_point_five = convert(eltype(model.grid), -0.5)
ab2_timestepper = model.timestepper
χ = ifelse(euler, minus_point_five, ab2_timestepper.χ)

# Set time-stepper χ (this is used in ab2_step!, but may also be used elsewhere)
χ₀ = ab2_timestepper.χ # Save initial value
ab2_timestepper.χ = χ

# Ensure zeroing out all previous tendency fields to avoid errors in
# case G⁻ includes NaNs. See https://github.com/CliMA/Oceananigans.jl/issues/2259
if euler
@debug "Taking a forward Euler step."
for field in ab2_timestepper.G⁻
!isnothing(field) && @apply_regionally fill!(field, 0)
end
end
# Full step for tracers, fractional step for velocities.
ab2_step!(model, Δt)

# Be paranoid and update state at iteration 0
model.clock.iteration == 0 && update_state!(model, callbacks; compute_tendencies=true)

ab2_step!(model, Δt) # full step for tracers, fractional step for velocities.

tick!(model.clock, Δt)
model.clock.last_Δt = Δt
model.clock.last_stage_Δt = Δt # just one stage
Expand All @@ -126,7 +113,7 @@ function time_step!(model::AbstractModel{<:QuasiAdamsBashforth2TimeStepper}, Δt

# Return χ to initial value
ab2_timestepper.χ = χ₀

return nothing
end

Expand All @@ -142,7 +129,6 @@ end

""" Generic implementation. """
function ab2_step!(model, Δt)

grid = model.grid
arch = architecture(grid)
model_fields = prognostic_fields(model)
Expand Down Expand Up @@ -176,11 +162,16 @@ Time step velocity fields via the 2nd-order quasi Adams-Bashforth method
@kernel function ab2_step_field!(u, Δt, χ, Gⁿ, G⁻)
i, j, k = @index(Global, NTuple)

FT = eltype(χ)
FT = typeof(χ)
Δt = convert(FT, Δt)
one_point_five = convert(FT, 1.5)
oh_point_five = convert(FT, 0.5)
not_euler = χ != convert(FT, -0.5) # use to prevent corruption by leftover NaNs in G⁻

@inbounds u[i, j, k] += convert(FT, Δt) * ((one_point_five + χ) * Gⁿ[i, j, k] - (oh_point_five + χ) * G⁻[i, j, k])
@inbounds begin
Gu = (one_point_five + χ) * Gⁿ[i, j, k] - (oh_point_five + χ) * G⁻[i, j, k] * not_euler
u[i, j, k] += Δt * Gu
end
end

@kernel ab2_step_field!(::FunctionField, Δt, χ, Gⁿ, G⁻) = nothing
Expand Down
5 changes: 2 additions & 3 deletions test/test_abstract_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ function x_derivative_cell(arch)
end

function times_x_derivative(a, b, location, i, j, k, answer)
a∇b = @at location b * ∂x(a)

return CUDA.@allowscalar a∇b[i, j, k] == answer
b∇a = @at location b * ∂x(a)
return CUDA.@allowscalar b∇a[i, j, k] == answer
end

for arch in archs
Expand Down
Loading