Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme test for differentiating a single column model with CATKEVerticalDiffusivity #3837

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e6758b3
Delete a lot of comments
glwagner Oct 8, 2024
c6dcefe
Add parameter estimation test
glwagner Oct 8, 2024
f9b9896
Introduce custom DiffusivityFields to control fill_halo_regions behavior
glwagner Oct 8, 2024
af61d4f
Some cosmetic changes
glwagner Oct 8, 2024
38fbd92
Even more specific fill halo regions for CATKE
glwagner Oct 9, 2024
77dd5f0
Merge branch 'main' into glw/enzyme-scm
glwagner Oct 18, 2024
20e4165
Add test for hydrostatic turbulence
glwagner Oct 24, 2024
92b314c
Bypass fill_halo_regions for tuples
glwagner Oct 24, 2024
b25c2b4
Cleanup to interface
glwagner Oct 24, 2024
fd3031c
Make the test better
glwagner Oct 24, 2024
0047d2f
Updates
glwagner Oct 25, 2024
3912c7a
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
jlk9 Oct 25, 2024
409cbb2
Figuring out nan entries
jlk9 Oct 25, 2024
45dbdf8
Merge branch 'glw/autodiff-hydrostatic-turbulence' of https://github.…
glwagner Oct 25, 2024
1522ae4
Fix z-coord bug
glwagner Oct 25, 2024
b97190b
Few more bugs
glwagner Oct 25, 2024
08292d2
Fix bug in field_tuples
glwagner Oct 25, 2024
0f48773
Tiny cleanup
glwagner Oct 25, 2024
9b9575f
Another bug
glwagner Oct 25, 2024
423281b
Back to pure ab2
glwagner Oct 25, 2024
fe0f9cf
Make set! more differentiable
glwagner Oct 25, 2024
e66af64
Restore syntax for some reason
glwagner Oct 26, 2024
15b582a
More appropriate test w tolerance
glwagner Oct 26, 2024
9706b8b
Merge branch 'main' into glw/autodiff-hydrostatic-turbulence
glwagner Oct 28, 2024
d2e8940
Merge remote-tracking branch 'origin/main' into glw/autodiff-hydrosta…
glwagner Oct 31, 2024
eced1e5
hopefully fix type stability issue with set!(u, ::Number)
glwagner Oct 31, 2024
59c07e5
Merge remote-tracking branch 'origin/main' into glw/enzyme-scm
glwagner Nov 1, 2024
e6bcc72
Merge remote-tracking branch 'origin/glw/autodiff-hydrostatic-turbule…
glwagner Nov 1, 2024
5110a57
Starting to work on Simulation
glwagner Nov 1, 2024
afd8fdd
Implement get_top_tracer_bcs utility
glwagner Nov 1, 2024
3559b8d
Cleanup
glwagner Nov 1, 2024
3a2d4a0
Type-stabilize single column mode
glwagner Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 1 addition & 308 deletions ext/OceananigansEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using Oceananigans.Utils: contiguousrange

using KernelAbstractions

#isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore)
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)

using Enzyme: EnzymeCore
Expand All @@ -17,320 +16,14 @@ EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Base.:(==)), ::Oceananigans.Abstr
EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.AbstractOperations.validate_grid), x...) = nothing
EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.AbstractOperations.metric_function), x...) = nothing
EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.Utils.flatten_reduced_dimensions), x...) = nothing
EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.Utils.prettytime), x...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(Oceananigans.Grids.total_size), x...) = nothing
EnzymeCore.EnzymeRules.inactive(::typeof(Oceananigans.BoundaryConditions.parent_size_and_offset), x...) = nothing
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{Oceananigans.Utils.KernelParameters}) = true

@inline batch(::Val{1}, ::Type{T}) where T = T
@inline batch(::Val{N}, ::Type{T}) where {T, N} = NTuple{N, T}

# function EnzymeCore.EnzymeRules.augmented_primal(config,
# func::EnzymeCore.Const{Type{Field}},
# ::Type{<:EnzymeCore.Annotation{RT}},
# loc::Union{EnzymeCore.Const{<:Tuple},
# EnzymeCore.Duplicated{<:Tuple}},
# grid::EnzymeCore.Annotation{<:Oceananigans.Grids.AbstractGrid},
# T::EnzymeCore.Const{<:DataType}; kw...) where RT
#
# primal = if EnzymeCore.EnzymeRules.needs_primal(config)
# func.val(loc.val, grid.val, T.val; kw...)
# else
# nothing
# end
#
# if haskey(kw, :a)
# # copy zeroing
# kw[:data] = copy(kw[:data])
# end
#
# shadow = if EnzymeCore.EnzymeRules.width(config) == 1
# func.val(loc.val, grid.val, T.val; kw...)
# else
# ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do i
# Base.@_inline_meta
# func.val(loc.val, grid.val, T.val; kw...)
# end
# end
#
# P = EnzymeCore.EnzymeRules.needs_primal(config) ? RT : Nothing
# B = batch(Val(EnzymeCore.EnzymeRules.width(config)), RT)
# return EnzymeCore.EnzymeRules.AugmentedReturn{P, B, Nothing}(primal, shadow, nothing)
# end
#
# #####
# ##### Field
# #####
#
# function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1},
# func::EnzymeCore.Const{Type{Field}},
# ::RT,
# tape,
# loc::Union{EnzymeCore.Const{<:Tuple}, EnzymeCore.Duplicated{<:Tuple}},
# grid::EnzymeCore.Const{<:Oceananigans.Grids.AbstractGrid},
# T::EnzymeCore.Const{<:DataType}; kw...) where RT
# return (nothing, nothing, nothing)
# end
#
# function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1},
# func::EnzymeCore.Const{Type{Field}},
# ::RT,
# tape,
# loc::Union{EnzymeCore.Const{<:Tuple}, EnzymeCore.Duplicated{<:Tuple}},
# grid::EnzymeCore.Active{<:Oceananigans.Grids.AbstractGrid},
# T::EnzymeCore.Const{<:DataType}; kw...) where RT
# return (nothing, EnzymeCore.make_ero(grid), nothing)
# end

#####
##### FunctionField
#####

# @inline FunctionField(L::Tuple, func, grid) = FunctionField{L[1], L[2], L[3]}(func, grid)

# function EnzymeCore.EnzymeRules.augmented_primal(config,
# enzyme_func::Union{EnzymeCore.Const{<:Type{<:FunctionField}}, EnzymeCore.Const{Type{FT2}}},
# ::Type{<:EnzymeCore.Annotation{RT}},
# function_field_func,
# grid;
# clock = nothing,
# parameters = nothing) where {RT, FT2 <: FunctionField}
#
# FunctionFieldType = enzyme_func.val
#
# primal = if EnzymeCore.EnzymeRules.needs_primal(config)
# FunctionFieldType(function_field_func.val, grid.val; clock, parameters)
# else
# nothing
# end
#
# # function_field_func can be Active, Const (inactive), Duplicated (active but mutable)
# function_field_is_active = function_field_func isa Active
# # @show function_field_func
#
# # Support batched differentiation!
# config_width = EnzymeCore.EnzymeRules.width(config)
#
# dactives = if function_field_is_active
# if config_width == 1
# Ref(EnzymeCore.make_zero(function_field_func.val))
# else
# ntuple(Val(config_width)) do i
# Base.@_inline_meta
# Ref(EnzymeCore.make_zero(function_field_func.val))
# end
# end
# else
# nothing
# end
#
# shadow = if config_width == 1
# dfunction_field_func = if function_field_is_active
# dactives[]
# else
# function_field_func.dval
# end
#
# FunctionFieldType(dfunction_field_func, grid.val; clock, parameters)
# else
# ntuple(Val(config_width)) do i
# Base.@_inline_meta
#
# dfunction_field_func = if function_field_is_active
# dactives[i][]
# else
# function_field_func.dval[i]
# end
#
# FunctionFieldType(dfunction_field_func, grid.val; clock, parameters)
# end
# end
#
# P = EnzymeCore.EnzymeRules.needs_primal(config) ? RT : Nothing
# B = batch(Val(EnzymeCore.EnzymeRules.width(config)), RT)
# D = typeof(dactives)
#
# return EnzymeCore.EnzymeRules.AugmentedReturn{P, B, D}(primal, shadow, dactives)
# end
#
# function EnzymeCore.EnzymeRules.reverse(config,
# enzyme_func::Union{EnzymeCore.Const{<:Type{<:FunctionField}}, EnzymeCore.Const{Type{FT2}}},
# ::RT,
# tape,
# function_field_func,
# grid;
# clock = nothing,
# parameters = nothing) where {RT, FT2 <: FunctionField}
#
# dactives = if function_field_func isa Active
# if EnzymeCore.EnzymeRules.width(config) == 1
# tape[]
# else
# ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do i
# Base.@_inline_meta
# tape[i][]
# end
# end
# else
# nothing
# end
#
# # return (dactives, grid (nothing))
# return (dactives, nothing)
# end

#####
##### launch!
#####

# function EnzymeCore.EnzymeRules.augmented_primal(config,
# func::EnzymeCore.Const{typeof(Oceananigans.Models.flattened_unique_values)},
# ::Type{<:EnzymeCore.Annotation{RT}},
# a) where RT
#
# sprimal = if EnzymeCore.EnzymeRules.needs_primal(config) || EnzymeCore.EnzymeRules.needs_shadow(config)
# func.val(a.val)
# else
# nothing
# end
#
# shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
# if EnzymeCore.EnzymeRules.width(config) == 1
# (typeof(a) <: Const) ? EnzymeCore.make_zero(sprimal)::RT : func.val(a.dval)
# else
# ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do i
# Base.@_inline_meta
# (typeof(a) <: Const) ? EnzymeCore.make_zero(sprimal)::RT : func.val(a.dval[i])
# end
# end
# else
# nothing
# end
#
# primal = if EnzymeCore.EnzymeRules.needs_primal(config)
# sprimal
# else
# nothing
# end
#
# P = EnzymeCore.EnzymeRules.needs_primal(config) ? RT : Nothing
# B = EnzymeCore.EnzymeRules.needs_primal(config) ? batch(Val(EnzymeCore.EnzymeRules.width(config)), RT) : Nothing
#
# return EnzymeCore.EnzymeRules.AugmentedReturn{P, B, Nothing}(primal, shadow, nothing)
# end
#
# function EnzymeCore.EnzymeRules.reverse(config,
# func::EnzymeCore.Const{typeof(Oceananigans.Models.flattened_unique_values)},
# ::Type{<:EnzymeCore.Annotation{RT}},
# tape,
# a) where RT
#
# return (nothing,)
# end

#####
##### launch!
#####

# function EnzymeCore.EnzymeRules.augmented_primal(config,
# func::EnzymeCore.Const{typeof(Oceananigans.Utils.launch!)},
# ::Type{EnzymeCore.Const{Nothing}},
# arch,
# grid,
# workspec,
# kernel!,
# kernel_args::Vararg{Any,N};
# include_right_boundaries = false,
# reduced_dimensions = (),
# location = nothing,
# active_cells_map = nothing,
# kwargs...) where N
#
#
# workgroup, worksize = Oceananigans.Utils.work_layout(grid.val, workspec.val;
# include_right_boundaries,
# reduced_dimensions,
# location)
#
# offset = Oceananigans.Utils.offsets(workspec.val)
#
# if !isnothing(active_cells_map)
# workgroup, worksize = Oceananigans.Utils.active_cells_work_layout(workgroup, worksize, active_cells_map, grid.val)
# offset = nothing
# end
#
# if worksize != 0
#
# # We can only launch offset kernels with Static sizes!!!!
#
# if isnothing(offset)
# loop! = kernel!.val(Oceananigans.Architectures.device(arch.val), workgroup, worksize)
# dloop! = (typeof(kernel!) <: EnzymeCore.Const) ? nothing : kernel!.dval(Oceananigans.Architectures.device(arch.val), workgroup, worksize)
# else
# loop! = kernel!.val(Oceananigans.Architectures.device(arch.val), KernelAbstractions.StaticSize(workgroup), Oceananigans.Utils.OffsetStaticSize(contiguousrange(worksize, offset)))
# dloop! = (typeof(kernel!) <: EnzymeCore.Const) ? nothing : kernel!.val(Oceananigans.Architectures.device(arch.val), KernelAbstractions.StaticSize(workgroup), Oceananigans.Utils.OffsetStaticSize(contiguousrange(worksize, offset)))
# end
#
# @debug "Launching kernel $kernel! with worksize $worksize and offsets $offset from $workspec.val"
#
#
# duploop = (typeof(kernel!) <: EnzymeCore.Const) ? EnzymeCore.Const(loop!) : EnzymeCore.Duplicated(loop!, dloop!)
#
# config2 = EnzymeCore.EnzymeRules.Config{#=needsprimal=#false, #=needsshadow=#false, #=width=#EnzymeCore.EnzymeRules.width(config), EnzymeCore.EnzymeRules.overwritten(config)[5:end]}()
# subtape = EnzymeCore.EnzymeRules.augmented_primal(config2, duploop, EnzymeCore.Const{Nothing}, kernel_args...).tape
#
# tape = (duploop, subtape)
# else
# tape = nothing
# end
#
# return EnzymeCore.EnzymeRules.AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, tape)
# end
#
# @inline arg_elem_type(::Type{T}, ::Val{i}) where {T<:Tuple, i} = eltype(T.parameters[i])
#
# function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1},
# func::EnzymeCore.Const{typeof(Oceananigans.Utils.launch!)},
# ::Type{EnzymeCore.Const{Nothing}},
# tape,
# arch,
# grid,
# workspec,
# kernel!,
# kernel_args::Vararg{Any,N};
# include_right_boundaries = false,
# reduced_dimensions = (),
# location = nothing,
# active_cells_map = nothing,
# kwargs...) where N
#
# subrets = if tape !== nothing
# duploop, subtape = tape
# config2 = EnzymeCore.EnzymeRules.Config{#=needsprimal=#false, #=needsshadow=#false, #=width=#EnzymeCore.EnzymeRules.width(config), EnzymeCore.EnzymeRules.overwritten(config)[5:end]}()
# EnzymeCore.EnzymeRules.reverse(config2, duploop, EnzymeCore.Const{Nothing}, subtape, kernel_args...)
# else
# res2 = ntuple(Val(N)) do i
# Base.@_inline_meta
# if kernel_args[i] isa Active
# EnzymeCore.make_zero(kernel_args[i].val)
# else
# nothing
# end
# end
# end
#
# subrets2 = ntuple(Val(N)) do i
# Base.@_inline_meta
# if kernel_args[i] isa Active
# subrets[i]::arg_elem_type(typeof(kernel_args), Val(i))
# else
# nothing
# end
# end
#
# return (nothing, nothing, nothing, nothing, subrets2...)
#
# end

#####
##### update_model_field_time_series!
#####
Expand Down
10 changes: 5 additions & 5 deletions src/BuoyancyModels/seawater_buoyancy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ temperature and salinity are active, or of type `FT` if temperature
or salinity are constant, respectively.
"""
struct SeawaterBuoyancy{FT, EOS, T, S} <: AbstractBuoyancyModel{EOS}
equation_of_state :: EOS
equation_of_state :: EOS
gravitational_acceleration :: FT
constant_temperature :: T
constant_salinity :: S
constant_temperature :: T
constant_salinity :: S
end

required_tracers(::SeawaterBuoyancy) = (:T, :S)
required_tracers(::SeawaterBuoyancy{FT, EOS, <:Nothing, <:Number}) where {FT, EOS} = (:T,) # active temperature only
required_tracers(::SeawaterBuoyancy{FT, EOS, <:Number, <:Nothing}) where {FT, EOS} = (:S,) # active salinity only
required_tracers(::SeawaterBuoyancy{FT, EOS, <:Nothing, <:Number}) where {FT, EOS} = tuple(:T) # active temperature only
required_tracers(::SeawaterBuoyancy{FT, EOS, <:Number, <:Nothing}) where {FT, EOS} = tuple(:S) # active salinity only

Base.nameof(::Type{SeawaterBuoyancy}) = "SeawaterBuoyancy"
Base.summary(b::SeawaterBuoyancy) = string(nameof(typeof(b)), " with g=", prettysummary(b.gravitational_acceleration),
Expand Down
4 changes: 2 additions & 2 deletions src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ Base.checkbounds(f::Field, I...) = Base.checkbounds(f.data, I...)
@propagate_inbounds Base.lastindex(f::Field) = lastindex(f.data)
@propagate_inbounds Base.lastindex(f::Field, dim) = lastindex(f.data, dim)

Base.fill!(f::Field, val) = fill!(parent(f), val)
Base.parent(f::Field) = parent(f.data)
@inline Base.fill!(f::Field, val) = fill!(parent(f), val)
@inline Base.parent(f::Field) = parent(f.data)
Adapt.adapt_structure(to, f::Field) = Adapt.adapt(to, f.data)

total_size(f::Field) = total_size(f.grid, location(f), f.indices)
Expand Down
Loading