Skip to content

Tr/refactor fm internal index #2346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,18 @@ steps:
agents:
slurm_gpus: 1

- label: "Unit: scalar_fieldmatrix (CPU)"
key: cpu_scalar_fieldmatrix
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl"

- label: "Unit: mscalar_fieldmatrix (GPU)"
key: gpu_scalar_fieldmatrix
command: "julia --color=yes --project=.buildkite test/MatrixFields/scalar_fieldmatrix.jl"
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1

- group: "Unit: MatrixFields - broadcasting (CPU)"
steps:

Expand Down
95 changes: 95 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ preconditioner_cache
check_preconditioner
lazy_or_concrete_preconditioner
apply_preconditioner
get_scalar_keys
field_offset_and_type
```

## Utilities
Expand All @@ -98,4 +100,97 @@ column_field2array
column_field2array_view
field2arrays
field2arrays_view
scalar_fieldmatrix
```

## Indexing a FieldMatrix

A FieldMatrix entry can be:

- An `UniformScaling`, which contains a `Number`
- A `DiagonalMatrixRow`, which can contain aything
- A `ColumnwiseBandMatrixField`, where each row is a [`BandMatrixRow`](@ref) where the band element type is representable with the space's base number type.

If an entry contains a composite type, the fields of that type can be extracted.
This is also true for nested composite types.

For example:

```@example 1
using ClimaCore.CommonSpaces # hide
import ClimaCore: MatrixFields, Quadratures # hide
import ClimaCore.MatrixFields: @name # hide
space = Box3DSpace(; # hide
z_elem = 3, # hide
x_min = 0, # hide
x_max = 1, # hide
y_min = 0, # hide
y_max = 1, # hide
z_min = 0, # hide
z_max = 10, # hide
periodic_x = false, # hide
periodic_y = false, # hide
n_quad_points = 1, # hide
quad = Quadratures.GL{1}(), # hide
x_elem = 1, # hide
y_elem = 2, # hide
staggering = CellCenter() # hide
) # hide
nt_entry_field = fill(MatrixFields.DiagonalMatrixRow((; foo = 1.0, bar = 2.0)), space)
nt_fieldmatrix = MatrixFields.FieldMatrix((@name(a), @name(b)) => nt_entry_field)
nt_fieldmatrix[(@name(a), @name(b))]
```

The internal values of the named tuples can be extracted with

```@example 1
nt_fieldmatrix[(@name(a.foo), @name(b))]
```

and

```@example 1
nt_fieldmatrix[(@name(a.bar), @name(b))]
```

### Further Indexing Details

Let key `(@name(name1), @name(name2))` correspond to entry `sample_entry` in `FieldMatrix` `A`.
An example of this is:

```julia
A = MatrixFields.FieldMatrix((@name(name1), @name(name2)) => sample_entry)
```

Now consider what happens indexing `A` with the key `(@name(name1.foo.bar.buz), @name(name2.biz.bop.fud))`.

First, a function searches the keys of `A` for a key that `(@name(foo.bar.buz), @name(biz.bop.fud))`
is a child of. In this example, `(@name(foo.bar.buz), @name(biz.bop.fud))` is a child of
the key `(@name(name1), @name(name2))`, and
`(@name(foo.bar.buz), @name(biz.bop.fud))` is referred to as the internal key.

Next, the entry that `(@name(name1), @name(name2))` is paired with is recursively indexed
by the internal key.

The recursive indexing of an internal entry given some entry `entry` and internal_key `internal_name_pair`
works as follows:

1. If the `internal_name_pair` is blank, return `entry`
2. If the element type of each band of `entry` is an `Axis2Tensor`, and `internal_name_pair` is of the form
`(@name(components.data.1...), @name(components.data.2...))` (potentially with different numbers),
then extract the specified component, and recurse on it with the remaining `internal_name_pair`.
3. If the element type of each band of `entry` is a `Geometry.AdjointAxisVector`, then recurse on the parent of the adjoint.
4. If `internal_name_pair[1]` is not empty, and the first name in it is a field of the element type of each band of `entry`,
extract that field from `entry`, and recurse on the it with the remaining names of `internal_name_pair[1]` and all of `internal_name_pair[2]`
5. If `internal_name_pair[2]` is not empty, and the first name in it is a field of the element type of each row of `entry`,
extract that field from `entry`, and recurse on the it with all of `internal_name_pair[1]` and the remaining names of `internal_name_pair[2]`
6. At this point, if none of the previous cases are true, both `internal_name_pair[1]` and `internal_name_pair[2]` should be
non-empty, and it is assumed that `entry` is being used to implicitly represent some tensor structure. If the first name in
`internal_name_pair[1]` is equivalent to `internal_name_pair[2]`, then both the first names are dropped, and entry is recursed onto.
If the first names are different, both the first names are dropped, and the zero of entry is recursed onto.

When the entry is a `ColumnWiseBandMatrixField`, indexing it will return a broadcasted object in
the following situations:

1. The internal key indexes to a type different than the basetype of the entry
2. The internal key indexes to a zero-ed value
5 changes: 4 additions & 1 deletion src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ coordinate_axis(::Type{<:LatLongPoint}) = (1, 2)

coordinate_axis(coord::AbstractPoint) = coordinate_axis(typeof(coord))

@inline idxin(I::Tuple{Int}, i::Int) = 1
@inline idxin(I::Tuple{Int}, i::Int) = I[1] == i ? 1 : nothing
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the original function is correct here because it behaves differently than the other idxin functions.
I found this when I was experimenting with using axis symbols to index with vs numbers.


@inline function idxin(I::Tuple{Int, Int}, i::Int)
@inbounds begin
Expand Down Expand Up @@ -308,6 +308,9 @@ Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =

const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}

const AxisVectorOrAdj{T, A, S} =
Union{AxisVector{T, A, S}, AdjointAxisVector{T, A, S}}

Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) =
getindex(components(va), i)
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
Expand Down
1 change: 1 addition & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: ⊠, ⊞, ⊟
import ..DataLayouts
import ..DataLayouts: AbstractData
import ..DataLayouts: vindex
import ..Geometry
Expand Down
Loading
Loading