-
Notifications
You must be signed in to change notification settings - Fork 196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement neural network parameterization of salty turbulent mixing in the upper ocean #3819
Draft
xkykai
wants to merge
113
commits into
main
Choose a base branch
from
xk/embed-nn
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 57 commits
Commits
Show all changes
113 commits
Select commit
Hold shift + click to select a range
13cd632
working version on CPU
xkykai bb70e5c
a working model for GPU!
xkykai e845dd5
it actually works on the gpu now!
xkykai b6552ab
build scaling from tuple values
xkykai c9a9c65
Update XinKaiLocalVerticalDiffusivity with 2 Pr values
xkykai 9a07b82
fix function construct_scaling to construct_zeromeanunitvariance_scaling
xkykai 74c8b61
update using KernelParameters
xkykai 6f97c3f
fix nn closure bug
xkykai 65f15e4
test script for nn closure
xkykai cb601ce
fix N2 average, tracer diffusivity expression for local diffusivity c…
xkykai dbe5759
2Pr version of local physical closure
xkykai e9ac435
nonlocal physical closure of vertical diffusivity
xkykai 5e548da
2Pr nonlocal physical closure
xkykai 772c235
working GPU version of NN closure with scaling and correction
xkykai 9c6873f
validation script for oceananigans NN implementation
xkykai f6d508a
close file and record video of validation
xkykai 1faa5a6
Update Project.toml with new package dependencies
xkykai 930cc7e
Update NN closure model to use a larger neural network
xkykai fd95fc6
Remove unused import of StaticArrays
xkykai 72aebc1
Add ComponentArrays dependency
xkykai c2d336e
add total_size and KernelParameters dependency
xkykai d4a1156
Coarsen LES data for NN closure model
xkykai ae7b563
rename file and compare LES with NN closure
xkykai 4d118c2
run LES for hald sinusoid cooling
xkykai fc081cc
run 3D simulation with limited extent
xkykai a78109e
add sponge at bottom
xkykai ca8d78c
fix metres to meters
xkykai 78e7a67
fix temperature flux
xkykai eaf2183
Calculate average velocities and tracers in 3D model LES simulation
xkykai 7e144ff
reduce size of model
xkykai ba7f58d
run double gyre with physical closure
xkykai fcb466a
rename file
xkykai b63e7d6
fix S forcing
xkykai 67e011c
use RiBasedVerticalDiffusivity as closure
xkykai 55d635a
fix sign errors in tracer fluxes
xkykai 8319ba8
fix boundary conditions, initialize trivial model first
xkykai 229feeb
update boundary conditions, initial state
xkykai a853275
fix tyop_buoyancy_flux bug
xkykai 477c6be
fix initial conditions, run for 10 years, set up checkpointing
xkykai 232b67d
run double gyre with new physical closure
xkykai 148a582
plotting barotropic streamfunction
xkykai e9f2f0e
fix plot units
xkykai 7f17cf2
Merge branch 'main' into xk/embed-nn
xkykai ed8ec7d
run double gyre with CATKE
xkykai f4081bf
Merge branch 'main' into xk/embed-nn
xkykai 893e529
add TKE tracer for CATKE
xkykai ba5491d
fix closure to use only CATKE
xkykai 21a8356
update CATKE configuration
xkykai de190d2
use older dependencies for compatibility with neural networks
xkykai 7e7a04f
add fields to be calculated for CATKE
xkykai 1c31868
local diffusivity for 2step calibration
xkykai 3609d6c
new NN closure with nof and base boundary layer criteria
xkykai 8eaaaae
fix bug in NN closure implementation
xkykai 6535be5
using Grids.total_size
xkykai 812df88
add using KernelParameters
xkykai eb50cca
update NN model
xkykai b2cd2fa
use BBL integral metric to compute base of boundary layer
xkykai d3fa8b8
Update xin_kai_vertical_diffusivity.jl
xkykai 2448003
remove type piracy TEOS10.s
xkykai cb5f864
NN closure using BBL zone below nonbabkground kappa
xkykai f443296
run double gyre with NDE BBLkappazonelast41
xkykai 2f3eed9
change initialized state to 8day restoration forcing
xkykai 987f1ed
run double gyre withj baseclosure initialized
xkykai c4351f7
NN closure for augmenting flux in a zone below MLD
xkykai 19f71d3
8 day relaxation double gyre for baseclosure
xkykai 3c4102a
uncomment CairoMakie
xkykai b122e7b
NN closure and CATKE with 8 day restoration and warm flush
xkykai e0bd65a
fix initial temperature issue
xkykai 8dcb6d3
add zonal average calculations to double gyre simulation
xkykai 132188a
run double gyre with seasonal forcing
xkykai 57c3ac5
increase simulation run time
xkykai d5a320e
wall restoration to maintain strratification
xkykai 0cd618a
baseclosure doublegyre with wallrestoration
xkykai 4abf7e0
change filenames
xkykai 7cf7090
using wider zone for NN closure
xkykai 70dbe2f
fix variable error
xkykai fd57bf8
update NN configuration
xkykai b582447
NDE double gyre script for seasonal forcing and wall restoration
xkykai 6235bc4
run NNclosure with Ri nof BBLkappazonelast55
xkykai 3609c21
updated NN model with no Ri
xkykai 823e8c7
run double gyre with NNclosure with no Ri input kappazonelast55
xkykai 893b7bc
add fluxes calculations and zonal average
xkykai ea77669
Merge branch 'main' into xk/embed-nn
xkykai 3428f5d
Merge branch 'main' into xk/embed-nn
xkykai 4a5a3e1
using centered second order in z instead of WENO
xkykai 0f6bf66
remove background vertical scalar diffusivity
xkykai 8385229
diffusivity fields indexing for base closure
xkykai 45eef0a
change temperature restoration to 30days
xkykai e29bfd9
change vertical advection scheme to WENO5
xkykai 0fba127
update neural network to new weights
xkykai 73afa68
fix file name change
xkykai 1fab19f
run double gyre with centered second order and wall restoration
xkykai 49d901f
fix dynamic function invocation that doesn't affect the baseclosure s…
xkykai 534d21d
add y resolution variable
xkykai 2551c38
run NN closure recording xz slices
xkykai ff272c4
recording and plotting xz yz slices and fluxes
xkykai 93bf0bc
change default z advection scheme to weno 5
xkykai d36677f
run with linear ramp seasonal forcing
xkykai ce68cc7
fix T_seasonal
xkykai e1603ca
fix temperature restoration
xkykai a3ab871
implement different NNclosure with Ri zone
xkykai 51acb9d
run double gyre on updated closure
xkykai 33f2ea8
new NN closure witth Rifirstzone
xkykai 36cecb6
fix NN closure
xkykai 946d710
fix NN_closure
xkykai 0c55c44
updated base closure with new calibration
xkykai dada2c5
run base closure and CATKE with mode waters
xkykai 1d0ef1f
fix advection scheme and file name
xkykai 4898be4
new NN closure including Ri
xkykai 3e84aaf
actual new closure including Ri
xkykai 88f7fa4
run double gyre with new NN closure
xkykai 030c890
run CATKE for 10800days
xkykai f4f7383
Merge branch 'main' into xk/embed-nn
xkykai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
#using Pkg | ||
using Oceananigans | ||
using Printf | ||
using Statistics | ||
|
||
using Oceananigans | ||
using Oceananigans.Units | ||
using Oceananigans.OutputReaders: FieldTimeSeries | ||
using Oceananigans.Grids: xnode, ynode, znode | ||
using SeawaterPolynomials | ||
using SeawaterPolynomials:TEOS10 | ||
import SeawaterPolynomials.TEOS10: s, ΔS, Sₐᵤ | ||
s(Sᴬ::Number) = Sᴬ + ΔS >= 0 ? √((Sᴬ + ΔS) / Sₐᵤ) : NaN | ||
using Glob | ||
|
||
# Architecture | ||
model_architecture = GPU() | ||
|
||
# number of grid points | ||
Ny = 20000 | ||
Nz = 256 | ||
|
||
const Ly = 40kilometers | ||
const Lz = 512meters | ||
|
||
grid = RectilinearGrid(model_architecture, | ||
topology = (Flat, Bounded, Bounded), | ||
size = (Ny, Nz), | ||
halo = (5, 5), | ||
y = (0, Ly), | ||
z = (-Lz, 0)) | ||
|
||
@info "Built a grid: $grid." | ||
|
||
##### | ||
##### Boundary conditions | ||
##### | ||
const dTdz = 0.014 | ||
const dSdz = 0.0021 | ||
|
||
const T_surface = 20.0 | ||
const S_surface = 36.6 | ||
const max_temperature_flux = 3e-4 | ||
|
||
FILE_DIR = "./LES/NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES" | ||
mkpath(FILE_DIR) | ||
|
||
@inline function temperature_flux(y, t) | ||
return max_temperature_flux * sin(π * y / Ly) | ||
end | ||
|
||
T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(temperature_flux)) | ||
|
||
##### | ||
##### Coriolis | ||
##### | ||
|
||
const f₀ = 8e-5 | ||
coriolis = FPlane(f=f₀) | ||
|
||
##### | ||
##### Forcing and initial condition | ||
##### | ||
T_initial(y, z) = dTdz * z + T_surface | ||
S_initial(y, z) = dSdz * z + S_surface | ||
|
||
##### | ||
##### Model building | ||
##### | ||
|
||
@info "Building a model..." | ||
|
||
model = NonhydrostaticModel(; grid = grid, | ||
advection = WENO(order=9), | ||
coriolis = coriolis, | ||
buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), | ||
tracers = (:T, :S), | ||
timestepper = :RungeKutta3, | ||
closure = nothing, | ||
boundary_conditions = (; T=T_bcs)) | ||
|
||
@info "Built $model." | ||
|
||
##### | ||
##### Initial conditions | ||
##### | ||
|
||
# resting initial condition | ||
noise(z) = rand() * exp(z / 8) | ||
|
||
T_initial_noisy(y, z) = T_initial(y, z) + 1e-6 * noise(z) | ||
S_initial_noisy(y, z) = S_initial(y, z) + 1e-6 * noise(z) | ||
|
||
set!(model, T=T_initial_noisy, S=S_initial_noisy) | ||
##### | ||
##### Simulation building | ||
##### | ||
simulation = Simulation(model, Δt = 0.1, stop_time = 30days) | ||
|
||
# add timestep wizard callback | ||
wizard = TimeStepWizard(cfl=0.6, max_change=1.05, max_Δt=20minutes) | ||
simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) | ||
|
||
# add progress callback | ||
wall_clock = [time_ns()] | ||
|
||
function print_progress(sim) | ||
@printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", | ||
100 * (sim.model.clock.time / sim.stop_time), | ||
sim.model.clock.iteration, | ||
prettytime(sim.model.clock.time), | ||
prettytime(1e-9 * (time_ns() - wall_clock[1])), | ||
maximum(sim.model.velocities.u), | ||
maximum(sim.model.velocities.v), | ||
maximum(sim.model.tracers.T), | ||
maximum(sim.model.tracers.S), | ||
prettytime(sim.Δt)) | ||
|
||
wall_clock[1] = time_ns() | ||
|
||
return nothing | ||
end | ||
|
||
simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(1000)) | ||
|
||
##### | ||
##### Diagnostics | ||
##### | ||
|
||
u, w = model.velocities.u, model.velocities.w | ||
v = @at (Center, Center, Center) model.velocities.v | ||
T, S = model.tracers.T, model.tracers.S | ||
|
||
outputs = (; u, v, w, T, S) | ||
|
||
##### | ||
##### Build checkpointer and output writer | ||
##### | ||
simulation.output_writers[:jld2] = JLD2OutputWriter(model, outputs, | ||
filename = "$(FILE_DIR)/instantaneous_fields.jld2", | ||
schedule = TimeInterval(1hour)) | ||
|
||
simulation.output_writers[:checkpointer] = Checkpointer(model, | ||
schedule = TimeInterval(1day), | ||
prefix = "$(FILE_DIR)/checkpointer", | ||
overwrite_existing = true) | ||
|
||
@info "Running the simulation..." | ||
|
||
try | ||
files = readdir(FILE_DIR) | ||
checkpoint_files = files[occursin.("checkpointer_iteration", files)] | ||
if !isempty(checkpoint_files) | ||
checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) | ||
pickup_iter = maximum(checkpoint_iters) | ||
run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") | ||
else | ||
run!(simulation) | ||
end | ||
catch err | ||
@info "run! threw an error! The error message is" | ||
showerror(stdout, err) | ||
end | ||
|
||
checkpointers = glob("$(FILE_DIR)/checkpointer_iteration*.jld2") | ||
if !isempty(checkpointers) | ||
rm.(checkpointers) | ||
end | ||
|
||
# ##### | ||
# ##### Visualization | ||
# ##### | ||
#%% | ||
using CairoMakie | ||
|
||
|
||
u_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "u", backend=OnDisk()) | ||
v_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "v", backend=OnDisk()) | ||
T_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "T", backend=OnDisk()) | ||
S_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "S", backend=OnDisk()) | ||
|
||
yC = ynodes(T_data.grid, Center()) | ||
yF = ynodes(T_data.grid, Face()) | ||
|
||
zC = znodes(T_data.grid, Center()) | ||
zF = znodes(T_data.grid, Face()) | ||
|
||
Nt = length(T_data.times) | ||
#%% | ||
fig = Figure(size = (1500, 900)) | ||
axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u") | ||
axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v") | ||
axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature") | ||
axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity") | ||
n = Obeservable(1) | ||
|
||
uₙ = @lift interior(u_data[$n], 1, :, :) | ||
vₙ = @lift interior(v_data[$n], 1, :, :) | ||
Tₙ = @lift interior(T_data[$n], 1, :, :) | ||
Sₙ = @lift interior(S_data[$n], 1, :, :) | ||
|
||
ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-16]), maximum([maximum(abs, $uₙ), 1e-16])) | ||
vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-16]), maximum([maximum(abs, $vₙ), 1e-16])) | ||
Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1]))) | ||
Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1]))) | ||
|
||
title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days" | ||
Label(fig[0, :], title_str, tellwidth = false) | ||
|
||
hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim) | ||
hv = heatmap!(axv, yC, zC, vₙ, colormap=:RdBu_9, colorrange=vlim) | ||
hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim) | ||
hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim) | ||
|
||
Colorbar(fig[1, 2], hu, label = "(m/s)") | ||
Colorbar(fig[1, 4], hv, label = "(m/s)") | ||
Colorbar(fig[2, 2], hT, label = "(°C)") | ||
Colorbar(fig[2, 4], hS, label = "(psu)") | ||
|
||
CairoMakie.record(fig, "$(FILE_DIR)/2D_sin_cooling_$(max_temperature_flux)_30days.mp4", 1:Nt, framerate=15) do nn | ||
n[] = nn | ||
end | ||
|
||
# display(fig) | ||
#%% |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is an interesting idea. Just need to use non-short-circuiting logic and be mindful of number type and it could potentially go in
SeawaterPolynomials
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was writing this in order to address the issue of
when I was training my model, in order to continue the training with
NaN
.Wouldn't doing non-short-circuiting logic mean that we are forced into a
DomainError
ifSᴬ + ΔS
becomes negative?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, we could clip to avoid it, for example:
but after reflecting on this a bit more, I think it's better to receive a
DomainError
from this point, than to get a NaN and not know why.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. I believe I started doing this to throw
NaN
s when using EKI, which was necessary to continue training, but apart from training purposes there shouldn't be a need to do thisThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh no but I see your point now. It's part of the issue of failure handling with EKI.
The thing is, we really do want to support training / automatic calibration and it shouldn't require hacks like this, I feel this really impacts reproducibility and understandability (technically this is type piracy...)
Maybe clipping+NaN should be an option of the equation of state then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another option would be to wrap
run!(simulation)
inside atry/catch
so that, if a simulation errors, we can replace the output with NaN / mark the simulation as failed for the purpose of estimating parameters. This would work too right? That might be simpler (simply capturing errors after they occur) than trying to prevent any errors from occurring, which is possible in this particular case because we own SeawaterPolynomials but is not generally a solution.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect what one could do is do clipping + throw
NaN
with a warning message that this is due to the model having salinity values that are beyond the regime where TEOS-10 is correct. Or perhaps throw a warning whenever the temperature and salinity values are outside of the reasonable regime of TEOS10 (between 0-40 °C and 0-42 psu)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't warn from inside a kernel though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, so try-catch it is then