diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index 921d163801..db48b2ad93 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -917,6 +917,18 @@ end rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T} = VF{S, Nv}(array) +@inline function rebuild_with_MArray(data::VF) + Nv = nlevels(data) + Nf = ncomponents(data) + FT = eltype(parent(data)) + localmem = MArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(undef) + rdata = rebuild(data, localmem) + @inbounds for v in 1:Nv + rdata[v] = data[v] + end + rdata +end + function replace_basetype(data::VF{S, Nv}, ::Type{T}) where {S, Nv, T} array = parent(data) S′ = replace_basetype(eltype(array), T, S) diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 3b31dde3b2..f862c827e1 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3364,6 +3364,8 @@ function Base.copyto!( (Ni, Nj, _, _, Nh) = size(local_geometry) context = ClimaComms.context(axes(field_out)) device = ClimaComms.device(context) + ᶜspace = Spaces.space(space, Grids.CellCenter()) + if (device isa ClimaComms.CPUMultiThreaded) && Nh > 1 return _threaded_copyto!(field_out, bc, Ni, Nj, Nh) end @@ -3386,6 +3388,70 @@ function window_bounds(space, bc) return (li, lw, rw, ri) end +# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer +# see Base.Broadcast.preprocess_args +@inline transform_to_local_mem_args(args::Tuple, hidx, lg_data) = ( + transform_to_local_mem(args[1], hidx, lg_data), + transform_to_local_mem_args(Base.tail(args), hidx, lg_data)..., +) +@inline transform_to_local_mem_args(args::Tuple{Any}, hidx, lg_data) = + (transform_to_local_mem(args[1], hidx, lg_data),) +@inline transform_to_local_mem_args(args::Tuple{}, hidx, lg_data) = () + +@inline function transform_to_local_mem( + bc::StencilBroadcasted{ColumnStencilStyle}, + hidx, lg_data +) + StencilBroadcasted{ColumnStencilStyle}( + bc.op, + transform_to_local_mem_args(bc.args, hidx, lg_data), + bc.axes + ) +end + +@inline function transform_to_local_mem( + bc::Base.Broadcast.Broadcasted, + hidx, lg_data +) + args = transform_to_local_mem_args(bc.args, hidx, lg_data) + Base.Broadcast.Broadcasted( + bc.f, + args, + bc.axes + ) +end +import StaticArrays: MArray +@inline function transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data) + if eltype(data) <: Geometry.LocalGeometry # we al + (ᶠlg, ᶜlg) = lg_data + if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg) + return ᶠlg + elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg) + return ᶜlg + else + error("oops") + end + elseif parent(data) isa MArray + return data + else + return DataLayouts.rebuild_with_MArray(data) + end +end +@inline function transform_to_local_mem(f::Fields.Field, hidx, lg_data) + (ᶠlg, ᶜlg) = lg_data + fdata = Fields.field_values(f) + datacol_lm = transform_to_local_mem(fdata, hidx, lg_data) + return Fields.Field(datacol_lm, axes(f)) +end +@inline transform_to_local_mem(x::Tuple, hidx, lg_data) = + (transform_to_local_mem(first(x), hidx, lg_data), + transform_to_local_mem(Base.tail(x), hidx, lg_data)...) +@inline transform_to_local_mem(x::Tuple{Any}, hidx, lg_data) = + (transform_to_local_mem(first(x), hidx, lg_data),) +@inline transform_to_local_mem(x::Tuple{}, hidx, lg_data) = () + +@inline transform_to_local_mem(x, hidx, lg_data) = x +@inline transform_to_local_mem(x::DataLayouts.VIJFH, hidx, lg_data) = error("Data $x was not columnized.") Base.@propagate_inbounds function apply_stencil!( space, @@ -3394,6 +3460,70 @@ Base.@propagate_inbounds function apply_stencil!( hidx, (li, lw, rw, ri) = window_bounds(space, bc), ) + + (i, j, h) = hidx + bc_col = Spaces.column(bc, i,j,h) + ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space) + ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space) + ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h) + ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h) + ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col) + ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col) + lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem) + + try + bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data) + catch + @show bc_col + bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data) + end + field_out_col = Fields.column(field_out, i,j,h) + if !Topologies.isperiodic(Spaces.vertical_topology(space)) + # left window + lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() + @inbounds for idx in li:(lw - 1) + setidx!( + space, + field_out_col, + idx, + hidx, + getidx(space, bc_localmem, lbw, idx, hidx), + ) + end + end + # interior + @inbounds for idx in lw:rw + setidx!( + space, + field_out_col, + idx, + hidx, + getidx(space, bc_localmem, Interior(), idx, hidx), + ) + end + if !Topologies.isperiodic(Spaces.vertical_topology(space)) + # right window + rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}() + @inbounds for idx in (rw + 1):ri + setidx!( + space, + field_out_col, + idx, + hidx, + getidx(space, bc_localmem, rbw, idx, hidx), + ) + end + end + return field_out +end + +Base.@propagate_inbounds function apply_stencil!( + space::Spaces.FiniteDifferenceSpace, + field_out, + bc, + hidx, + (li, lw, rw, ri) = window_bounds(space, bc), +) if !Topologies.isperiodic(Spaces.vertical_topology(space)) # left window lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() diff --git a/test/Operators/finitedifference/opt_examples.jl b/test/Operators/finitedifference/opt_examples.jl index 41389c874f..86ea28970b 100644 --- a/test/Operators/finitedifference/opt_examples.jl +++ b/test/Operators/finitedifference/opt_examples.jl @@ -565,7 +565,7 @@ function set_ᶠuₕ³!(ᶜx, ᶠx) @. ᶠx.ᶠuₕ³ = ᶠwinterp(ᶜx.ρ * ᶜJ, CT3(ᶜx.uₕ)) return nothing end -@testset "Inference/allocations when broadcasting types" begin +# @testset "Inference/allocations when broadcasting types" begin FT = Float64 cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; zelem = 25, helem = 10) fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace) @@ -583,4 +583,4 @@ end @benchmark set_ᶠuₕ³!($ ᶜx, $ᶠx) end show(stdout, MIME("text/plain"), trial) -end +# end