diff --git a/2D_model_LES_sin.jl b/2D_model_LES_sin.jl new file mode 100644 index 0000000000..259d370719 --- /dev/null +++ b/2D_model_LES_sin.jl @@ -0,0 +1,224 @@ +#using Pkg +using Oceananigans +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + +using Glob + +# Architecture +model_architecture = GPU() + +# number of grid points +Ny = 20000 +Nz = 256 + +const Ly = 40kilometers +const Lz = 512meters + +grid = RectilinearGrid(model_architecture, + topology = (Flat, Bounded, Bounded), + size = (Ny, Nz), + halo = (5, 5), + y = (0, Ly), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const dTdz = 0.014 +const dSdz = 0.0021 + +const T_surface = 20.0 +const S_surface = 36.6 +const max_temperature_flux = 2e-4 + +FILE_DIR = "./LES/NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES" +mkpath(FILE_DIR) + +@inline function temperature_flux(y, t) + return max_temperature_flux * sin(π * y / Ly) +end + +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(temperature_flux)) + +##### +##### Coriolis +##### + +const f₀ = 8e-5 +coriolis = FPlane(f=f₀) + +##### +##### Forcing and initial condition +##### +T_initial(y, z) = dTdz * z + T_surface +S_initial(y, z) = dSdz * z + S_surface + +##### +##### Model building +##### + +@info "Building a model..." + +model = NonhydrostaticModel(; grid = grid, + advection = WENO(order=9), + coriolis = coriolis, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + tracers = (:T, :S), + timestepper = :RungeKutta3, + closure = nothing, + boundary_conditions = (; T=T_bcs)) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(y, z) = T_initial(y, z) + 1e-6 * noise(z) +S_initial_noisy(y, z) = S_initial(y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +##### +##### Simulation building +##### +simulation = Simulation(model, Δt = 0.1, stop_time = 30days) + +# add timestep wizard callback +wizard = TimeStepWizard(cfl=0.6, max_change=1.05, max_Δt=20minutes) +simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(sim.model.velocities.u), + maximum(sim.model.velocities.v), + maximum(sim.model.tracers.T), + maximum(sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(1000)) + +##### +##### Diagnostics +##### + +u, w = model.velocities.u, model.velocities.w +v = @at (Center, Center, Center) model.velocities.v +T, S = model.tracers.T, model.tracers.S + +outputs = (; u, v, w, T, S) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:jld2] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields.jld2", + schedule = TimeInterval(1hour)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(1day), + prefix = "$(FILE_DIR)/checkpointer", + overwrite_existing = true) + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +checkpointers = glob("$(FILE_DIR)/checkpointer_iteration*.jld2") +if !isempty(checkpointers) + rm.(checkpointers) +end + +# ##### +# ##### Visualization +# ##### +#%% +using CairoMakie + + +u_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "u", backend=OnDisk()) +v_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "v", backend=OnDisk()) +T_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "T", backend=OnDisk()) +S_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "S", backend=OnDisk()) + +yC = ynodes(T_data.grid, Center()) +yF = ynodes(T_data.grid, Face()) + +zC = znodes(T_data.grid, Center()) +zF = znodes(T_data.grid, Face()) + +Nt = length(T_data.times) +#%% +fig = Figure(size = (1500, 900)) +axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u") +axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v") +axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature") +axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity") +n = Obeservable(1) + +uₙ = @lift interior(u_data[$n], 1, :, :) +vₙ = @lift interior(v_data[$n], 1, :, :) +Tₙ = @lift interior(T_data[$n], 1, :, :) +Sₙ = @lift interior(S_data[$n], 1, :, :) + +ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-16]), maximum([maximum(abs, $uₙ), 1e-16])) +vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-16]), maximum([maximum(abs, $vₙ), 1e-16])) +Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1]))) +Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1]))) + +title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days" +Label(fig[0, :], title_str, tellwidth = false) + +hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim) +hv = heatmap!(axv, yC, zC, vₙ, colormap=:RdBu_9, colorrange=vlim) +hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim) +hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim) + +Colorbar(fig[1, 2], hu, label = "(m/s)") +Colorbar(fig[1, 4], hv, label = "(m/s)") +Colorbar(fig[2, 2], hT, label = "(°C)") +Colorbar(fig[2, 4], hS, label = "(psu)") + +CairoMakie.record(fig, "$(FILE_DIR)/2D_sin_cooling_$(max_temperature_flux)_30days.mp4", 1:Nt, framerate=15) do nn + n[] = nn +end + +# display(fig) +#%% \ No newline at end of file diff --git a/3D_model_LES_sin.jl b/3D_model_LES_sin.jl new file mode 100644 index 0000000000..8e3d0f7766 --- /dev/null +++ b/3D_model_LES_sin.jl @@ -0,0 +1,245 @@ +#using Pkg +using Oceananigans +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + +using Glob + +# Architecture +model_architecture = GPU() + +# number of grid points +Nx = 125 +Ny = 1000 +Nz = 250 + +const Lx = 250meters +const Ly = 2kilometers +const Lz = 500meters + +grid = RectilinearGrid(model_architecture, + topology = (Periodic, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (5, 5, 5), + x = (0, Lx), + y = (0, Ly), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const dTdz = 0.014 +const dSdz = 0.0021 + +const T_surface = 20.0 +const S_surface = 36.6 +const max_temperature_flux = 3e-4 + +FILE_DIR = "./LES/NN_3D_channel_sin_cooling_$(max_temperature_flux)_LES_Lx_$(Lx)_Ly_$(Ly)_Lz_$(Lz)_Nx_$(Nx)_Ny_$(Ny)_Nz_$(Nz)" +mkpath(FILE_DIR) + +@inline function temperature_flux(x, y, t) + return max_temperature_flux * sin(π * y / Ly) +end + +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(temperature_flux)) + +##### +##### Coriolis +##### + +const f₀ = 8e-5 +coriolis = FPlane(f=f₀) + +##### +##### Forcing and initial condition +##### +T_initial(x, y, z) = dTdz * z + T_surface +S_initial(x, y, z) = dSdz * z + S_surface + +damping_rate = 1/15minute + +T_target(x, y, z, t) = T_initial(x, y, z) +S_target(x, y, z, t) = S_initial(x, y, z) + +bottom_mask = GaussianMask{:z}(center=-grid.Lz, width=grid.Lz/10) + +uvw_sponge = Relaxation(rate=damping_rate, mask=bottom_mask) +T_sponge = Relaxation(rate=damping_rate, mask=bottom_mask, target=T_target) +S_sponge = Relaxation(rate=damping_rate, mask=bottom_mask, target=S_target) + +##### +##### Model building +##### + +@info "Building a model..." + +model = NonhydrostaticModel(; grid = grid, + advection = WENO(order=9), + coriolis = coriolis, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + tracers = (:T, :S), + timestepper = :RungeKutta3, + closure = nothing, + boundary_conditions = (; T=T_bcs), + forcing = (u=uvw_sponge, v=uvw_sponge, w=uvw_sponge, T=T_sponge, S=S_sponge)) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +##### +##### Simulation building +##### +simulation = Simulation(model, Δt = 0.1, stop_time = 10days) + +# add timestep wizard callback +wizard = TimeStepWizard(cfl=0.6, max_change=1.05, max_Δt=20minutes) +simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(sim.model.velocities.u), + maximum(sim.model.velocities.v), + maximum(sim.model.tracers.T), + maximum(sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(1000)) + +##### +##### Diagnostics +##### + +u, w = model.velocities.u, model.velocities.w +v = @at (Center, Center, Center) model.velocities.v +T, S = model.tracers.T, model.tracers.S + +ubar = Average(u, dims=1) +vbar = Average(v, dims=1) +wbar = Average(w, dims=1) +Tbar = Average(T, dims=1) +Sbar = Average(S, dims=1) + +# outputs = (; u, v, w, T, S) +outputs = (; ubar, vbar, wbar, Tbar, Sbar) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:jld2] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields.jld2", + schedule = TimeInterval(1hour)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(0.5day), + prefix = "$(FILE_DIR)/checkpointer", + overwrite_existing = true) + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +checkpointers = glob("$(FILE_DIR)/checkpointer_iteration*.jld2") +if !isempty(checkpointers) + rm.(checkpointers) +end + +# ##### +# ##### Visualization +# ##### +#%% +using CairoMakie + +u_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "u", backend=OnDisk()) +v_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "v", backend=OnDisk()) +T_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "T", backend=OnDisk()) +S_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "S", backend=OnDisk()) + +yC = ynodes(T_data.grid, Center()) +yF = ynodes(T_data.grid, Face()) + +zC = znodes(T_data.grid, Center()) +zF = znodes(T_data.grid, Face()) + +Nt = length(T_data.times) +#%% +fig = Figure(size = (1500, 900)) +axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u") +axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v") +axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature") +axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity") +n = Obeservable(1) + +uₙ = @lift interior(u_data[$n], 1, :, :) +vₙ = @lift interior(v_data[$n], 1, :, :) +Tₙ = @lift interior(T_data[$n], 1, :, :) +Sₙ = @lift interior(S_data[$n], 1, :, :) + +ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-16]), maximum([maximum(abs, $uₙ), 1e-16])) +vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-16]), maximum([maximum(abs, $vₙ), 1e-16])) +Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1]))) +Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1]))) + +title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days" +Label(fig[0, :], title_str, tellwidth = false) + +hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim) +hv = heatmap!(axv, yC, zC, vₙ, colormap=:RdBu_9, colorrange=vlim) +hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim) +hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim) + +Colorbar(fig[1, 2], hu, label = "(m/s)") +Colorbar(fig[1, 4], hv, label = "(m/s)") +Colorbar(fig[2, 2], hT, label = "(°C)") +Colorbar(fig[2, 4], hS, label = "(psu)") + +CairoMakie.record(fig, "$(FILE_DIR)/3D_sin_cooling_$(max_temperature_flux)_Lx_$(Lx)_Ly_$(Ly)_Lz_$(Lz)_Nx_$(Nx)_Ny_$(Ny)_Nz_$(Nz)_10days.mp4", 1:Nt, framerate=10) do nn + n[] = nn +end + +# display(fig) +#%% \ No newline at end of file diff --git a/NN_1D_model.jl b/NN_1D_model.jl new file mode 100644 index 0000000000..532cb6cdf4 --- /dev/null +++ b/NN_1D_model.jl @@ -0,0 +1,219 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global.jl") +include("xin_kai_vertical_diffusivity_local.jl") +include("feature_scaling.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + + +# Architecture +model_architecture = CPU() + +# number of grid points +Nz = 64 +const Lz = 512 + +grid = RectilinearGrid(model_architecture, + topology = (Flat, Flat, Bounded), + size = Nz, + halo = 3, + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const dTdz = 0.014 +const dSdz = 0.0021 + +const T_surface = 20.0 +const S_surface = 36.6 + +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(3e-4)) +u_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(2e-4)) + +##### +##### Coriolis +##### + +const f₀ = 1e-4 +const β = 1e-11 +# coriolis = BetaPlane(f₀=f₀, β = β) +coriolis = FPlane(f=f₀) + +##### +##### Forcing and initial condition +##### +T_initial(z) = dTdz * z + T_surface +S_initial(z) = dSdz * z + S_surface + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() + +##### +##### Model building +##### + +@info "Building a model..." + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = ImplicitFreeSurface(), + momentum_advection = WENO(grid = grid), + tracer_advection = WENO(grid = grid), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = (nn_closure, base_closure), + # closure = base_closure, + tracers = (:T, :S), + # boundary_conditions = (; T = T_bcs), + boundary_conditions = (; T = T_bcs, u = u_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(z) = T_initial(z) + 1e-6 * noise(z) +S_initial_noisy(z) = S_initial(z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.1, max_change=1.1, max_Δt=20minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(20)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(20)) + +##### +##### Diagnostics +##### + +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S + +ubar = Field(Average(u, dims = (1,2))) +vbar = Field(Average(v, dims = (1,2))) +Tbar = Field(Average(T, dims = (1,2))) +Sbar = Field(Average(S, dims = (1,2))) + +averaged_outputs = (; ubar, vbar, Tbar, Sbar) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:jld2] = JLD2OutputWriter(model, averaged_outputs, + filename = "NN_1D_channel_averages", + schedule = TimeInterval(10minutes), + overwrite_existing = true) + +@info "Running the simulation..." + +try + run!(simulation, pickup = false) +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +# ##### +# ##### Visualization +# ##### +#%% +using CairoMakie + +ubar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "ubar") +vbar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "vbar") +Tbar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "Tbar") +Sbar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "Sbar") + +zC = znodes(Tbar_data.grid, Center()) +zF = znodes(Tbar_data.grid, Face()) + +Nt = length(Tbar_data.times) + +fig = Figure(size = (1800, 600)) +axu = CairoMakie.Axis(fig[1, 1], xlabel = "u (m s⁻¹)", ylabel = "z (m)") +axv = CairoMakie.Axis(fig[1, 2], xlabel = "v (m s⁻¹)", ylabel = "z (m)") +axT = CairoMakie.Axis(fig[1, 3], xlabel = "T (°C)", ylabel = "z (m)") +axS = CairoMakie.Axis(fig[1, 4], xlabel = "S (g kg⁻¹)", ylabel = "z (m)") +# slider = Slider(fig[2, :], range=1:Nt) +n = Observable(1) + +ubarₙ = @lift interior(ubar_data[$n], 1, 1, :) +vbarₙ = @lift interior(vbar_data[$n], 1, 1, :) +Tbarₙ = @lift interior(Tbar_data[$n], 1, 1, :) +Sbarₙ = @lift interior(Sbar_data[$n], 1, 1, :) + +ulim = (minimum(ubar_data), maximum(ubar_data)) +vlim = (minimum(vbar_data), maximum(vbar_data)) +Tlim = (minimum(Tbar_data), maximum(Tbar_data)) +Slim = (minimum(Sbar_data), maximum(Sbar_data)) + +title_str = @lift "Time: $(round(Tbar_data.times[$n] / 86400, digits=3)) days" + +lines!(axu, ubarₙ, zC) +lines!(axv, vbarₙ, zC) +lines!(axT, Tbarₙ, zC) +lines!(axS, Sbarₙ, zC) + +xlims!(axu, ulim) +xlims!(axv, vlim) +xlims!(axT, Tlim) +xlims!(axS, Slim) + +Label(fig[0, :], title_str, tellwidth = false) + +CairoMakie.record(fig, "./NN_1D_fields.mp4", 1:Nt, framerate=60) do nn + n[] = nn +end + +# display(fig) +#%% \ No newline at end of file diff --git a/NN_2D_channel.jl b/NN_2D_channel.jl new file mode 100644 index 0000000000..56e2204153 --- /dev/null +++ b/NN_2D_channel.jl @@ -0,0 +1,310 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure.jl") +include("xin_kai_vertical_diffusivity_local.jl") +include("feature_scaling.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + +const Ly = 2000kilometers # meridional domain length [m] + +# Architecture +model_architecture = CPU() + +# number of grid points +Ny = 192 +Nz = 128 + +const Lz = 1024 + +grid = RectilinearGrid(model_architecture, + topology = (Flat, Bounded, Bounded), + size = (Ny, Nz), + halo = (3, 3), + y = (0, Ly), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(1e-4), bottom=FluxBoundaryCondition(0.0)) + +##### +##### Coriolis +##### + +const f₀ = 8e-5 +const β = 1e-11 +coriolis = BetaPlane(f₀=f₀, β = β) + +##### +##### Forcing and initial condition +##### +const dTdz = 0.014 +const dSdz = 0.0021 + +const T_surface = 20.0 +const S_surface = 36.6 + +T_initial(y, z) = dTdz * z + T_surface +S_initial(y, z) = dSdz * z + S_surface + +# closure +κh = 0.5e-5 # [m²/s] horizontal diffusivity +νh = 30.0 # [m²/s] horizontal viscocity +κz = 0.5e-5 # [m²/s] vertical diffusivity +νz = 3e-4 # [m²/s] vertical viscocity + +horizontal_closure = HorizontalScalarDiffusivity(ν = νh, κ = κh) +vertical_closure = VerticalScalarDiffusivity(ν = νz, κ = κz) + +convective_adjustment = ConvectiveAdjustmentVerticalDiffusivity(convective_κz = 1.0, + convective_νz = 0.0) + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() + +##### +##### Model building +##### + +@info "Building a model..." + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = ImplicitFreeSurface(), + momentum_advection = WENO(grid = grid), + tracer_advection = WENO(grid = grid), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = (nn_closure, base_closure), + # closure = (horizontal_closure, vertical_closure, convective_adjustment), + tracers = (:T, :S), + boundary_conditions = (; T = T_bcs), + # forcing = (; b = Fb) +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(y, z) = rand() * exp(z / 8) + +T_initial_noisy(y, z) = T_initial(y, z) + 1e-6 * noise(y, z) +S_initial_noisy(y, z) = S_initial(y, z) + 1e-6 * noise(y, z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 1days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.1, max_change=1.1, max_Δt=20minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(20)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): (%6.3e, %6.3e, %6.3e) m/s, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.velocities.w), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(20)) + + +##### +##### Diagnostics +##### + +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S + +ζ = Field(∂x(v) - ∂y(u)) + +Tbar = Field(Average(T, dims = 1)) +Sbar = Field(Average(S, dims = 1)) +V = Field(Average(v, dims = 1)) +W = Field(Average(w, dims = 1)) + +T′ = T - Tbar +S′ = S - Sbar +v′ = v - V +w′ = w - W + +v′T′ = Field(Average(v′ * T′, dims = 1)) +w′T′ = Field(Average(w′ * T′, dims = 1)) +v′S′ = Field(Average(v′ * S′, dims = 1)) +w′S′ = Field(Average(w′ * S′, dims = 1)) + +outputs = (; T, S, ζ, w) + +averaged_outputs = (; v′T′, w′T′, v′S′, w′S′, Tbar, Sbar) + +##### +##### Build checkpointer and output writer +##### + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(100days), + prefix = "NN_channel", + overwrite_existing = true) + +simulation.output_writers[:fields] = JLD2OutputWriter(model, outputs, + schedule = TimeInterval(5days), + filename = "NN_channel", + # field_slicer = nothing, + verbose = true, + overwrite_existing = true) + +simulation.output_writers[:averages] = JLD2OutputWriter(model, averaged_outputs, + schedule = AveragedTimeInterval(1days, window = 1days, stride = 1), + filename = "NN_channel_averages", + verbose = true, + overwrite_existing = true) + +@info "Running the simulation..." + +try + run!(simulation, pickup = false) +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#= +# ##### +# ##### Visualization +# ##### + +using Plots + +grid = RectilinearGrid(architecture = CPU(), + topology = (Periodic, Bounded, Bounded), + size = (grid.Nx, grid.Ny, grid.Nz), + halo = (3, 3, 3), + x = (0, grid.Lx), + y = (0, grid.Ly), + z = z_faces) + +xζ, yζ, zζ = nodes((Face, Face, Center), grid) +xc, yc, zc = nodes((Center, Center, Center), grid) +xw, yw, zw = nodes((Center, Center, Face), grid) + +j′ = round(Int, grid.Ny / 2) +y′ = yζ[j′] + +b_timeseries = FieldTimeSeries("abernathey_channel.jld2", "b", grid = grid) +ζ_timeseries = FieldTimeSeries("abernathey_channel.jld2", "ζ", grid = grid) +w_timeseries = FieldTimeSeries("abernathey_channel.jld2", "w", grid = grid) + +@show b_timeseries + +anim = @animate for i in 1:length(b_timeseries.times) + b = b_timeseries[i] + ζ = ζ_timeseries[i] + w = w_timeseries[i] + + b′ = interior(b) .- mean(b) + b_xy = b′[:, :, grid.Nz] + ζ_xy = interior(ζ)[:, :, grid.Nz] + ζ_xz = interior(ζ)[:, j′, :] + w_xz = interior(w)[:, j′, :] + + @show bmax = max(1e-9, maximum(abs, b_xy)) + @show ζmax = max(1e-9, maximum(abs, ζ_xy)) + @show wmax = max(1e-9, maximum(abs, w_xz)) + + blims = (-bmax, bmax) .* 0.8 + ζlims = (-ζmax, ζmax) .* 0.8 + wlims = (-wmax, wmax) .* 0.8 + + blevels = vcat([-bmax], range(blims[1], blims[2], length = 31), [bmax]) + ζlevels = vcat([-ζmax], range(ζlims[1], ζlims[2], length = 31), [ζmax]) + wlevels = vcat([-wmax], range(wlims[1], wlims[2], length = 31), [wmax]) + + xlims = (-grid.Lx / 2, grid.Lx / 2) .* 1e-3 + ylims = (0, grid.Ly) .* 1e-3 + zlims = (-grid.Lz, 0) + + w_xz_plot = contourf(xw * 1e-3, zw, w_xz', + xlabel = "x (km)", + ylabel = "z (m)", + aspectratio = 0.05, + linewidth = 0, + levels = wlevels, + clims = wlims, + xlims = xlims, + ylims = zlims, + color = :balance) + + ζ_xy_plot = contourf(xζ * 1e-3, yζ * 1e-3, ζ_xy', + xlabel = "x (km)", + ylabel = "y (km)", + aspectratio = :equal, + linewidth = 0, + levels = ζlevels, + clims = ζlims, + xlims = xlims, + ylims = ylims, + color = :balance) + + b_xy_plot = contourf(xc * 1e-3, yc * 1e-3, b_xy', + xlabel = "x (km)", + ylabel = "y (km)", + aspectratio = :equal, + linewidth = 0, + levels = blevels, + clims = blims, + xlims = xlims, + ylims = ylims, + color = :balance) + + w_xz_title = @sprintf("w(x, z) at t = %s", prettytime(ζ_timeseries.times[i])) + ζ_xz_title = @sprintf("ζ(x, z) at t = %s", prettytime(ζ_timeseries.times[i])) + ζ_xy_title = "ζ(x, y)" + b_xy_title = "b(x, y)" + + layout = @layout [upper_slice_plot{0.2h} + Plots.grid(1, 2)] + + plot(w_xz_plot, ζ_xy_plot, b_xy_plot, layout = layout, size = (1200, 1200), title = [w_xz_title ζ_xy_title b_xy_title]) +end + +mp4(anim, "abernathey_channel.mp4", fps = 8) #hide +=# \ No newline at end of file diff --git a/NN_closure.jl b/NN_closure.jl new file mode 100644 index 0000000000..e728c63679 --- /dev/null +++ b/NN_closure.jl @@ -0,0 +1,171 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using StaticArrays + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(first(neural_network.model(input, neural_network.ps, neural_network.st))) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_model = Chain(Dense(11, 128, relu), Dense(128, 128, relu), Dense(128, 1)) + + # scaling = jldopen("./NN_model2.jld2")["scaling"] + scaling = (; ∂T∂z = ZeroMeanUnitVarianceScaling(-0.0006850967567052092, 0.019041912105983983), + ∂S∂z = ZeroMeanUnitVarianceScaling(-0.00042981832021978374, 0.0028927446724707905), + ∂ρ∂z = ZeroMeanUnitVarianceScaling(-0.0011311157767216616, 0.0008333035237211424), + f = ZeroMeanUnitVarianceScaling(-1.5e-5, 8.73212459828649e-5), + wb = ZeroMeanUnitVarianceScaling(6.539366623323223e-8, 1.827377562065243e-7), + wT = ZeroMeanUnitVarianceScaling(1.8169228278423015e-5, 0.00010721779595955453), + wS = ZeroMeanUnitVarianceScaling(-5.8185988680682135e-6, 1.7691239104281005e-5)) + + # NNs = jldopen("./NN_model2.jld2")["NNs"] + ps = jldopen("./NN_model2.jld2")["u"] + sts = jldopen("./NN_model2.jld2")["sts"] + + ps_static = Lux.recursive_map(tosarray, ps) + sts_static = Lux.recursive_map(tosarray, sts) + + wT_NN = NN(nn_model, ps.wT, sts.wT) + wS_NN = NN(nn_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) = + (; wT = ZFaceField(grid), + wS = ZFaceField(grid)) + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + launch!(arch, grid, parameters, + _compute_residual_fluxes!, diffusivities, grid, closure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + + return nothing +end + +@kernel function _compute_residual_fluxes!(diffusivities, grid, closure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + + # Find a way to extract the type FT + nn_input = @private eltype(grid) 11 + + scaling = closure.scaling + + nn_input[10] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, tracers)) + nn_input[11] = fᶜᶜ = scaling.f(fᶜᶜᵃ(i, j, k, grid, coriolis)) + + nn_input[1] = ∂Tᵢ₋₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.T)) + nn_input[2] = ∂Tᵢ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.T)) + nn_input[3] = ∂Tᵢ₊₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.T)) + + nn_input[4] = ∂Sᵢ₋₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.S)) + nn_input[5] = ∂Sᵢ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.S)) + nn_input[6] = ∂Sᵢ₊₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.S)) + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + nn_input[7] = ∂σᵢ = scaling.∂ρ∂z(ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + nn_input[8] = ∂σᵢ₋₁ = scaling.∂ρ∂z(ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + nn_input[9] = ∂σᵢ₊₁ = scaling.∂ρ∂z(ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + + @inbounds wT = inv(scaling.wT)(closure.wT(nn_input)) + @inbounds wS = inv(scaling.wS)(closure.wS(nn_input)) + + @inbounds diffusivities.wT[i, j, k] = ifelse(k > grid.Nz - 2, 0, wT) + @inbounds diffusivities.wS[i, j, k] = ifelse(k > grid.Nz - 2, 0, wS) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global.jl b/NN_closure_global.jl new file mode 100644 index 0000000000..63021bcd52 --- /dev/null +++ b/NN_closure_global.jl @@ -0,0 +1,200 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + # nn_path = "./NDE_FC_Qb_absf_24simnew_2layer_128_relu_2Pr_model.jld2" + nn_path = "./NDE_FC_Qb_18simnew_2layer_128_relu_2Pr_model.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + + wrk_in = OffsetArray(zeros(11, Nx_in, Ny_in, Nz_in), 0, ox_in, oy_in, oz_in) + wrk_in = on_architecture(arch, wrk_in) + + return (; wrk_in, wT, wS) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + input = diffusivities.wrk_in + wT = diffusivities.wT + wS = diffusivities.wS + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, + _populate_input!, input, grid, closure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + + wT.data.parent .= dropdims(closure.wT(input.parent), dims=1) + wS.data.parent .= dropdims(closure.wS(input.parent), dims=1) + + launch!(arch, grid, kp, _rescale_nn_fluxes!, diffusivities, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + launch!(arch, grid, kp, _adjust_nn_bottom_fluxes!, diffusivities, grid, closure) + return nothing +end + +@kernel function _populate_input!(input, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.T)) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.T)) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.T)) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.S)) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.S)) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.S)) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k-1, grid, buoyancy, tracers) / g) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k+1, grid, buoyancy, tracers) / g) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + # @inbounds input[11, i, j, k] = fᶜᶜ = scaling.f(abs(fᶜᶜᵃ(i, j, k, grid, coriolis))) + @inbounds input[11, i, j, k] = fᶜᶜ = scaling.f(fᶜᶜᵃ(i, j, k, grid, coriolis)) +end + +@kernel function _rescale_nn_fluxes!(diffusivities, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + interior_point = k <= grid.Nz - 1 & k >= 3 + + @inbounds diffusivities.wT[i, j, k] = ifelse(convecting & interior_point, inv(scaling.wT)(diffusivities.wT[i, j, k]) - inv(scaling.wT)(0), 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(convecting & interior_point, inv(scaling.wS)(diffusivities.wS[i, j, k]) - inv(scaling.wS)(0), 0) +end + +@kernel function _adjust_nn_bottom_fluxes!(diffusivities, grid, closure::NNFluxClosure) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.wT[i, j, k] = ifelse(k <= 3, diffusivities.wT[i, j, 4], diffusivities.wT[i, j, k]) + @inbounds diffusivities.wS[i, j, k] = ifelse(k <= 3, diffusivities.wS[i, j, 4], diffusivities.wS[i, j, k]) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_Ri_nof_BBLRifirstzone510.jl b/NN_closure_global_Ri_nof_BBLRifirstzone510.jl new file mode 100644 index 0000000000..bde30dfe58 --- /dev/null +++ b/NN_closure_global_Ri_nof_BBLRifirstzone510.jl @@ -0,0 +1,294 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S, G} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S + grid_point_above :: G + grid_point_below :: G +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling), + Adapt.adapt(to, nn.grid_point_above), + Adapt.adapt(to, nn.grid_point_below)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE5_FC_Qb_Ri_nof_BBLRifirst510_train62newnohighrotation_scalingtrain62newnohighrotation_validate30new_3layer_128_relu_30seed_2Pr_ls5_model_temp.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u_validation"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + grid_point_above = 10 + grid_point_below = 5 + + return NNFluxClosure(wT_NN, wS_NN, scaling, grid_point_above, grid_point_below) +end + +function DiffusivityFields(grid, tracer_names, bcs, closure::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + first_index = Field((Center, Center, Nothing), grid, Int32) + last_index = Field((Center, Center, Nothing), grid, Int32) + + N_input = closure.wT.model.layers.layer_1.in_dims + N_levels = closure.grid_point_above + closure.grid_point_below + + Nx_in, Ny_in, _ = size(wT) + wrk_in = zeros(N_input, Nx_in, Ny_in, N_levels) + wrk_in = on_architecture(arch, wrk_in) + + wrk_wT = zeros(Nx_in, Ny_in, 15) + wrk_wS = zeros(Nx_in, Ny_in, 15) + wrk_wT = on_architecture(arch, wrk_wT) + wrk_wS = on_architecture(arch, wrk_wS) + + return (; wrk_in, wrk_wT, wrk_wS, wT, wS, first_index, last_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Riᶜ = model.closure[1].Riᶜ + Ri = model.diffusivity_fields[1].Ri + + wrk_in = diffusivities.wrk_in + wrk_wT = diffusivities.wrk_wT + wrk_wS = diffusivities.wrk_wS + wT = diffusivities.wT + wS = diffusivities.wS + + first_index = diffusivities.first_index + last_index = diffusivities.last_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + N_levels = closure.grid_point_above + closure.grid_point_below + + kp_wrk = KernelParameters((Nx_in, Ny_in, N_levels), (0, 0, 0)) + + launch!(arch, grid, kp_2D, _find_NN_active_region!, Ri, grid, Riᶜ, first_index, last_index, closure) + + launch!(arch, grid, kp_wrk, + _populate_input!, wrk_in, first_index, last_index, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + + wrk_wT .= dropdims(closure.wT(wrk_in), dims=1) + wrk_wS .= dropdims(closure.wS(wrk_in), dims=1) + + launch!(arch, grid, kp, _fill_adjust_nn_fluxes!, diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, first_index, last_index, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + quiescent = quiescent_condition(k_first, k_last) + + @inbounds k_tracer = clamp_k_interior(k_first + k - 1, grid) + + @inbounds k₋₂ = clamp_k_interior(k_tracer - 2, grid) + @inbounds k₋₁ = clamp_k_interior(k_tracer - 1, grid) + @inbounds k₀ = clamp_k_interior(k_tracer, grid) + @inbounds k₊₁ = clamp_k_interior(k_tracer + 1, grid) + @inbounds k₊₂ = clamp_k_interior(k_tracer + 2, grid) + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = ifelse(quiescent, 0, atan(Ri[i, j, k₋₂])) + @inbounds input[2, i, j, k] = ifelse(quiescent, 0, atan(Ri[i, j, k₋₁])) + @inbounds input[3, i, j, k] = ifelse(quiescent, 0, atan(Ri[i, j, k₀])) + @inbounds input[4, i, j, k] = ifelse(quiescent, 0, atan(Ri[i, j, k₊₁])) + @inbounds input[5, i, j, k] = ifelse(quiescent, 0, atan(Ri[i, j, k₊₂])) + + @inbounds input[6, i, j, k] = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋₂, grid, T))) + @inbounds input[7, i, j, k] = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋₁, grid, T))) + @inbounds input[8, i, j, k] = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₀, grid, T))) + @inbounds input[9, i, j, k] = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊₁, grid, T))) + @inbounds input[10, i, j, k] = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊₂, grid, T))) + + @inbounds input[11, i, j, k] = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋₂, grid, S))) + @inbounds input[12, i, j, k] = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋₁, grid, S))) + @inbounds input[13, i, j, k] = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₀, grid, S))) + @inbounds input[14, i, j, k] = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊₁, grid, S))) + @inbounds input[15, i, j, k] = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊₂, grid, S))) + + @inbounds input[16, i, j, k] = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₋₂, grid, buoyancy, tracers) / g)) + @inbounds input[17, i, j, k] = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₋₁, grid, buoyancy, tracers) / g)) + @inbounds input[18, i, j, k] = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₀, grid, buoyancy, tracers) / g)) + @inbounds input[19, i, j, k] = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₊₁, grid, buoyancy, tracers) / g)) + @inbounds input[20, i, j, k] = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₊₂, grid, buoyancy, tracers) / g)) + + @inbounds input[21, i, j, k] = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + +end + +@kernel function _find_NN_active_region!(Ri, grid, Riᶜ, first_index, last_index, closure::NNFluxClosure) + i, j = @index(Global, NTuple) + top_index = grid.Nz + 1 + grid_point_above_kappa = closure.grid_point_above + grid_point_below_kappa = closure.grid_point_below + + # Find the first index of the background κᶜ + kloc = grid.Nz+1 + @inbounds for k in grid.Nz:-1:2 + kloc = ifelse(Ri[i, j, k] < Riᶜ, k, kloc) + end + + background_κ_index = kloc - 1 + nonbackground_κ_index = background_κ_index + 1 + + @inbounds last_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index - 1, clamp_k_interior(background_κ_index + grid_point_above_kappa, grid)) + @inbounds first_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index, clamp_k_interior(background_κ_index - grid_point_below_kappa + 1, grid)) +end + +@inline function quiescent_condition(lo, hi) + return hi < lo +end + +@inline function within_zone_condition(k, lo, hi) + return (k >= lo) & (k <= hi) +end + +@inline function clamp_k_interior(k, grid) + kmax = grid.Nz + kmin = 2 + + return clamp(k, kmin, kmax) +end + +@kernel function _fill_adjust_nn_fluxes!(diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + quiescent = quiescent_condition(k_first, k_last) + within_zone = within_zone_condition(k, k_first, k_last) + + N_levels = closure.grid_point_above + closure.grid_point_below + @inbounds k_wrk = clamp(k - k_first + 1, 1, N_levels) + + NN_active = convecting & !quiescent & within_zone + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * wrk_wT[i, j, k_wrk], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * wrk_wS[i, j, k_wrk], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_Ri_nof_BBLkappazonelast55.jl b/NN_closure_global_Ri_nof_BBLkappazonelast55.jl new file mode 100644 index 0000000000..2255a8697b --- /dev/null +++ b/NN_closure_global_Ri_nof_BBLkappazonelast55.jl @@ -0,0 +1,268 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE_FC_Qb_Ri_nof_BBLkappazonelast55_trainFC19new_scalingtrain59new_2layer_256_relu_10seed_2Pr_model_temp.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + first_index = Field((Center, Center, Nothing), grid, Int32) + last_index = Field((Center, Center, Nothing), grid, Int32) + + Nx_in, Ny_in, _ = size(wT) + wrk_in = zeros(13, Nx_in, Ny_in, 10) + wrk_in = on_architecture(arch, wrk_in) + + wrk_wT = zeros(Nx_in, Ny_in, 10) + wrk_wS = zeros(Nx_in, Ny_in, 10) + wrk_wT = on_architecture(arch, wrk_wT) + wrk_wS = on_architecture(arch, wrk_wS) + + return (; wrk_in, wrk_wT, wrk_wS, wT, wS, first_index, last_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + κᶜ = model.diffusivity_fields[1].κᶜ + κ₀ = model.closure[1].ν₀ / model.closure[1].Pr_shearₜ + Ri = model.diffusivity_fields[1].Ri + + wrk_in = diffusivities.wrk_in + wrk_wT = diffusivities.wrk_wT + wrk_wS = diffusivities.wrk_wS + wT = diffusivities.wT + wS = diffusivities.wS + + first_index = diffusivities.first_index + last_index = diffusivities.last_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + kp_wrk = KernelParameters((Nx_in, Ny_in, 10), (0, 0, 0)) + + launch!(arch, grid, kp_2D, _find_NN_active_region!, κᶜ, grid, κ₀, first_index, last_index) + + launch!(arch, grid, kp_wrk, + _populate_input!, wrk_in, first_index, last_index, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + + wrk_wT .= dropdims(closure.wT(wrk_in), dims=1) + wrk_wS .= dropdims(closure.wS(wrk_in), dims=1) + + launch!(arch, grid, kp, _fill_adjust_nn_fluxes!, diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, first_index, last_index, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + eos = TEOS10.TEOS10EquationOfState() + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + quiescent = quiescent_condition(k_first, k_last) + + @inbounds k_tracer = clamp_k_interior(k_first + k - 1, grid) + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = Riᵢ₋₁ = ifelse(quiescent, 0, atan(Ri[i, j, k_tracer-1])) + @inbounds input[2, i, j, k] = Riᵢ = ifelse(quiescent, 0, atan(Ri[i, j, k_tracer])) + @inbounds input[3, i, j, k] = Riᵢ₊₁ = ifelse(quiescent, 0, atan(Ri[i, j, k_tracer+1])) + + @inbounds input[4, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, T))) + @inbounds input[5, i, j, k] = ∂Tᵢ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, T))) + @inbounds input[6, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, T))) + + @inbounds input[7, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, S))) + @inbounds input[8, i, j, k] = ∂Sᵢ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, S))) + @inbounds input[9, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, S))) + + @inbounds input[10, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer-1, grid, buoyancy, tracers) / g)) + @inbounds input[11, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer, grid, buoyancy, tracers) / g)) + @inbounds input[12, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer+1, grid, buoyancy, tracers) / g)) + + @inbounds input[13, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + +end + +@kernel function _find_NN_active_region!(κᶜ, grid, κ₀, first_index, last_index) + i, j = @index(Global, NTuple) + top_index = grid.Nz + 1 + grid_point_above_kappa = 5 + grid_point_below_kappa = 5 + + # Find the last index of the background κᶜ + kloc = 1 + @inbounds for k in 2:grid.Nz + kloc = ifelse(κᶜ[i, j, k] ≈ κ₀, k, kloc) + end + + nonbackground_κ_index = kloc + 1 + + @inbounds last_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index - 1, clamp_k_interior(kloc + grid_point_above_kappa, grid)) + @inbounds first_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index, clamp_k_interior(kloc - grid_point_below_kappa + 1, grid)) +end + +@inline function quiescent_condition(lo, hi) + return hi < lo +end + +@inline function within_zone_condition(k, lo, hi) + return (k >= lo) & (k <= hi) +end + +@inline function clamp_k_interior(k, grid) + kmax = grid.Nz - 1 + kmin = 3 + + return clamp(k, kmin, kmax) +end + +@kernel function _fill_adjust_nn_fluxes!(diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + quiescent = quiescent_condition(k_first, k_last) + within_zone = within_zone_condition(k, k_first, k_last) + + @inbounds k_wrk = clamp(k - k_first + 1, 1, 10) + + NN_active = convecting & !quiescent & within_zone + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * wrk_wT[i, j, k_wrk], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * wrk_wS[i, j, k_wrk], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_nof_BBL.jl b/NN_closure_global_nof_BBL.jl new file mode 100644 index 0000000000..995e3aa2b9 --- /dev/null +++ b/NN_closure_global_nof_BBL.jl @@ -0,0 +1,224 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE_FC_Qb_nof_BBL_trainFC24new_scalingtrain54new_2layer_64_relu_2Pr_model.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + ∂ρ²∂z² = ZFaceField(grid) + BBL_index = Field((Center, Center, Nothing), grid, Int32) + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + + wrk_in = OffsetArray(zeros(10, Nx_in, Ny_in, Nz_in), 0, ox_in, oy_in, oz_in) + wrk_in = on_architecture(arch, wrk_in) + + return (; wrk_in, wT, wS, ∂ρ²∂z², BBL_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + input = diffusivities.wrk_in + wT = diffusivities.wT + wS = diffusivities.wS + ∂ρ²∂z² = diffusivities.∂ρ²∂z² + BBL_index = diffusivities.BBL_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + launch!(arch, grid, kp, + _populate_input!, input, ∂ρ²∂z², grid, closure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + + launch!(arch, grid, kp_2D, _find_base_boundary_layer!, ∂ρ²∂z², grid, BBL_index) + + wT.data.parent .= dropdims(closure.wT(input.parent), dims=1) + wS.data.parent .= dropdims(closure.wS(input.parent), dims=1) + + launch!(arch, grid, kp, _adjust_nn_fluxes!, diffusivities, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, ∂ρ²∂z², grid, closure::NNFluxClosure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.T)) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.T)) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.T)) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k-1, grid, tracers.S)) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k, grid, tracers.S)) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k+1, grid, tracers.S)) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k-1, grid, buoyancy, tracers) / g) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k+1, grid, buoyancy, tracers) / g) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + + @inbounds ∂ρ²∂z²[i, j, k] = abs(-ρ₀ * ∂zᶜᶜᶜ(i, j, k, grid, ∂z_b, buoyancy, tracers) / g) +end + +@inline function find_field_max!(i, j, field, grid, h) + kmax = grid.Nz + @inbounds maxf = field[i, j, grid.Nz-1] + + @inbounds for k in grid.Nz-2:-1:2 + kmax = ifelse(field[i, j, k] > maxf, k, kmax) + maxf = ifelse(field[i, j, k] > maxf, field[i, j, k], maxf) + end + + @inbounds h[i, j, 1] = kmax +end + +@kernel function _find_base_boundary_layer!(∂ρ²∂z², grid, h) + i, j = @index(Global, NTuple) + find_field_max!(i, j, ∂ρ²∂z², grid, h) + + @inbounds h[i, j, 1] = ifelse(h[i, j, 1] < 7, ifelse(h[i, j, 1] == 2, grid.Nz+1, 4), h[i, j, 1] - 3) +end + +@kernel function _adjust_nn_fluxes!(diffusivities, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + @inbounds BBL_index = diffusivities.BBL_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + above_base_boundary_layer = k > BBL_index + below_top = k <= grid.Nz - 1 + above_bottom = k >= 3 + + NN_active = convecting & above_base_boundary_layer & below_top & above_bottom + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * diffusivities.wT[i, j, k], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * diffusivities.wS[i, j, k], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_nof_BBLRifirstzone510.jl b/NN_closure_global_nof_BBLRifirstzone510.jl new file mode 100644 index 0000000000..ec31e26dec --- /dev/null +++ b/NN_closure_global_nof_BBLRifirstzone510.jl @@ -0,0 +1,280 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S, G} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S + grid_point_above :: G + grid_point_below :: G +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling), + Adapt.adapt(to, nn.grid_point_above), + Adapt.adapt(to, nn.grid_point_below)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE3_FC_Qb_nof_BBLRifirst510_trainFC26new_model.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u_train"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + grid_point_above = 10 + grid_point_below = 5 + + return NNFluxClosure(wT_NN, wS_NN, scaling, grid_point_above, grid_point_below) +end + +function DiffusivityFields(grid, tracer_names, bcs, closure::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + first_index = Field((Center, Center, Nothing), grid, Int32) + last_index = Field((Center, Center, Nothing), grid, Int32) + + N_input = closure.wT.model.layers.layer_1.in_dims + N_levels = closure.grid_point_above + closure.grid_point_below + + Nx_in, Ny_in, _ = size(wT) + wrk_in = zeros(N_input, Nx_in, Ny_in, N_levels) + wrk_in = on_architecture(arch, wrk_in) + + wrk_wT = zeros(Nx_in, Ny_in, 15) + wrk_wS = zeros(Nx_in, Ny_in, 15) + wrk_wT = on_architecture(arch, wrk_wT) + wrk_wS = on_architecture(arch, wrk_wS) + + return (; wrk_in, wrk_wT, wrk_wS, wT, wS, first_index, last_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Riᶜ = model.closure[1].Riᶜ + Ri = model.diffusivity_fields[1].Ri + + wrk_in = diffusivities.wrk_in + wrk_wT = diffusivities.wrk_wT + wrk_wS = diffusivities.wrk_wS + wT = diffusivities.wT + wS = diffusivities.wS + + first_index = diffusivities.first_index + last_index = diffusivities.last_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + N_levels = closure.grid_point_above + closure.grid_point_below + + kp_wrk = KernelParameters((Nx_in, Ny_in, N_levels), (0, 0, 0)) + + launch!(arch, grid, kp_2D, _find_NN_active_region!, Ri, grid, Riᶜ, first_index, last_index, closure) + + launch!(arch, grid, kp_wrk, + _populate_input!, wrk_in, first_index, last_index, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + + wrk_wT .= dropdims(closure.wT(wrk_in), dims=1) + wrk_wS .= dropdims(closure.wS(wrk_in), dims=1) + + launch!(arch, grid, kp, _fill_adjust_nn_fluxes!, diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, first_index, last_index, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + quiescent = quiescent_condition(k_first, k_last) + + @inbounds k_tracer = clamp_k_interior(k_first + k - 1, grid) + + @inbounds k₋ = clamp_k_interior(k_tracer - 1, grid) + @inbounds k₀ = clamp_k_interior(k_tracer, grid) + @inbounds k₊ = clamp_k_interior(k_tracer + 1, grid) + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋, grid, T))) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₀, grid, T))) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊, grid, T))) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₋, grid, S))) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₀, grid, S))) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k₊, grid, S))) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₋, grid, buoyancy, tracers) / g)) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₀, grid, buoyancy, tracers) / g)) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k₊, grid, buoyancy, tracers) / g)) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + +end + +@kernel function _find_NN_active_region!(Ri, grid, Riᶜ, first_index, last_index, closure::NNFluxClosure) + i, j = @index(Global, NTuple) + top_index = grid.Nz + 1 + grid_point_above_kappa = closure.grid_point_above + grid_point_below_kappa = closure.grid_point_below + + # Find the first index of the background κᶜ + kloc = grid.Nz+1 + @inbounds for k in grid.Nz:-1:2 + kloc = ifelse(Ri[i, j, k] < Riᶜ, k, kloc) + end + + background_κ_index = kloc - 1 + nonbackground_κ_index = background_κ_index + 1 + + @inbounds last_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index - 1, clamp_k_interior(background_κ_index + grid_point_above_kappa, grid)) + @inbounds first_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index, clamp_k_interior(background_κ_index - grid_point_below_kappa + 1, grid)) +end + +@inline function quiescent_condition(lo, hi) + return hi < lo +end + +@inline function within_zone_condition(k, lo, hi) + return (k >= lo) & (k <= hi) +end + +@inline function clamp_k_interior(k, grid) + kmax = grid.Nz + kmin = 2 + + return clamp(k, kmin, kmax) +end + +@kernel function _fill_adjust_nn_fluxes!(diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + quiescent = quiescent_condition(k_first, k_last) + within_zone = within_zone_condition(k, k_first, k_last) + + N_levels = closure.grid_point_above + closure.grid_point_below + @inbounds k_wrk = clamp(k - k_first + 1, 1, N_levels) + + NN_active = convecting & !quiescent & within_zone + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * wrk_wT[i, j, k_wrk], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * wrk_wS[i, j, k_wrk], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_nof_BBLintegral.jl b/NN_closure_global_nof_BBLintegral.jl new file mode 100644 index 0000000000..781b8e8902 --- /dev/null +++ b/NN_closure_global_nof_BBLintegral.jl @@ -0,0 +1,224 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE_FC_Qb_nof_BBLintegral_trainFC24new_scalingtrain54new_2layer_64_relu_2Pr_model.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + BBL_constraint = CenterField(grid) + BBL_index = Field((Center, Center, Nothing), grid, Int32) + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + + wrk_in = OffsetArray(zeros(10, Nx_in, Ny_in, Nz_in), 0, ox_in, oy_in, oz_in) + wrk_in = on_architecture(arch, wrk_in) + + return (; wrk_in, wT, wS, BBL_constraint, BBL_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + input = diffusivities.wrk_in + wT = diffusivities.wT + wS = diffusivities.wS + BBL_constraint = diffusivities.BBL_constraint + BBL_index = diffusivities.BBL_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + launch!(arch, grid, kp, + _populate_input!, input, BBL_constraint, grid, closure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + + launch!(arch, grid, kp_2D, _find_base_boundary_layer!, BBL_constraint, grid, BBL_index) + + wT.data.parent .= dropdims(closure.wT(input.parent), dims=1) + wS.data.parent .= dropdims(closure.wS(input.parent), dims=1) + + launch!(arch, grid, kp, _adjust_nn_fluxes!, diffusivities, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, BBL_constraint, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, coriolis, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + eos = TEOS10.TEOS10EquationOfState() + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k-1, grid, T)) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k, grid, T)) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(∂zᶜᶜᶠ(i, j, k+1, grid, T)) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k-1, grid, S)) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k, grid, S)) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(∂zᶜᶜᶠ(i, j, k+1, grid, S)) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k-1, grid, buoyancy, tracers) / g) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k, grid, buoyancy, tracers) / g) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(-ρ₀ * ∂z_b(i, j, k+1, grid, buoyancy, tracers) / g) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + + @inbounds BBL_constraint[i, j, k] = -ρ₀ * ∂z_b(i, j, k+1, grid, buoyancy, tracers) / g - 8 / grid.zᵃᵃᶜ[k] * (TEOS10.ρ(T[i, j, k], S[i, j, k], 0, eos) - TEOS10.ρ(T[i, j, grid.Nz], S[i, j, grid.Nz], 0, eos)) +end + +@inline function find_based_boundary_layer_index!(i, j, field, grid, h) + kmax = 3 + + @inbounds for k in 6:grid.Nz-1 + kmax = ifelse(field[i, j, k] > 0, k-2, kmax) + end + + @inbounds h[i, j, 1] = kmax +end + +@kernel function _find_base_boundary_layer!(BBL_constraint, grid, h) + i, j = @index(Global, NTuple) + find_based_boundary_layer_index!(i, j, BBL_constraint, grid, h) +end + +@kernel function _adjust_nn_fluxes!(diffusivities, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + @inbounds BBL_index = diffusivities.BBL_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + above_base_boundary_layer = k > BBL_index + below_top = k <= grid.Nz - 1 + above_bottom = k >= 3 + + NN_active = convecting & above_base_boundary_layer & below_top & above_bottom + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * diffusivities.wT[i, j, k], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * diffusivities.wS[i, j, k], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_nof_BBLkappazonelast41.jl b/NN_closure_global_nof_BBLkappazonelast41.jl new file mode 100644 index 0000000000..9ee4c3c063 --- /dev/null +++ b/NN_closure_global_nof_BBLkappazonelast41.jl @@ -0,0 +1,245 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE_Qb_dt20min_nof_BBLkappazonelast41_wTwS_64simnew_2layer_128_relu_123seed_1.0e-5lr_localbaseclosure_2Pr_6simstableRi_model_temp.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + first_index = Field((Center, Center, Nothing), grid, Int32) + last_index = Field((Center, Center, Nothing), grid, Int32) + + Nx_in, Ny_in, _ = size(wT) + + wrk_in = zeros(10, Nx_in, Ny_in, 5) + wrk_in = on_architecture(arch, wrk_in) + + wrk_wT = zeros(Nx_in, Ny_in, 5) + wrk_wS = zeros(Nx_in, Ny_in, 5) + wrk_wT = on_architecture(arch, wrk_wT) + wrk_wS = on_architecture(arch, wrk_wS) + + return (; wrk_in, wrk_wT, wrk_wS, wT, wS, first_index, last_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + κᶜ = model.diffusivity_fields[1].κᶜ + κ₀ = model.closure[1].ν₀ / model.closure[1].Pr_shearₜ + + wrk_in = diffusivities.wrk_in + wrk_wT = diffusivities.wrk_wT + wrk_wS = diffusivities.wrk_wS + wT = diffusivities.wT + wS = diffusivities.wS + + first_index = diffusivities.first_index + last_index = diffusivities.last_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + kp_wrk = KernelParameters((Nx_in, Ny_in, 5), (0, 0, 0)) + + launch!(arch, grid, kp_2D, _find_NN_active_region!, κᶜ, grid, κ₀, first_index, last_index) + + launch!(arch, grid, kp_wrk, + _populate_input!, wrk_in, first_index, last_index, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + + wrk_wT .= dropdims(closure.wT(wrk_in), dims=1) + wrk_wS .= dropdims(closure.wS(wrk_in), dims=1) + + launch!(arch, grid, kp, _fill_adjust_nn_fluxes!, diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, first_index, last_index, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + eos = TEOS10.TEOS10EquationOfState() + + @inbounds quiescent = quiescent_condition(first_index[i, j, 1], last_index[i, j, 1]) + @inbounds k_tracer = first_index[i, j, 1] + k - 1 + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, T))) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, T))) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, T))) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, S))) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, S))) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, S))) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer-1, grid, buoyancy, tracers) / g)) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer, grid, buoyancy, tracers) / g)) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer+1, grid, buoyancy, tracers) / g)) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(ifelse(quiescent, 0, top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)))) + +end + +@kernel function _find_NN_active_region!(κᶜ, grid, κ₀, first_index, last_index) + i, j = @index(Global, NTuple) + top_index = grid.Nz + 1 + + # Find the last index of the background κᶜ + kmax = 1 + @inbounds for k in 2:grid.Nz + kmax = ifelse(κᶜ[i, j, k] ≈ κ₀, k, kmax) + end + + @inbounds last_index[i, j, 1] = ifelse(kmax == top_index, grid.Nz, min(kmax + 1, grid.Nz)) + @inbounds first_index[i, j, 1] = ifelse(kmax == top_index, top_index, max(kmax - 3, 2)) +end + +@inline function quiescent_condition(lo, hi) + return hi - lo != 4 +end + +@kernel function _fill_adjust_nn_fluxes!(diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + + k_first = first_index[i, j, 1] + k_last = last_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + @inbounds quiescent = quiescent_condition(k_first, k_last) + within_zone = (k >= k_first) & (k <= k_last) + + @inbounds k_wrk = clamp(k - k_first + 1, 1, 5) + + NN_active = convecting & !quiescent & within_zone + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * wrk_wT[i, j, k_wrk], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * wrk_wS[i, j, k_wrk], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NN_closure_global_nof_BBLkappazonelast55.jl b/NN_closure_global_nof_BBLkappazonelast55.jl new file mode 100644 index 0000000000..f7452dae84 --- /dev/null +++ b/NN_closure_global_nof_BBLkappazonelast55.jl @@ -0,0 +1,263 @@ +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_uy, + viscous_flux_uz, + viscous_flux_vx, + viscous_flux_vy, + viscous_flux_vz, + viscous_flux_wx, + viscous_flux_wy, + viscous_flux_wz + +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b +using Oceananigans.Coriolis +using Oceananigans.Grids: φnode +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans: architecture, on_architecture +using Lux, LuxCUDA +using JLD2 +using ComponentArrays +using OffsetArrays +using SeawaterPolynomials.TEOS10 + +using KernelAbstractions: @index, @kernel, @private + +import Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure, ExplicitTimeDiscretization + +using Adapt + +include("./feature_scaling.jl") + +@inline hack_sind(φ) = sin(φ * π / 180) + +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::HydrostaticSphericalCoriolis) = 2 * coriolis.rotation_rate * hack_sind(φnode(i, j, k, grid, Center(), Center(), Center())) +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::FPlane) = coriolis.f +@inline fᶜᶜᵃ(i, j, k, grid, coriolis::BetaPlane) = coriolis.f₀ + coriolis.β * ynode(i, j, k, grid, Center(), Center(), Center()) + +struct NN{M, P, S} + model :: M + ps :: P + st :: S +end + +@inline (neural_network::NN)(input) = first(neural_network.model(input, neural_network.ps, neural_network.st)) +@inline tosarray(x::AbstractArray) = SArray{Tuple{size(x)...}}(x) + +struct NNFluxClosure{A <: NN, S} <: AbstractTurbulenceClosure{ExplicitTimeDiscretization, 3} + wT :: A + wS :: A + scaling :: S +end + +Adapt.adapt_structure(to, nn :: NNFluxClosure) = + NNFluxClosure(Adapt.adapt(to, nn.wT), + Adapt.adapt(to, nn.wS), + Adapt.adapt(to, nn.scaling)) + +Adapt.adapt_structure(to, nn :: NN) = + NN(Adapt.adapt(to, nn.model), + Adapt.adapt(to, nn.ps), + Adapt.adapt(to, nn.st)) + +function NNFluxClosure(arch) + dev = ifelse(arch == GPU(), gpu_device(), cpu_device()) + nn_path = "./NDE_FC_Qb_nof_BBLkappazonelast55_trainFC23new_scalingtrain53new_2layer_128_relu_20seed_2Pr_model.jld2" + + ps, sts, scaling_params, wT_model, wS_model = jldopen(nn_path, "r") do file + ps = file["u"] |> dev |> f64 + sts = file["sts"] |> dev |> f64 + scaling_params = file["scaling"] + wT_model = file["model"].wT + wS_model = file["model"].wS + return ps, sts, scaling_params, wT_model, wS_model + end + + scaling = construct_zeromeanunitvariance_scaling(scaling_params) + + wT_NN = NN(wT_model, ps.wT, sts.wT) + wS_NN = NN(wS_model, ps.wS, sts.wS) + + return NNFluxClosure(wT_NN, wS_NN, scaling) +end + +function DiffusivityFields(grid, tracer_names, bcs, ::NNFluxClosure) + arch = architecture(grid) + wT = ZFaceField(grid) + wS = ZFaceField(grid) + first_index = Field((Center, Center, Nothing), grid, Int32) + last_index = Field((Center, Center, Nothing), grid, Int32) + + Nx_in, Ny_in, _ = size(wT) + wrk_in = zeros(10, Nx_in, Ny_in, 10) + wrk_in = on_architecture(arch, wrk_in) + + wrk_wT = zeros(Nx_in, Ny_in, 10) + wrk_wS = zeros(Nx_in, Ny_in, 10) + wrk_wT = on_architecture(arch, wrk_wT) + wrk_wS = on_architecture(arch, wrk_wS) + + return (; wrk_in, wrk_wT, wrk_wS, wT, wS, first_index, last_index) +end + +function compute_diffusivities!(diffusivities, closure::NNFluxClosure, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + velocities = model.velocities + tracers = model.tracers + buoyancy = model.buoyancy + coriolis = model.coriolis + clock = model.clock + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + κᶜ = model.diffusivity_fields[1].κᶜ + κ₀ = model.closure[1].ν₀ / model.closure[1].Pr_shearₜ + Ri = model.diffusivity_fields[1].Ri + + wrk_in = diffusivities.wrk_in + wrk_wT = diffusivities.wrk_wT + wrk_wS = diffusivities.wrk_wS + wT = diffusivities.wT + wS = diffusivities.wS + + first_index = diffusivities.first_index + last_index = diffusivities.last_index + + Nx_in, Ny_in, Nz_in = total_size(wT) + ox_in, oy_in, oz_in = wT.data.offsets + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + kp_2D = KernelParameters((Nx_in, Ny_in), (ox_in, oy_in)) + + kp_wrk = KernelParameters((Nx_in, Ny_in, 10), (0, 0, 0)) + + launch!(arch, grid, kp_2D, _find_NN_active_region!, κᶜ, grid, κ₀, first_index, last_index) + + launch!(arch, grid, kp_wrk, + _populate_input!, wrk_in, first_index, last_index, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + + wrk_wT .= dropdims(closure.wT(wrk_in), dims=1) + wrk_wS .= dropdims(closure.wS(wrk_in), dims=1) + + launch!(arch, grid, kp, _fill_adjust_nn_fluxes!, diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + return nothing +end + +@kernel function _populate_input!(input, first_index, last_index, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, Ri, clock) + i, j, k = @index(Global, NTuple) + + scaling = closure.scaling + + ρ₀ = buoyancy.model.equation_of_state.reference_density + g = buoyancy.model.gravitational_acceleration + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + quiescent = quiescent_condition(k_first, k_last) + + @inbounds k_tracer = clamp_k_interior(k_first + k - 1, grid) + + T, S = tracers.T, tracers.S + + @inbounds input[1, i, j, k] = ∂Tᵢ₋₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, T))) + @inbounds input[2, i, j, k] = ∂Tᵢ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, T))) + @inbounds input[3, i, j, k] = ∂Tᵢ₊₁ = scaling.∂T∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, T))) + + @inbounds input[4, i, j, k] = ∂Sᵢ₋₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer-1, grid, S))) + @inbounds input[5, i, j, k] = ∂Sᵢ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer, grid, S))) + @inbounds input[6, i, j, k] = ∂Sᵢ₊₁ = scaling.∂S∂z(ifelse(quiescent, 0, ∂zᶜᶜᶠ(i, j, k_tracer+1, grid, S))) + + @inbounds input[7, i, j, k] = ∂σᵢ₋₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer-1, grid, buoyancy, tracers) / g)) + @inbounds input[8, i, j, k] = ∂σᵢ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer, grid, buoyancy, tracers) / g)) + @inbounds input[9, i, j, k] = ∂σᵢ₊₁ = scaling.∂ρ∂z(ifelse(quiescent, 0, -ρ₀ * ∂z_b(i, j, k_tracer+1, grid, buoyancy, tracers) / g)) + + @inbounds input[10, i, j, k] = Jᵇ = scaling.wb(top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers))) + +end + +@kernel function _find_NN_active_region!(κᶜ, grid, κ₀, first_index, last_index) + i, j = @index(Global, NTuple) + top_index = grid.Nz + 1 + grid_point_above_kappa = 5 + grid_point_below_kappa = 5 + + # Find the last index of the background κᶜ + kloc = 1 + @inbounds for k in 2:grid.Nz + kloc = ifelse(κᶜ[i, j, k] ≈ κ₀, k, kloc) + end + + nonbackground_κ_index = kloc + 1 + + @inbounds last_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index - 1, clamp_k_interior(kloc + grid_point_above_kappa, grid)) + @inbounds first_index[i, j, 1] = ifelse(nonbackground_κ_index == top_index, top_index, clamp_k_interior(kloc - grid_point_below_kappa + 1, grid)) +end + +@inline function quiescent_condition(lo, hi) + return hi < lo +end + +@inline function within_zone_condition(k, lo, hi) + return (k >= lo) & (k <= hi) +end + +@inline function clamp_k_interior(k, grid) + kmax = grid.Nz - 1 + kmin = 3 + + return clamp(k, kmin, kmax) +end + +@kernel function _fill_adjust_nn_fluxes!(diffusivities, first_index, last_index, wrk_wT, wrk_wS, grid, closure::NNFluxClosure, tracers, velocities, buoyancy, top_tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + scaling = closure.scaling + + @inbounds k_first = first_index[i, j, 1] + @inbounds k_last = last_index[i, j, 1] + + convecting = top_buoyancy_flux(i, j, grid, buoyancy, top_tracer_bcs, clock, merge(velocities, tracers)) > 0 + quiescent = quiescent_condition(k_first, k_last) + within_zone = within_zone_condition(k, k_first, k_last) + + @inbounds k_wrk = clamp(k - k_first + 1, 1, 10) + + NN_active = convecting & !quiescent & within_zone + + @inbounds diffusivities.wT[i, j, k] = ifelse(NN_active, scaling.wT.σ * wrk_wT[i, j, k_wrk], 0) + @inbounds diffusivities.wS[i, j, k] = ifelse(NN_active, scaling.wS.σ * wrk_wS[i, j, k_wrk], 0) +end + +# Write here your constructor +# NNFluxClosure() = ... insert NN here ... (make sure it is on GPU if you need it on GPU!) + +const NNC = NNFluxClosure + +##### +##### Abstract Smagorinsky functionality +##### + +# Horizontal fluxes are zero! +@inline viscous_flux_wz( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_wy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_ux( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vx( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_uy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vy( i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline diffusive_flux_x(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) +@inline diffusive_flux_y(i, j, k, grid, clo::NNC, K, ::Val{tracer_index}, c, clock, fields, buoyancy) where tracer_index = zero(grid) + +# Viscous fluxes are zero (for now) +@inline viscous_flux_uz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) +@inline viscous_flux_vz(i, j, k, grid, clo::NNC, K, clk, fields, b) = zero(grid) + +# The only function extended by NNFluxClosure +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{1}, c, clock, fields, buoyancy) = @inbounds K.wT[i, j, k] +@inline diffusive_flux_z(i, j, k, grid, clo::NNC, K, ::Val{2}, c, clock, fields, buoyancy) = @inbounds K.wS[i, j, k] \ No newline at end of file diff --git a/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model.jl b/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model.jl new file mode 100644 index 0000000000..7a5f9d2a2e --- /dev/null +++ b/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model.jl @@ -0,0 +1,1263 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_Ri_nof_BBLRifirstzone510.jl") +include("xin_kai_vertical_diffusivity_local_2step_new.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_NN_closure_NDE5_Ri_BBLRifirztzone510_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +closure = (base_closure, nn_closure) + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +wT = wT_NN + wT_base +wS = wS_NN + wS_base + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +wT_NNbar_zonal = Average(wT_NN, dims=1) +wS_NNbar_zonal = Average(wS_NN, dims=1) + +wT_basebar_zonal = Average(wT_base, dims=1) +wS_basebar_zonal = Average(wS_base, dims=1) + +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base, wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wT_NNbar_zonal, wS_NNbar_zonal, wT_basebar_zonal, wS_basebar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in yz" + +fieldname = "T" +fluxname = "wT_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +# save("./Output/compare_3D_instantaneous_fields_slices_NNclosure_fluxes.png", fig) +# display(fig) +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% +@info "Recording S fields and fluxes in yz" + +fieldname = "S" +fluxname = "wS_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in xz" + +fieldname = "T" +fluxname = "wT_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording S fields and fluxes in xz" + +fieldname = "S" +fluxname = "wS_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% \ No newline at end of file diff --git a/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model_modewater.jl b/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model_modewater.jl new file mode 100644 index 0000000000..1f107c5cb9 --- /dev/null +++ b/NNclosure_Ri_nof_BBLRifirstzone510_doublegyre_model_modewater.jl @@ -0,0 +1,1275 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_Ri_nof_BBLRifirstzone510.jl") +include("xin_kai_vertical_diffusivity_local_2step_new.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_modewater_zWENO5_NN_closure_NDE5_Ri_BBLRifirztzone510_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +closure = (base_closure, nn_closure) + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const X₀ = -Lx/2 + 800kilometers +const Y₀ = -Ly/2 + 1500kilometers +const R₀ = 700kilometers +const Qᵀ_mode = 4.5e-4 +const σ_mode = 20kilometers + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +# @inline T_initial(x, y, z) = (T_north + T_south / 2) + 5 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y + +@inline Qᵀ_winter(t) = max(0, -Qᵀ_mode * sin(2π * t / 360days)) +@inline Qᵀ_subpolar(x, y, t) = ifelse((x - X₀)^2 + (y - Y₀)^2 <= R₀^2, Qᵀ_winter(t), + exp(-(sqrt((x - X₀)^2 + (y - Y₀)^2) - R₀)^2 / (2 * σ_mode^2)) * Qᵀ_winter(t)) + +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) + Qᵀ_subpolar(x, y, t) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +wT = wT_NN + wT_base +wS = wS_NN + wS_base + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +wT_NNbar_zonal = Average(wT_NN, dims=1) +wS_NNbar_zonal = Average(wS_NN, dims=1) + +wT_basebar_zonal = Average(wT_base, dims=1) +wS_basebar_zonal = Average(wS_base, dims=1) + +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base, wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wT_NNbar_zonal, wS_NNbar_zonal, wT_basebar_zonal, wS_basebar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in yz" + +fieldname = "T" +fluxname = "wT_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +# save("./Output/compare_3D_instantaneous_fields_slices_NNclosure_fluxes.png", fig) +# display(fig) +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% +@info "Recording S fields and fluxes in yz" + +fieldname = "S" +fluxname = "wS_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in xz" + +fieldname = "T" +fluxname = "wT_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording S fields and fluxes in xz" + +fieldname = "S" +fluxname = "wS_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% \ No newline at end of file diff --git a/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model.jl b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model.jl new file mode 100644 index 0000000000..6e6e425e94 --- /dev/null +++ b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model.jl @@ -0,0 +1,533 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_Ri_nof_BBLkappazonelast55.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_NN_closure_NDE_Ri_BBLkappazonelast55_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing.jl b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing.jl new file mode 100644 index 0000000000..11dd334c24 --- /dev/null +++ b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing.jl @@ -0,0 +1,539 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_Ri_nof_BBLkappazonelast55.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_seasonalforcing_30C-20C_relaxation_8days_NN_closure_NDE_Ri_BBLkappazonelast55_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/8days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 15 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * exp(-y^2/(2 * seasonal_forcing_width^2)) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 36000days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing_wallrestoration.jl b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing_wallrestoration.jl new file mode 100644 index 0000000000..0d0aea8d81 --- /dev/null +++ b/NNclosure_Ri_nof_BBLkappazonelast55_doublegyre_model_seasonalforcing_wallrestoration.jl @@ -0,0 +1,549 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_Ri_nof_BBLkappazonelast55.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_seasonalforcing_30C-20C_relaxation_wallrestoration_8days_NN_closure_NDE_Ri_BBLkappazonelast55_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +##### +##### Boundary conditions +##### +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 10 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/8days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 20 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * exp(-y^2/(2 * seasonal_forcing_width^2)) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = 10 * (1 + z / Lz) +@inline north_T_flux(x, z, t, T) = μ_T * δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLRifirstzone510_doublegyre_model.jl b/NNclosure_nof_BBLRifirstzone510_doublegyre_model.jl new file mode 100644 index 0000000000..64af3bda2f --- /dev/null +++ b/NNclosure_nof_BBLRifirstzone510_doublegyre_model.jl @@ -0,0 +1,1265 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLRifirstzone510.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zC2O_NN_closure_NDE_BBLRifirztzone510" +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_NN_closure_NDE_BBLRifirztzone510" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +closure = (base_closure, nn_closure) + +# advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), CenteredSecondOrder()) +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 7300days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +wT = wT_NN + wT_base +wS = wS_NN + wS_base + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +wT_NNbar_zonal = Average(wT_NN, dims=1) +wS_NNbar_zonal = Average(wS_NN, dims=1) + +wT_basebar_zonal = Average(wT_base, dims=1) +wS_basebar_zonal = Average(wS_base, dims=1) + +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base, wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wT_NNbar_zonal, wS_NNbar_zonal, wT_basebar_zonal, wS_basebar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in yz" + +fieldname = "T" +fluxname = "wT_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +# save("./Output/compare_3D_instantaneous_fields_slices_NNclosure_fluxes.png", fig) +# display(fig) +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% +@info "Recording S fields and fluxes in yz" + +fieldname = "S" +fluxname = "wS_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in xz" + +fieldname = "T" +fluxname = "wT_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording S fields and fluxes in xz" + +fieldname = "S" +fluxname = "wS_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLkappazonelast41_doublegyre_model.jl b/NNclosure_nof_BBLkappazonelast41_doublegyre_model.jl new file mode 100644 index 0000000000..7658778581 --- /dev/null +++ b/NNclosure_nof_BBLkappazonelast41_doublegyre_model.jl @@ -0,0 +1,534 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLkappazonelast41.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_8days_NN_closure_NDE_BBLkappazonelast41_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/8days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLkappazonelast41_doublegyre_model_initialized.jl b/NNclosure_nof_BBLkappazonelast41_doublegyre_model_initialized.jl new file mode 100644 index 0000000000..eacddcbe77 --- /dev/null +++ b/NNclosure_nof_BBLkappazonelast41_doublegyre_model_initialized.jl @@ -0,0 +1,596 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLkappazonelast41.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_relaxation_30days_NN_closure_2D_channel_NDE_FC_Qb_nof_BBLkappazonelast41_trainFC24new_scalingtrain54new_2layer_64_relu_2Pr_initialized_41" +FILE_DIR = "./Output/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### + +@inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +# noise(z) = rand() * exp(z / 8) + +# T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +# S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +# set!(model, T=T_initial_noisy, S=S_initial_noisy) +DATA_DIR = "./Output/doublegyre_30Cwarmflush_relaxation_8days_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" +# DATA_DIR = "./Output/doublegyre_30Cwarmflush_relaxation_30days_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" + +u_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "u", backend=OnDisk()) +v_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "v", backend=OnDisk()) +T_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "T", backend=OnDisk()) +S_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "S", backend=OnDisk()) + +ntimes = length(u_data.times) + +set!(model, T=T_data[ntimes], S=S_data[ntimes], u=u_data[ntimes], v=v_data[ntimes]) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 5110days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_south] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_south", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_north] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_north", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +Nt = length(N²_xz_north_data) +times = N²_xz_north_data.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, + find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +#%% +fig = Figure(size=(800, 800)) +ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +n = Observable(2) + +N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +colorscheme = colorschemes[:jet] + +N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +title_str = @lift "Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +@info "Recording buoyancy frequency xz slice" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLkappazonelast41_doublegyre_model_seasonalforcing.jl b/NNclosure_nof_BBLkappazonelast41_doublegyre_model_seasonalforcing.jl new file mode 100644 index 0000000000..c100a5a516 --- /dev/null +++ b/NNclosure_nof_BBLkappazonelast41_doublegyre_model_seasonalforcing.jl @@ -0,0 +1,539 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLkappazonelast41.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_seasonalforcing_30C-20C_relaxation_8days_NN_closure_NDE_BBLkappazonelast41_temp" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, nn_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/8days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 15 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * exp(-y^2/(2 * seasonal_forcing_width^2)) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLkappazonelast55_doublegyre_model.jl b/NNclosure_nof_BBLkappazonelast55_doublegyre_model.jl new file mode 100644 index 0000000000..89a98b92a1 --- /dev/null +++ b/NNclosure_nof_BBLkappazonelast55_doublegyre_model.jl @@ -0,0 +1,1265 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLkappazonelast55.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zC2O_NN_closure_NDE_BBLkappazonelast55" +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_NN_closure_NDE_BBLkappazonelast55" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +closure = (base_closure, nn_closure) + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), CenteredSecondOrder()) +# advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +wT = wT_NN + wT_base +wS = wS_NN + wS_base + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +wT_NNbar_zonal = Average(wT_NN, dims=1) +wS_NNbar_zonal = Average(wS_NN, dims=1) + +wT_basebar_zonal = Average(wT_base, dims=1) +wS_basebar_zonal = Average(wS_base, dims=1) + +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base, wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wT_NNbar_zonal, wS_NNbar_zonal, wT_basebar_zonal, wS_basebar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in yz" + +fieldname = "T" +fluxname = "wT_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +# save("./Output/compare_3D_instantaneous_fields_slices_NNclosure_fluxes.png", fig) +# display(fig) +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% +@info "Recording S fields and fluxes in yz" + +fieldname = "S" +fluxname = "wS_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in xz" + +fieldname = "T" +fluxname = "wT_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording S fields and fluxes in xz" + +fieldname = "S" +fluxname = "wS_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% \ No newline at end of file diff --git a/NNclosure_nof_BBLkappazonelast55_doublegyre_model_wallrestoration.jl b/NNclosure_nof_BBLkappazonelast55_doublegyre_model_wallrestoration.jl new file mode 100644 index 0000000000..7c95a8f5e0 --- /dev/null +++ b/NNclosure_nof_BBLkappazonelast55_doublegyre_model_wallrestoration.jl @@ -0,0 +1,1271 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLkappazonelast55.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_wallrestoration_30days_zC2O_NN_closure_NDE_BBLkappazonelast55" +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_NN_closure_NDE_BBLkappazonelast55" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() +closure = (base_closure, nn_closure) + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), CenteredSecondOrder()) +# advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = min(0, -5 + 5 * (1 + (z + 500) / (Lz - 500))) +@inline north_T_flux(x, z, t, T) = μ_T * δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); +first_index = model.diffusivity_fields[2].first_index +last_index = model.diffusivity_fields[2].last_index +wT_NN = model.diffusivity_fields[2].wT +wS_NN = model.diffusivity_fields[2].wS + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +wT = wT_NN + wT_base +wS = wS_NN + wS_base + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +@inline function get_top_buoyancy_flux(i, j, k, grid, buoyancy, T_bc, S_bc, velocities, tracers, clock) + return top_buoyancy_flux(i, j, grid, buoyancy, (; T=T_bc, S=S_bc), clock, merge(velocities, tracers)) +end + +Qb = KernelFunctionOperation{Center, Center, Nothing}(get_top_buoyancy_flux, model.grid, model.buoyancy, T.boundary_conditions.top, S.boundary_conditions.top, model.velocities, model.tracers, model.clock) +Qb = Field(Qb) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +wT_NNbar_zonal = Average(wT_NN, dims=1) +wS_NNbar_zonal = Average(wS_NN, dims=1) + +wT_basebar_zonal = Average(wT_base, dims=1) +wS_basebar_zonal = Average(wS_base, dims=1) + +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_NN, wS_NN, wT_base, wS_base, wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wT_NNbar_zonal, wS_NNbar_zonal, wT_basebar_zonal, wS_basebar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:BBL] = JLD2OutputWriter(model, (; first_index, last_index, Qb), + filename = "$(FILE_DIR)/instantaneous_fields_NN_active_diagnostics", + indices = (:, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in yz" + +fieldname = "T" +fluxname = "wT_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +# save("./Output/compare_3D_instantaneous_fields_slices_NNclosure_fluxes.png", fig) +# display(fig) +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% +@info "Recording S fields and fluxes in yz" + +fieldname = "S" +fluxname = "wS_NN" +field_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +field_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +field_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +field_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +field_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +field_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +field_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +field_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +field_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +field_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_00 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", fluxname, backend=OnDisk()) +flux_NN_data_10 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_10.jld2", fluxname, backend=OnDisk()) +flux_NN_data_20 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_20.jld2", fluxname, backend=OnDisk()) +flux_NN_data_30 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_30.jld2", fluxname, backend=OnDisk()) +flux_NN_data_40 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_40.jld2", fluxname, backend=OnDisk()) +flux_NN_data_50 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_50.jld2", fluxname, backend=OnDisk()) +flux_NN_data_60 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_60.jld2", fluxname, backend=OnDisk()) +flux_NN_data_70 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_70.jld2", fluxname, backend=OnDisk()) +flux_NN_data_80 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_80.jld2", fluxname, backend=OnDisk()) +flux_NN_data_90 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz_90.jld2", fluxname, backend=OnDisk()) + +xC = field_NN_data_00.grid.xᶜᵃᵃ[1:field_NN_data_00.grid.Nx] +yC = field_NN_data_00.grid.yᵃᶜᵃ[1:field_NN_data_00.grid.Ny] +zC = field_NN_data_00.grid.zᵃᵃᶜ[1:field_NN_data_00.grid.Nz] +zF = field_NN_data_00.grid.zᵃᵃᶠ[1:field_NN_data_00.grid.Nz+1] + +Nt = length(field_NN_data_90) +times = field_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_00 = CairoMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_00.grid.xᶜᵃᵃ[field_NN_data_00.indices[1][1]] / 1000) km") +axfield_10 = CairoMakie.Axis(fig[1, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_10.grid.xᶜᵃᵃ[field_NN_data_10.indices[1][1]] / 1000) km") +axfield_20 = CairoMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_20.grid.xᶜᵃᵃ[field_NN_data_20.indices[1][1]] / 1000) km") +axfield_30 = CairoMakie.Axis(fig[2, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_30.grid.xᶜᵃᵃ[field_NN_data_30.indices[1][1]] / 1000) km") +axfield_40 = CairoMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_40.grid.xᶜᵃᵃ[field_NN_data_40.indices[1][1]] / 1000) km") +axfield_50 = CairoMakie.Axis(fig[3, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_50.grid.xᶜᵃᵃ[field_NN_data_50.indices[1][1]] / 1000) km") +axfield_60 = CairoMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_60.grid.xᶜᵃᵃ[field_NN_data_60.indices[1][1]] / 1000) km") +axfield_70 = CairoMakie.Axis(fig[4, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_70.grid.xᶜᵃᵃ[field_NN_data_70.indices[1][1]] / 1000) km") +axfield_80 = CairoMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_80.grid.xᶜᵃᵃ[field_NN_data_80.indices[1][1]] / 1000) km") +axfield_90 = CairoMakie.Axis(fig[5, 5], xlabel="y (m)", ylabel="z (m)", title="x = $(field_NN_data_90.grid.xᶜᵃᵃ[field_NN_data_90.indices[1][1]] / 1000) km") + +axflux_00 = CairoMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_00.grid.xᶜᵃᵃ[flux_NN_data_00.indices[1][1]] / 1000) km") +axflux_10 = CairoMakie.Axis(fig[1, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_10.grid.xᶜᵃᵃ[flux_NN_data_10.indices[1][1]] / 1000) km") +axflux_20 = CairoMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_20.grid.xᶜᵃᵃ[flux_NN_data_20.indices[1][1]] / 1000) km") +axflux_30 = CairoMakie.Axis(fig[2, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_30.grid.xᶜᵃᵃ[flux_NN_data_30.indices[1][1]] / 1000) km") +axflux_40 = CairoMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_40.grid.xᶜᵃᵃ[flux_NN_data_40.indices[1][1]] / 1000) km") +axflux_50 = CairoMakie.Axis(fig[3, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_50.grid.xᶜᵃᵃ[flux_NN_data_50.indices[1][1]] / 1000) km") +axflux_60 = CairoMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_60.grid.xᶜᵃᵃ[flux_NN_data_60.indices[1][1]] / 1000) km") +axflux_70 = CairoMakie.Axis(fig[4, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_70.grid.xᶜᵃᵃ[flux_NN_data_70.indices[1][1]] / 1000) km") +axflux_80 = CairoMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_80.grid.xᶜᵃᵃ[flux_NN_data_80.indices[1][1]] / 1000) km") +axflux_90 = CairoMakie.Axis(fig[5, 7], xlabel="y (m)", ylabel="z (m)", title="x = $(flux_NN_data_90.grid.xᶜᵃᵃ[flux_NN_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lim = (find_min(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_NN_data_00[timeframes[1]], :, :, zC_indices), interior(field_NN_data_00[timeframes[end]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (find_min(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_NN_data_00[timeframes[1]], :, :, zC_indices), interior(flux_NN_data_00[timeframes[end]], :, :, zC_indices))) + +flux_lim = (-maximum(abs, flux_lim), maximum(abs, flux_lim)) + +NN_00ₙ = @lift interior(field_NN_data_00[$n], 1, :, zC_indices) +NN_10ₙ = @lift interior(field_NN_data_10[$n], 1, :, zC_indices) +NN_20ₙ = @lift interior(field_NN_data_20[$n], 1, :, zC_indices) +NN_30ₙ = @lift interior(field_NN_data_30[$n], 1, :, zC_indices) +NN_40ₙ = @lift interior(field_NN_data_40[$n], 1, :, zC_indices) +NN_50ₙ = @lift interior(field_NN_data_50[$n], 1, :, zC_indices) +NN_60ₙ = @lift interior(field_NN_data_60[$n], 1, :, zC_indices) +NN_70ₙ = @lift interior(field_NN_data_70[$n], 1, :, zC_indices) +NN_80ₙ = @lift interior(field_NN_data_80[$n], 1, :, zC_indices) +NN_90ₙ = @lift interior(field_NN_data_90[$n], 1, :, zC_indices) + +flux_00ₙ = @lift interior(flux_NN_data_00[$n], 1, :, zF_indices) +flux_10ₙ = @lift interior(flux_NN_data_10[$n], 1, :, zF_indices) +flux_20ₙ = @lift interior(flux_NN_data_20[$n], 1, :, zF_indices) +flux_30ₙ = @lift interior(flux_NN_data_30[$n], 1, :, zF_indices) +flux_40ₙ = @lift interior(flux_NN_data_40[$n], 1, :, zF_indices) +flux_50ₙ = @lift interior(flux_NN_data_50[$n], 1, :, zF_indices) +flux_60ₙ = @lift interior(flux_NN_data_60[$n], 1, :, zF_indices) +flux_70ₙ = @lift interior(flux_NN_data_70[$n], 1, :, zF_indices) +flux_80ₙ = @lift interior(flux_NN_data_80[$n], 1, :, zF_indices) +flux_90ₙ = @lift interior(flux_NN_data_90[$n], 1, :, zF_indices) + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_00_surface = heatmap!(axfield_00, yC, zC[zC_indices], NN_00ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_10_surface = heatmap!(axfield_10, yC, zC[zC_indices], NN_10ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_20_surface = heatmap!(axfield_20, yC, zC[zC_indices], NN_20ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_30_surface = heatmap!(axfield_30, yC, zC[zC_indices], NN_30ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_40_surface = heatmap!(axfield_40, yC, zC[zC_indices], NN_40ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_50_surface = heatmap!(axfield_50, yC, zC[zC_indices], NN_50ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_60_surface = heatmap!(axfield_60, yC, zC[zC_indices], NN_60ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_70_surface = heatmap!(axfield_70, yC, zC[zC_indices], NN_70ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_80_surface = heatmap!(axfield_80, yC, zC[zC_indices], NN_80ₙ, colormap=colorscheme_field, colorrange=field_lim) +field_90_surface = heatmap!(axfield_90, yC, zC[zC_indices], NN_90ₙ, colormap=colorscheme_field, colorrange=field_lim) + +flux_00_surface = heatmap!(axflux_00, yC, zC[zF_indices], flux_00ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_10_surface = heatmap!(axflux_10, yC, zC[zF_indices], flux_10ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_20_surface = heatmap!(axflux_20, yC, zC[zF_indices], flux_20ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_30_surface = heatmap!(axflux_30, yC, zC[zF_indices], flux_30ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_40_surface = heatmap!(axflux_40, yC, zC[zF_indices], flux_40ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_50_surface = heatmap!(axflux_50, yC, zC[zF_indices], flux_50ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_60_surface = heatmap!(axflux_60, yC, zC[zF_indices], flux_60ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_70_surface = heatmap!(axflux_70, yC, zC[zF_indices], flux_70ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_80_surface = heatmap!(axflux_80, yC, zC[zF_indices], flux_80ₙ, colormap=colorscheme_flux, colorrange=flux_lim) +flux_90_surface = heatmap!(axflux_90, yC, zC[zF_indices], flux_90ₙ, colormap=colorscheme_flux, colorrange=flux_lim) + +Colorbar(fig[1:5, 2], field_00_surface, label="Field") +Colorbar(fig[1:5, 4], flux_00_surface, label="NN Flux") +Colorbar(fig[1:5, 6], field_00_surface, label="Field") +Colorbar(fig[1:5, 8], flux_00_surface, label="NN Flux") + +xlims!(axfield_00, minimum(yC), maximum(yC)) +xlims!(axfield_10, minimum(yC), maximum(yC)) +xlims!(axfield_20, minimum(yC), maximum(yC)) +xlims!(axfield_30, minimum(yC), maximum(yC)) +xlims!(axfield_40, minimum(yC), maximum(yC)) +xlims!(axfield_50, minimum(yC), maximum(yC)) +xlims!(axfield_60, minimum(yC), maximum(yC)) +xlims!(axfield_70, minimum(yC), maximum(yC)) +xlims!(axfield_80, minimum(yC), maximum(yC)) +xlims!(axfield_90, minimum(yC), maximum(yC)) + +ylims!(axfield_00, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_10, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_20, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_30, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_40, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_50, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_60, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_70, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_80, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +ylims!(axfield_90, minimum(zC[zC_indices]), maximum(zC[zC_indices])) + +xlims!(axflux_00, minimum(yC), maximum(yC)) +xlims!(axflux_10, minimum(yC), maximum(yC)) +xlims!(axflux_20, minimum(yC), maximum(yC)) +xlims!(axflux_30, minimum(yC), maximum(yC)) +xlims!(axflux_40, minimum(yC), maximum(yC)) +xlims!(axflux_50, minimum(yC), maximum(yC)) +xlims!(axflux_60, minimum(yC), maximum(yC)) +xlims!(axflux_70, minimum(yC), maximum(yC)) +xlims!(axflux_80, minimum(yC), maximum(yC)) +xlims!(axflux_90, minimum(yC), maximum(yC)) + +ylims!(axflux_00, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_10, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_20, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_30, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_40, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_50, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_60, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_70, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_80, minimum(zF[zF_indices]), maximum(zF[zF_indices])) +ylims!(axflux_90, minimum(zF[zF_indices]), maximum(zF[zF_indices])) + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_yzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording T fields and fluxes in xz" + +fieldname = "T" +fluxname = "wT_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_T.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end + +#%% +@info "Recording S fields and fluxes in xz" + +fieldname = "S" +fluxname = "wS_NN" + +field_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fieldname, backend=OnDisk()) +field_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fieldname, backend=OnDisk()) +field_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fieldname, backend=OnDisk()) +field_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fieldname, backend=OnDisk()) +field_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fieldname, backend=OnDisk()) +field_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fieldname, backend=OnDisk()) +field_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fieldname, backend=OnDisk()) +field_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fieldname, backend=OnDisk()) +field_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fieldname, backend=OnDisk()) +field_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fieldname, backend=OnDisk()) + +flux_NN_data_5 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_5.jld2", fluxname, backend=OnDisk()) +flux_NN_data_15 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_15.jld2", fluxname, backend=OnDisk()) +flux_NN_data_25 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_25.jld2", fluxname, backend=OnDisk()) +flux_NN_data_35 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_35.jld2", fluxname, backend=OnDisk()) +flux_NN_data_45 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_45.jld2", fluxname, backend=OnDisk()) +flux_NN_data_55 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_55.jld2", fluxname, backend=OnDisk()) +flux_NN_data_65 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_65.jld2", fluxname, backend=OnDisk()) +flux_NN_data_75 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_75.jld2", fluxname, backend=OnDisk()) +flux_NN_data_85 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_85.jld2", fluxname, backend=OnDisk()) +flux_NN_data_95 = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_95.jld2", fluxname, backend=OnDisk()) + +field_datas = [field_NN_data_5, field_NN_data_15, field_NN_data_25, field_NN_data_35, field_NN_data_45, field_NN_data_55, field_NN_data_65, field_NN_data_75, field_NN_data_85, field_NN_data_95] +flux_datas = [flux_NN_data_5, flux_NN_data_15, flux_NN_data_25, flux_NN_data_35, flux_NN_data_45, flux_NN_data_55, flux_NN_data_65, flux_NN_data_75, flux_NN_data_85, flux_NN_data_95] + +xC = field_NN_data_5.grid.xᶜᵃᵃ[1:field_NN_data_5.grid.Nx] +yC = field_NN_data_5.grid.yᵃᶜᵃ[1:field_NN_data_5.grid.Ny] +zC = field_NN_data_5.grid.zᵃᵃᶜ[1:field_NN_data_5.grid.Nz] +zF = field_NN_data_5.grid.zᵃᵃᶠ[1:field_NN_data_5.grid.Nz+1] + +Nt = length(field_NN_data_95) +times = field_NN_data_5.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(3000, 1200)) + +axfield_5 = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_5.grid.yᵃᶜᵃ[field_NN_data_5.indices[2][1]] / 1000) km") +axfield_15 = CairoMakie.Axis(fig[1, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_15.grid.yᵃᶜᵃ[field_NN_data_15.indices[2][1]] / 1000) km") +axfield_25 = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_25.grid.yᵃᶜᵃ[field_NN_data_25.indices[2][1]] / 1000) km") +axfield_35 = CairoMakie.Axis(fig[2, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_35.grid.yᵃᶜᵃ[field_NN_data_35.indices[2][1]] / 1000) km") +axfield_45 = CairoMakie.Axis(fig[3, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_45.grid.yᵃᶜᵃ[field_NN_data_45.indices[2][1]] / 1000) km") +axfield_55 = CairoMakie.Axis(fig[3, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_55.grid.yᵃᶜᵃ[field_NN_data_55.indices[2][1]] / 1000) km") +axfield_65 = CairoMakie.Axis(fig[4, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_65.grid.yᵃᶜᵃ[field_NN_data_65.indices[2][1]] / 1000) km") +axfield_75 = CairoMakie.Axis(fig[4, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_75.grid.yᵃᶜᵃ[field_NN_data_75.indices[2][1]] / 1000) km") +axfield_85 = CairoMakie.Axis(fig[5, 1], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_85.grid.yᵃᶜᵃ[field_NN_data_85.indices[2][1]] / 1000) km") +axfield_95 = CairoMakie.Axis(fig[5, 5], xlabel="x (m)", ylabel="z (m)", title="y = $(field_NN_data_95.grid.yᵃᶜᵃ[field_NN_data_95.indices[2][1]] / 1000) km") + +axflux_5 = CairoMakie.Axis(fig[1, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_5.grid.yᵃᶜᵃ[flux_NN_data_5.indices[2][1]] / 1000) km") +axflux_15 = CairoMakie.Axis(fig[1, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_15.grid.yᵃᶜᵃ[flux_NN_data_15.indices[2][1]] / 1000) km") +axflux_25 = CairoMakie.Axis(fig[2, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_25.grid.yᵃᶜᵃ[flux_NN_data_25.indices[2][1]] / 1000) km") +axflux_35 = CairoMakie.Axis(fig[2, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_35.grid.yᵃᶜᵃ[flux_NN_data_35.indices[2][1]] / 1000) km") +axflux_45 = CairoMakie.Axis(fig[3, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_45.grid.yᵃᶜᵃ[flux_NN_data_45.indices[2][1]] / 1000) km") +axflux_55 = CairoMakie.Axis(fig[3, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_55.grid.yᵃᶜᵃ[flux_NN_data_55.indices[2][1]] / 1000) km") +axflux_65 = CairoMakie.Axis(fig[4, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_65.grid.yᵃᶜᵃ[flux_NN_data_65.indices[2][1]] / 1000) km") +axflux_75 = CairoMakie.Axis(fig[4, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_75.grid.yᵃᶜᵃ[flux_NN_data_75.indices[2][1]] / 1000) km") +axflux_85 = CairoMakie.Axis(fig[5, 3], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_85.grid.yᵃᶜᵃ[flux_NN_data_85.indices[2][1]] / 1000) km") +axflux_95 = CairoMakie.Axis(fig[5, 7], xlabel="x (m)", ylabel="z (m)", title="y = $(flux_NN_data_95.grid.yᵃᶜᵃ[flux_NN_data_95.indices[2][1]] / 1000) km") + +axfields = [axfield_5, axfield_15, axfield_25, axfield_35, axfield_45, axfield_55, axfield_65, axfield_75, axfield_85, axfield_95] +axfluxes = [axflux_5, axflux_15, axflux_25, axflux_35, axflux_45, axflux_55, axflux_65, axflux_75, axflux_85, axflux_95] + +n = Observable(1096) + +zC_indices = 1:200 +zF_indices = 2:200 + +field_lims = [(find_min(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(field_data[timeframes[1]], :, :, zC_indices), interior(field_data[timeframes[end]], :, :, zC_indices))) for field_data in field_datas] + +flux_lims = [(find_min(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices)), + find_max(interior(flux_data[timeframes[1]], :, :, zC_indices), interior(flux_data[timeframes[end]], :, :, zC_indices))) for flux_data in flux_datas] + +flux_lims = [(-maximum(abs, flux_lim), maximum(abs, flux_lim)) for flux_lim in flux_lims] + +NNₙs = [@lift interior(field_data[$n], :, 1, zC_indices) for field_data in field_datas] +fluxₙs = [@lift interior(flux_data[$n], :, 1, zF_indices) for flux_data in flux_datas] + +colorscheme_field = colorschemes[:viridis] +colorscheme_flux = colorschemes[:balance] + +field_5_surface = heatmap!(axfield_5, xC, zC[zC_indices], NN_5ₙ, colormap=colorscheme_field, colorrange=field_lims[1]) +field_15_surface = heatmap!(axfield_15, xC, zC[zC_indices], NN_15ₙ, colormap=colorscheme_field, colorrange=field_lims[2]) +field_25_surface = heatmap!(axfield_25, xC, zC[zC_indices], NN_25ₙ, colormap=colorscheme_field, colorrange=field_lims[3]) +field_35_surface = heatmap!(axfield_35, xC, zC[zC_indices], NN_35ₙ, colormap=colorscheme_field, colorrange=field_lims[4]) +field_45_surface = heatmap!(axfield_45, xC, zC[zC_indices], NN_45ₙ, colormap=colorscheme_field, colorrange=field_lims[5]) +field_55_surface = heatmap!(axfield_55, xC, zC[zC_indices], NN_55ₙ, colormap=colorscheme_field, colorrange=field_lims[6]) +field_65_surface = heatmap!(axfield_65, xC, zC[zC_indices], NN_65ₙ, colormap=colorscheme_field, colorrange=field_lims[7]) +field_75_surface = heatmap!(axfield_75, xC, zC[zC_indices], NN_75ₙ, colormap=colorscheme_field, colorrange=field_lims[8]) +field_85_surface = heatmap!(axfield_85, xC, zC[zC_indices], NN_85ₙ, colormap=colorscheme_field, colorrange=field_lims[9]) +field_95_surface = heatmap!(axfield_95, xC, zC[zC_indices], NN_95ₙ, colormap=colorscheme_field, colorrange=field_lims[10]) + +flux_5_surface = heatmap!(axflux_5, xC, zC[zF_indices], flux_5ₙ, colormap=colorscheme_flux, colorrange=flux_lims[1]) +flux_15_surface = heatmap!(axflux_15, xC, zC[zF_indices], flux_15ₙ, colormap=colorscheme_flux, colorrange=flux_lims[2]) +flux_25_surface = heatmap!(axflux_25, xC, zC[zF_indices], flux_25ₙ, colormap=colorscheme_flux, colorrange=flux_lims[3]) +flux_35_surface = heatmap!(axflux_35, xC, zC[zF_indices], flux_35ₙ, colormap=colorscheme_flux, colorrange=flux_lims[4]) +flux_45_surface = heatmap!(axflux_45, xC, zC[zF_indices], flux_45ₙ, colormap=colorscheme_flux, colorrange=flux_lims[5]) +flux_55_surface = heatmap!(axflux_55, xC, zC[zF_indices], flux_55ₙ, colormap=colorscheme_flux, colorrange=flux_lims[6]) +flux_65_surface = heatmap!(axflux_65, xC, zC[zF_indices], flux_65ₙ, colormap=colorscheme_flux, colorrange=flux_lims[7]) +flux_75_surface = heatmap!(axflux_75, xC, zC[zF_indices], flux_75ₙ, colormap=colorscheme_flux, colorrange=flux_lims[8]) +flux_85_surface = heatmap!(axflux_85, xC, zC[zF_indices], flux_85ₙ, colormap=colorscheme_flux, colorrange=flux_lims[9]) +flux_95_surface = heatmap!(axflux_95, xC, zC[zF_indices], flux_95ₙ, colormap=colorscheme_flux, colorrange=flux_lims[10]) + +Colorbar(fig[1, 2], field_5_surface, label="Field") +Colorbar(fig[1, 4], flux_5_surface, label="NN Flux") +Colorbar(fig[1, 6], field_15_surface, label="Field") +Colorbar(fig[1, 8], flux_15_surface, label="NN Flux") +Colorbar(fig[2, 2], field_25_surface, label="Field") +Colorbar(fig[2, 4], flux_25_surface, label="NN Flux") +Colorbar(fig[2, 6], field_35_surface, label="Field") +Colorbar(fig[2, 8], flux_35_surface, label="NN Flux") +Colorbar(fig[3, 2], field_45_surface, label="Field") +Colorbar(fig[3, 4], flux_45_surface, label="NN Flux") +Colorbar(fig[3, 6], field_55_surface, label="Field") +Colorbar(fig[3, 8], flux_55_surface, label="NN Flux") +Colorbar(fig[4, 2], field_65_surface, label="Field") +Colorbar(fig[4, 4], flux_65_surface, label="NN Flux") +Colorbar(fig[4, 6], field_75_surface, label="Field") +Colorbar(fig[4, 8], flux_75_surface, label="NN Flux") +Colorbar(fig[5, 2], field_85_surface, label="Field") +Colorbar(fig[5, 4], flux_85_surface, label="NN Flux") +Colorbar(fig[5, 6], field_95_surface, label="Field") +Colorbar(fig[5, 8], flux_95_surface, label="NN Flux") + +for axfield in axfields + xlims!(axfield, minimum(xC), maximum(xC)) + ylims!(axfield, minimum(zC[zC_indices]), maximum(zC[zC_indices])) +end + +for axflux in axfluxes + xlims!(axflux, minimum(xC), maximum(xC)) + ylims!(axflux, minimum(zC[zF_indices]), maximum(zC[zF_indices])) +end + +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_xzslices_fluxes_S.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + n[] = nn +end +#%% \ No newline at end of file diff --git a/Project.toml b/Project.toml index cec2f85661..d430772f25 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.94.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -20,33 +21,35 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PencilArrays = "0e08944d-e94e-41b1-9406-dcf66b6a9d2e" +PencilFFTs = "4a48f351-57a6-4416-9ec4-c37015456aae" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" SeawaterPolynomials = "d496a93d-167e-4197-9f49-d3af4ff8fe40" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" [extensions] OceananigansEnzymeExt = "Enzyme" -OceananigansMakieExt = ["MakieCore", "Makie"] [compat] Adapt = "4.1.1" CUDA = "4.1.1, 5" Crayons = "4" -CubedSphere = "0.2, 0.3" +CubedSphere = "0.1, 0.2" Dates = "1.9" Distances = "0.10" DocStringExtensions = "0.8, 0.9" @@ -61,15 +64,16 @@ KernelAbstractions = "0.9.21" LinearAlgebra = "1.9" Logging = "1.9" MPI = "0.16, 0.17, 0.18, 0.19, 0.20" -Makie = "0.21" -MakieCore = "0.7, 0.8" NCDatasets = "0.12.10, 0.13.1, 0.14" OffsetArrays = "1.4" OrderedCollections = "1.1" +PencilArrays = "0.16, 0.17, 0.18, 0.19" +PencilFFTs = "0.13.5, 0.14, 0.15" +Pkg = "1.9" Printf = "1.9" Random = "1.9" Rotations = "1.0" -SeawaterPolynomials = "0.3.5" +SeawaterPolynomials = "0.3.4" SparseArrays = "1.9" Statistics = "1.9" StructArrays = "0.4, 0.5, 0.6, 0.7" diff --git a/baseclosure_doublegyre_model.jl b/baseclosure_doublegyre_model.jl new file mode 100644 index 0000000000..51cff7e579 --- /dev/null +++ b/baseclosure_doublegyre_model.jl @@ -0,0 +1,578 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zC2O_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +closure = base_closure + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +# @inline T_initial(x, y, z) = (T_north + T_south / 2) + 5 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields.κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) +wTbar_zonal = Average(wT_base, dims=1) +wSbar_zonal = Average(wS_base, dims=1) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT=wT_base, wS=wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:complete_fields] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields", + schedule = TimeInterval(1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 365 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/baseclosure_doublegyre_model_initialized.jl b/baseclosure_doublegyre_model_initialized.jl new file mode 100644 index 0000000000..a376834270 --- /dev/null +++ b/baseclosure_doublegyre_model_initialized.jl @@ -0,0 +1,575 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + + +#%% +filename = "doublegyre_relaxation_30days_baseclosure_2step_initialized" +FILE_DIR = "./Output/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### + +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +@inline T_initial(x, y, z) = (T_north + T_south / 2) + 5 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### +# resting initial condition +# noise(z) = rand() * exp(z / 8) + +# T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +# S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) +DATA_DIR = "./Output/doublegyre_30Cwarmflush_relaxation_8days_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" + +u_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "u", backend=OnDisk()) +v_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "v", backend=OnDisk()) +T_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "T", backend=OnDisk()) +S_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "S", backend=OnDisk()) + +ntimes = length(u_data.times) + +set!(model, T=T_data[ntimes], S=S_data[ntimes], u=u_data[ntimes], v=v_data[ntimes]) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +outputs = (; u, v, w, T, S, ρ, N², wT_base, wS_base) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_south] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_south", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_north] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_north", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 365 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/baseclosure_doublegyre_model_modewater.jl b/baseclosure_doublegyre_model_modewater.jl new file mode 100644 index 0000000000..11328e9ede --- /dev/null +++ b/baseclosure_doublegyre_model_modewater.jl @@ -0,0 +1,587 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step_new.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_modewater_zWENO5_newbaseclosure" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +closure = base_closure + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const X₀ = -Lx/2 + 800kilometers +const Y₀ = -Ly/2 + 1500kilometers +const R₀ = 700kilometers +const Qᵀ_mode = 4.5e-4 +const σ_mode = 20kilometers + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +# @inline T_initial(x, y, z) = (T_north + T_south / 2) + 5 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y + +@inline Qᵀ_winter(t) = max(0, -Qᵀ_mode * sin(2π * t / 360days)) +@inline Qᵀ_subpolar(x, y, t) = ifelse((x - X₀)^2 + (y - Y₀)^2 <= R₀^2, Qᵀ_winter(t), + exp(-(sqrt((x - X₀)^2 + (y - Y₀)^2) - R₀)^2 / (2 * σ_mode^2)) * Qᵀ_winter(t)) + +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) + Qᵀ_subpolar(x, y, t) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields.κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) +wTbar_zonal = Average(wT_base, dims=1) +wSbar_zonal = Average(wS_base, dims=1) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT=wT_base, wS=wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:complete_fields] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields", + schedule = TimeInterval(1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 365 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/baseclosure_doublegyre_model_seasonalforcing.jl b/baseclosure_doublegyre_model_seasonalforcing.jl new file mode 100644 index 0000000000..f9719ed68b --- /dev/null +++ b/baseclosure_doublegyre_model_seasonalforcing.jl @@ -0,0 +1,627 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +filename = "doublegyre_linearseasonalforcing_10C_relaxation_30days_baseclosure" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 10 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * (y/Ly + 1/2) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 36000days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:complete_fields] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields", + schedule = TimeInterval(1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 360 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/baseclosure_doublegyre_model_seasonalforcing_wallrestoration.jl b/baseclosure_doublegyre_model_seasonalforcing_wallrestoration.jl new file mode 100644 index 0000000000..c484faa1af --- /dev/null +++ b/baseclosure_doublegyre_model_seasonalforcing_wallrestoration.jl @@ -0,0 +1,633 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +filename = "doublegyre_linearseasonalforcing_10C_relaxation_wallrestoration_30days_baseclosure" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +vertical_scalar_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +closure = (base_closure, vertical_scalar_closure) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 10 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * (y/Ly + 1/2) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = min(0, -5 + 5 * (1 + (z + 500) / (Lz - 500))) +@inline north_T_flux(x, z, t, T) = μ_T * δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields[1].κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT_base, wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:complete_fields] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields", + schedule = TimeInterval(1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +T_colormap = colorschemes[:viridis] +S_colormap = colorschemes[:viridis] +u_colormap = colorschemes[:balance] +v_colormap = colorschemes[:balance] + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 360 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/baseclosure_doublegyre_model_wallrestoration.jl b/baseclosure_doublegyre_model_wallrestoration.jl new file mode 100644 index 0000000000..22e3f67a05 --- /dev/null +++ b/baseclosure_doublegyre_model_wallrestoration.jl @@ -0,0 +1,586 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("xin_kai_vertical_diffusivity_local_2step.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes + +#%% +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_wallrestoration_30days_zC2O_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" +filename = "doublegyre_30Cwarmflushbottom10_relaxation_wallrestoration_30days_zWENO5_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +base_closure = XinKaiLocalVerticalDiffusivity() +closure = base_closure + +# advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), CenteredSecondOrder()) +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +# @inline T_initial(x, y, z) = (T_north + T_south / 2) + 5 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = min(0, -5 + 5 * (1 + (z + 500) / (Lz - 500))) +@inline north_T_flux(x, z, t, T) = μ_T * δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)); +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)); + +κ = model.diffusivity_fields.κᶜ +wT_base = κ * ∂z(T) +wS_base = κ * ∂z(S) +wTbar_zonal = Average(wT_base, dims=1) +wSbar_zonal = Average(wS_base, dims=1) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, N², wT=wT_base, wS=wS_base) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:complete_fields] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields", + schedule = TimeInterval(1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = CairoMakie.Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = CairoMakie.Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = CairoMakie.Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = CairoMakie.Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% +# function find_min(a...) +# return minimum(minimum.([a...])) +# end + +# function find_max(a...) +# return maximum(maximum.([a...])) +# end + +# N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +# N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +# xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +# zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +# yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +# yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +# Nt = length(N²_xz_north_data) +# times = N²_xz_north_data.times / 24 / 60^2 / 365 +# timeframes = 1:Nt + +# N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, +# find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +# #%% +# fig = Figure(size=(800, 800)) +# ax_north = CairoMakie.Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +# ax_south = CairoMakie.Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +# n = Observable(2) + +# N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +# N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +# colorscheme = colorschemes[:jet] + +# N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +# N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +# Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +# title_str = @lift "Time = $(round(times[$n], digits=2)) years" +# Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +# trim!(fig.layout) + +# @info "Recording buoyancy frequency xz slice" +# CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=3, px_per_unit=2) do nn +# n[] = nn +# end + +# @info "Done!" +# #%% \ No newline at end of file diff --git a/compare_3D_instantaneous_fields_slices_BBLkappazonelast.jl b/compare_3D_instantaneous_fields_slices_BBLkappazonelast.jl new file mode 100644 index 0000000000..0695fc815f --- /dev/null +++ b/compare_3D_instantaneous_fields_slices_BBLkappazonelast.jl @@ -0,0 +1,316 @@ +using GLMakie +using Oceananigans +using ColorSchemes +using SeawaterPolynomials +using SeawaterPolynomials.TEOS10 + +NN_FILE_DIR = "./Output/doublegyre_30Cwarmflushbottom10_relaxation_8days_NN_closure_BBLkappazonelast41_temp" +CATKE_FILE_DIR = "./Output/doublegyre_30Cwarmflushbottom10_relaxation_8days_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" + +fieldname = "S" +ρ_NN_data_00 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_10 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_20 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_30 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_40 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_50 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_60 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_70 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_80 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +ρ_NN_data_90 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +ρ_CATKE_data_00 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_10 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_20 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_30 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_40 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_50 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_60 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_70 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_80 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname, backend=OnDisk()) +ρ_CATKE_data_90 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname, backend=OnDisk()) + +# ρ_NN_data_00 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz.jld2", fieldname) +# ρ_NN_data_10 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname) +# ρ_NN_data_20 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname) +# ρ_NN_data_30 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname) +# ρ_NN_data_40 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname) +# ρ_NN_data_50 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname) +# ρ_NN_data_60 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname) +# ρ_NN_data_70 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname) +# ρ_NN_data_80 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname) +# ρ_NN_data_90 = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname) + +# ρ_CATKE_data_00 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz.jld2", fieldname) +# ρ_CATKE_data_10 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_10.jld2", fieldname) +# ρ_CATKE_data_20 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_20.jld2", fieldname) +# ρ_CATKE_data_30 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_30.jld2", fieldname) +# ρ_CATKE_data_40 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_40.jld2", fieldname) +# ρ_CATKE_data_50 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_50.jld2", fieldname) +# ρ_CATKE_data_60 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_60.jld2", fieldname) +# ρ_CATKE_data_70 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_70.jld2", fieldname) +# ρ_CATKE_data_80 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_80.jld2", fieldname) +# ρ_CATKE_data_90 = FieldTimeSeries("$(CATKE_FILE_DIR)/instantaneous_fields_yz_90.jld2", fieldname) + +first_index_data = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_NN_active_diagnostics.jld2", "first_index") +last_index_data = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_NN_active_diagnostics.jld2", "last_index") +Qb_data = FieldTimeSeries("$(NN_FILE_DIR)/instantaneous_fields_NN_active_diagnostics.jld2", "Qb") + +Nx, Ny, Nz = ρ_NN_data_00.grid.Nx, ρ_NN_data_00.grid.Ny, ρ_NN_data_00.grid.Nz + +xC = ρ_NN_data_00.grid.xᶜᵃᵃ[1:ρ_NN_data_00.grid.Nx] +yC = ρ_NN_data_00.grid.yᵃᶜᵃ[1:ρ_NN_data_00.grid.Ny] +zC = ρ_NN_data_00.grid.zᵃᵃᶜ[1:ρ_NN_data_00.grid.Nz] +zF = ρ_NN_data_00.grid.zᵃᵃᶠ[1:ρ_NN_data_00.grid.Nz+1] + +Nt = length(ρ_NN_data_90) +times = ρ_NN_data_00.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +#%% +fig = Figure(size=(2400, 2400)) + +axNN_00 = GLMakie.Axis(fig[1, 1], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_00.grid.xᶜᵃᵃ[ρ_NN_data_00.indices[1][1]] / 1000) km") +axNN_10 = GLMakie.Axis(fig[1, 3], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_10.grid.xᶜᵃᵃ[ρ_NN_data_10.indices[1][1]] / 1000) km") +axNN_20 = GLMakie.Axis(fig[2, 1], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_20.grid.xᶜᵃᵃ[ρ_NN_data_20.indices[1][1]] / 1000) km") +axNN_30 = GLMakie.Axis(fig[2, 3], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_30.grid.xᶜᵃᵃ[ρ_NN_data_30.indices[1][1]] / 1000) km") +axNN_40 = GLMakie.Axis(fig[3, 1], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_40.grid.xᶜᵃᵃ[ρ_NN_data_40.indices[1][1]] / 1000) km") +axNN_50 = GLMakie.Axis(fig[3, 3], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_50.grid.xᶜᵃᵃ[ρ_NN_data_50.indices[1][1]] / 1000) km") +axNN_60 = GLMakie.Axis(fig[4, 1], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_60.grid.xᶜᵃᵃ[ρ_NN_data_60.indices[1][1]] / 1000) km") +axNN_70 = GLMakie.Axis(fig[4, 3], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_70.grid.xᶜᵃᵃ[ρ_NN_data_70.indices[1][1]] / 1000) km") +axNN_80 = GLMakie.Axis(fig[5, 1], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_80.grid.xᶜᵃᵃ[ρ_NN_data_80.indices[1][1]] / 1000) km") +axNN_90 = GLMakie.Axis(fig[5, 3], xlabel="y (m)", ylabel="z (m)", title="NN, x = $(ρ_NN_data_90.grid.xᶜᵃᵃ[ρ_NN_data_90.indices[1][1]] / 1000) km") + +axCATKE_00 = GLMakie.Axis(fig[1, 2], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_00.grid.xᶜᵃᵃ[ρ_CATKE_data_00.indices[1][1]] / 1000) km") +axCATKE_10 = GLMakie.Axis(fig[1, 4], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_10.grid.xᶜᵃᵃ[ρ_CATKE_data_10.indices[1][1]] / 1000) km") +axCATKE_20 = GLMakie.Axis(fig[2, 2], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_20.grid.xᶜᵃᵃ[ρ_CATKE_data_20.indices[1][1]] / 1000) km") +axCATKE_30 = GLMakie.Axis(fig[2, 4], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_30.grid.xᶜᵃᵃ[ρ_CATKE_data_30.indices[1][1]] / 1000) km") +axCATKE_40 = GLMakie.Axis(fig[3, 2], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_40.grid.xᶜᵃᵃ[ρ_CATKE_data_40.indices[1][1]] / 1000) km") +axCATKE_50 = GLMakie.Axis(fig[3, 4], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_50.grid.xᶜᵃᵃ[ρ_CATKE_data_50.indices[1][1]] / 1000) km") +axCATKE_60 = GLMakie.Axis(fig[4, 2], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_60.grid.xᶜᵃᵃ[ρ_CATKE_data_60.indices[1][1]] / 1000) km") +axCATKE_70 = GLMakie.Axis(fig[4, 4], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_70.grid.xᶜᵃᵃ[ρ_CATKE_data_70.indices[1][1]] / 1000) km") +axCATKE_80 = GLMakie.Axis(fig[5, 2], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_80.grid.xᶜᵃᵃ[ρ_CATKE_data_80.indices[1][1]] / 1000) km") +axCATKE_90 = GLMakie.Axis(fig[5, 4], xlabel="y (m)", ylabel="z (m)", title="Base Closure, x = $(ρ_CATKE_data_90.grid.xᶜᵃᵃ[ρ_CATKE_data_90.indices[1][1]] / 1000) km") + +n = Observable(1096) + +z_indices = 1:200 + +ρlim = (find_min(interior(ρ_NN_data_00[timeframes[1]], :, :, z_indices), interior(ρ_NN_data_00[timeframes[end]], :, :, z_indices), interior(ρ_CATKE_data_00[timeframes[1]], :, :, z_indices), interior(ρ_CATKE_data_00[timeframes[end]], :, :, z_indices)), + find_max(interior(ρ_NN_data_00[timeframes[1]], :, :, z_indices), interior(ρ_NN_data_00[timeframes[end]], :, :, z_indices), interior(ρ_CATKE_data_00[timeframes[1]], :, :, z_indices), interior(ρ_CATKE_data_00[timeframes[end]], :, :, z_indices))) + +NN_00ₙ = @lift interior(ρ_NN_data_00[$n], 1, :, z_indices) +NN_10ₙ = @lift interior(ρ_NN_data_10[$n], 1, :, z_indices) +NN_20ₙ = @lift interior(ρ_NN_data_20[$n], 1, :, z_indices) +NN_30ₙ = @lift interior(ρ_NN_data_30[$n], 1, :, z_indices) +NN_40ₙ = @lift interior(ρ_NN_data_40[$n], 1, :, z_indices) +NN_50ₙ = @lift interior(ρ_NN_data_50[$n], 1, :, z_indices) +NN_60ₙ = @lift interior(ρ_NN_data_60[$n], 1, :, z_indices) +NN_70ₙ = @lift interior(ρ_NN_data_70[$n], 1, :, z_indices) +NN_80ₙ = @lift interior(ρ_NN_data_80[$n], 1, :, z_indices) +NN_90ₙ = @lift interior(ρ_NN_data_90[$n], 1, :, z_indices) + +CATKE_00ₙ = @lift interior(ρ_CATKE_data_00[$n], 1, :, z_indices) +CATKE_10ₙ = @lift interior(ρ_CATKE_data_10[$n], 1, :, z_indices) +CATKE_20ₙ = @lift interior(ρ_CATKE_data_20[$n], 1, :, z_indices) +CATKE_30ₙ = @lift interior(ρ_CATKE_data_30[$n], 1, :, z_indices) +CATKE_40ₙ = @lift interior(ρ_CATKE_data_40[$n], 1, :, z_indices) +CATKE_50ₙ = @lift interior(ρ_CATKE_data_50[$n], 1, :, z_indices) +CATKE_60ₙ = @lift interior(ρ_CATKE_data_60[$n], 1, :, z_indices) +CATKE_70ₙ = @lift interior(ρ_CATKE_data_70[$n], 1, :, z_indices) +CATKE_80ₙ = @lift interior(ρ_CATKE_data_80[$n], 1, :, z_indices) +CATKE_90ₙ = @lift interior(ρ_CATKE_data_90[$n], 1, :, z_indices) + +zs_first_index_00ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_00.indices[1][1], :, :)))] +zs_first_index_10ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_10.indices[1][1], :, :)))] +zs_first_index_20ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_20.indices[1][1], :, :)))] +zs_first_index_30ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_30.indices[1][1], :, :)))] +zs_first_index_40ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_40.indices[1][1], :, :)))] +zs_first_index_50ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_50.indices[1][1], :, :)))] +zs_first_index_60ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_60.indices[1][1], :, :)))] +zs_first_index_70ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_70.indices[1][1], :, :)))] +zs_first_index_80ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_80.indices[1][1], :, :)))] +zs_first_index_90ₙ = @lift zF[Int.(vec(interior(first_index_data[$n], ρ_NN_data_90.indices[1][1], :, :)))] + +zs_last_index_00ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_00.indices[1][1], :, :)))] +zs_last_index_10ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_10.indices[1][1], :, :)))] +zs_last_index_20ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_20.indices[1][1], :, :)))] +zs_last_index_30ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_30.indices[1][1], :, :)))] +zs_last_index_40ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_40.indices[1][1], :, :)))] +zs_last_index_50ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_50.indices[1][1], :, :)))] +zs_last_index_60ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_60.indices[1][1], :, :)))] +zs_last_index_70ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_70.indices[1][1], :, :)))] +zs_last_index_80ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_80.indices[1][1], :, :)))] +zs_last_index_90ₙ = @lift zF[Int.(vec(interior(last_index_data[$n], ρ_NN_data_90.indices[1][1], :, :)))] + +ys_convection_00ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_00.indices[1][1], :, 1) .> 0] +ys_convection_10ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_10.indices[1][1], :, 1) .> 0] +ys_convection_20ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_20.indices[1][1], :, 1) .> 0] +ys_convection_30ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_30.indices[1][1], :, 1) .> 0] +ys_convection_40ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_40.indices[1][1], :, 1) .> 0] +ys_convection_50ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_50.indices[1][1], :, 1) .> 0] +ys_convection_60ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_60.indices[1][1], :, 1) .> 0] +ys_convection_70ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_70.indices[1][1], :, 1) .> 0] +ys_convection_80ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_80.indices[1][1], :, 1) .> 0] +ys_convection_90ₙ = @lift yC[interior(Qb_data[$n], ρ_NN_data_90.indices[1][1], :, 1) .> 0] + +zs_convection_00ₙ = @lift fill(zC[1], length($ys_convection_00ₙ)) +zs_convection_10ₙ = @lift fill(zC[1], length($ys_convection_10ₙ)) +zs_convection_20ₙ = @lift fill(zC[1], length($ys_convection_20ₙ)) +zs_convection_30ₙ = @lift fill(zC[1], length($ys_convection_30ₙ)) +zs_convection_40ₙ = @lift fill(zC[1], length($ys_convection_40ₙ)) +zs_convection_50ₙ = @lift fill(zC[1], length($ys_convection_50ₙ)) +zs_convection_60ₙ = @lift fill(zC[1], length($ys_convection_60ₙ)) +zs_convection_70ₙ = @lift fill(zC[1], length($ys_convection_70ₙ)) +zs_convection_80ₙ = @lift fill(zC[1], length($ys_convection_80ₙ)) +zs_convection_90ₙ = @lift fill(zC[1], length($ys_convection_90ₙ)) + +# colorscheme = Reverse(colorschemes[:jet]) +colorscheme = colorschemes[:jet] + +NN_00_surface = heatmap!(axNN_00, yC, zC[z_indices], NN_00ₙ, colormap=colorscheme, colorrange=ρlim) +NN_10_surface = heatmap!(axNN_10, yC, zC[z_indices], NN_10ₙ, colormap=colorscheme, colorrange=ρlim) +NN_20_surface = heatmap!(axNN_20, yC, zC[z_indices], NN_20ₙ, colormap=colorscheme, colorrange=ρlim) +NN_30_surface = heatmap!(axNN_30, yC, zC[z_indices], NN_30ₙ, colormap=colorscheme, colorrange=ρlim) +NN_40_surface = heatmap!(axNN_40, yC, zC[z_indices], NN_40ₙ, colormap=colorscheme, colorrange=ρlim) +NN_50_surface = heatmap!(axNN_50, yC, zC[z_indices], NN_50ₙ, colormap=colorscheme, colorrange=ρlim) +NN_60_surface = heatmap!(axNN_60, yC, zC[z_indices], NN_60ₙ, colormap=colorscheme, colorrange=ρlim) +NN_70_surface = heatmap!(axNN_70, yC, zC[z_indices], NN_70ₙ, colormap=colorscheme, colorrange=ρlim) +NN_80_surface = heatmap!(axNN_80, yC, zC[z_indices], NN_80ₙ, colormap=colorscheme, colorrange=ρlim) +NN_90_surface = heatmap!(axNN_90, yC, zC[z_indices], NN_90ₙ, colormap=colorscheme, colorrange=ρlim) + +CATKE_00_surface = heatmap!(axCATKE_00, yC, zC[z_indices], CATKE_00ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_10_surface = heatmap!(axCATKE_10, yC, zC[z_indices], CATKE_10ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_20_surface = heatmap!(axCATKE_20, yC, zC[z_indices], CATKE_20ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_30_surface = heatmap!(axCATKE_30, yC, zC[z_indices], CATKE_30ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_40_surface = heatmap!(axCATKE_40, yC, zC[z_indices], CATKE_40ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_50_surface = heatmap!(axCATKE_50, yC, zC[z_indices], CATKE_50ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_60_surface = heatmap!(axCATKE_60, yC, zC[z_indices], CATKE_60ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_70_surface = heatmap!(axCATKE_70, yC, zC[z_indices], CATKE_70ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_80_surface = heatmap!(axCATKE_80, yC, zC[z_indices], CATKE_80ₙ, colormap=colorscheme, colorrange=ρlim) +CATKE_90_surface = heatmap!(axCATKE_90, yC, zC[z_indices], CATKE_90ₙ, colormap=colorscheme, colorrange=ρlim) + +# contourlevels = range(ρlim[1], ρlim[2], length=10) + +# NN_00_surface = contourf!(axNN_00, yC, zC, NN_00ₙ, colormap=colorscheme, levels=contourlevels) +# NN_10_surface = contourf!(axNN_10, yC, zC, NN_10ₙ, colormap=colorscheme, levels=contourlevels) +# NN_20_surface = contourf!(axNN_20, yC, zC, NN_20ₙ, colormap=colorscheme, levels=contourlevels) +# NN_30_surface = contourf!(axNN_30, yC, zC, NN_30ₙ, colormap=colorscheme, levels=contourlevels) +# NN_40_surface = contourf!(axNN_40, yC, zC, NN_40ₙ, colormap=colorscheme, levels=contourlevels) +# NN_50_surface = contourf!(axNN_50, yC, zC, NN_50ₙ, colormap=colorscheme, levels=contourlevels) +# NN_60_surface = contourf!(axNN_60, yC, zC, NN_60ₙ, colormap=colorscheme, levels=contourlevels) +# NN_70_surface = contourf!(axNN_70, yC, zC, NN_70ₙ, colormap=colorscheme, levels=contourlevels) +# NN_80_surface = contourf!(axNN_80, yC, zC, NN_80ₙ, colormap=colorscheme, levels=contourlevels) +# NN_90_surface = contourf!(axNN_90, yC, zC, NN_90ₙ, colormap=colorscheme, levels=contourlevels) + +# CATKE_00_surface = contourf!(axCATKE_00, yC, zC, CATKE_00ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_10_surface = contourf!(axCATKE_10, yC, zC, CATKE_10ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_20_surface = contourf!(axCATKE_20, yC, zC, CATKE_20ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_30_surface = contourf!(axCATKE_30, yC, zC, CATKE_30ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_40_surface = contourf!(axCATKE_40, yC, zC, CATKE_40ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_50_surface = contourf!(axCATKE_50, yC, zC, CATKE_50ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_60_surface = contourf!(axCATKE_60, yC, zC, CATKE_60ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_70_surface = contourf!(axCATKE_70, yC, zC, CATKE_70ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_80_surface = contourf!(axCATKE_80, yC, zC, CATKE_80ₙ, colormap=colorscheme, levels=contourlevels) +# CATKE_90_surface = contourf!(axCATKE_90, yC, zC, CATKE_90ₙ, colormap=colorscheme, levels=contourlevels) + +Colorbar(fig[1:5, 5], NN_00_surface) + +# lines!(axNN_00, yC, zs_first_index_00ₙ, color=:black) +# lines!(axNN_10, yC, zs_first_index_10ₙ, color=:black) +# lines!(axNN_20, yC, zs_first_index_20ₙ, color=:black) +# lines!(axNN_30, yC, zs_first_index_30ₙ, color=:black) +# lines!(axNN_40, yC, zs_first_index_40ₙ, color=:black) +# lines!(axNN_50, yC, zs_first_index_50ₙ, color=:black) +# lines!(axNN_60, yC, zs_first_index_60ₙ, color=:black) +# lines!(axNN_70, yC, zs_first_index_70ₙ, color=:black) +# lines!(axNN_80, yC, zs_first_index_80ₙ, color=:black) +# lines!(axNN_90, yC, zs_first_index_90ₙ, color=:black) + +# lines!(axNN_00, yC, zs_last_index_00ₙ, color=:black) +# lines!(axNN_10, yC, zs_last_index_10ₙ, color=:black) +# lines!(axNN_20, yC, zs_last_index_20ₙ, color=:black) +# lines!(axNN_30, yC, zs_last_index_30ₙ, color=:black) +# lines!(axNN_40, yC, zs_last_index_40ₙ, color=:black) +# lines!(axNN_50, yC, zs_last_index_50ₙ, color=:black) +# lines!(axNN_60, yC, zs_last_index_60ₙ, color=:black) +# lines!(axNN_70, yC, zs_last_index_70ₙ, color=:black) +# lines!(axNN_80, yC, zs_last_index_80ₙ, color=:black) +# lines!(axNN_90, yC, zs_last_index_90ₙ, color=:black) + +scatter!(axNN_00, ys_convection_00ₙ, zs_convection_00ₙ, color=:red, markersize=10) +scatter!(axNN_10, ys_convection_10ₙ, zs_convection_10ₙ, color=:red, markersize=10) +scatter!(axNN_20, ys_convection_20ₙ, zs_convection_20ₙ, color=:red, markersize=10) +scatter!(axNN_30, ys_convection_30ₙ, zs_convection_30ₙ, color=:red, markersize=10) +scatter!(axNN_40, ys_convection_40ₙ, zs_convection_40ₙ, color=:red, markersize=10) +scatter!(axNN_50, ys_convection_50ₙ, zs_convection_50ₙ, color=:red, markersize=10) +scatter!(axNN_60, ys_convection_60ₙ, zs_convection_60ₙ, color=:red, markersize=10) +scatter!(axNN_70, ys_convection_70ₙ, zs_convection_70ₙ, color=:red, markersize=10) +scatter!(axNN_80, ys_convection_80ₙ, zs_convection_80ₙ, color=:red, markersize=10) +scatter!(axNN_90, ys_convection_90ₙ, zs_convection_90ₙ, color=:red, markersize=10) + +xlims!(axNN_00, minimum(yC), maximum(yC)) +xlims!(axNN_10, minimum(yC), maximum(yC)) +xlims!(axNN_20, minimum(yC), maximum(yC)) +xlims!(axNN_30, minimum(yC), maximum(yC)) +xlims!(axNN_40, minimum(yC), maximum(yC)) +xlims!(axNN_50, minimum(yC), maximum(yC)) +xlims!(axNN_60, minimum(yC), maximum(yC)) +xlims!(axNN_70, minimum(yC), maximum(yC)) +xlims!(axNN_80, minimum(yC), maximum(yC)) +xlims!(axNN_90, minimum(yC), maximum(yC)) + +ylims!(axNN_00, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_10, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_20, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_30, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_40, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_50, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_60, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_70, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_80, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axNN_90, minimum(zC[z_indices]), maximum(zC[z_indices])) + +xlims!(axCATKE_00, minimum(yC), maximum(yC)) +xlims!(axCATKE_10, minimum(yC), maximum(yC)) +xlims!(axCATKE_20, minimum(yC), maximum(yC)) +xlims!(axCATKE_30, minimum(yC), maximum(yC)) +xlims!(axCATKE_40, minimum(yC), maximum(yC)) +xlims!(axCATKE_50, minimum(yC), maximum(yC)) +xlims!(axCATKE_60, minimum(yC), maximum(yC)) +xlims!(axCATKE_70, minimum(yC), maximum(yC)) +xlims!(axCATKE_80, minimum(yC), maximum(yC)) +xlims!(axCATKE_90, minimum(yC), maximum(yC)) + +ylims!(axCATKE_00, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_10, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_20, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_30, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_40, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_50, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_60, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_70, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_80, minimum(zC[z_indices]), maximum(zC[z_indices])) +ylims!(axCATKE_90, minimum(zC[z_indices]), maximum(zC[z_indices])) + +# title_str = @lift "Temperature (°C), Time = $(round(times[$n], digits=2)) years" +title_str = @lift "Salinity (psu), Time = $(round(times[$n], digits=2)) years" +# title_str = @lift "Potential Density (kg m⁻³), Time = $(round(times[$n], digits=2)) years" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +display(fig) +GLMakie.record(fig, "./Output/doublegyre_relaxation_8days_30Cwarmflush10bottom_NNclosure_BBLkappazonelast41_baseclosure_S_BBLlines_yzslices.mp4", 1:Nt, framerate=15, px_per_unit=2) do nn + @info "Recording frame $nn" + n[] = nn +end +#%% \ No newline at end of file diff --git a/compare_LES_NN_closure.jl b/compare_LES_NN_closure.jl new file mode 100644 index 0000000000..a18fbd5b9d --- /dev/null +++ b/compare_LES_NN_closure.jl @@ -0,0 +1,208 @@ +using Oceananigans +using Statistics +using JLD2 +using CairoMakie + +# LES_FILE_DIR = "./NN_2D_channel_horizontal_convection_0.0003_LES.jld2" + +# u_data_LES = FieldTimeSeries(LES_FILE_DIR, "u", backend=OnDisk()) +# v_data_LES = FieldTimeSeries(LES_FILE_DIR, "v", backend=OnDisk()) +# T_data_LES = FieldTimeSeries(LES_FILE_DIR, "T", backend=OnDisk()) +# S_data_LES = FieldTimeSeries(LES_FILE_DIR, "S", backend=OnDisk()) + +# yC_LES = ynodes(T_data_LES.grid, Center()) +# yF_LES = ynodes(T_data_LES.grid, Face()) + +# zC_LES = znodes(T_data_LES.grid, Center()) +# zF_LES = znodes(T_data_LES.grid, Face()) + +# Nt_LES = findfirst(x -> x ≈ end_time, T_data_LES.times) + +# Δy_LES = T_data_LES.grid.Ly / T_data_LES.grid.Ny +# Δz_LES = T_data_LES.grid.Lz / T_data_LES.grid.Nz + +MODEL_FILE_DIR = "./NN_closure_2D_channel.jld2" + +u_data_model = FieldTimeSeries(MODEL_FILE_DIR, "u") +v_data_model = FieldTimeSeries(MODEL_FILE_DIR, "v") +T_data_model = FieldTimeSeries(MODEL_FILE_DIR, "T") +S_data_model = FieldTimeSeries(MODEL_FILE_DIR, "S") + +end_time = 23 * 60^2 * 24 + +Ny_model = T_data_model.grid.Ny +Nz_model = T_data_model.grid.Nz + +Δy_model = T_data_model.grid.Ly / Ny_model +Δz_model = T_data_model.grid.Lz / Nz_model + +yC_model = ynodes(T_data_model.grid, Center()) +zC_model = znodes(T_data_model.grid, Center()) + +# coarse_ratio_y = Int(Δy_model / Δy_LES) +# coarse_ratio_z = Int(Δz_model / Δz_LES) + +# function coarsen_dataᵃᶜᵃ(data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt_LES) +# Ny_LES = data_LES.grid.Ny +# Nz_LES = data_LES.grid.Nz +# data_LES_coarse = zeros(1, Ny_model, Nz_model, Nt_LES) +# LES_temp = zeros(Ny_LES, Nz_LES) + +# for nt in 1:Nt_LES +# LES_temp .= interior(data_LES[nt], 1, :, :) +# Threads.@threads for j in axes(data_LES_coarse, 2) +# @info "nt = $nt, Processing j = $j" +# for k in axes(data_LES_coarse, 3) +# data_LES_coarse[1, j, k, nt] = mean(LES_temp[(j-1)*coarse_ratio_y+1:j*coarse_ratio_y, (k-1)*coarse_ratio_z+1:k*coarse_ratio_z]) +# end +# end +# end + +# return data_LES_coarse +# end + +# T_data_LES_coarse = coarsen_dataᵃᶜᵃ(T_data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt_LES) +# S_data_LES_coarse = coarsen_dataᵃᶜᵃ(S_data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt_LES) +# u_data_LES_coarse = coarsen_dataᵃᶜᵃ(u_data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt_LES) + +# function coarsen_dataᵃᶠᵃ(data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt) +# Ny_LES = data_LES.grid.Ny +# Nz_LES = data_LES.grid.Nz + +# dataᵃᶠᵃ_temp = zeros(Ny_LES+1, Nz_LES) +# dataᵃᶜᵃ_temp = zeros(Ny_LES, Nz_LES) + +# data_coarse = zeros(1, Ny_model, Nz_model, Nt) + +# for nt in 1:Nt +# @info "nt = $nt, Interpolating for LES data" +# dataᵃᶠᵃ_temp .= interior(data_LES[nt], 1, :, :) +# Threads.@threads for j in 1:Ny_LES +# # for j in 1:Ny_LES +# for k in 1:Nz_LES +# dataᵃᶜᵃ_temp[j, k] = mean(dataᵃᶠᵃ_temp[j:j+1, k]) +# end +# end + +# @info "nt = $nt, Coarsening data" +# Threads.@threads for j in axes(data_coarse, 2) +# # for j in axes(data_coarse, 2) +# for k in axes(data_coarse, 3) +# data_coarse[1, j, k, nt] = mean(dataᵃᶜᵃ_temp[(j-1)*coarse_ratio_y+1:j*coarse_ratio_y, (k-1)*coarse_ratio_z+1:k*coarse_ratio_z]) +# end +# end +# end + +# return data_coarse +# end + +# v_data_LES_coarse = coarsen_dataᵃᶠᵃ(v_data_LES, coarse_ratio_y, coarse_ratio_z, Ny_model, Nz_model, Nt_LES) +# v_data_model_coarse = coarsen_dataᵃᶠᵃ(v_data_model, 1, 1, Ny_model, Nz_model, Nt_LES) + +# u_data_model_coarse = interior(u_data_model) +# T_data_model_coarse = interior(T_data_model) +# S_data_model_coarse = interior(S_data_model) + +# jldopen("./LES_NDE_FC_Qb_absf_24simnew_2layer_128_relu_2Pr_coarsened.jld2", "w") do file +# file["u_LES"] = u_data_LES_coarse +# file["v_LES"] = v_data_LES_coarse +# file["T_LES"] = T_data_LES_coarse +# file["S_LES"] = S_data_LES_coarse +# file["u_NN_model"] = u_data_model_coarse +# file["v_NN_model"] = v_data_model_coarse +# file["T_NN_model"] = T_data_model_coarse +# file["S_NN_model"] = S_data_model_coarse +# end + +FILE_DIR = "./LES_NDE_FC_Qb_absf_24simnew_2layer_128_relu_2Pr_coarsened.jld2" + +u_data_LES_coarse, v_data_LES_coarse, T_data_LES_coarse, S_data_LES_coarse, u_data_model_coarse, v_data_model_coarse, T_data_model_coarse, S_data_model_coarse = jldopen(FILE_DIR, "r") do file + u_data_LES_coarse = file["u_LES"] + v_data_LES_coarse = file["v_LES"] + T_data_LES_coarse = file["T_LES"] + S_data_LES_coarse = file["S_LES"] + u_data_model_coarse = file["u_NN_model"] + v_data_model_coarse = file["v_NN_model"] + T_data_model_coarse = file["T_NN_model"] + S_data_model_coarse = file["S_NN_model"] + return u_data_LES_coarse, v_data_LES_coarse, T_data_LES_coarse, S_data_LES_coarse, u_data_model_coarse, v_data_model_coarse, T_data_model_coarse, S_data_model_coarse +end + +Nt = size(u_data_LES_coarse, 4) +#%% +fig = CairoMakie.Figure(size = (2000, 900)) +axu_LES = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u (LES) m/s") +axv_LES = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v (LES) m/s") +axT_LES = CairoMakie.Axis(fig[1, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature (LES) °C") +axS_LES = CairoMakie.Axis(fig[1, 7], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity (LES) psu") +axu_model = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u (NN closure) (m/s)") +axv_model = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v (NN closure) (m/s)") +axT_model = CairoMakie.Axis(fig[2, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature (NN closure) °C") +axS_model = CairoMakie.Axis(fig[2, 7], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity (NN closure) psu") +axΔu = CairoMakie.Axis(fig[3, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u (LES) - u(NN closure) (m/s)") +axΔv = CairoMakie.Axis(fig[3, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v (LES) - v(NN closure) (m/s)") +axΔT = CairoMakie.Axis(fig[3, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature (LES) - Temperature (NN closure) (°C)") +axΔS = CairoMakie.Axis(fig[3, 7], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity (LES) - Salinity (NN closure) (psu)") + +n = Observable(1) +u_LESₙ = @lift u_data_LES_coarse[1, :, :, $n] +v_LESₙ = @lift v_data_LES_coarse[1, :, :, $n] +T_LESₙ = @lift T_data_LES_coarse[1, :, :, $n] +S_LESₙ = @lift S_data_LES_coarse[1, :, :, $n] + +u_modelₙ = @lift u_data_model_coarse[1, :, :, $n] +v_modelₙ = @lift v_data_model_coarse[1, :, :, $n] +T_modelₙ = @lift T_data_model_coarse[1, :, :, $n] +S_modelₙ = @lift S_data_model_coarse[1, :, :, $n] + +Δuₙ = @lift $u_LESₙ .- $u_modelₙ +Δvₙ = @lift $v_LESₙ .- $v_modelₙ +ΔTₙ = @lift $T_LESₙ .- $T_modelₙ +ΔSₙ = @lift $S_LESₙ .- $S_modelₙ + +ulim = @lift (-maximum([maximum(abs, $u_LESₙ), 1e-16, maximum(abs, $u_modelₙ)]), + maximum([maximum(abs, $u_LESₙ), 1e-16, maximum(abs, $u_modelₙ)])) +vlim = @lift (-maximum([maximum(abs, $v_LESₙ), 1e-16, maximum(abs, $v_modelₙ)]), + maximum([maximum(abs, $v_LESₙ), 1e-16, maximum(abs, $v_modelₙ)])) +Tlim = (minimum(T_data_LES_coarse[1, :, :, 1]), maximum(T_data_LES_coarse[1, :, :, 1])) +Slim = (minimum(S_data_LES_coarse[1, :, :, 1]), maximum(S_data_LES_coarse[1, :, :, 1])) + +Δulim = @lift (-maximum([maximum(abs, $Δuₙ), 1e-16]), maximum([maximum(abs, $Δuₙ), 1e-16])) +Δvlim = @lift (-maximum([maximum(abs, $Δvₙ), 1e-16]), maximum([maximum(abs, $Δvₙ), 1e-16])) +ΔTlim = @lift (-maximum([maximum(abs, $ΔTₙ), 1e-16]), maximum([maximum(abs, $ΔTₙ), 1e-16])) +ΔSlim = @lift (-maximum([maximum(abs, $ΔSₙ), 1e-16]), maximum([maximum(abs, $ΔSₙ), 1e-16])) + +hu = heatmap!(axu_LES, yC_model, zC_model, u_LESₙ, colormap = :RdBu_9, colorrange = ulim) +hv = heatmap!(axv_LES, yC_model, zC_model, v_LESₙ, colormap = :RdBu_9, colorrange = vlim) +hT = heatmap!(axT_LES, yC_model, zC_model, T_LESₙ, colorrange = Tlim) +hS = heatmap!(axS_LES, yC_model, zC_model, S_LESₙ, colorrange = Slim) + +hu_model = heatmap!(axu_model, yC_model, zC_model, u_modelₙ, colormap = :RdBu_9, colorrange = ulim) +hv_model = heatmap!(axv_model, yC_model, zC_model, v_modelₙ, colormap = :RdBu_9, colorrange = vlim) +hT_model = heatmap!(axT_model, yC_model, zC_model, T_modelₙ, colorrange = Tlim) +hS_model = heatmap!(axS_model, yC_model, zC_model, S_modelₙ, colorrange = Slim) + +hΔu = heatmap!(axΔu, yC_model, zC_model, Δuₙ, colormap = :RdBu_9, colorrange = Δulim) +hΔv = heatmap!(axΔv, yC_model, zC_model, Δvₙ, colormap = :RdBu_9, colorrange = Δvlim) +hΔT = heatmap!(axΔT, yC_model, zC_model, ΔTₙ, colormap = :RdBu_9, colorrange = ΔTlim) +hΔS = heatmap!(axΔS, yC_model, zC_model, ΔSₙ, colormap = :RdBu_9, colorrange = ΔSlim) + +Colorbar(fig[1:2, 2], hu, label = "u (m/s)") +Colorbar(fig[1:2, 4], hv, label = "v (m/s)") +Colorbar(fig[1:2, 6], hT, label = "T (°C)") +Colorbar(fig[1:2, 8], hS, label = "S (psu)") + +Colorbar(fig[3, 2], hΔu, label = "u (m/s)") +Colorbar(fig[3, 4], hΔv, label = "v (m/s)") +Colorbar(fig[3, 6], hΔT, label = "T (°C)") +Colorbar(fig[3, 8], hΔS, label = "S (psu)") + +# display(fig) + +CairoMakie.record(fig, "./LES_NN_2D_sin_cooling_heating_3e-4_23_days_comparison.mp4", 1:Nt, framerate=30) do nn + @info nn + n[] = nn +end + + +#%% \ No newline at end of file diff --git a/feature_scaling.jl b/feature_scaling.jl new file mode 100644 index 0000000000..4d3d8592a6 --- /dev/null +++ b/feature_scaling.jl @@ -0,0 +1,88 @@ +using Statistics + +abstract type AbstractFeatureScaling end + +##### +##### Zero-mean unit-variance feature scaling +##### + +struct ZeroMeanUnitVarianceScaling{T} <: AbstractFeatureScaling + μ :: T + σ :: T +end + +""" + ZeroMeanUnitVarianceScaling(data) + +Returns a feature scaler for `data` with zero mean and unit variance. +""" +function ZeroMeanUnitVarianceScaling(data) + μ, σ = mean(data), std(data) + return ZeroMeanUnitVarianceScaling(μ, σ) +end + +scale(x, s::ZeroMeanUnitVarianceScaling) = (x .- s.μ) / s.σ +unscale(y, s::ZeroMeanUnitVarianceScaling) = s.σ * y .+ s.μ + +##### +##### Min-max feature scaling +##### + +struct MinMaxScaling{T} <: AbstractFeatureScaling + a :: T + b :: T + data_min :: T + data_max :: T +end + +""" + MinMaxScaling(data; a=0, b=1) + +Returns a feature scaler for `data` with minimum `a` and `maximum `b`. +""" +function MinMaxScaling(data; a=0, b=1) + data_min, data_max = extrema(data) + return MinMaxScaling{typeof(data_min)}(a, b, data_min, data_max) +end + +scale(x, s::MinMaxScaling) = s.a + (x - s.data_min) * (s.b - s.a) / (s.data_max - s.data_min) +unscale(y, s::MinMaxScaling) = s.data_min .+ (y .- s.a) * (s.data_max - s.data_min) / (s.b - s.a) + +##### +##### Convenience functions +##### + +(s::AbstractFeatureScaling)(x) = scale(x, s) +Base.inv(s::AbstractFeatureScaling) = y -> unscale(y, s) + +struct DiffusivityScaling{T} <: AbstractFeatureScaling + ν₀ :: T + κ₀ :: T + ν₁ :: T + κ₁ :: T +end + +function DiffusivityScaling(ν₀=1e-5, κ₀=1e-5, ν₁=0.1, κ₁=0.1) + return DiffusivityScaling(ν₀, κ₀, ν₁, κ₁) +end + +function scale(x, s::DiffusivityScaling) + ν, κ = x + ν₀, κ₀, ν₁, κ₁ = s.ν₀, s.κ₀, s.ν₁, s.κ₁ + return ν₀ + ν * ν₁, κ₀ + κ * κ₁ +end + +function unscale(y, s::DiffusivityScaling) + ν, κ = y + ν₀, κ₀, ν₁, κ₁ = s.ν₀, s.κ₀, s.ν₁, s.κ₁ + return (ν - ν₀) / ν₁, (κ - κ₀) / κ₁ +end + +(s::DiffusivityScaling)(x) = scale(x, s) +Base.inv(s::DiffusivityScaling) = y -> unscale(y, s) + +function construct_zeromeanunitvariance_scaling(scaling_params) + return NamedTuple(key=>ZeroMeanUnitVarianceScaling(scaling_params[key].μ, scaling_params[key].σ) for key in keys(scaling_params)) +end + + diff --git a/physicalclosure_doublegyre_model.jl b/physicalclosure_doublegyre_model.jl new file mode 100644 index 0000000000..c11ec722de --- /dev/null +++ b/physicalclosure_doublegyre_model.jl @@ -0,0 +1,539 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + +#%% +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zC2O_CATKEVerticalDiffusivity" +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_zWENO5_CATKEVerticalDiffusivity" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +# convection_closure = XinKaiVerticalDiffusivity() +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure +# closure = vertical_base_closure + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/physicalclosure_doublegyre_model_initialized.jl b/physicalclosure_doublegyre_model_initialized.jl new file mode 100644 index 0000000000..df33e85250 --- /dev/null +++ b/physicalclosure_doublegyre_model_initialized.jl @@ -0,0 +1,592 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + + +#%% +filename = "doublegyre_relaxation_30days_CATKEVerticalDiffusivity_initialized_test" +FILE_DIR = "./Output/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +# convection_closure = XinKaiVerticalDiffusivity() +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure +# closure = vertical_base_closure + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### + +@inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +# noise(z) = rand() * exp(z / 8) + +# T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +# S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) +DATA_DIR = "./Output/doublegyre_30Cwarmflush_relaxation_8days_baseclosure_trainFC24new_scalingtrain54new_2Pr_2step" + +u_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "u", backend=OnDisk()) +v_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "v", backend=OnDisk()) +T_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "T", backend=OnDisk()) +S_data = FieldTimeSeries("$(DATA_DIR)/instantaneous_fields.jld2", "S", backend=OnDisk()) + +ntimes = length(u_data.times) + +set!(model, T=T_data[ntimes], S=S_data[ntimes], u=u_data[ntimes], v=v_data[ntimes]) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_south] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_south", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_north] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_north", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +Nt = length(Ψ_data) +times = Ψ_data.times / 24 / 60^2 / 365 +#%% +timeframe = Nt +Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +clim = maximum(abs, Ψ_frame) + 1e-13 +N_levels = 16 +levels = range(-clim, stop=clim, length=N_levels) +fig = Figure(size=(800, 800)) +ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +tightlimits!(ax) +save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +display(fig) +#%% +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +N²_xz_north_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_north.jld2", "N²") +N²_xz_south_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz_south.jld2", "N²") + +xC = N²_xz_north_data.grid.xᶜᵃᵃ[1:N²_xz_north_data.grid.Nx] +zf = N²_xz_north_data.grid.zᵃᵃᶠ[1:N²_xz_north_data.grid.Nz+1] + +yloc_north = N²_xz_north_data.grid.yᵃᶠᵃ[N²_xz_north_data.indices[2][1]] +yloc_south = N²_xz_south_data.grid.yᵃᶠᵃ[N²_xz_south_data.indices[2][1]] + +Nt = length(N²_xz_north_data) +times = N²_xz_north_data.times / 24 / 60^2 / 365 +timeframes = 1:Nt + +N²_lim = (find_min(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) - 1e-13, + find_max(interior(N²_xz_north_data, :, 1, :, timeframes), interior(N²_xz_south_data, :, 1, :, timeframes)) + 1e-13) +#%% +fig = Figure(size=(800, 800)) +ax_north = Axis(fig[1, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_south / 1e3)) km") +ax_south = Axis(fig[2, 1], xlabel="x (m)", ylabel="z (m)", title="Buoyancy Frequency N² at y = $(round(yloc_north / 1e3)) km") + +n = Observable(2) + +N²_north = @lift interior(N²_xz_north_data[$n], :, 1, :) +N²_south = @lift interior(N²_xz_south_data[$n], :, 1, :) + +colorscheme = colorschemes[:jet] + +N²_north_surface = heatmap!(ax_north, xC, zf, N²_north, colormap=colorscheme, colorrange=N²_lim) +N²_south_surface = heatmap!(ax_south, xC, zf, N²_south, colormap=colorscheme, colorrange=N²_lim) + +Colorbar(fig[1:2, 2], N²_north_surface, label="N² (s⁻²)") + +title_str = @lift "Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=title_str, tellwidth=false, font=:bold) + +trim!(fig.layout) + +@info "Recording buoyancy frequency xz slice" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_buoyancy_frequency_xz_slice.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% \ No newline at end of file diff --git a/physicalclosure_doublegyre_model_modewater.jl b/physicalclosure_doublegyre_model_modewater.jl new file mode 100644 index 0000000000..6b80b049fb --- /dev/null +++ b/physicalclosure_doublegyre_model_modewater.jl @@ -0,0 +1,546 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + +#%% +filename = "doublegyre_30Cwarmflushbottom10_relaxation_30days_modewater_zWENO5_CATKEVerticalDiffusivity" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure + +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const X₀ = -Lx/2 + 800kilometers +const Y₀ = -Ly/2 + 1500kilometers +const R₀ = 700kilometers +const Qᵀ_mode = 4.5e-4 +const σ_mode = 20kilometers + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y + +@inline Qᵀ_winter(t) = max(0, -Qᵀ_mode * sin(2π * t / 360days)) +@inline Qᵀ_subpolar(x, y, t) = ifelse((x - X₀)^2 + (y - Y₀)^2 <= R₀^2, Qᵀ_winter(t), + exp(-(sqrt((x - X₀)^2 + (y - Y₀)^2) - R₀)^2 / (2 * σ_mode^2)) * Qᵀ_winter(t)) + +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) + Qᵀ_subpolar(x, y, t) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/physicalclosure_doublegyre_model_seasonalforcing.jl b/physicalclosure_doublegyre_model_seasonalforcing.jl new file mode 100644 index 0000000000..7e811d1c3b --- /dev/null +++ b/physicalclosure_doublegyre_model_seasonalforcing.jl @@ -0,0 +1,589 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + + +#%% +filename = "doublegyre_linearseasonalforcing_10C_relaxation_30days_CATKEVerticalDiffusivity" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +# convection_closure = XinKaiVerticalDiffusivity() +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure +# closure = vertical_base_closure + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 10 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * (y/Ly + 1/2) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 36000days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/physicalclosure_doublegyre_model_seasonalforcing_wallrestoration.jl b/physicalclosure_doublegyre_model_seasonalforcing_wallrestoration.jl new file mode 100644 index 0000000000..b0efdf5f82 --- /dev/null +++ b/physicalclosure_doublegyre_model_seasonalforcing_wallrestoration.jl @@ -0,0 +1,595 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + + +#%% +filename = "doublegyre_linearseasonalforcing_10C_relaxation_wallrestoration_30days_CATKEVerticalDiffusivity" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +# convection_closure = XinKaiVerticalDiffusivity() +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure +# closure = vertical_base_closure + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const Δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +const seasonal_period = 360days +const seasonal_forcing_width = Ly / 6 +const seasonal_T_amplitude = 10 + +##### +##### Forcing and initial condition +##### +@inline T_initial(x, y, z) = 20 + 10 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_seasonal(y, t) = seasonal_T_amplitude * (y/Ly + 1/2) * sin(2π * t / seasonal_period) +@inline T_ref(y, t) = T_mid - ΔT / Ly * y + T_seasonal(y, t) +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y, t)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = min(0, -5 + 5 * (1 + (z + 500) / (Lz - 500))) +@inline north_T_flux(x, z, t, T) = μ_T * Δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10800days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_5] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_5", + indices = (:, 5, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_15] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_15", + indices = (:, 15, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_25] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_25", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_35] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_35", + indices = (:, 35, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_45] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_45", + indices = (:, 45, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_55] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_55", + indices = (:, 55, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_65] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_65", + indices = (:, 65, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_75] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_75", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_85] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_85", + indices = (:, 85, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_95] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz_95", + indices = (:, 95, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +T_colormap = colorschemes[:viridis] +S_colormap = colorschemes[:viridis] +u_colormap = colorschemes[:balance] +v_colormap = colorschemes[:balance] + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 360 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/physicalclosure_doublegyre_model_wallrestoration.jl b/physicalclosure_doublegyre_model_wallrestoration.jl new file mode 100644 index 0000000000..63921f0a42 --- /dev/null +++ b/physicalclosure_doublegyre_model_wallrestoration.jl @@ -0,0 +1,547 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity +using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: CATKEVerticalDiffusivity, CATKEMixingLength, CATKEEquation + +using Oceananigans.BuoyancyModels: ∂z_b +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +# include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + +#%% +# filename = "doublegyre_30Cwarmflushbottom10_relaxation_wallrestoration_30days_zC2O_CATKEVerticalDiffusivity" +filename = "doublegyre_30Cwarmflushbottom10_relaxation_wallrestoration_30days_zWENO5_CATKEVerticalDiffusivity" +FILE_DIR = "./Output/$(filename)" +# FILE_DIR = "/storage6/xinkai/NN_Oceananigans/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +# convection_closure = XinKaiVerticalDiffusivity() +function CATKE_ocean_closure() + mixing_length = CATKEMixingLength(Cᵇ=0.01) + turbulent_kinetic_energy_equation = CATKEEquation(Cᵂϵ=1.0) + return CATKEVerticalDiffusivity(; mixing_length, turbulent_kinetic_energy_equation) +end +convection_closure = CATKE_ocean_closure() +closure = convection_closure +# closure = vertical_base_closure + +# advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), CenteredSecondOrder()) +advection_scheme = FluxFormAdvection(WENO(order=5), WENO(order=5), WENO(order=5)) + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +const δy = Ly / Ny + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/30days + +##### +##### Forcing and initial condition +##### +# @inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) +@inline T_initial(x, y, z) = 10 + 20 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) + +@inline T_north_ref(z) = min(0, -5 + 5 * (1 + (z + 500) / (Lz - 500))) +@inline north_T_flux(x, z, t, T) = μ_T * δy * (T - T_north_ref(z)) +north_T_flux_bc = FluxBoundaryCondition(north_T_flux; field_dependencies=:T) + +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc, north = north_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = advection_scheme, + tracer_advection = advection_scheme, + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = closure, + tracers = (:T, :S, :e), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +@inline function get_N²(i, j, k, grid, b, C) + return ∂z_b(i, j, k, grid, b, C) +end + +N²_op = KernelFunctionOperation{Center, Center, Face}(get_N², model.grid, model.buoyancy.model, model.tracers) +N² = Field(N²_op) + +@inline function get_density(i, j, k, grid, b, C) + T, S = Oceananigans.BuoyancyModels.get_temperature_and_salinity(b, C) + @inbounds ρ = TEOS10.ρ(T[i, j, k], S[i, j, k], 0, b.model.equation_of_state) + return ρ +end + +ρ_op = KernelFunctionOperation{Center, Center, Center}(get_density, model.grid, model.buoyancy, model.tracers) +ρ = Field(ρ_op) +compute!(ρ) + +ρᶠ = @at (Center, Center, Face) ρ +∂ρ∂z = ∂z(ρ) +∂²ρ∂z² = ∂z(∂ρ∂z) + +κc = model.diffusivity_fields.κc +wT = κc * ∂z(T) +wS = κc * ∂z(S) + +ubar_zonal = Average(u, dims=1) +vbar_zonal = Average(v, dims=1) +wbar_zonal = Average(w, dims=1) +Tbar_zonal = Average(T, dims=1) +Sbar_zonal = Average(S, dims=1) +ρbar_zonal = Average(ρ, dims=1) +wTbar_zonal = Average(wT, dims=1) +wSbar_zonal = Average(wS, dims=1) + +outputs = (; u, v, w, T, S, ρ, ρᶠ, ∂ρ∂z, ∂²ρ∂z², N², wT, wS) +zonal_outputs = (; ubar_zonal, vbar_zonal, wbar_zonal, Tbar_zonal, Sbar_zonal, ρbar_zonal, wTbar_zonal, wSbar_zonal) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_10] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_10", + indices = (10, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_20] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_20", + indices = (20, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_30] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_30", + indices = (30, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_40] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_40", + indices = (40, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_50] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_50", + indices = (50, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_60] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_60", + indices = (60, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_70] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_70", + indices = (70, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_80] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_80", + indices = (80, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:yz_90] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz_90", + indices = (90, :, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:zonal_average] = JLD2OutputWriter(model, zonal_outputs, + filename = "$(FILE_DIR)/averaged_fields_zonal", + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(1825days, window=1825days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +@info "Recording 3D fields" +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + n[] = nn +end + +@info "Done!" +#%% +# Ψ_data = FieldTimeSeries("$(FILE_DIR)/averaged_fields_streamfunction.jld2", "Ψ") + +# xF = Ψ_data.grid.xᶠᵃᵃ[1:Ψ_data.grid.Nx+1] +# yC = Ψ_data.grid.yᵃᶜᵃ[1:Ψ_data.grid.Ny] + +# Nt = length(Ψ_data) +# times = Ψ_data.times / 24 / 60^2 / 365 +# #%% +# timeframe = Nt +# Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +# clim = maximum(abs, Ψ_frame) + 1e-13 +# N_levels = 16 +# levels = range(-clim, stop=clim, length=N_levels) +# fig = Figure(size=(800, 800)) +# ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="CATKE Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +# cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +# Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +# tightlimits!(ax) +# save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +# display(fig) +#%% \ No newline at end of file diff --git a/plot_barotropic_streamfunction.jl b/plot_barotropic_streamfunction.jl new file mode 100644 index 0000000000..6bc5454513 --- /dev/null +++ b/plot_barotropic_streamfunction.jl @@ -0,0 +1,30 @@ +using CairoMakie +using Oceananigans + +filename = "doublegyre_RiBasedVerticalDiffusivity_streamfunction" +FILE_DIR = "./Output/$(filename)/" + +Ψ_data = FieldTimeSeries("$(FILE_DIR)/doublegyre_Ri_based_vertical_diffusivity_2Pr_streamfunction.jld2", "Ψ") + +Nx = Ψ_data.grid.Nx +Ny = Ψ_data.grid.Ny + +xF = Ψ_data.grid.xᶠᵃᵃ[1:Nx+1] +yC = Ψ_data.grid.yᵃᶜᵃ[1:Ny] + +Nt = length(Ψ_data) +times = Ψ_data.times / 24 / 60^2 / 365 +#%% +timeframe = 31 +Ψ_frame = interior(Ψ_data[timeframe], :, :, 1) ./ 1e6 +clim = maximum(abs, Ψ_frame) +N_levels = 16 +levels = range(-clim, stop=clim, length=N_levels) +fig = Figure(size=(800, 800)) +ax = Axis(fig[1, 1], xlabel="x (m)", ylabel="y (m)", title="Ri-based Vertical Diffusivity, Yearly-Averaged Barotropic streamfunction Ψ, Year $(times[timeframe])") +cf = contourf!(ax, xF, yC, Ψ_frame, levels=levels, colormap=Reverse(:RdBu_11)) +Colorbar(fig[1, 2], cf, label="Ψ (Sv)") +tightlimits!(ax) +save("$(FILE_DIR)/barotropic_streamfunction_$(timeframe).png", fig, px_per_unit=4) +display(fig) +#%% \ No newline at end of file diff --git a/validate_NN_1D_model.jl b/validate_NN_1D_model.jl new file mode 100644 index 0000000000..c7b93d2cd6 --- /dev/null +++ b/validate_NN_1D_model.jl @@ -0,0 +1,203 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global.jl") +include("xin_kai_vertical_diffusivity_local.jl") +include("feature_scaling.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + + +# Architecture +model_architecture = CPU() + +file = jldopen("model_inference_run.jld2", "r") + +# number of grid points +const Nz = file["Nz"] +const Lz = file["Lz"] + +grid = RectilinearGrid(model_architecture, + topology = (Flat, Flat, Bounded), + size = Nz, + halo = 3, + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const dTdz = file["dTdz"] +const dSdz = file["dSdz"] + +const T_surface = file["T_surface"] +const S_surface = file["S_surface"] + +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(file["wT_top"])) +S_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(file["wS_top"])) + +##### +##### Coriolis +##### + +const f₀ = file["f₀"] +coriolis = FPlane(f=f₀) + +##### +##### Forcing and initial condition +##### +T_initial(z) = dTdz * z + T_surface +S_initial(z) = dSdz * z + S_surface + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() + +##### +##### Model building +##### + +@info "Building a model..." + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = ImplicitFreeSurface(), + momentum_advection = WENO(grid = grid), + tracer_advection = WENO(grid = grid), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = (nn_closure, base_closure), + # closure = base_closure, + tracers = (:T, :S), + boundary_conditions = (; T = T_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(z) = T_initial(z) + 1e-6 * noise(z) +S_initial_noisy(z) = S_initial(z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = file["Δt"] +stop_time = file["τ"] + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(20)) + +##### +##### Diagnostics +##### + +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S + +Tbar = Field(Average(T, dims = (1,2))) +Sbar = Field(Average(S, dims = (1,2))) + +averaged_outputs = (; Tbar, Sbar) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:jld2] = JLD2OutputWriter(model, averaged_outputs, + filename = "NN_1D_channel_averages", + schedule = TimeInterval(Δt₀), + overwrite_existing = true) + +@info "Running the simulation..." + +try + run!(simulation, pickup = false) +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +# ##### +# ##### Visualization +# ##### +#%% +using GLMakie + +Tbar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "Tbar") +Sbar_data = FieldTimeSeries("./NN_1D_channel_averages.jld2", "Sbar") + +zC = znodes(Tbar_data.grid, Center()) +zF = znodes(Tbar_data.grid, Face()) + +Nt = length(Tbar_data.times) + +fig = Figure(size = (900, 600)) +axT = GLMakie.Axis(fig[1, 1], xlabel = "T (°C)", ylabel = "z (m)") +axS = GLMakie.Axis(fig[1, 2], xlabel = "S (g kg⁻¹)", ylabel = "z (m)") +slider = Slider(fig[2, :], range=1:Nt) +n = slider.value + +Tbarₙ = @lift interior(Tbar_data[$n], 1, 1, :) +Sbarₙ = @lift interior(Sbar_data[$n], 1, 1, :) + +Tbar_truthₙ = @lift file["sol_T"][:, $n] +Sbar_truthₙ = @lift file["sol_S"][:, $n] + +title_str = @lift "Time: $(round(Tbar_data.times[$n] / 86400, digits=3)) days" + +lines!(axT, Tbarₙ, zC, label="Oceananigans") +lines!(axS, Sbarₙ, zC, label="Oceananigans") + +lines!(axT, Tbar_truthₙ, zC, label="Truth") +lines!(axS, Sbar_truthₙ, zC, label="Truth") + +axislegend(axT, position = :lb) +Label(fig[0, :], title_str, tellwidth = false) + +GLMakie.record(fig, "./NN_1D_validation.mp4", 1:Nt, framerate=60, px_per_unit=4) do nn + @info nn + n[] = nn +end + +display(fig) +#%% +close(file) \ No newline at end of file diff --git a/validate_NN_nof_BBLRifirstzone510_1D_model.jl b/validate_NN_nof_BBLRifirstzone510_1D_model.jl new file mode 100644 index 0000000000..54f4fdd9df --- /dev/null +++ b/validate_NN_nof_BBLRifirstzone510_1D_model.jl @@ -0,0 +1,249 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +include("NN_closure_global_nof_BBLRifirstzone510.jl") +include("xin_kai_vertical_diffusivity_local_2step.jl") +include("feature_scaling.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using Oceananigans.TimeSteppers: update_state! +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 + + +# Architecture +model_architecture = CPU() + +file = jldopen("model_inference_run_nof_BBLRifirstzone510.jld2", "r") + +# number of grid points +const Nz = file["Nz"] +const Lz = file["Lz"] + +grid = RectilinearGrid(model_architecture, + topology = (Flat, Flat, Bounded), + size = Nz, + halo = 3, + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const dTdz = file["dTdz"] +const dSdz = file["dSdz"] + +const T_surface = file["T_surface"] +const S_surface = file["S_surface"] + +T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(file["wT_top"])) +S_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(file["wS_top"])) + +##### +##### Coriolis +##### + +const f₀ = file["f₀"] +coriolis = FPlane(f=f₀) + +##### +##### Forcing and initial condition +##### +T_initial(z) = dTdz * z + T_surface +S_initial(z) = dSdz * z + S_surface + +nn_closure = NNFluxClosure(model_architecture) +base_closure = XinKaiLocalVerticalDiffusivity() + +##### +##### Model building +##### + +@info "Building a model..." + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = ImplicitFreeSurface(), + momentum_advection = WENO(grid = grid), + tracer_advection = WENO(grid = grid), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = (base_closure, nn_closure), + tracers = (:T, :S), + boundary_conditions = (; T = T_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(z) = T_initial(z) + 1e-6 * noise(z) +S_initial_noisy(z) = S_initial(z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = file["Δt"] +stop_time = file["τ"] +# stop_time = 100minutes + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(20)) +# simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(1)) + +##### +##### Diagnostics +##### + +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +wT_residual, wS_residual = model.diffusivity_fields[2].wT, model.diffusivity_fields[2].wS +ν, κ = model.diffusivity_fields[1].κᵘ, model.diffusivity_fields[1].κᶜ + +Tbar = Field(Average(T, dims = (1,2))) +Sbar = Field(Average(S, dims = (1,2))) + +averaged_outputs = (; Tbar, Sbar, wT_residual, wS_residual, ν, κ) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:jld2] = JLD2OutputWriter(model, averaged_outputs, + filename = "NN_1D_channel_averages_nof_BBLRifirstzone510", + schedule = TimeInterval(Δt₀), + overwrite_existing = true) + +@info "Running the simulation..." + +try + run!(simulation, pickup = false) +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +# ##### +# ##### Visualization +# ##### +#%% +using GLMakie + +Tbar_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "Tbar") +Sbar_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "Sbar") +wT_residual_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "wT_residual") +wS_residual_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "wS_residual") +ν_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "ν") +κ_data = FieldTimeSeries("./NN_1D_channel_averages_nof_BBLRifirstzone510.jld2", "κ") + +#%% +zC = znodes(Tbar_data.grid, Center()) +zF = znodes(Tbar_data.grid, Face()) + +Nt = length(Tbar_data.times) + +fig = Figure(size = (1500, 1000)) +axT = GLMakie.Axis(fig[1, 1], xlabel = "T (°C)", ylabel = "z (m)") +axS = GLMakie.Axis(fig[2, 1], xlabel = "S (g kg⁻¹)", ylabel = "z (m)") +axwT_residual = GLMakie.Axis(fig[1, 2], xlabel = "wT residual", ylabel = "z (m)") +axwS_residual = GLMakie.Axis(fig[2, 2], xlabel = "wS residual", ylabel = "z (m)") +axν = GLMakie.Axis(fig[1, 3], xlabel = "ν (m² s⁻¹)", ylabel = "z (m)", xscale=log10) +axκ = GLMakie.Axis(fig[2, 3], xlabel = "κ (m² s⁻¹)", ylabel = "z (m)", xscale=log10) + +slider = Slider(fig[3, :], range=2:Nt) +n = slider.value + +Tbarₙ = @lift interior(Tbar_data[$n], 1, 1, :) +Sbarₙ = @lift interior(Sbar_data[$n], 1, 1, :) +wT_residualₙ = @lift interior(wT_residual_data[$n], 1, 1, :) +wS_residualₙ = @lift interior(wS_residual_data[$n], 1, 1, :) +νₙ = @lift interior(ν_data[$n], 1, 1, 2:32) +κₙ = @lift interior(κ_data[$n], 1, 1, 2:32) + +Tbar_truthₙ = @lift file["sol_T"][:, $n] +Sbar_truthₙ = @lift file["sol_S"][:, $n] +wT_residual_truthₙ = @lift file["sol_wT_residual_unscaled"][:, $n] +wS_residual_truthₙ = @lift file["sol_wS_residual_unscaled"][:, $n] +ν_truthₙ = @lift file["sol_ν"][2:32, $n] +κ_truthₙ = @lift file["sol_κ"][2:32, $n] + +title_str = @lift "Time: $(round(Tbar_data.times[$n] / 86400, digits=3)) days" + +wTlim = (minimum(interior(wT_residual_data)), maximum(interior(wT_residual_data))) +wSlim = (minimum(interior(wS_residual_data)), maximum(interior(wS_residual_data))) + +νlim = (1e-6, 10) +κlim = (1e-6, 10) + +lines!(axT, Tbarₙ, zC, label="Oceananigans") +lines!(axS, Sbarₙ, zC, label="Oceananigans") + +lines!(axwT_residual, wT_residualₙ, zF, label="Oceananigans") +lines!(axwS_residual, wS_residualₙ, zF, label="Oceananigans") + +lines!(axν, νₙ, zF[2:32], label="Oceananigans") +lines!(axκ, κₙ, zF[2:32], label="Oceananigans") + +lines!(axT, Tbar_truthₙ, zC, label="Truth") +lines!(axS, Sbar_truthₙ, zC, label="Truth") + +lines!(axwT_residual, wT_residual_truthₙ, zF, label="Truth") +lines!(axwS_residual, wS_residual_truthₙ, zF, label="Truth") + +lines!(axν, ν_truthₙ, zF[2:32], label="Truth") +lines!(axκ, κ_truthₙ, zF[2:32], label="Truth") + +xlims!(axwT_residual, wTlim) +xlims!(axwS_residual, wSlim) +xlims!(axν, νlim) +xlims!(axκ, κlim) + +linkyaxes!(axT, axS, axwT_residual, axwS_residual, axν, axκ) + +axislegend(axT, position = :lb) +Label(fig[0, :], title_str, tellwidth = false) + +# GLMakie.record(fig, "./NN_1D_validation_nof_BBL.mp4", 1:Nt, framerate=60, px_per_unit=4) do nn +# @info nn +# n[] = nn +# end + +display(fig) +#%% +close(file) \ No newline at end of file diff --git a/xin_kai_vertical_diffusivity.jl b/xin_kai_vertical_diffusivity.jl new file mode 100644 index 0000000000..29a482aa75 --- /dev/null +++ b/xin_kai_vertical_diffusivity.jl @@ -0,0 +1,255 @@ +using Oceananigans +using Oceananigans.Architectures: architecture +using Oceananigans.BuoyancyModels: ∂z_b +using Oceananigans.Operators +using Oceananigans.BoundaryConditions +using Oceananigans.Grids: inactive_node +using Oceananigans.Operators: ℑzᵃᵃᶜ, ℑxyᶠᶠᵃ, ℑxyᶜᶜᵃ + +using Adapt + +using KernelAbstractions: @index, @kernel +using KernelAbstractions.Extras.LoopInfo: @unroll + +using Oceananigans.TurbulenceClosures: + tapering_factorᶠᶜᶜ, + tapering_factorᶜᶠᶜ, + tapering_factorᶜᶜᶠ, + tapering_factor, + SmallSlopeIsopycnalTensor, + AbstractScalarDiffusivity, + ExplicitTimeDiscretization, + FluxTapering, + isopycnal_rotation_tensor_xz_ccf, + isopycnal_rotation_tensor_yz_ccf, + isopycnal_rotation_tensor_zz_ccf + +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + getclosure, + top_buoyancy_flux, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_vx, + viscous_flux_uy, + viscous_flux_vy + +using Oceananigans.Utils: launch! +using Oceananigans.Coriolis: fᶠᶠᵃ +using Oceananigans.Operators +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b + +using Oceananigans.TurbulenceClosures +using Oceananigans.TurbulenceClosures: HorizontalFormulation, VerticalFormulation, AbstractScalarDiffusivity +using Oceananigans.TurbulenceClosures: AbstractScalarBiharmonicDiffusivity +using Oceananigans.Operators +using Oceananigans.Operators: Δxᶜᶜᶜ, Δyᶜᶜᶜ, ℑxyᶜᶜᵃ, ζ₃ᶠᶠᶜ, div_xyᶜᶜᶜ +using Oceananigans.Operators: Δx, Δy +using Oceananigans.Operators: ℑxyz + +using Oceananigans.Operators: ℑxyzᶜᶜᶠ, ℑyzᵃᶜᶠ, ℑxzᶜᵃᶠ, Δxᶜᶜᶜ, Δyᶜᶜᶜ + +struct XinKaiVerticalDiffusivity{TD, FT} <: AbstractScalarDiffusivity{TD, VerticalFormulation, 2} + ν₀ :: FT + νˢʰ :: FT + νᶜⁿ :: FT + Cᵉⁿ :: FT + Prₜ :: FT + Riᶜ :: FT + δRi :: FT + Q₀ :: FT + δQ :: FT +end + +function XinKaiVerticalDiffusivity{TD}(ν₀ :: FT, + νˢʰ :: FT, + νᶜⁿ :: FT, + Cᵉⁿ :: FT, + Prₜ :: FT, + Riᶜ :: FT, + δRi :: FT, + Q₀ :: FT, + δQ :: FT) where {TD, FT} + + return XinKaiVerticalDiffusivity{TD, FT}(ν₀, νˢʰ, νᶜⁿ, Cᵉⁿ, Prₜ, Riᶜ, δRi, Q₀, δQ) +end + +function XinKaiVerticalDiffusivity(time_discretization = VerticallyImplicitTimeDiscretization(), + FT = Float64; + ν₀ = 1e-5, + νˢʰ = 0.0885, + νᶜⁿ = 4.3668, + Cᵉⁿ = 0.2071, + Prₜ = 1.207, + Riᶜ = -0.21982, + δRi = 8.342e-4, + Q₀ = 0.08116, + δQ = 0.02622) + + TD = typeof(time_discretization) + + return XinKaiVerticalDiffusivity{TD}(convert(FT, ν₀), + convert(FT, νˢʰ), + convert(FT, νᶜⁿ), + convert(FT, Cᵉⁿ), + convert(FT, Prₜ), + convert(FT, Riᶜ), + convert(FT, δRi), + convert(FT, Q₀), + convert(FT, δQ)) +end + +XinKaiVerticalDiffusivity(FT::DataType; kw...) = + XinKaiVerticalDiffusivity(VerticallyImplicitTimeDiscretization(), FT; kw...) + +Adapt.adapt_structure(to, clo::XinKaiVerticalDiffusivity{TD, FT}) where {TD, FT} = + XinKaiVerticalDiffusivity{TD, FT}(clo.ν₀, clo.νˢʰ, clo.νᶜⁿ, clo.Cᵉⁿ, clo.Prₜ, clo.Riᶜ, clo.δRi, clo.Q₀, clo.δQ) + +##### +##### Diffusivity field utilities +##### + +const RBVD = XinKaiVerticalDiffusivity +const RBVDArray = AbstractArray{<:RBVD} +const FlavorOfXKVD = Union{RBVD, RBVDArray} +const c = Center() +const f = Face() + +@inline viscosity_location(::FlavorOfXKVD) = (c, c, f) +@inline diffusivity_location(::FlavorOfXKVD) = (c, c, f) + +@inline viscosity(::FlavorOfXKVD, diffusivities) = diffusivities.κᵘ +@inline diffusivity(::FlavorOfXKVD, diffusivities, id) = diffusivities.κᶜ + +with_tracers(tracers, closure::FlavorOfXKVD) = closure + +# Note: computing diffusivities at cell centers for now. +function DiffusivityFields(grid, tracer_names, bcs, closure::FlavorOfXKVD) + κᶜ = Field((Center, Center, Face), grid) + κᵘ = Field((Center, Center, Face), grid) + N² = Field((Center, Center, Face), grid) + Ri = Field((Center, Center, Face), grid) + return (; κᶜ, κᵘ, Ri, N²) +end + +function compute_diffusivities!(diffusivities, closure::FlavorOfXKVD, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + clock = model.clock + tracers = model.tracers + buoyancy = model.buoyancy + velocities = model.velocities + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Nx_in, Ny_in, Nz_in = total_size(diffusivities.κᶜ) + ox_in, oy_in, oz_in = diffusivities.κᶜ.data.offsets + + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, compute_N²!, diffusivities, grid, closure, tracers, buoyancy) + launch!(arch, grid, kp, compute_ri_number!, diffusivities, grid, closure, velocities) + + # Use `only_local_halos` to ensure that no communication occurs during + # this call to fill_halo_regions! + fill_halo_regions!(diffusivities.Ri; only_local_halos=true) + + launch!(arch, grid, kp, + compute_xinkai_diffusivities!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + return nothing +end + +@inline ϕ²(i, j, k, grid, ϕ, args...) = ϕ(i, j, k, grid, args...)^2 + +@inline function shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + ∂z_u² = ℑxᶜᵃᵃ(i, j, k, grid, ϕ², ∂zᶠᶜᶠ, velocities.u) + ∂z_v² = ℑyᵃᶜᵃ(i, j, k, grid, ϕ², ∂zᶜᶠᶠ, velocities.v) + return ∂z_u² + ∂z_v² +end + +@inline function N²ᶜᶜᶠ(i, j, k, grid, buoyancy, tracers) + return ∂z_b(i, j, k, grid, buoyancy, tracers) +end + +@inline function Riᶜᶜᶠ(i, j, k, grid, velocities, diffusivities) + S² = shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + N² = diffusivities.N²[i, j, k] + Ri = N² / S² + + # Clip N² and avoid NaN + return ifelse(N² == 0, zero(grid), Ri) +end + +const c = Center() +const f = Face() + +@kernel function compute_N²!(diffusivities, grid, closure::FlavorOfXKVD, tracers, buoyancy) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.N²[i, j, k] = N²ᶜᶜᶠ(i, j, k, grid, buoyancy, tracers) +end + +@kernel function compute_ri_number!(diffusivities, grid, closure::FlavorOfXKVD, velocities) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.Ri[i, j, k] = Riᶜᶜᶠ(i, j, k, grid, velocities, diffusivities) +end + +@kernel function compute_xinkai_diffusivities!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) +end + +@inline function _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) + + # Ensure this works with "ensembles" of closures, in addition to ordinary single closures + closure_ij = getclosure(i, j, closure) + + ν₀ = closure_ij.ν₀ + νˢʰ = closure_ij.νˢʰ + νᶜⁿ = closure_ij.νᶜⁿ + Cᵉⁿ = closure_ij.Cᵉⁿ + Prₜ = closure_ij.Prₜ + Riᶜ = closure_ij.Riᶜ + δRi = closure_ij.δRi + Q₀ = closure_ij.Q₀ + δQ = closure_ij.δQ + + Qᵇ = top_buoyancy_flux(i, j, grid, buoyancy, tracer_bcs, clock, merge(velocities, tracers)) + + # (Potentially) apply a horizontal filter to the Richardson number + Ri = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + Ri_above = ℑxyᶜᶜᵃ(i, j, k + 1, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + N² = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.N²) + + # Conditions + convecting = Ri < 0 # applies regardless of Qᵇ + entraining = (Ri > 0) & (Ri_above < 0) & (Qᵇ > 0) + + # Convective adjustment diffusivity + ν_local = ifelse(convecting, - (νᶜⁿ - νˢʰ) / 2 * tanh(Ri / δRi) + νˢʰ, clamp(Riᶜ * Ri + νˢʰ + ν₀, ν₀, νˢʰ)) + + # Entrainment diffusivity + x = Qᵇ / (N² + 1e-11) + ν_nonlocal = ifelse(entraining, Cᵉⁿ * νᶜⁿ * 0.5 * (tanh((x - Q₀) / δQ) + 1), 0) + + # Update by averaging in time + @inbounds diffusivities.κᵘ[i, j, k] = ifelse((k <= 1) | (k >= grid.Nz+1), 0, ν_local + ν_nonlocal) + @inbounds diffusivities.κᶜ[i, j, k] = ifelse((k <= 1) | (k >= grid.Nz+1), 0, (ν_local + ν_nonlocal) / Prₜ) + + return nothing +end diff --git a/xin_kai_vertical_diffusivity_2Pr.jl b/xin_kai_vertical_diffusivity_2Pr.jl new file mode 100644 index 0000000000..ca28a5fdfd --- /dev/null +++ b/xin_kai_vertical_diffusivity_2Pr.jl @@ -0,0 +1,268 @@ +using Oceananigans +using Oceananigans.Architectures: architecture +using Oceananigans.BuoyancyModels: ∂z_b +using Oceananigans.Operators +using Oceananigans.BoundaryConditions +using Oceananigans.Grids: inactive_node +using Oceananigans.Grids: total_size +using Oceananigans.Utils: KernelParameters +using Oceananigans.Operators: ℑzᵃᵃᶜ, ℑxyᶠᶠᵃ, ℑxyᶜᶜᵃ + +using Adapt + +using KernelAbstractions: @index, @kernel +using KernelAbstractions.Extras.LoopInfo: @unroll + +using Oceananigans.TurbulenceClosures: + tapering_factorᶠᶜᶜ, + tapering_factorᶜᶠᶜ, + tapering_factorᶜᶜᶠ, + tapering_factor, + SmallSlopeIsopycnalTensor, + AbstractScalarDiffusivity, + ExplicitTimeDiscretization, + FluxTapering, + isopycnal_rotation_tensor_xz_ccf, + isopycnal_rotation_tensor_yz_ccf, + isopycnal_rotation_tensor_zz_ccf + +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + getclosure, + top_buoyancy_flux, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_vx, + viscous_flux_uy, + viscous_flux_vy + +using Oceananigans.Utils: launch! +using Oceananigans.Coriolis: fᶠᶠᵃ +using Oceananigans.Operators +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b + +using Oceananigans.TurbulenceClosures +using Oceananigans.TurbulenceClosures: HorizontalFormulation, VerticalFormulation, AbstractScalarDiffusivity +using Oceananigans.TurbulenceClosures: AbstractScalarBiharmonicDiffusivity +using Oceananigans.Operators +using Oceananigans.Operators: Δxᶜᶜᶜ, Δyᶜᶜᶜ, ℑxyᶜᶜᵃ, ζ₃ᶠᶠᶜ, div_xyᶜᶜᶜ +using Oceananigans.Operators: Δx, Δy +using Oceananigans.Operators: ℑxyz + +using Oceananigans.Operators: ℑxyzᶜᶜᶠ, ℑyzᵃᶜᶠ, ℑxzᶜᵃᶠ, Δxᶜᶜᶜ, Δyᶜᶜᶜ + +struct XinKaiVerticalDiffusivity{TD, FT} <: AbstractScalarDiffusivity{TD, VerticalFormulation, 2} + ν₀ :: FT + νˢʰ :: FT + νᶜⁿ :: FT + Cᵉⁿ :: FT + Pr_convₜ :: FT + Pr_shearₜ :: FT + Riᶜ :: FT + δRi :: FT + Q₀ :: FT + δQ :: FT +end + +function XinKaiVerticalDiffusivity{TD}(ν₀ :: FT, + νˢʰ :: FT, + νᶜⁿ :: FT, + Cᵉⁿ :: FT, + Pr_convₜ :: FT, + Pr_shearₜ :: FT, + Riᶜ :: FT, + δRi :: FT, + Q₀ :: FT, + δQ :: FT) where {TD, FT} + + return XinKaiVerticalDiffusivity{TD, FT}(ν₀, νˢʰ, νᶜⁿ, Cᵉⁿ, Pr_convₜ, Pr_shearₜ, Riᶜ, δRi, Q₀, δQ) +end + +function XinKaiVerticalDiffusivity(time_discretization = VerticallyImplicitTimeDiscretization(), + FT = Float64; + ν₀ = 1e-5, + νˢʰ = 0.07738088203341657, + νᶜⁿ = 0.533741914196933, + Cᵉⁿ = 0.5196272898085122, + Pr_convₜ = 0.01632117727992826, + Pr_shearₜ = 1.8499159986192901, + Riᶜ = 0.4923581673007292, + δRi = 0.00012455519496760374, + Q₀ = 0.048232078296680234, + δQ = 0.01884938627051353) + + TD = typeof(time_discretization) + + return XinKaiVerticalDiffusivity{TD}(convert(FT, ν₀), + convert(FT, νˢʰ), + convert(FT, νᶜⁿ), + convert(FT, Cᵉⁿ), + convert(FT, Pr_convₜ), + convert(FT, Pr_shearₜ), + convert(FT, Riᶜ), + convert(FT, δRi), + convert(FT, Q₀), + convert(FT, δQ)) +end + +XinKaiVerticalDiffusivity(FT::DataType; kw...) = + XinKaiVerticalDiffusivity(VerticallyImplicitTimeDiscretization(), FT; kw...) + +Adapt.adapt_structure(to, clo::XinKaiVerticalDiffusivity{TD, FT}) where {TD, FT} = + XinKaiVerticalDiffusivity{TD, FT}(clo.ν₀, clo.νˢʰ, clo.νᶜⁿ, clo.Cᵉⁿ, clo.Pr_convₜ, clo.Pr_shearₜ, clo.Riᶜ, clo.δRi, clo.Q₀, clo.δQ) + +##### +##### Diffusivity field utilities +##### + +const RBVD = XinKaiVerticalDiffusivity +const RBVDArray = AbstractArray{<:RBVD} +const FlavorOfXKVD = Union{RBVD, RBVDArray} +const c = Center() +const f = Face() + +@inline viscosity_location(::FlavorOfXKVD) = (c, c, f) +@inline diffusivity_location(::FlavorOfXKVD) = (c, c, f) + +@inline viscosity(::FlavorOfXKVD, diffusivities) = diffusivities.κᵘ +@inline diffusivity(::FlavorOfXKVD, diffusivities, id) = diffusivities.κᶜ + +with_tracers(tracers, closure::FlavorOfXKVD) = closure + +# Note: computing diffusivities at cell centers for now. +function DiffusivityFields(grid, tracer_names, bcs, closure::FlavorOfXKVD) + κᶜ = Field((Center, Center, Face), grid) + κᵘ = Field((Center, Center, Face), grid) + N² = Field((Center, Center, Face), grid) + Ri = Field((Center, Center, Face), grid) + return (; κᶜ, κᵘ, Ri, N²) +end + +function compute_diffusivities!(diffusivities, closure::FlavorOfXKVD, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + clock = model.clock + tracers = model.tracers + buoyancy = model.buoyancy + velocities = model.velocities + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Nx_in, Ny_in, Nz_in = total_size(diffusivities.κᶜ) + ox_in, oy_in, oz_in = diffusivities.κᶜ.data.offsets + + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, compute_N²!, diffusivities, grid, closure, tracers, buoyancy) + launch!(arch, grid, kp, compute_ri_number!, diffusivities, grid, closure, velocities) + + # Use `only_local_halos` to ensure that no communication occurs during + # this call to fill_halo_regions! + fill_halo_regions!(diffusivities.Ri; only_local_halos=true) + + launch!(arch, grid, kp, + compute_xinkai_diffusivities!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + return nothing +end + +@inline ϕ²(i, j, k, grid, ϕ, args...) = ϕ(i, j, k, grid, args...)^2 + +@inline function shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + ∂z_u² = ℑxᶜᵃᵃ(i, j, k, grid, ϕ², ∂zᶠᶜᶠ, velocities.u) + ∂z_v² = ℑyᵃᶜᵃ(i, j, k, grid, ϕ², ∂zᶜᶠᶠ, velocities.v) + return ∂z_u² + ∂z_v² +end + +@inline function N²ᶜᶜᶠ(i, j, k, grid, buoyancy, tracers) + return ∂z_b(i, j, k, grid, buoyancy, tracers) +end + +@inline function Riᶜᶜᶠ(i, j, k, grid, velocities, diffusivities) + S² = shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + N² = diffusivities.N²[i, j, k] + Ri = N² / S² + + # Clip N² and avoid NaN + return ifelse(N² == 0, zero(grid), Ri) +end + +const c = Center() +const f = Face() + +@kernel function compute_N²!(diffusivities, grid, closure::FlavorOfXKVD, tracers, buoyancy) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.N²[i, j, k] = N²ᶜᶜᶠ(i, j, k, grid, buoyancy, tracers) +end + +@kernel function compute_ri_number!(diffusivities, grid, closure::FlavorOfXKVD, velocities) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.Ri[i, j, k] = Riᶜᶜᶠ(i, j, k, grid, velocities, diffusivities) +end + +@kernel function compute_xinkai_diffusivities!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) +end + +@inline function _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) + + # Ensure this works with "ensembles" of closures, in addition to ordinary single closures + closure_ij = getclosure(i, j, closure) + + ν₀ = closure_ij.ν₀ + νˢʰ = closure_ij.νˢʰ + νᶜⁿ = closure_ij.νᶜⁿ + Cᵉⁿ = closure_ij.Cᵉⁿ + Pr_convₜ = closure_ij.Pr_convₜ + Pr_shearₜ = closure_ij.Pr_shearₜ + Riᶜ = closure_ij.Riᶜ + δRi = closure_ij.δRi + Q₀ = closure_ij.Q₀ + δQ = closure_ij.δQ + + κ₀ = ν₀ / Pr_shearₜ + κˢʰ = νˢʰ / Pr_shearₜ + κᶜⁿ = νᶜⁿ / Pr_convₜ + + Qᵇ = top_buoyancy_flux(i, j, grid, buoyancy, tracer_bcs, clock, merge(velocities, tracers)) + + # (Potentially) apply a horizontal filter to the Richardson number + Ri = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + Ri_above = ℑxyᶜᶜᵃ(i, j, k + 1, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + N² = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.N²) + + # Conditions + convecting = Ri < 0 # applies regardless of Qᵇ + entraining = (Ri > 0) & (Ri_above < 0) & (Qᵇ > 0) + + # Convective adjustment diffusivity + ν_local = ifelse(convecting, (νˢʰ - νᶜⁿ) * tanh(Ri / δRi) + νˢʰ, clamp((ν₀ - νˢʰ) * Ri / Riᶜ + νˢʰ, ν₀, νˢʰ)) + κ_local = ifelse(convecting, (κˢʰ - κᶜⁿ) * tanh(Ri / δRi) + κˢʰ, clamp((κ₀ - κˢʰ) * Ri / Riᶜ + κˢʰ, κ₀, κˢʰ)) + + # Entrainment diffusivity + x = Qᵇ / (N² + 1e-11) + ν_nonlocal = ifelse(entraining, Cᵉⁿ * νᶜⁿ * 0.5 * (tanh((x - Q₀) / δQ) + 1), 0) + κ_nonlocal = ifelse(entraining, ν_nonlocal / Pr_shearₜ, 0) + + # Update by averaging in time + @inbounds diffusivities.κᵘ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, ν_local + ν_nonlocal) + @inbounds diffusivities.κᶜ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, κ_local + κ_nonlocal) + + return nothing +end diff --git a/xin_kai_vertical_diffusivity_local.jl b/xin_kai_vertical_diffusivity_local.jl new file mode 100644 index 0000000000..f911b40f2d --- /dev/null +++ b/xin_kai_vertical_diffusivity_local.jl @@ -0,0 +1,242 @@ +using Oceananigans +using Oceananigans.Architectures: architecture +using Oceananigans.BuoyancyModels: ∂z_b +using Oceananigans.Operators +using Oceananigans.Grids: inactive_node, total_size +using Oceananigans.Operators: ℑzᵃᵃᶜ, ℑxyᶠᶠᵃ, ℑxyᶜᶜᵃ +using Oceananigans.Utils: KernelParameters + +using Adapt + +using KernelAbstractions: @index, @kernel +using KernelAbstractions.Extras.LoopInfo: @unroll + +using Oceananigans.TurbulenceClosures: + tapering_factorᶠᶜᶜ, + tapering_factorᶜᶠᶜ, + tapering_factorᶜᶜᶠ, + tapering_factor, + SmallSlopeIsopycnalTensor, + AbstractScalarDiffusivity, + ExplicitTimeDiscretization, + FluxTapering, + isopycnal_rotation_tensor_xz_ccf, + isopycnal_rotation_tensor_yz_ccf, + isopycnal_rotation_tensor_zz_ccf + +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + getclosure, + top_buoyancy_flux, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_vx, + viscous_flux_uy, + viscous_flux_vy + +using Oceananigans.Utils: launch! +using Oceananigans.Coriolis: fᶠᶠᵃ +using Oceananigans.Operators +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b + +using Oceananigans.TurbulenceClosures +using Oceananigans.TurbulenceClosures: HorizontalFormulation, VerticalFormulation, AbstractScalarDiffusivity +using Oceananigans.TurbulenceClosures: AbstractScalarBiharmonicDiffusivity +using Oceananigans.Operators +using Oceananigans.Operators: Δxᶜᶜᶜ, Δyᶜᶜᶜ, ℑxyᶜᶜᵃ, ζ₃ᶠᶠᶜ, div_xyᶜᶜᶜ +using Oceananigans.Operators: Δx, Δy +using Oceananigans.Operators: ℑxyz + +using Oceananigans.Operators: ℑxyzᶜᶜᶠ, ℑyzᵃᶜᶠ, ℑxzᶜᵃᶠ, Δxᶜᶜᶜ, Δyᶜᶜᶜ + +using Oceananigans.BoundaryConditions + +struct XinKaiLocalVerticalDiffusivity{TD, FT} <: AbstractScalarDiffusivity{TD, VerticalFormulation, 2} + ν₀ :: FT + νˢʰ :: FT + νᶜⁿ :: FT + Pr_convₜ :: FT + Pr_shearₜ :: FT + Riᶜ :: FT + δRi :: FT +end + +function XinKaiLocalVerticalDiffusivity{TD}(ν₀ :: FT, + νˢʰ :: FT, + νᶜⁿ :: FT, + Pr_convₜ :: FT, + Pr_shearₜ :: FT, + Riᶜ :: FT, + δRi :: FT) where {TD, FT} + + return XinKaiLocalVerticalDiffusivity{TD, FT}(ν₀, νˢʰ, νᶜⁿ, Pr_convₜ, Pr_shearₜ, Riᶜ, δRi) +end + +function XinKaiLocalVerticalDiffusivity(time_discretization = VerticallyImplicitTimeDiscretization(), + FT = Float64; + ν₀ = 1e-5, + νˢʰ = 0.04569735882746968, + νᶜⁿ = 0.47887785611155065, + Pr_convₜ = 0.1261854430705509, + Pr_shearₜ = 1.594794053970444, + Riᶜ = 0.9964350402840053, + δRi = 0.05635304878092709) + + TD = typeof(time_discretization) + + return XinKaiLocalVerticalDiffusivity{TD}(convert(FT, ν₀), + convert(FT, νˢʰ), + convert(FT, νᶜⁿ), + convert(FT, Pr_convₜ), + convert(FT, Pr_shearₜ), + convert(FT, Riᶜ), + convert(FT, δRi)) +end + +XinKaiLocalVerticalDiffusivity(FT::DataType; kw...) = + XinKaiLocalVerticalDiffusivity(VerticallyImplicitTimeDiscretization(), FT; kw...) + +Adapt.adapt_structure(to, clo::XinKaiLocalVerticalDiffusivity{TD, FT}) where {TD, FT} = + XinKaiLocalVerticalDiffusivity{TD, FT}(clo.ν₀, clo.νˢʰ, clo.νᶜⁿ, clo.Pr_convₜ, clo.Pr_shearₜ, clo.Riᶜ, clo.δRi) + +##### +##### Diffusivity field utilities +##### + +const RBVD = XinKaiLocalVerticalDiffusivity +const RBVDArray = AbstractArray{<:RBVD} +const FlavorOfXKVD = Union{RBVD, RBVDArray} +const c = Center() +const f = Face() + +@inline viscosity_location(::FlavorOfXKVD) = (c, c, f) +@inline diffusivity_location(::FlavorOfXKVD) = (c, c, f) + +@inline viscosity(::FlavorOfXKVD, diffusivities) = diffusivities.κᵘ +@inline diffusivity(::FlavorOfXKVD, diffusivities, id) = diffusivities.κᶜ + +with_tracers(tracers, closure::FlavorOfXKVD) = closure + +# Note: computing diffusivities at cell centers for now. +function DiffusivityFields(grid, tracer_names, bcs, closure::FlavorOfXKVD) + κᶜ = Field((Center, Center, Face), grid) + κᵘ = Field((Center, Center, Face), grid) + Ri = Field((Center, Center, Face), grid) + return (; κᶜ, κᵘ, Ri) +end + +function compute_diffusivities!(diffusivities, closure::FlavorOfXKVD, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + clock = model.clock + tracers = model.tracers + buoyancy = model.buoyancy + velocities = model.velocities + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Nx_in, Ny_in, Nz_in = total_size(diffusivities.κᶜ) + ox_in, oy_in, oz_in = diffusivities.κᶜ.data.offsets + + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, + compute_ri_number!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + # Use `only_local_halos` to ensure that no communication occurs during + # this call to fill_halo_regions! + fill_halo_regions!(diffusivities.Ri; only_local_halos=true) + + launch!(arch, grid, kp, + compute_xinkai_diffusivities!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + return nothing +end + +@inline ϕ²(i, j, k, grid, ϕ, args...) = ϕ(i, j, k, grid, args...)^2 + +@inline function shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + ∂z_u² = ℑxᶜᵃᵃ(i, j, k, grid, ϕ², ∂zᶠᶜᶠ, velocities.u) + ∂z_v² = ℑyᵃᶜᵃ(i, j, k, grid, ϕ², ∂zᶜᶠᶠ, velocities.v) + return ∂z_u² + ∂z_v² +end + +@inline function Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) + S² = shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + N² = ∂z_b(i, j, k, grid, buoyancy, tracers) + Ri = N² / S² + + # Clip N² and avoid NaN + return ifelse(N² == 0, zero(grid), Ri) +end + +const c = Center() +const f = Face() + +@kernel function compute_ri_number!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.Ri[i, j, k] = Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) +end + +@kernel function compute_xinkai_diffusivities!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) +end + +@inline function _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) + + # Ensure this works with "ensembles" of closures, in addition to ordinary single closures + closure_ij = getclosure(i, j, closure) + + ν₀ = closure_ij.ν₀ + νˢʰ = closure_ij.νˢʰ + νᶜⁿ = closure_ij.νᶜⁿ + Pr_convₜ = closure_ij.Pr_convₜ + Pr_shearₜ = closure_ij.Pr_shearₜ + Riᶜ = closure_ij.Riᶜ + δRi = closure_ij.δRi + + κ₀ = ν₀ / Pr_shearₜ + κˢʰ = νˢʰ / Pr_shearₜ + κᶜⁿ = νᶜⁿ / Pr_convₜ + + # (Potentially) apply a horizontal filter to the Richardson number + Ri = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + + # Conditions + convecting = Ri < 0 # applies regardless of Qᵇ + + # Convective adjustment diffusivity + ν_local = ifelse(convecting, (νˢʰ - νᶜⁿ) * tanh(Ri / δRi) + νˢʰ, clamp((ν₀ - νˢʰ) * Ri / Riᶜ + νˢʰ, ν₀, νˢʰ)) + κ_local = ifelse(convecting, (κˢʰ - κᶜⁿ) * tanh(Ri / δRi) + κˢʰ, clamp((κ₀ - κˢʰ) * Ri / Riᶜ + κˢʰ, κ₀, κˢʰ)) + + # Update by averaging in time + @inbounds diffusivities.κᵘ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, ν_local) + @inbounds diffusivities.κᶜ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, κ_local) + + return nothing +end diff --git a/xin_kai_vertical_diffusivity_local_2step.jl b/xin_kai_vertical_diffusivity_local_2step.jl new file mode 100644 index 0000000000..2c26e09c9a --- /dev/null +++ b/xin_kai_vertical_diffusivity_local_2step.jl @@ -0,0 +1,242 @@ +using Oceananigans +using Oceananigans.Architectures: architecture +using Oceananigans.BuoyancyModels: ∂z_b +using Oceananigans.Operators +using Oceananigans.Grids: inactive_node, total_size +using Oceananigans.Operators: ℑzᵃᵃᶜ, ℑxyᶠᶠᵃ, ℑxyᶜᶜᵃ +using Oceananigans.Utils: KernelParameters + +using Adapt + +using KernelAbstractions: @index, @kernel +using KernelAbstractions.Extras.LoopInfo: @unroll + +using Oceananigans.TurbulenceClosures: + tapering_factorᶠᶜᶜ, + tapering_factorᶜᶠᶜ, + tapering_factorᶜᶜᶠ, + tapering_factor, + SmallSlopeIsopycnalTensor, + AbstractScalarDiffusivity, + ExplicitTimeDiscretization, + FluxTapering, + isopycnal_rotation_tensor_xz_ccf, + isopycnal_rotation_tensor_yz_ccf, + isopycnal_rotation_tensor_zz_ccf + +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + getclosure, + top_buoyancy_flux, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_vx, + viscous_flux_uy, + viscous_flux_vy + +using Oceananigans.Utils: launch! +using Oceananigans.Coriolis: fᶠᶠᵃ +using Oceananigans.Operators +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b + +using Oceananigans.TurbulenceClosures +using Oceananigans.TurbulenceClosures: HorizontalFormulation, VerticalFormulation, AbstractScalarDiffusivity +using Oceananigans.TurbulenceClosures: AbstractScalarBiharmonicDiffusivity +using Oceananigans.Operators +using Oceananigans.Operators: Δxᶜᶜᶜ, Δyᶜᶜᶜ, ℑxyᶜᶜᵃ, ζ₃ᶠᶠᶜ, div_xyᶜᶜᶜ +using Oceananigans.Operators: Δx, Δy +using Oceananigans.Operators: ℑxyz + +using Oceananigans.Operators: ℑxyzᶜᶜᶠ, ℑyzᵃᶜᶠ, ℑxzᶜᵃᶠ, Δxᶜᶜᶜ, Δyᶜᶜᶜ + +using Oceananigans.BoundaryConditions + +struct XinKaiLocalVerticalDiffusivity{TD, FT} <: AbstractScalarDiffusivity{TD, VerticalFormulation, 2} + ν₀ :: FT + νˢʰ :: FT + νᶜⁿ :: FT + Pr_convₜ :: FT + Pr_shearₜ :: FT + Riᶜ :: FT + δRi :: FT +end + +function XinKaiLocalVerticalDiffusivity{TD}(ν₀ :: FT, + νˢʰ :: FT, + νᶜⁿ :: FT, + Pr_convₜ :: FT, + Pr_shearₜ :: FT, + Riᶜ :: FT, + δRi :: FT) where {TD, FT} + + return XinKaiLocalVerticalDiffusivity{TD, FT}(ν₀, νˢʰ, νᶜⁿ, Pr_convₜ, Pr_shearₜ, Riᶜ, δRi) +end + +function XinKaiLocalVerticalDiffusivity(time_discretization = VerticallyImplicitTimeDiscretization(), + FT = Float64; + ν₀ = 1e-5, + νˢʰ = 0.0615914063656973, + νᶜⁿ = 1.5364711416895118, + Pr_convₜ = 0.18711389733455402, + Pr_shearₜ = 1.0842017486284887, + Riᶜ = 0.4366901962987793, + δRi = 0.0009691362773690692) + + TD = typeof(time_discretization) + + return XinKaiLocalVerticalDiffusivity{TD}(convert(FT, ν₀), + convert(FT, νˢʰ), + convert(FT, νᶜⁿ), + convert(FT, Pr_convₜ), + convert(FT, Pr_shearₜ), + convert(FT, Riᶜ), + convert(FT, δRi)) +end + +XinKaiLocalVerticalDiffusivity(FT::DataType; kw...) = + XinKaiLocalVerticalDiffusivity(VerticallyImplicitTimeDiscretization(), FT; kw...) + +Adapt.adapt_structure(to, clo::XinKaiLocalVerticalDiffusivity{TD, FT}) where {TD, FT} = + XinKaiLocalVerticalDiffusivity{TD, FT}(clo.ν₀, clo.νˢʰ, clo.νᶜⁿ, clo.Pr_convₜ, clo.Pr_shearₜ, clo.Riᶜ, clo.δRi) + +##### +##### Diffusivity field utilities +##### + +const RBVD = XinKaiLocalVerticalDiffusivity +const RBVDArray = AbstractArray{<:RBVD} +const FlavorOfXKVD = Union{RBVD, RBVDArray} +const c = Center() +const f = Face() + +@inline viscosity_location(::FlavorOfXKVD) = (c, c, f) +@inline diffusivity_location(::FlavorOfXKVD) = (c, c, f) + +@inline viscosity(::FlavorOfXKVD, diffusivities) = diffusivities.κᵘ +@inline diffusivity(::FlavorOfXKVD, diffusivities, id) = diffusivities.κᶜ + +with_tracers(tracers, closure::FlavorOfXKVD) = closure + +# Note: computing diffusivities at cell centers for now. +function DiffusivityFields(grid, tracer_names, bcs, closure::FlavorOfXKVD) + κᶜ = Field((Center, Center, Face), grid) + κᵘ = Field((Center, Center, Face), grid) + Ri = Field((Center, Center, Face), grid) + return (; κᶜ, κᵘ, Ri) +end + +function compute_diffusivities!(diffusivities, closure::FlavorOfXKVD, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + clock = model.clock + tracers = model.tracers + buoyancy = model.buoyancy + velocities = model.velocities + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Nx_in, Ny_in, Nz_in = total_size(diffusivities.κᶜ) + ox_in, oy_in, oz_in = diffusivities.κᶜ.data.offsets + + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, + compute_ri_number!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + # Use `only_local_halos` to ensure that no communication occurs during + # this call to fill_halo_regions! + fill_halo_regions!(diffusivities.Ri; only_local_halos=true) + + launch!(arch, grid, kp, + compute_xinkai_diffusivities!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + return nothing +end + +@inline ϕ²(i, j, k, grid, ϕ, args...) = ϕ(i, j, k, grid, args...)^2 + +@inline function shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + ∂z_u² = ℑxᶜᵃᵃ(i, j, k, grid, ϕ², ∂zᶠᶜᶠ, velocities.u) + ∂z_v² = ℑyᵃᶜᵃ(i, j, k, grid, ϕ², ∂zᶜᶠᶠ, velocities.v) + return ∂z_u² + ∂z_v² +end + +@inline function Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) + S² = shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + N² = ∂z_b(i, j, k, grid, buoyancy, tracers) + Ri = N² / S² + + # Clip N² and avoid NaN + return ifelse(N² == 0, zero(grid), Ri) +end + +const c = Center() +const f = Face() + +@kernel function compute_ri_number!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.Ri[i, j, k] = Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) +end + +@kernel function compute_xinkai_diffusivities!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) +end + +@inline function _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) + + # Ensure this works with "ensembles" of closures, in addition to ordinary single closures + closure_ij = getclosure(i, j, closure) + + ν₀ = closure_ij.ν₀ + νˢʰ = closure_ij.νˢʰ + νᶜⁿ = closure_ij.νᶜⁿ + Pr_convₜ = closure_ij.Pr_convₜ + Pr_shearₜ = closure_ij.Pr_shearₜ + Riᶜ = closure_ij.Riᶜ + δRi = closure_ij.δRi + + κ₀ = ν₀ / Pr_shearₜ + κˢʰ = νˢʰ / Pr_shearₜ + κᶜⁿ = νᶜⁿ / Pr_convₜ + + # (Potentially) apply a horizontal filter to the Richardson number + Ri = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + + # Conditions + convecting = Ri < 0 # applies regardless of Qᵇ + + # Convective adjustment diffusivity + ν_local = ifelse(convecting, (νˢʰ - νᶜⁿ) * tanh(Ri / δRi) + νˢʰ, clamp((ν₀ - νˢʰ) * Ri / Riᶜ + νˢʰ, ν₀, νˢʰ)) + κ_local = ifelse(convecting, (κˢʰ - κᶜⁿ) * tanh(Ri / δRi) + κˢʰ, clamp((κ₀ - κˢʰ) * Ri / Riᶜ + κˢʰ, κ₀, κˢʰ)) + + # Update by averaging in time + @inbounds diffusivities.κᵘ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, ν_local) + @inbounds diffusivities.κᶜ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, κ_local) + + return nothing +end diff --git a/xin_kai_vertical_diffusivity_local_2step_new.jl b/xin_kai_vertical_diffusivity_local_2step_new.jl new file mode 100644 index 0000000000..9dda064e79 --- /dev/null +++ b/xin_kai_vertical_diffusivity_local_2step_new.jl @@ -0,0 +1,242 @@ +using Oceananigans +using Oceananigans.Architectures: architecture +using Oceananigans.BuoyancyModels: ∂z_b +using Oceananigans.Operators +using Oceananigans.Grids: inactive_node, total_size +using Oceananigans.Operators: ℑzᵃᵃᶜ, ℑxyᶠᶠᵃ, ℑxyᶜᶜᵃ +using Oceananigans.Utils: KernelParameters + +using Adapt + +using KernelAbstractions: @index, @kernel +using KernelAbstractions.Extras.LoopInfo: @unroll + +using Oceananigans.TurbulenceClosures: + tapering_factorᶠᶜᶜ, + tapering_factorᶜᶠᶜ, + tapering_factorᶜᶜᶠ, + tapering_factor, + SmallSlopeIsopycnalTensor, + AbstractScalarDiffusivity, + ExplicitTimeDiscretization, + FluxTapering, + isopycnal_rotation_tensor_xz_ccf, + isopycnal_rotation_tensor_yz_ccf, + isopycnal_rotation_tensor_zz_ccf + +import Oceananigans.TurbulenceClosures: + compute_diffusivities!, + DiffusivityFields, + viscosity, + diffusivity, + getclosure, + top_buoyancy_flux, + diffusive_flux_x, + diffusive_flux_y, + diffusive_flux_z, + viscous_flux_ux, + viscous_flux_vx, + viscous_flux_uy, + viscous_flux_vy + +using Oceananigans.Utils: launch! +using Oceananigans.Coriolis: fᶠᶠᵃ +using Oceananigans.Operators +using Oceananigans.BuoyancyModels: ∂x_b, ∂y_b, ∂z_b + +using Oceananigans.TurbulenceClosures +using Oceananigans.TurbulenceClosures: HorizontalFormulation, VerticalFormulation, AbstractScalarDiffusivity +using Oceananigans.TurbulenceClosures: AbstractScalarBiharmonicDiffusivity +using Oceananigans.Operators +using Oceananigans.Operators: Δxᶜᶜᶜ, Δyᶜᶜᶜ, ℑxyᶜᶜᵃ, ζ₃ᶠᶠᶜ, div_xyᶜᶜᶜ +using Oceananigans.Operators: Δx, Δy +using Oceananigans.Operators: ℑxyz + +using Oceananigans.Operators: ℑxyzᶜᶜᶠ, ℑyzᵃᶜᶠ, ℑxzᶜᵃᶠ, Δxᶜᶜᶜ, Δyᶜᶜᶜ + +using Oceananigans.BoundaryConditions + +struct XinKaiLocalVerticalDiffusivity{TD, FT} <: AbstractScalarDiffusivity{TD, VerticalFormulation, 2} + ν₀ :: FT + νˢʰ :: FT + νᶜⁿ :: FT + Pr_convₜ :: FT + Pr_shearₜ :: FT + Riᶜ :: FT + δRi :: FT +end + +function XinKaiLocalVerticalDiffusivity{TD}(ν₀ :: FT, + νˢʰ :: FT, + νᶜⁿ :: FT, + Pr_convₜ :: FT, + Pr_shearₜ :: FT, + Riᶜ :: FT, + δRi :: FT) where {TD, FT} + + return XinKaiLocalVerticalDiffusivity{TD, FT}(ν₀, νˢʰ, νᶜⁿ, Pr_convₜ, Pr_shearₜ, Riᶜ, δRi) +end + +function XinKaiLocalVerticalDiffusivity(time_discretization = VerticallyImplicitTimeDiscretization(), + FT = Float64; + ν₀ = 1e-5, + νˢʰ = 0.0615914063656973, + νᶜⁿ = 1.0514706176740092, + Pr_convₜ = 0.2684497234339729, + Pr_shearₜ = 1.0842017486284887, + Riᶜ = 0.4366901962987793, + δRi = 0.001484740489701266) + + TD = typeof(time_discretization) + + return XinKaiLocalVerticalDiffusivity{TD}(convert(FT, ν₀), + convert(FT, νˢʰ), + convert(FT, νᶜⁿ), + convert(FT, Pr_convₜ), + convert(FT, Pr_shearₜ), + convert(FT, Riᶜ), + convert(FT, δRi)) +end + +XinKaiLocalVerticalDiffusivity(FT::DataType; kw...) = + XinKaiLocalVerticalDiffusivity(VerticallyImplicitTimeDiscretization(), FT; kw...) + +Adapt.adapt_structure(to, clo::XinKaiLocalVerticalDiffusivity{TD, FT}) where {TD, FT} = + XinKaiLocalVerticalDiffusivity{TD, FT}(clo.ν₀, clo.νˢʰ, clo.νᶜⁿ, clo.Pr_convₜ, clo.Pr_shearₜ, clo.Riᶜ, clo.δRi) + +##### +##### Diffusivity field utilities +##### + +const RBVD = XinKaiLocalVerticalDiffusivity +const RBVDArray = AbstractArray{<:RBVD} +const FlavorOfXKVD = Union{RBVD, RBVDArray} +const c = Center() +const f = Face() + +@inline viscosity_location(::FlavorOfXKVD) = (c, c, f) +@inline diffusivity_location(::FlavorOfXKVD) = (c, c, f) + +@inline viscosity(::FlavorOfXKVD, diffusivities) = diffusivities.κᵘ +@inline diffusivity(::FlavorOfXKVD, diffusivities, id) = diffusivities.κᶜ + +with_tracers(tracers, closure::FlavorOfXKVD) = closure + +# Note: computing diffusivities at cell centers for now. +function DiffusivityFields(grid, tracer_names, bcs, closure::FlavorOfXKVD) + κᶜ = Field((Center, Center, Face), grid) + κᵘ = Field((Center, Center, Face), grid) + Ri = Field((Center, Center, Face), grid) + return (; κᶜ, κᵘ, Ri) +end + +function compute_diffusivities!(diffusivities, closure::FlavorOfXKVD, model; parameters = :xyz) + arch = model.architecture + grid = model.grid + clock = model.clock + tracers = model.tracers + buoyancy = model.buoyancy + velocities = model.velocities + top_tracer_bcs = NamedTuple(c => tracers[c].boundary_conditions.top for c in propertynames(tracers)) + + Nx_in, Ny_in, Nz_in = total_size(diffusivities.κᶜ) + ox_in, oy_in, oz_in = diffusivities.κᶜ.data.offsets + + kp = KernelParameters((Nx_in, Ny_in, Nz_in), (ox_in, oy_in, oz_in)) + + launch!(arch, grid, kp, + compute_ri_number!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + # Use `only_local_halos` to ensure that no communication occurs during + # this call to fill_halo_regions! + fill_halo_regions!(diffusivities.Ri; only_local_halos=true) + + launch!(arch, grid, kp, + compute_xinkai_diffusivities!, + diffusivities, + grid, + closure, + velocities, + tracers, + buoyancy, + top_tracer_bcs, + clock) + + return nothing +end + +@inline ϕ²(i, j, k, grid, ϕ, args...) = ϕ(i, j, k, grid, args...)^2 + +@inline function shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + ∂z_u² = ℑxᶜᵃᵃ(i, j, k, grid, ϕ², ∂zᶠᶜᶠ, velocities.u) + ∂z_v² = ℑyᵃᶜᵃ(i, j, k, grid, ϕ², ∂zᶜᶠᶠ, velocities.v) + return ∂z_u² + ∂z_v² +end + +@inline function Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) + S² = shear_squaredᶜᶜᶠ(i, j, k, grid, velocities) + N² = ∂z_b(i, j, k, grid, buoyancy, tracers) + Ri = N² / S² + + # Clip N² and avoid NaN + return ifelse(N² == 0, zero(grid), Ri) +end + +const c = Center() +const f = Face() + +@kernel function compute_ri_number!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + @inbounds diffusivities.Ri[i, j, k] = Riᶜᶜᶠ(i, j, k, grid, velocities, buoyancy, tracers) +end + +@kernel function compute_xinkai_diffusivities!(diffusivities, grid, closure::FlavorOfXKVD, + velocities, tracers, buoyancy, tracer_bcs, clock) + i, j, k = @index(Global, NTuple) + _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) +end + +@inline function _compute_xinkai_diffusivities!(i, j, k, diffusivities, grid, closure, + velocities, tracers, buoyancy, tracer_bcs, clock) + + # Ensure this works with "ensembles" of closures, in addition to ordinary single closures + closure_ij = getclosure(i, j, closure) + + ν₀ = closure_ij.ν₀ + νˢʰ = closure_ij.νˢʰ + νᶜⁿ = closure_ij.νᶜⁿ + Pr_convₜ = closure_ij.Pr_convₜ + Pr_shearₜ = closure_ij.Pr_shearₜ + Riᶜ = closure_ij.Riᶜ + δRi = closure_ij.δRi + + κ₀ = ν₀ / Pr_shearₜ + κˢʰ = νˢʰ / Pr_shearₜ + κᶜⁿ = νᶜⁿ / Pr_convₜ + + # (Potentially) apply a horizontal filter to the Richardson number + Ri = ℑxyᶜᶜᵃ(i, j, k, grid, ℑxyᶠᶠᵃ, diffusivities.Ri) + + # Conditions + convecting = Ri < 0 # applies regardless of Qᵇ + + # Convective adjustment diffusivity + ν_local = ifelse(convecting, (νˢʰ - νᶜⁿ) * tanh(Ri / δRi) + νˢʰ, clamp((ν₀ - νˢʰ) * Ri / Riᶜ + νˢʰ, ν₀, νˢʰ)) + κ_local = ifelse(convecting, (κˢʰ - κᶜⁿ) * tanh(Ri / δRi) + κˢʰ, clamp((κ₀ - κˢʰ) * Ri / Riᶜ + κˢʰ, κ₀, κˢʰ)) + + # Update by averaging in time + @inbounds diffusivities.κᵘ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, ν_local) + @inbounds diffusivities.κᶜ[i, j, k] = ifelse(k <= 1 || k >= grid.Nz+1, 0, κ_local) + + return nothing +end diff --git a/xk_physicalclosure_doublegyre_model.jl b/xk_physicalclosure_doublegyre_model.jl new file mode 100644 index 0000000000..0001b36b20 --- /dev/null +++ b/xk_physicalclosure_doublegyre_model.jl @@ -0,0 +1,579 @@ +#using Pkg +# pkg"add Oceananigans CairoMakie" +using Oceananigans +# include("NN_closure_global.jl") +# include("xin_kai_vertical_diffusivity_local.jl") +include("xin_kai_vertical_diffusivity_2Pr.jl") + +ENV["GKSwstype"] = "100" + +pushfirst!(LOAD_PATH, @__DIR__) + +using Printf +using Statistics +using CairoMakie + +using Oceananigans +using Oceananigans.Units +using Oceananigans.OutputReaders: FieldTimeSeries +using Oceananigans.Grids: xnode, ynode, znode +using SeawaterPolynomials +using SeawaterPolynomials:TEOS10 +using ColorSchemes +using Glob + + +#%% +filename = "doublegyre_XinKaiVerticalDiffusivity_streamfunction" +FILE_DIR = "./Output/$(filename)" +mkpath(FILE_DIR) + +# Architecture +model_architecture = GPU() + +# nn_closure = NNFluxClosure(model_architecture) +# base_closure = XinKaiLocalVerticalDiffusivity() +# closure = (nn_closure, base_closure) + +vertical_base_closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5) +convection_closure = XinKaiVerticalDiffusivity() +# convection_closure = RiBasedVerticalDiffusivity() +closure = (vertical_base_closure, convection_closure) +# closure = vertical_base_closure + +# number of grid points +const Nx = 100 +const Ny = 100 +const Nz = 200 + +const Δz = 8meters +const Lx = 4000kilometers +const Ly = 6000kilometers +const Lz = Nz * Δz + +grid = RectilinearGrid(model_architecture, Float64, + topology = (Bounded, Bounded, Bounded), + size = (Nx, Ny, Nz), + halo = (4, 4, 4), + x = (-Lx/2, Lx/2), + y = (-Ly/2, Ly/2), + z = (-Lz, 0)) + +@info "Built a grid: $grid." + +##### +##### Boundary conditions +##### +const T_north = 0 +const T_south = 30 +const T_mid = (T_north + T_south) / 2 +const ΔT = T_south - T_north + +const S_north = 34 +const S_south = 37 +const S_mid = (S_north + S_south) / 2 + +const τ₀ = 1e-4 + +const μ_drag = 1/30days +const μ_T = 1/8days + +##### +##### Forcing and initial condition +##### + +@inline T_initial(x, y, z) = T_north + ΔT / 2 * (1 + z / Lz) + +@inline surface_u_flux(x, y, t) = -τ₀ * cos(2π * y / Ly) + +surface_u_flux_bc = FluxBoundaryCondition(surface_u_flux) + +@inline u_drag(x, y, t, u) = @inbounds -μ_drag * Lz * u +@inline v_drag(x, y, t, v) = @inbounds -μ_drag * Lz * v + +u_drag_bc = FluxBoundaryCondition(u_drag; field_dependencies=:u) +v_drag_bc = FluxBoundaryCondition(v_drag; field_dependencies=:v) + +u_bcs = FieldBoundaryConditions( top = surface_u_flux_bc, + bottom = u_drag_bc, + north = ValueBoundaryCondition(0), + south = ValueBoundaryCondition(0)) + +v_bcs = FieldBoundaryConditions( top = FluxBoundaryCondition(0), + bottom = v_drag_bc, + east = ValueBoundaryCondition(0), + west = ValueBoundaryCondition(0)) + +@inline T_ref(y) = T_mid - ΔT / Ly * y +@inline surface_T_flux(x, y, t, T) = μ_T * Δz * (T - T_ref(y)) +surface_T_flux_bc = FluxBoundaryCondition(surface_T_flux; field_dependencies=:T) +T_bcs = FieldBoundaryConditions(top = surface_T_flux_bc) + +@inline S_ref(y) = (S_north - S_south) / Ly * y + S_mid +@inline S_initial(x, y, z) = S_ref(y) +@inline surface_S_flux(x, y, t, S) = μ_T * Δz * (S - S_ref(y)) +surface_S_flux_bc = FluxBoundaryCondition(surface_S_flux; field_dependencies=:S) +S_bcs = FieldBoundaryConditions(top = surface_S_flux_bc) + +##### +##### Coriolis +##### +coriolis = BetaPlane(rotation_rate=7.292115e-5, latitude=45, radius=6371e3) + +##### +##### Model building +##### + +@info "Building a model..." + +# This is a weird bug. If a model is not initialized with a closure other than XinKaiVerticalDiffusivity, +# the code will throw a CUDA: illegal memory access error for models larger than a small size. +# This is a workaround to initialize the model with a closure other than XinKaiVerticalDiffusivity first, +# then the code will run without any issues. +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + closure = VerticalScalarDiffusivity(ν=1e-5, κ=1e-5), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +model = HydrostaticFreeSurfaceModel( + grid = grid, + free_surface = SplitExplicitFreeSurface(grid, cfl=0.75), + momentum_advection = WENO(order=5), + tracer_advection = WENO(order=5), + buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()), + coriolis = coriolis, + # closure = (nn_closure, base_closure), + closure = closure, + # closure = RiBasedVerticalDiffusivity(), + tracers = (:T, :S), + boundary_conditions = (; u = u_bcs, v = v_bcs, T = T_bcs, S = S_bcs), +) + +@info "Built $model." + +##### +##### Initial conditions +##### + +# resting initial condition +noise(z) = rand() * exp(z / 8) + +T_initial_noisy(x, y, z) = T_initial(x, y, z) + 1e-6 * noise(z) +S_initial_noisy(x, y, z) = S_initial(x, y, z) + 1e-6 * noise(z) + +set!(model, T=T_initial_noisy, S=S_initial_noisy) +using Oceananigans.TimeSteppers: update_state! +update_state!(model) +##### +##### Simulation building +##### +Δt₀ = 5minutes +stop_time = 10950days + +simulation = Simulation(model, Δt = Δt₀, stop_time = stop_time) + +# add timestep wizard callback +# wizard = TimeStepWizard(cfl=0.25, max_change=1.05, max_Δt=12minutes) +# simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10)) + +# add progress callback +wall_clock = [time_ns()] + +function print_progress(sim) + @printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n", + 100 * (sim.model.clock.time / sim.stop_time), + sim.model.clock.iteration, + prettytime(sim.model.clock.time), + prettytime(1e-9 * (time_ns() - wall_clock[1])), + maximum(abs, sim.model.velocities.u), + maximum(abs, sim.model.velocities.v), + maximum(abs, sim.model.tracers.T), + maximum(abs, sim.model.tracers.S), + prettytime(sim.Δt)) + + wall_clock[1] = time_ns() + + return nothing +end + +simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(100)) + +##### +##### Diagnostics +##### + +u, v, w = model.velocities +T, S = model.tracers.T, model.tracers.S +# ν, κ = model.diffusivity_fields[2].κᵘ, model.diffusivity_fields[2].κᶜ +# Ri = model.diffusivity_fields[2].Ri +# wT, wS = model.diffusivity_fields[2].wT, model.diffusivity_fields[2].wS +U_bt = Field(Integral(u, dims=3)) +Ψ = Field(CumulativeIntegral(-U_bt, dims=2)) + +# outputs = (; u, v, w, T, S, ν, κ, Ri, wT, wS) +# outputs = (; u, v, w, T, S, ν, κ, Ri) +outputs = (; u, v, w, T, S) + +##### +##### Build checkpointer and output writer +##### +simulation.output_writers[:xy] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xy", + indices = (:, :, Nz), + schedule = TimeInterval(5days)) + +simulation.output_writers[:yz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_yz", + indices = (1, :, :), + schedule = TimeInterval(5days)) + +simulation.output_writers[:xz] = JLD2OutputWriter(model, outputs, + filename = "$(FILE_DIR)/instantaneous_fields_xz", + indices = (:, 1, :), + schedule = TimeInterval(5days)) + +simulation.output_writers[:xz_south] = JLD2OutputWriter(model, outputs, + # filename = "NN_closure_2D_channel_NDE_FC_Qb_18simnew_2layer_128_relu_2Pr", + filename = "$(FILE_DIR)/instantaneous_fields_xz_south", + indices = (:, 25, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:xz_north] = JLD2OutputWriter(model, outputs, + # filename = "NN_closure_2D_channel_NDE_FC_Qb_18simnew_2layer_128_relu_2Pr", + filename = "$(FILE_DIR)/instantaneous_fields_xz_north", + indices = (:, 75, :), + schedule = TimeInterval(10days)) + +simulation.output_writers[:streamfunction] = JLD2OutputWriter(model, (; Ψ=Ψ,), + # filename = "NN_closure_2D_channel_NDE_FC_Qb_18simnew_2layer_128_relu_2Pr", + filename = "$(FILE_DIR)/averaged_fields_streamfunction", + schedule = AveragedTimeInterval(365days, window=365days)) + +simulation.output_writers[:checkpointer] = Checkpointer(model, + schedule = TimeInterval(730days), + prefix = "$(FILE_DIR)/checkpointer") + +@info "Running the simulation..." + +try + files = readdir(FILE_DIR) + checkpoint_files = files[occursin.("checkpointer_iteration", files)] + if !isempty(checkpoint_files) + checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files]) + pickup_iter = maximum(checkpoint_iters) + run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2") + else + run!(simulation) + end +catch err + @info "run! threw an error! The error message is" + showerror(stdout, err) +end + +#%% +T_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "T") +T_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "T") +T_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "T") + +S_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "S") +S_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "S") +S_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "S") + +u_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "u") +u_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "u") +u_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "u") + +v_xy_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xy.jld2", "v") +v_xz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_xz.jld2", "v") +v_yz_data = FieldTimeSeries("$(FILE_DIR)/instantaneous_fields_yz.jld2", "v") + +times = T_xy_data.times ./ 24 ./ 60^2 +Nt = length(times) +timeframes = 1:Nt + +# Nx, Ny, Nz = T_xy_data.grid.Nx, T_xy_data.grid.Ny, T_xy_data.grid.Nz +xC, yC, zC = T_xy_data.grid.xᶜᵃᵃ[1:Nx], T_xy_data.grid.yᵃᶜᵃ[1:Ny], T_xy_data.grid.zᵃᵃᶜ[1:Nz] +zF = T_xy_data.grid.zᵃᵃᶠ[1:Nz+1] + +# Lx, Ly, Lz = T_xy_data.grid.Lx, T_xy_data.grid.Ly, T_xy_data.grid.Lz + +xCs_xy = xC +yCs_xy = yC +zCs_xy = [zC[Nz] for x in xCs_xy, y in yCs_xy] + +yCs_yz = yC +xCs_yz = range(xC[1], stop=xC[1], length=length(zC)) +zCs_yz = zeros(length(xCs_yz), length(yCs_yz)) +for j in axes(zCs_yz, 2) + zCs_yz[:, j] .= zC +end + +xCs_xz = xC +yCs_xz = range(yC[1], stop=yC[1], length=length(zC)) +zCs_xz = zeros(length(xCs_xz), length(yCs_xz)) +for i in axes(zCs_xz, 1) + zCs_xz[i, :] .= zC +end + +xFs_xy = xC +yFs_xy = yC +zFs_xy = [zF[Nz+1] for x in xFs_xy, y in yFs_xy] + +yFs_yz = yC +xFs_yz = range(xC[1], stop=xC[1], length=length(zF)) +zFs_yz = zeros(length(xFs_yz), length(yFs_yz)) +for j in axes(zFs_yz, 2) + zFs_yz[:, j] .= zF +end + +xFs_xz = xC +yFs_xz = range(yC[1], stop=yC[1], length=length(zF)) +zFs_xz = zeros(length(xFs_xz), length(yFs_xz)) +for i in axes(zFs_xz, 1) + zFs_xz[i, :] .= zF +end + +function find_min(a...) + return minimum(minimum.([a...])) +end + +function find_max(a...) + return maximum(maximum.([a...])) +end + +# for freeconvection +# startheight = 64 + +# for wind mixing +startheight = 1 +Tlim = (find_min(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(T_xy_data, :, :, 1, timeframes), interior(T_yz_data, 1, :, startheight:Nz, timeframes), interior(T_xz_data, :, 1, startheight:Nz, timeframes))) +Slim = (find_min(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(S_xy_data, :, :, 1, timeframes), interior(S_yz_data, 1, :, startheight:Nz, timeframes), interior(S_xz_data, :, 1, startheight:Nz, timeframes))) +ulim = (-find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(u_xy_data, :, :, 1, timeframes), interior(u_yz_data, 1, :, startheight:Nz, timeframes), interior(u_xz_data, :, 1, startheight:Nz, timeframes))) +vlim = (-find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes)), + find_max(interior(v_xy_data, :, :, 1, timeframes), interior(v_yz_data, 1, :, startheight:Nz, timeframes), interior(v_xz_data, :, 1, startheight:Nz, timeframes))) + +colorscheme = colorschemes[:balance] +T_colormap = colorscheme +S_colormap = colorscheme +u_colormap = colorscheme +v_colormap = colorscheme + +T_color_range = Tlim +S_color_range = Slim +u_color_range = ulim +v_color_range = vlim +#%% +plot_aspect = (2, 3, 0.5) +fig = Figure(size=(1500, 700)) +axT = Axis3(fig[1, 1], title="Temperature (°C)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axS = Axis3(fig[1, 3], title="Salinity (g kg⁻¹)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axu = Axis3(fig[2, 1], title="u (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) +axv = Axis3(fig[2, 3], title="v (m/s)", xlabel="x (m)", ylabel="y (m)", zlabel="z (m)", viewmode=:fitzoom, aspect=plot_aspect) + +n = Observable(1) + +T_xy = @lift interior(T_xy_data[$n], :, :, 1) +T_yz = @lift transpose(interior(T_yz_data[$n], 1, :, :)) +T_xz = @lift interior(T_xz_data[$n], :, 1, :) + +S_xy = @lift interior(S_xy_data[$n], :, :, 1) +S_yz = @lift transpose(interior(S_yz_data[$n], 1, :, :)) +S_xz = @lift interior(S_xz_data[$n], :, 1, :) + +u_xy = @lift interior(u_xy_data[$n], :, :, 1) +u_yz = @lift transpose(interior(u_yz_data[$n], 1, :, :)) +u_xz = @lift interior(u_xz_data[$n], :, 1, :) + +v_xy = @lift interior(v_xy_data[$n], :, :, 1) +v_yz = @lift transpose(interior(v_yz_data[$n], 1, :, :)) +v_xz = @lift interior(v_xz_data[$n], :, 1, :) + +# time_str = @lift "Surface Cooling, Time = $(round(times[$n], digits=2)) hours" +time_str = @lift "Surface Wind Stress, Time = $(round(times[$n], digits=2)) days" +Label(fig[0, :], text=time_str, tellwidth=false, font=:bold) + +T_xy_surface = surface!(axT, xCs_xy, yCs_xy, zCs_xy, color=T_xy, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_yz_surface = surface!(axT, xCs_yz, yCs_yz, zCs_yz, color=T_yz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) +T_xz_surface = surface!(axT, xCs_xz, yCs_xz, zCs_xz, color=T_xz, colormap=T_colormap, colorrange = T_color_range, lowclip=T_colormap[1]) + +S_xy_surface = surface!(axS, xCs_xy, yCs_xy, zCs_xy, color=S_xy, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_yz_surface = surface!(axS, xCs_yz, yCs_yz, zCs_yz, color=S_yz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) +S_xz_surface = surface!(axS, xCs_xz, yCs_xz, zCs_xz, color=S_xz, colormap=S_colormap, colorrange = S_color_range, lowclip=S_colormap[1]) + +u_xy_surface = surface!(axu, xCs_xy, yCs_xy, zCs_xy, color=u_xy, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_yz_surface = surface!(axu, xCs_yz, yCs_yz, zCs_yz, color=u_yz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) +u_xz_surface = surface!(axu, xCs_xz, yCs_xz, zCs_xz, color=u_xz, colormap=u_colormap, colorrange = u_color_range, lowclip=u_colormap[1], highclip=u_colormap[end]) + +v_xy_surface = surface!(axv, xCs_xy, yCs_xy, zCs_xy, color=v_xy, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_yz_surface = surface!(axv, xCs_yz, yCs_yz, zCs_yz, color=v_yz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) +v_xz_surface = surface!(axv, xCs_xz, yCs_xz, zCs_xz, color=v_xz, colormap=v_colormap, colorrange = v_color_range, lowclip=v_colormap[1], highclip=v_colormap[end]) + +Colorbar(fig[1,2], T_xy_surface) +Colorbar(fig[1,4], S_xy_surface) +Colorbar(fig[2,2], u_xy_surface) +Colorbar(fig[2,4], v_xy_surface) + +xlims!(axT, (-Lx/2, Lx/2)) +xlims!(axS, (-Lx/2, Lx/2)) +xlims!(axu, (-Lx/2, Lx/2)) +xlims!(axv, (-Lx/2, Lx/2)) + +ylims!(axT, (-Ly/2, Ly/2)) +ylims!(axS, (-Ly/2, Ly/2)) +ylims!(axu, (-Ly/2, Ly/2)) +ylims!(axv, (-Ly/2, Ly/2)) + +zlims!(axT, (-Lz, 0)) +zlims!(axS, (-Lz, 0)) +zlims!(axu, (-Lz, 0)) +zlims!(axv, (-Lz, 0)) + +CairoMakie.record(fig, "$(FILE_DIR)/$(filename)_3D_instantaneous_fields.mp4", 1:Nt, framerate=20, px_per_unit=2) do nn + @info nn + n[] = nn +end + +# display(fig) +#%% + +# # ##### +# # ##### Visualization +# # ##### +# using CairoMakie + +# dataname = "NN_closure_doublegyre_NDE_FC_Qb_absf_24simnew_2layer_128_relu_2Pr" +# DATA_DIR = "./$(dataname).jld2" + +# u_data = FieldTimeSeries("$(DATA_DIR)", "u") +# v_data = FieldTimeSeries("$(DATA_DIR)", "v") +# T_data = FieldTimeSeries("$(DATA_DIR)", "T") +# S_data = FieldTimeSeries("$(DATA_DIR)", "S") +# # ν_data = FieldTimeSeries("$(DATA_DIR)", "ν") +# # κ_data = FieldTimeSeries("$(DATA_DIR)", "κ") +# # Ri_data = FieldTimeSeries("$(DATA_DIR)", "Ri") +# # wT_data = FieldTimeSeries("$(DATA_DIR)", "wT") +# # wS_data = FieldTimeSeries("$(DATA_DIR)", "wS") + +# yC = ynodes(T_data.grid, Center()) +# yF = ynodes(T_data.grid, Face()) + +# zC = znodes(T_data.grid, Center()) +# zF = znodes(T_data.grid, Face()) + +# Nt = length(T_data.times) +# #%% +# fig = Figure(size = (1500, 900)) +# axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u (m/s)") +# axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v (m/s)") +# axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature (°C)") +# axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity (psu)") +# n = Observable(1) + +# uₙ = @lift interior(u_data[$n], 45, :, :) +# vₙ = @lift interior(v_data[$n], 45, :, :) +# Tₙ = @lift interior(T_data[$n], 45, :, :) +# Sₙ = @lift interior(S_data[$n], 45, :, :) + +# ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-16]), maximum([maximum(abs, $uₙ), 1e-16])) +# vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-16]), maximum([maximum(abs, $vₙ), 1e-16])) +# Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1]))) +# Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1]))) + +# title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days" +# Label(fig[0, :], title_str, tellwidth = false) + +# hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim) +# hv = heatmap!(axv, yF, zC, vₙ, colormap=:RdBu_9, colorrange=vlim) +# hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim) +# hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim) + +# Colorbar(fig[1, 2], hu, label = "u (m/s)") +# Colorbar(fig[1, 4], hv, label = "v (m/s)") +# Colorbar(fig[2, 2], hT, label = "T (°C)") +# Colorbar(fig[2, 4], hS, label = "S (psu)") + +# CairoMakie.record(fig, "./$(dataname)_test.mp4", 1:Nt, framerate=10) do nn +# n[] = nn +# end + +# display(fig) +# #%% +# fig = Figure(size = (1920, 1080)) +# axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u") +# axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v") +# axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature") +# axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity") +# axν = CairoMakie.Axis(fig[1, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Viscosity (log10 scale)") +# axκ = CairoMakie.Axis(fig[2, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Diffusivity (log10 scale)") +# axRi = CairoMakie.Axis(fig[3, 5], xlabel = "y (m)", ylabel = "z (m)", title = "Richardson number") +# axwT = CairoMakie.Axis(fig[3, 1], xlabel = "y (m)", ylabel = "z (m)", title = "wT(NN)") +# axwS = CairoMakie.Axis(fig[3, 3], xlabel = "y (m)", ylabel = "z (m)", title = "wS(NN)") + +# n = Observable(1) + +# uₙ = @lift interior(u_data[$n], 1, :, :) +# vₙ = @lift interior(v_data[$n], 1, :, :) +# Tₙ = @lift interior(T_data[$n], 1, :, :) +# Sₙ = @lift interior(S_data[$n], 1, :, :) +# νₙ = @lift log10.(interior(ν_data[$n], 1, :, :)) +# κₙ = @lift log10.(interior(κ_data[$n], 1, :, :)) +# Riₙ = @lift clamp.(interior(Ri_data[$n], 1, :, :), -20, 20) +# wTₙ = @lift interior(wT_data[$n], 1, :, :) +# wSₙ = @lift interior(wS_data[$n], 1, :, :) + +# ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-7]), maximum([maximum(abs, $uₙ), 1e-7])) +# vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-7]), maximum([maximum(abs, $vₙ), 1e-7])) +# Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1]))) +# Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1]))) +# νlim = (-6, 2) +# κlim = (-6, 2) +# wTlim = @lift (-maximum([maximum(abs, $wTₙ), 1e-7]), maximum([maximum(abs, $wTₙ), 1e-7])) +# wSlim = @lift (-maximum([maximum(abs, $wSₙ), 1e-7]), maximum([maximum(abs, $wSₙ), 1e-7])) + +# title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days" +# Label(fig[0, :], title_str, tellwidth = false) + +# hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim) +# hv = heatmap!(axv, yF, zC, vₙ, colormap=:RdBu_9, colorrange=vlim) +# hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim) +# hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim) +# hν = heatmap!(axν, yC, zC, νₙ, colorrange=νlim) +# hκ = heatmap!(axκ, yC, zC, κₙ, colorrange=κlim) +# hRi = heatmap!(axRi, yC, zF, Riₙ, colormap=:RdBu_9, colorrange=(-20, 20)) +# hwT = heatmap!(axwT, yC, zF, wTₙ, colormap=:RdBu_9, colorrange=wTlim) +# hwS = heatmap!(axwS, yC, zF, wSₙ, colormap=:RdBu_9, colorrange=wSlim) + +# cbu = Colorbar(fig[1, 2], hu, label = "(m/s)") +# cbv = Colorbar(fig[1, 4], hv, label = "(m/s)") +# cbT = Colorbar(fig[2, 2], hT, label = "(°C)") +# cbS = Colorbar(fig[2, 4], hS, label = "(psu)") +# cbν = Colorbar(fig[1, 6], hν, label = "(m²/s)") +# cbκ = Colorbar(fig[2, 6], hκ, label = "(m²/s)") +# cbRi = Colorbar(fig[3, 6], hRi) +# cbwT = Colorbar(fig[3, 2], hwT, label = "(m/s °C)") +# cbwS = Colorbar(fig[3, 4], hwS, label = "(m/s psu)") + +# tight_ticklabel_spacing!(cbu) +# tight_ticklabel_spacing!(cbv) +# tight_ticklabel_spacing!(cbT) +# tight_ticklabel_spacing!(cbS) +# tight_ticklabel_spacing!(cbν) +# tight_ticklabel_spacing!(cbκ) +# tight_ticklabel_spacing!(cbRi) +# tight_ticklabel_spacing!(cbwT) +# tight_ticklabel_spacing!(cbwS) + +# CairoMakie.record(fig, "./$(dataname)_2D_sin_cooling_heating_23days_fluxes.mp4", 1:Nt, framerate=30) do nn +# n[] = nn +# end +# #%% \ No newline at end of file