diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 93bdf9b49b..1c19c6d973 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -861,6 +861,18 @@ steps: agents: slurm_gpus: 1 + - label: "Unit: scalar_fieldmatrix (CPU)" + key: cpu_scalar_fieldmatrix + command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl" + + - label: "Unit: mscalar_fieldmatrix (GPU)" + key: gpu_scalar_fieldmatrix + command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl" + env: + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_gpus: 1 + - group: "Unit: MatrixFields - broadcasting (CPU)" steps: diff --git a/docs/src/matrix_fields.md b/docs/src/matrix_fields.md index 4c89aa765d..21f423e2b7 100644 --- a/docs/src/matrix_fields.md +++ b/docs/src/matrix_fields.md @@ -89,6 +89,8 @@ preconditioner_cache check_preconditioner lazy_or_concrete_preconditioner apply_preconditioner +get_scalar_keys +field_offset_and_type ``` ## Utilities @@ -98,4 +100,97 @@ column_field2array column_field2array_view field2arrays field2arrays_view +scalar_fieldmatrix ``` + +## Indexing a FieldMatrix + +A FieldMatrix entry can be: + +- An `UniformScaling`, which contains a `Number` +- A `DiagonalMatrixRow`, which can contain aything +- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type. + +If an entry contains a composite type, the fields of that type can be extracted. +This is also true for nested composite types. + +For example: + +```@example 1 +using ClimaCore.CommonSpaces # hide +import ClimaCore: MatrixFields, Quadratures # hide +import ClimaCore.MatrixFields: @name # hide +space = Box3DSpace(; # hide + z_elem = 3, # hide + x_min = 0, # hide + x_max = 1, # hide + y_min = 0, # hide + y_max = 1, # hide + z_min = 0, # hide + z_max = 10, # hide + periodic_x = false, # hide + periodic_y = false, # hide + n_quad_points = 1, # hide + quad = Quadratures.GL{1}(), # hide + x_elem = 1, # hide + y_elem = 2, # hide + staggering = CellCenter() # hide + ) # hide +nt_entry_field = fill(MatrixFields.DiagonalMatrixRow((; foo = 1.0, bar = 2.0)), space) +nt_fieldmatrix = MatrixFields.FieldMatrix((@name(a), @name(b)) => nt_entry_field) +nt_fieldmatrix[(@name(a), @name(b))] +``` + +The internal values of the named tuples can be extracted with + +```@example 1 +nt_fieldmatrix[(@name(a.foo), @name(b))] +``` + +and + +```@example 1 +nt_fieldmatrix[(@name(a.bar), @name(b))] +``` + +### Further Indexing Details + +Let key `(@name(name1), @name(name2))` correspond to entry `sample_entry` in `FieldMatrix` `A`. +An example of this is: + +```julia + A = MatrixFields.FieldMatrix((@name(name1), @name(name2)) => sample_entry) +``` + +Now consider what happens indexing `A` with the key `(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`. + +First, a function searches the keys of `A` for a key that `(@name(foo.bar.buz), @name(biz.bop.fud))` +is a child of. In this example, `(@name(foo.bar.buz), @name(biz.bop.fud))` is a child of +the key `(@name(name1), @name(name2))`, and +`(@name(foo.bar.buz), @name(biz.bop.fud))` is referred to as the internal key. + +Next, the entry that `(@name(name1), @name(name2))` is paired with is recursively indexed +by the internal key. + +The recursive indexing of an internal entry given some entry `entry` and internal_key `internal_name_pair` +works as follows: + +1. If the `internal_name_pair` is blank, return `entry` +2. If the element type of each band of `entry` is an `Axis2Tensor`, and `internal_name_pair` is of the form +`(@name(components.data.1...), @name(components.data.2...))` (potentially with different numbers), +then extract the specified component, and recurse on it with the remaining `internal_name_pair`. +3. If the element type of each band of `entry` is a `Geometry.AdjointAxisVector`, then recurse on the parent of the adjoint. +4. If `internal_name_pair[1]` is not empty, and the first name in it is a field of the element type of each band of `entry`, +extract that field from `entry`, and recurse on the it with the remaining names of `internal_name_pair[1]` and all of `internal_name_pair[2]` +5. If `internal_name_pair[2]` is not empty, and the first name in it is a field of the element type of each row of `entry`, +extract that field from `entry`, and recurse on the it with all of `internal_name_pair[1]` and the remaining names of `internal_name_pair[2]` +6. At this point, if none of the previous cases are true, both `internal_name_pair[1]` and `internal_name_pair[2]` should be +non-empty, and it is assumed that `entry` is being used to implicitly represent some tensor structure. If the first name in +`internal_name_pair[1]` is equivalent to `internal_name_pair[2]`, then both the first names are dropped, and entry is recursed onto. +If the first names are different, both the first names are dropped, and the zero of entry is recursed onto. + +When the entry is a `ColumnWiseBandMatrixField`, indexing it will return a broadcasted object in +the following situations: + +1. The internal key indexes to a type different than the basetype of the entry +2. The internal key indexes to a zero-ed value diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index 534e628380..8bbe7a6002 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -61,7 +61,7 @@ coordinate_axis(::Type{<:LatLongPoint}) = (1, 2) coordinate_axis(coord::AbstractPoint) = coordinate_axis(typeof(coord)) -@inline idxin(I::Tuple{Int}, i::Int) = 1 +@inline idxin(I::Tuple{Int}, i::Int) = I[1] == i ? 1 : nothing @inline function idxin(I::Tuple{Int, Int}, i::Int) @inbounds begin @@ -308,6 +308,9 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} = const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}} +const AxisVectorOrAdj{T, A, S} = + Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}} + Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) = getindex(components(va), i) Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) = diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index e2acdc67c1..43853e0096 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv import ..RecursiveApply: ⊠, ⊞, ⊟ +import ..DataLayouts import ..DataLayouts: AbstractData import ..DataLayouts: vindex import ..Geometry diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 2d98bb3ff3..6dd9cf773b 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -150,44 +150,122 @@ get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = ( ) get_internal_entry(entry, name::FieldName, key_error) = get_field(entry, name) -get_internal_entry(entry, name_pair::FieldNamePair, key_error) = - name_pair == (@name(), @name()) ? entry : throw(key_error) +# call get_internal_entry on scaling value, and rebuild entry container +get_internal_entry(entry::UniformScaling, name_pair::FieldNamePair, key_error) = + UniformScaling( + get_internal_entry(scaling_value(entry), name_pair, key_error), + ) get_internal_entry( - entry::ScalingFieldMatrixEntry, + entry::DiagonalMatrixRow, name_pair::FieldNamePair, key_error, -) = - if name_pair[1] == name_pair[2] - entry - elseif is_overlapping_name(name_pair[1], name_pair[2]) - throw(key_error) +) = DiagonalMatrixRow( + get_internal_entry(scaling_value(entry), name_pair, key_error), +) +# get_internal_entry to be used on the values held inside a `BandMatrixRow` +function get_internal_entry( + entry::T, + name_pair::FieldNamePair, + key_error, +) where {T} + if name_pair == (@name(), @name()) + return entry + elseif T <: Geometry.Axis2Tensor && + all(n -> is_child_name(n, @name(components.data)), name_pair) + # two indices needed to index into a 2d tensor (one can be Colon()) + internal_row_name = + extract_internal_name(name_pair[1], @name(components.data)) + internal_col_name = + extract_internal_name(name_pair[2], @name(components.data)) + row_index = extract_first(internal_row_name) + col_index = extract_first(internal_col_name) + return get_internal_entry( + entry[row_index, col_index], + (drop_first(internal_row_name), drop_first(internal_col_name)), + key_error, + ) + elseif T <: Geometry.AdjointAxisVector # bypass parent for adjoint vectors + return get_internal_entry( + getfield(entry, :parent), + name_pair, + key_error, + ) + elseif name_pair[1] != @name() && + extract_first(name_pair[1]) in fieldnames(T) + return get_internal_entry( + getfield(entry, extract_first(name_pair[1])), + (drop_first(name_pair[1]), name_pair[2]), + key_error, + ) + elseif name_pair[2] != @name() && + extract_first(name_pair[2]) in fieldnames(T) + return get_internal_entry( + getfield(entry, extract_first(name_pair[2])), + (name_pair[1], drop_first(name_pair[2])), + key_error, + ) + elseif !any(isequal(@name()), name_pair) # implicit tensor structure + return get_internal_entry( + extract_first(name_pair[1]) == extract_first(name_pair[2]) ? entry : + zero(entry), + (drop_first(name_pair[1]), drop_first(name_pair[2])), + key_error, + ) else - zero(entry) + throw(key_error) end +end function get_internal_entry( entry::ColumnwiseBandMatrixField, name_pair::FieldNamePair, key_error, ) - # Ensure compatibility with RecursiveApply (i.e., with rmul). - # See note above matrix_product_keys in field_name_set.jl for more details. - T = eltype(eltype(entry)) - if name_pair == (@name(), @name()) - entry - elseif name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - @assert T <: Geometry.SingleValue && - !broadcasted_has_field(T, name_pair[1]) - entry - elseif name_pair[2] == @name() && broadcasted_has_field(T, name_pair[1]) - # multiplication case 2 or 4, second argument - Base.broadcasted(entry) do matrix_row + name_pair == (@name(), @name()) && return entry + S = eltype(eltype(entry)) + T = eltype(parent(entry)) + (start_offset, target_type, apply_zero) = + field_offset_and_type(name_pair, T, S, key_error) + if target_type <: eltype(parent(entry)) && !apply_zero + band_element_size = + DataLayouts.typesize(eltype(parent(entry)), eltype(eltype(entry))) + singleton_datalayout = DataLayouts.singleton(Fields.field_values(entry)) + scalar_band_type = + band_matrix_row_type(outer_diagonals(eltype(entry))..., target_type) + field_dim_size = DataLayouts.ncomponents(Fields.field_values(entry)) + parent_indices = DataLayouts.to_data_specific_field( + singleton_datalayout, + (:, :, (start_offset + 1):band_element_size:field_dim_size, :, :), + ) + scalar_data = view(parent(entry), parent_indices...) + values = DataLayouts.union_all(singleton_datalayout){ + scalar_band_type, + Base.tail(DataLayouts.type_params(Fields.field_values(entry)))..., + }( + scalar_data, + ) + return Fields.Field(values, axes(entry)) + elseif apply_zero + zero_value = zero(target_type) + return Base.broadcasted(entry) do matrix_row map(matrix_row) do matrix_row_entry - broadcasted_get_field(matrix_row_entry, name_pair[1]) + # zero(target_type) + zero_value end - end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. + end + elseif target_type == S + return entry else - throw(key_error) + return Base.broadcasted(entry) do matrix_row + map(matrix_row) do matrix_row_entry + get_internal_entry(matrix_row_entry, name_pair, key_error) + end + end + end +end +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(get_internal_entry) + m.recursion_relation = dont_limit end end @@ -237,6 +315,238 @@ function Base.one(matrix::FieldMatrix) return FieldNameDict(inferred_diagonal_keys, entries) end +""" + field_offset_and_type(name_pair::FieldNamePair, ::Type{T}, ::Type{S}, key_error) + +Returns the offset of the field with name `name_pair` in an object of type `S` in +multiples of `sizeof(T)` and the type of the field with name `name_pair`. + +When `S` is a `Geometry.Axis2Tensor`, the name pair must index into a scalar of +the tensor or be empty. In other words, the name pair cannot index into a slice. + +If neither element of `name_pair` is `@name()`, the first name in the pair is indexed with +first, and then the second name is used to index the result of the first. +""" +function field_offset_and_type( + name_pair::FieldNamePair, + ::Type{T}, + ::Type{S}, + key_error, +) where {S, T} + name_pair == (@name(), @name()) && return (0, S, false) # base case + if S <: Geometry.Axis2Tensor && + all(n -> is_child_name(n, @name(components.data)), name_pair)# special case to calculate index + (name_pair[1] == @name() || name_pair[2] == @name()) && throw(key_error) + internal_row_name = + extract_internal_name(name_pair[1], @name(components.data)) + internal_col_name = + extract_internal_name(name_pair[2], @name(components.data)) + row_index = extract_first(internal_row_name) + col_index = extract_first(internal_col_name) + ((row_index isa Number) && (col_index isa Number)) || throw(key_error) # slicing not supported + (n_rows, n_cols) = map(length, axes(S)) + (remaining_offset, end_type, apply_zero) = field_offset_and_type( + (drop_first(internal_row_name), drop_first(internal_col_name)), + T, + eltype(S), + key_error, + ) + (row_index <= n_rows && col_index <= n_cols) || throw(key_error) + return ( + (n_rows * (col_index - 1) + row_index - 1) + remaining_offset, + end_type, + apply_zero, + ) + elseif S <: Geometry.AdjointAxisVector + return field_offset_and_type(name_pair, T, fieldtype(S, 1), key_error) + elseif name_pair[1] != @name() && + extract_first(name_pair[1]) in fieldnames(S) + + remaining_field_chain = (drop_first(name_pair[1]), name_pair[2]) + child_type = fieldtype(S, extract_first(name_pair[1])) + field_index = unrolled_filter( + i -> fieldname(S, i) == extract_first(name_pair[1]), + 1:fieldcount(S), + )[1] + (remaining_offset, end_type, apply_zero) = field_offset_and_type( + remaining_field_chain, + T, + child_type, + key_error, + ) + return ( + DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset, + end_type, + apply_zero, + ) + elseif name_pair[2] != @name() && + extract_first(name_pair[2]) in fieldnames(S) + + remaining_field_chain = name_pair[1], drop_first(name_pair[2]) + child_type = fieldtype(S, extract_first(name_pair[2])) + field_index = unrolled_filter( + i -> fieldname(S, i) == extract_first(name_pair[2]), + 1:fieldcount(S), + )[1] + (remaining_offset, end_type, apply_zero) = field_offset_and_type( + remaining_field_chain, + T, + child_type, + key_error, + ) + return ( + DataLayouts.fieldtypeoffset(T, S, field_index) + remaining_offset, + end_type, + apply_zero, + ) + elseif !any(isequal(@name()), name_pair) # implicit tensor structure + (remaining_offset, end_type, apply_zero) = field_offset_and_type( + (drop_first(name_pair[1]), drop_first(name_pair[2])), + T, + S, + key_error, + ) + return ( + remaining_offset, + end_type, + extract_first(name_pair[1]) == extract_first(name_pair[2]) ? + apply_zero : true, + ) + else + throw(key_error) + end +end +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(field_offset_and_type) + m.recursion_relation = dont_limit + end +end + +""" + get_scalar_keys(dict::FieldMatrix) + +Returns a `FieldMatrixKeys` object that contains the keys that result in +a `ScalingFieldMatrixEntry{<: target_type}` or a `ColumnwiseBandMatrixField` with bands of +eltype `<: target_type` when indexing `dict`. `target_type` is determined by the eltype of the +parent of the first entry in `dict` that is a `Fields.Field`. If no such entry +is found, `target_type` defaults to `Number`. +""" +function get_scalar_keys(dict::FieldMatrix) + first_field_idx = unrolled_findfirst(x -> x isa Fields.Field, dict.entries) + target_type = Val( + isnothing(first_field_idx) ? Number : + eltype(parent(dict.entries[first_field_idx])), + ) + keys_tuple = unrolled_flatmap(keys(dict).values) do outer_key + unrolled_map( + get_scalar_keys(eltype(dict[outer_key]), target_type), + ) do inner_key + ( + append_internal_name(outer_key[1], inner_key[1]), + append_internal_name(outer_key[2], inner_key[2]), + ) + end + end + return FieldMatrixKeys(keys_tuple) +end + +""" + get_scalar_keys(T::Type, ::Val{FT}) + +Returns a tuple of `FieldNamePair` objects that correspond to any children +of `T` that are of type `<: FT`. +""" +function get_scalar_keys(::Type{T}, ::Val{FT}) where {T, FT} + if T <: FT + return ((@name(), @name()),) + elseif T <: BandMatrixRow + return get_scalar_keys(eltype(T), Val(FT)) + elseif T <: Geometry.Axis2Tensor + return unrolled_flatmap(1:length(axes(T)[1])) do row_component + unrolled_map(1:length(axes(T)[2])) do col_component + append_internal_name.( + Ref(@name(components.data)), + (FieldName(row_component), FieldName(col_component)), + ) + end + end + elseif T <: Geometry.AdjointAxisVector + return unrolled_map( + get_scalar_keys(fieldtype(T, :parent), Val(FT)), + ) do inner_key + (inner_key[2], inner_key[1]) # assumes that adjoints only appear with d/dvec + end + elseif T <: Geometry.AxisVector # special case to avoid recursing into the axis field + return unrolled_map( + get_scalar_keys(fieldtype(T, :components), Val(FT)), + ) do inner_key + ( + append_internal_name(@name(components), inner_key[1]), + inner_key[2], + ) + end + else + return unrolled_flatmap(fieldnames(T)) do inner_name + unrolled_map( + get_scalar_keys(fieldtype(T, inner_name), Val(FT)), + ) do inner_key + ( + append_internal_name(FieldName(inner_name), inner_key[1]), + inner_key[2], + ) + end + end + end +end +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(get_scalar_keys) + m.recursion_relation = dont_limit + end +end + + +""" + scalar_fieldmatrix(field_matrix::FieldMatrix) + +Constructs a `FieldNameDict` where the keys and entries are views +of the entries of `field_matrix`, which corresponding to the +`FT` typed components of entries of `field_matrix`. + +# Example usage +```julia +e¹² = Geometry.Covariant12Vector(1.6, 0.7) +e₃ = Geometry.Contravariant3Vector(1.0) +e³ = Geometry.Covariant3Vector(1) +ᶜᶜmat3 = fill(TridiagonalMatrixRow(2.0, 3.2, 2.1), center_space) +ᶜᶠmat2 = fill(BidiagonalMatrixRow(4.3, 1.7), center_space) +ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) +ρχ_unit = (;ρq_liq = 1.0, ρq_ice = 1.0) +ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) + +A = MatrixFields.FieldMatrix( + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, +) + +A_scalar = MatrixFields.scalar_fieldmatrix(A) +keys(A_scalar) +# Output: +# (@name(c.ρχ.ρq_liq), @name(f.u₃.:(1))) +# (@name(c.ρχ.ρq_ice), @name(f.u₃.:(1))) +# (@name(c.uₕ.:(1)), @name(c.sgsʲs.:(1).ρa)) +# (@name(c.uₕ.:(2)), @name(c.sgsʲs.:(1).ρa)) +``` +""" +function scalar_fieldmatrix(field_matrix::FieldMatrix) + scalar_keys = get_scalar_keys(field_matrix) + entries = unrolled_map(scalar_keys.values) do key + field_matrix[key] + end + return FieldNameDict(scalar_keys, entries) +end + replace_name_tree(dict::FieldNameDict, name_tree) = FieldNameDict(replace_name_tree(keys(dict), name_tree), values(dict)) @@ -546,8 +856,8 @@ function Base.Broadcast.broadcasted( ) product_value = scaling_value(entry1) * scaling_value(entry2) product_value isa Number ? - UniformScaling(product_value) : - DiagonalMatrixRow(product_value) + (UniformScaling(product_value),) : + (DiagonalMatrixRow(product_value),) elseif entry1 isa ScalingFieldMatrixEntry Base.Broadcast.broadcasted(*, (scaling_value(entry1),), entry2) elseif entry2 isa ScalingFieldMatrixEntry diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index 524d5b8b9e..4f65861fd5 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -376,15 +376,10 @@ end center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit) center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit) - ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,) ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',) - ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',) ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',) - ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit)), ᶜᶜmat3) - ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit)), ᶜᶜmat3) ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) - ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit ⊠ e₃')), ᶜᶠmat2) # We need to use Fix1 and Fix2 instead of defining anonymous functions in # order for the result of map to be inferrable. @@ -464,7 +459,10 @@ end ), b = b_moist_dycore_diagnostic_edmf, ) - + ( + A_moist_dycore_prognostic_edmf_prognostic_surface, + b_moist_dycore_prognostic_edmf_prognostic_surface, + ) = dycore_prognostic_EDMF_FieldMatrix(FT) test_field_matrix_solver(; test_name = "similar solve to ClimaAtmos's moist dycore + prognostic \ EDMF + prognostic surface temperature with implicit \ @@ -478,53 +476,7 @@ end n_iters = 6, ), ), - A = MatrixFields.FieldMatrix( - # GS-GS blocks: - (@name(sfc), @name(sfc)) => I, - (@name(c.ρ), @name(c.ρ)) => I, - (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, - (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, - (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, - (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3, - (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, - (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, - # GS-SGS blocks: - (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, - (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, - (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, - (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, - (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, - (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, - (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, - (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, - (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, - (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - # SGS-SGS blocks: - (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, - (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, - (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, - (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => - ᶜᶠmat2_scalar_u₃, - (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => - ᶜᶠmat2_scalar_u₃, - (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, - (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => - ᶠᶜmat2_u₃_scalar, - (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => - ᶠᶜmat2_u₃_scalar, - (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, - ), + A = A_moist_dycore_prognostic_edmf_prognostic_surface, b = b_moist_dycore_prognostic_edmf_prognostic_surface, ) end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 0253dd57bc..6966124e60 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -770,9 +770,9 @@ end (@name(a), @name(a)) => -I_CT3XC3, ) - for (vector, matrix, I_foo, I_a) in ( - (vector_of_scalars, matrix_of_scalars, I, I), - (vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3), + for (vector, matrix, I_foo, I_a, is_scalar_test) in ( + (vector_of_scalars, matrix_of_scalars, I, I, true), + (vector_of_vectors, matrix_of_tensors, I_C12XCT12, I_CT3XC3, false), ) @test_all MatrixFields.field_vector_view(vector) == MatrixFields.FieldVectorView( @@ -842,10 +842,13 @@ end @test_all matrix[@name(a.c), @name(a.b)] == zero(I_a) @test_all matrix[@name(foo._value), @name(foo._value)] == matrix[@name(foo), @name(foo)] - - @test_all matrix[@name(foo._value), @name(a.b)] isa - Base.AbstractBroadcasted - @test Base.materialize(matrix[@name(foo._value), @name(a.b)]) == map( + entry = matrix[@name(foo._value), @name(a.b)] + @test_all entry isa ( + is_scalar_test ? MatrixFields.ColumnwiseBandMatrixField : + Base.AbstractBroadcasted + ) + entry = is_scalar_test ? entry : Base.materialize(entry) + @test entry == map( row -> map(foo -> foo.value, row), matrix[@name(foo), @name(a.b)], ) diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl index 9e65be8830..acaa5a04b2 100644 --- a/test/MatrixFields/matrix_field_test_utils.jl +++ b/test/MatrixFields/matrix_field_test_utils.jl @@ -21,6 +21,10 @@ import ClimaCore: Operators, Quadratures using ClimaCore.MatrixFields +import ClimaCore.Utilities: half +import ClimaCore.RecursiveApply: ⊠ +import LinearAlgebra: I, norm, ldiv!, mul! +import ClimaCore.MatrixFields: @name # Test that an expression is true and that it is also type-stable. macro test_all(expression) @@ -32,7 +36,7 @@ macro test_all(expression) end end -# Compute the minimum time (in seconds) required to run an expression after it +# Compute the minimum time (in seconds) required to run an expression after it # has been compiled. This macro is used instead of @benchmark from # BenchmarkTools.jl because the latter is extremely slow (it appears to keep # triggering recompilations and allocating a lot of memory in the process). @@ -134,6 +138,209 @@ function test_field_broadcast(; end end +# Create a field matrix for a similar solve to ClimaAtmos's moist dycore + prognostic, +# EDMF + prognostic surface temperature with implicit acoustic waves and SGS fluxes +# also returns corresponding FieldVector +function dycore_prognostic_EDMF_FieldMatrix( + ::Type{FT}, + center_space = nothing, + face_space = nothing, +) where {FT} + seed!(1) # For reproducibility with random fields + if isnothing(center_space) || isnothing(face_space) + center_space, face_space = test_spaces(FT) + end + surface_space = Spaces.level(face_space, half) + surface_space = Spaces.level(face_space, half) + sfc_vec = random_field(FT, surface_space) + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + λ = 10 + ᶜᶜmat1 = random_field(DiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶜᶠmat2 = random_field(BidiagonalMatrixRow{FT}, center_space) ./ λ + ᶠᶜmat2 = random_field(BidiagonalMatrixRow{FT}, face_space) ./ λ + ᶜᶜmat3 = random_field(TridiagonalMatrixRow{FT}, center_space) ./ λ .+ (I,) + ᶠᶠmat3 = random_field(TridiagonalMatrixRow{FT}, face_space) ./ λ .+ (I,) + # Geometry.Covariant123Vector(1, 2, 3) * Geometry.Covariant12Vector(1, 2)' + e¹² = Geometry.Covariant12Vector(1, 1) + e₁₂ = Geometry.Contravariant12Vector(1, 1) + e³ = Geometry.Covariant3Vector(1) + e₃ = Geometry.Contravariant3Vector(1) + + ρχ_unit = (; ρq_tot = 1, ρq_liq = 1, ρq_ice = 1, ρq_rai = 1, ρq_sno = 1) + ρaχ_unit = + (; ρaq_tot = 1, ρaq_liq = 1, ρaq_ice = 1, ρaq_rai = 1, ρaq_sno = 1) + + + ᶠᶜmat2_u₃_scalar = ᶠᶜmat2 .* (e³,) + ᶜᶠmat2_scalar_u₃ = ᶜᶠmat2 .* (e₃',) + ᶠᶠmat3_u₃_u₃ = ᶠᶠmat3 .* (e³ * e₃',) + ᶜᶠmat2_ρχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit ⊠ e₃')), ᶜᶠmat2) + ᶜᶜmat3_uₕ_scalar = ᶜᶜmat3 .* (e¹²,) + ᶜᶜmat3_uₕ_uₕ = + ᶜᶜmat3 .* ( + Geometry.Covariant12Vector(1, 0) * + Geometry.Contravariant12Vector(1, 0)' + + Geometry.Covariant12Vector(0, 1) * + Geometry.Contravariant12Vector(0, 1)', + ) + ᶜᶠmat2_uₕ_u₃ = ᶜᶠmat2 .* (e¹² * e₃',) + ᶜᶜmat3_ρχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρχ_unit)), ᶜᶜmat3) + ᶜᶜmat3_ρaχ_scalar = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit)), ᶜᶜmat3) + ᶜᶠmat2_ρaχ_u₃ = map(Base.Fix1(map, Base.Fix2(⊠, ρaχ_unit ⊠ e₃')), ᶜᶠmat2) + + dry_center_gs_unit = (; ρ = 1, ρe_tot = 1, uₕ = e¹²) + center_gs_unit = (; dry_center_gs_unit..., ρatke = 1, ρχ = ρχ_unit) + center_sgsʲ_unit = (; ρa = 1, ρae_tot = 1, ρaχ = ρaχ_unit) + + b = Fields.FieldVector(; + sfc = sfc_vec .* ((; T = 1),), + c = ᶜvec .* ((; center_gs_unit..., sgsʲs = (center_sgsʲ_unit,)),), + f = ᶠvec .* ((; u₃ = e³, sgsʲs = ((; u₃ = e³),)),), + ) + A = MatrixFields.FieldMatrix( + # GS-GS blocks: + (@name(sfc), @name(sfc)) => I, + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.ρe_tot), @name(c.ρe_tot)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.ρatke)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.ρχ)) => ᶜᶜmat3, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜmat3_uₕ_uₕ, + (@name(c.ρ), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρe_tot), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(f.u₃), @name(c.ρ)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(c.ρe_tot)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠmat3_u₃_u₃, + # GS-SGS blocks: + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => ᶜᶜmat3, + (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => ᶜᶜmat3, + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3, + (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_ρχ_scalar, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜmat3_uₕ_scalar, + (@name(c.ρe_tot), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρatke), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_scalar_u₃, + (@name(c.ρχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρχ_u₃, + (@name(c.uₕ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_uₕ_u₃, + (@name(f.u₃), @name(c.sgsʲs.:(1).ρa)) => ᶠᶜmat2_u₃_scalar, + (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + # SGS-SGS blocks: + (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + (@name(c.sgsʲs.:(1).ρa), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρae_tot), @name(f.sgsʲs.:(1).u₃)) => + ᶜᶠmat2_scalar_u₃, + (@name(c.sgsʲs.:(1).ρaχ), @name(f.sgsʲs.:(1).u₃)) => ᶜᶠmat2_ρaχ_u₃, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρa)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(c.sgsʲs.:(1).ρae_tot)) => + ᶠᶜmat2_u₃_scalar, + (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠmat3_u₃_u₃, + ) + return A, b +end + +function scaling_only_dycore_prognostic_EDMF_FieldMatrix( + ::Type{FT}, + center_space = nothing, + face_space = nothing, +) where {FT} + seed!(1) # For reproducibility with random fields + if isnothing(center_space) || isnothing(face_space) + center_space, face_space = test_spaces(FT) + end + surface_space = Spaces.level(face_space, half) + surface_space = Spaces.level(face_space, half) + sfc_vec = random_field(FT, surface_space) + ᶜvec = random_field(FT, center_space) + ᶠvec = random_field(FT, face_space) + λ = 10 + # Geometry.Covariant123Vector(1, 2, 3) * Geometry.Covariant12Vector(1, 2)' + e¹² = Geometry.Covariant12Vector(FT(1), FT(1)) + e₁₂ = Geometry.Contravariant12Vector(FT(1), FT(1)) + e³ = Geometry.Covariant3Vector(FT(1)) + e₃ = Geometry.Contravariant3Vector(FT(1)) + + ρχ_unit = (; + ρq_tot = FT(1), + ρq_liq = FT(1), + ρq_ice = FT(1), + ρq_rai = FT(1), + ρq_sno = FT(1), + ) + ρaχ_unit = (; + ρaq_tot = FT(1), + ρaq_liq = FT(1), + ρaq_ice = FT(1), + ρaq_rai = FT(1), + ρaq_sno = FT(1), + ) + + + + ᶠᶠu₃_u₃ = DiagonalMatrixRow(e³ * e₃') + ᶜᶜuₕ_scalar = DiagonalMatrixRow(e¹²) + ᶜᶜuₕ_uₕ = DiagonalMatrixRow( + Geometry.Covariant12Vector(FT(1), FT(0)) * + Geometry.Contravariant12Vector(FT(1), FT(0))' + + Geometry.Covariant12Vector(FT(0), FT(1)) * + Geometry.Contravariant12Vector(FT(0), FT(1))', + ) + ᶜᶜρχ_scalar = DiagonalMatrixRow(ρχ_unit) + ᶜᶜρaχ_scalar = DiagonalMatrixRow(ρaχ_unit) + + dry_center_gs_unit = (; ρ = FT(1), ρe_tot = FT(1), uₕ = e¹²) + center_gs_unit = (; dry_center_gs_unit..., ρatke = FT(1), ρχ = ρχ_unit) + center_sgsʲ_unit = (; ρa = FT(1), ρae_tot = FT(1), ρaχ = ρaχ_unit) + + b = Fields.FieldVector(; + sfc = sfc_vec .* ((; T = 1),), + c = ᶜvec .* ((; center_gs_unit..., sgsʲs = (center_sgsʲ_unit,)),), + f = ᶠvec .* ((; u₃ = e³, sgsʲs = ((; u₃ = e³),)),), + ) + A = MatrixFields.FieldMatrix( + # GS-GS blocks: + (@name(sfc), @name(sfc)) => I, + (@name(c.ρ), @name(c.ρ)) => I, + (@name(c.uₕ), @name(c.uₕ)) => ᶜᶜuₕ_uₕ, + (@name(f.u₃), @name(f.u₃)) => ᶠᶠu₃_u₃, + # GS-SGS blocks: + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρae_tot)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_tot), @name(c.sgsʲs.:(1).ρaχ.ρaq_tot)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_liq), @name(c.sgsʲs.:(1).ρaχ.ρaq_liq)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_ice), @name(c.sgsʲs.:(1).ρaχ.ρaq_ice)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_rai), @name(c.sgsʲs.:(1).ρaχ.ρaq_rai)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ.ρq_sno), @name(c.sgsʲs.:(1).ρaχ.ρaq_sno)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρe_tot), @name(c.sgsʲs.:(1).ρa)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρatke), @name(c.sgsʲs.:(1).ρa)) => + DiagonalMatrixRow(rand(FT)), + (@name(c.ρχ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜρχ_scalar, + (@name(c.uₕ), @name(c.sgsʲs.:(1).ρa)) => ᶜᶜuₕ_scalar, + (@name(f.u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠu₃_u₃, + # SGS-SGS blocks: + (@name(c.sgsʲs.:(1).ρa), @name(c.sgsʲs.:(1).ρa)) => I, + (@name(c.sgsʲs.:(1).ρae_tot), @name(c.sgsʲs.:(1).ρae_tot)) => I, + (@name(c.sgsʲs.:(1).ρaχ), @name(c.sgsʲs.:(1).ρaχ)) => I, + (@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)) => ᶠᶠu₃_u₃, + ) + return A, b +end + # Generate extruded finite difference spaces for testing. Include topography # when possible. function test_spaces(::Type{FT}) where {FT} diff --git a/test/MatrixFields/scalar_fieldmatrix.jl b/test/MatrixFields/scalar_fieldmatrix.jl new file mode 100644 index 0000000000..0ba7a75fe4 --- /dev/null +++ b/test/MatrixFields/scalar_fieldmatrix.jl @@ -0,0 +1,182 @@ +using Test +using JET + +import ClimaCore: + Geometry, Domains, Meshes, Spaces, Fields, MatrixFields, CommonSpaces +import ClimaCore.Utilities: half +import ClimaComms +import ClimaCore.MatrixFields: @name +ClimaComms.@import_required_backends +include("matrix_field_test_utils.jl") + +@testset "field_offset_and_type" begin + FT = Float64 + struct Singleton{T} + x::T + end + struct TwoFields{T1, T2} + x::T1 + y::T2 + end + function test_field_offset_and_type( + name, + ::Type{T}, + ::Type{S}, + expected_offset, + ::Type{E}, + key_error; + apply_zero = false, + ) where {T, S, E} + @test_all MatrixFields.field_offset_and_type(name, T, S, key_error) == + (expected_offset, E, apply_zero) + end + test_field_offset_and_type( + (@name(x), @name()), + FT, + Singleton{Singleton{Singleton{Singleton{FT}}}}, + 0, + Singleton{Singleton{Singleton{FT}}}, + KeyError(@name(x.x.x.x)), + ) + test_field_offset_and_type( + (@name(), @name(x.x.x.x)), + FT, + Singleton{Singleton{Singleton{Singleton{FT}}}}, + 0, + FT, + KeyError(@name(x.x.x.x)), + ) + test_field_offset_and_type( + (@name(), @name(y.x)), + FT, + TwoFields{TwoFields{FT, FT}, TwoFields{FT, FT}}, + 2, + FT, + KeyError(@name(y.x)), + ) + test_field_offset_and_type( + (@name(y), @name(y)), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + TwoFields{FT, Singleton{FT}}, + KeyError(@name(y.y.x)), + ) + test_field_offset_and_type( + (@name(y.k), @name(y.k)), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + TwoFields{FT, Singleton{FT}}, + KeyError(@name(y.y.x)), + ) + test_field_offset_and_type( + (@name(y.k.g), @name(y.k.l)), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + TwoFields{FT, Singleton{FT}}, + KeyError(@name(y.y.x)), + apply_zero = true, + ) + test_field_offset_and_type( + (@name(y.y), @name(x)), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 3, + FT, + KeyError(@name(y.y.x.x)), + ) + test_field_offset_and_type( + (@name(y.y), @name(y.x)), + FT, + TwoFields{ + TwoFields{FT, FT}, + TwoFields{FT, TwoFields{FT, Singleton{FT}}}, + }, + 4, + FT, + KeyError(@name(y.y.y.x)), + ) +end + +@testset "fieldmatrix to scalar fieldmatrix unit tests" begin + FT = Float64 + for (A, _) in ( + dycore_prognostic_EDMF_FieldMatrix(FT), + scaling_only_dycore_prognostic_EDMF_FieldMatrix(FT), + ) + @test all( + entry -> + entry isa MatrixFields.UniformScaling || + eltype(eltype(entry)) <: FT, + MatrixFields.scalar_fieldmatrix(A).entries, + ) + test_get(A, entry, key) = A[key] === entry + for (key, entry) in MatrixFields.scalar_fieldmatrix(A) + @test test_get(A, entry, key) + @test (@allocated test_get(A, entry, key)) == 0 + @test_opt test_get(A, entry, key) + end + + function scalar_fieldmatrix_wrapper(field_matrix_of_tensors) + A_scalar = MatrixFields.scalar_fieldmatrix(field_matrix_of_tensors) + return nothing + end + + scalar_fieldmatrix_wrapper(A) + @test (@allocated scalar_fieldmatrix_wrapper(A)) == 0 + @test_opt MatrixFields.scalar_fieldmatrix(A) + end +end + +@testset "implicit tensor structure optimization indexing" begin + FT = Float64 + center_space = test_spaces(FT)[1] + for (maybe_copy, maybe_to_field) in + ((identity, identity), (copy, x -> fill(x, center_space))) + A = MatrixFields.FieldMatrix( + (@name(c.uₕ), @name(c.uₕ)) => + maybe_to_field(DiagonalMatrixRow(FT(2))), + (@name(foo), @name(bar)) => maybe_to_field( + DiagonalMatrixRow( + Geometry.Covariant12Vector(FT(1), FT(2)) * + Geometry.Contravariant12Vector(FT(1), FT(2))', + ), + ), + ) + @test A[( + @name(c.uₕ.components.data.:1), + @name(c.uₕ.components.data.:1) + )] == A[(@name(c.uₕ), @name(c.uₕ))] + @test maybe_copy( + A[(@name(c.uₕ.components.data.:2), @name(c.uₕ.components.data.:1))], + ) == maybe_to_field(DiagonalMatrixRow(FT(0))) + @test maybe_copy(A[(@name(foo.dog), @name(bar.dog))]) == + A[(@name(foo), @name(bar))] + @test maybe_copy(A[(@name(foo.cat), @name(bar.dog))]) == + zero(A[(@name(foo), @name(bar))]) + @test A[( + @name(foo.dog.components.data.:1), + @name(bar.dog.components.data.:2) + )] == maybe_to_field(DiagonalMatrixRow(FT(2))) + @test maybe_copy( + A[( + @name(foo.dog.components.data.:1), + @name(bar.cat.components.data.:2) + )], + ) == maybe_to_field(DiagonalMatrixRow(FT(0))) + end +end