Skip to content

Commit

Permalink
Switch to using Unrolled.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Dec 5, 2023
1 parent 3ac7026 commit 61037c4
Show file tree
Hide file tree
Showing 16 changed files with 239 additions and 247 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
Expand Down
8 changes: 7 additions & 1 deletion benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.1"
Expand Down Expand Up @@ -1265,6 +1265,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2577,6 +2577,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2128,6 +2128,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.6"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "WeakValueDicts"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "Memoize", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled", "WeakValueDicts"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.11.2"
Expand Down Expand Up @@ -2194,6 +2194,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import ..Topologies
import ..Grids: ColumnIndex
import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace
import ..Geometry: Geometry, Cartesian12Vector
import ..Utilities: PlusHalf, half
import ..Utilities: PlusHalf, half, unrolled_map

using ..RecursiveApply
using CUDA
Expand Down
23 changes: 6 additions & 17 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,27 +260,16 @@ LinearAlgebra.ldiv!(A::LinearAlgebra.LU, x::FieldVector) =
x .= LinearAlgebra.ldiv!(A, Vector(x))

function LinearAlgebra.norm_sqr(x::FieldVector)
Base.sum(value -> LinearAlgebra.norm_sqr(backing_array(value)), _values(x))
value_norm_sqrs = unrolled_map(_values(x)) do value
LinearAlgebra.norm_sqr(backing_array(value))
end
return sum(value_norm_sqrs; init = zero(eltype(x)))
end
function LinearAlgebra.norm(x::FieldVector)
sqrt(LinearAlgebra.norm_sqr(x))
end

import ClimaComms

ClimaComms.array_type(x::FieldVector) = _array_type(x)

@inline _array_type(x::FieldVector) = _array_type(x, propertynames(x))
@inline _array_type(x::FieldVector, pns::Tuple{}) = Any

@inline _array_type(x::Field) = ClimaComms.array_type(x)
@inline _array_type(x::FieldVector, sym::Symbol) =
_array_type(getproperty(x, sym))

@inline _array_type(x::FieldVector, pns::Tuple{Symbol}) =
_array_type(getproperty(x, first(pns)))

@inline _array_type(x::FieldVector, pns::Tuple) = promote_type(
_array_type(getproperty(x, first(pns))),
_array_type(x, Base.tail(pns)),
)
ClimaComms.array_type(x::FieldVector) =
promote_type(unrolled_map(ClimaComms.array_type, _values(x))...)
7 changes: 5 additions & 2 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ import RecursiveArrayTools: recursive_bottom_eltype
import KrylovKit
import ClimaComms

import ..Utilities: PlusHalf, half
import ..Utilities: PlusHalf, half, unrolled_take, unrolled_drop
import ..Utilities: unrolled_foreach, unrolled_map, unrolled_reduce
import ..Utilities: unrolled_in, unrolled_any, unrolled_all, unrolled_unique
import ..Utilities: unrolled_flatten, unrolled_flatmap, unrolled_product
import ..Utilities: unrolled_filter, unrolled_split, unrolled_findonly
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: , ,
Expand Down Expand Up @@ -94,7 +98,6 @@ include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")
include("unrolled_functions.jl")
include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
Expand Down
51 changes: 25 additions & 26 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ is_child_name(
::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
length(child_name_chain) >= length(parent_name_chain) &&
child_name_chain[1:length(parent_name_chain)] == parent_name_chain
unrolled_take(child_name_chain, Val(length(parent_name_chain))) ==
parent_name_chain

names_are_overlapping(name1, name2) =
is_overlapping_name(name1, name2) =
is_child_name(name1, name2) || is_child_name(name2, name1)

extract_internal_name(
child_name::FieldName{child_name_chain},
parent_name::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
is_child_name(child_name, parent_name) ?
FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) :
error("$child_name is not a child name of $parent_name")
FieldName(
unrolled_drop(child_name_chain, Val(length(parent_name_chain)))...,
) : error("$child_name is not a child name of $parent_name")

append_internal_name(
::FieldName{name_chain},
Expand Down Expand Up @@ -118,41 +120,38 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <:
subtrees::S
end

FieldNameTree(x) = make_subtree_at_name(x, @name())
function make_subtree_at_name(x, name)
FieldNameTree(x) = subtree_at_name(x, @name())
function subtree_at_name(x, name)
internal_names = top_level_names(get_field(x, name))
isempty(internal_names) && return FieldNameTreeLeaf(name)
subsubtrees = unrolled_map(internal_names) do internal_name
make_subtree_at_name(x, append_internal_name(name, internal_name))
return if isempty(internal_names)
FieldNameTreeLeaf(name)
else
subsubtrees_at_name = unrolled_map(internal_names) do internal_name
subtree_at_name(x, append_internal_name(name, internal_name))
end
FieldNameTreeNode(name, subsubtrees_at_name)
end
return FieldNameTreeNode(name, subsubtrees)
end

is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name
is_valid_name(name, tree::FieldNameTreeNode) =
is_valid_name(name, tree) =
name == tree.name ||
is_child_name(name, tree.name) &&
tree isa FieldNameTreeNode &&
unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees)

function child_names(name, tree)
is_valid_name(name, tree) || error("$name is not a valid name")
subtree = get_subtree_at_name(name, tree)
subtree isa FieldNameTreeNode ||
error("FieldNameTree does not contain any child names for $name")
subtree isa FieldNameTreeNode || error("$name does not have child names")
return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees)
end
get_subtree_at_name(name, tree::FieldNameTreeLeaf) =
name == tree.name ? tree :
error("FieldNameTree does not contain the name $name")
get_subtree_at_name(name, tree::FieldNameTreeNode) =
get_subtree_at_name(name, tree) =
if name == tree.name
tree
elseif is_valid_name(name, tree)
subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree
is_child_name(name, subtree.name)
end
get_subtree_at_name(name, subtree_that_contains_name)
else
error("FieldNameTree does not contain the name $name")
subtree = unrolled_findonly(tree.subtrees) do subtree
is_valid_name(name, subtree)
end
get_subtree_at_name(name, subtree)
end

################################################################################
Expand All @@ -175,7 +174,7 @@ if hasfield(Method, :recursion_relation)
for m in methods(wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(make_subtree_at_name)
for m in methods(subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(is_valid_name)
Expand Down
23 changes: 11 additions & 12 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ Base.keys(dict::FieldNameDict) = dict.keys
Base.values(dict::FieldNameDict) = dict.entries

Base.pairs(dict::FieldNameDict) =
unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup
key_entry_tup[1] => key_entry_tup[2]
end
unrolled_map((key, value) -> key => value, keys(dict).values, values(dict))

Base.length(dict::FieldNameDict) = length(keys(dict))

Expand All @@ -112,15 +110,16 @@ function Base.getindex(dict::FieldNameDict, key)
return get_internal_entry(entry′, get_internal_key(key, key′))
end

get_internal_key(name1::FieldName, name2::FieldName) =
extract_internal_name(name1, name2)
get_internal_key(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = (
extract_internal_name(name_pair1[1], name_pair2[1]),
extract_internal_name(name_pair1[2], name_pair2[2]),
get_internal_key(child_name::FieldName, name::FieldName) =
extract_internal_name(child_name, name)
get_internal_key(child_name_pair::FieldNamePair, name_pair::FieldNamePair) = (
extract_internal_name(child_name_pair[1], name_pair[1]),
extract_internal_name(child_name_pair[2], name_pair[2]),
)

unsupported_internal_entry_error(::T, key) where {T} =
error("Unsupported call to get_internal_entry(<$(T.name.name)>, $key)")
unsupported_internal_entry_error(entry, key) =
error("Unsupported FieldNameDict operation: \
get_internal_entry(<$(typeof(entry).name.name)>, $key)")

get_internal_entry(entry, name::FieldName) = get_field(entry, name)
get_internal_entry(entry, name_pair::FieldNamePair) =
Expand Down Expand Up @@ -227,7 +226,7 @@ function field_vector_view(x, name_tree = FieldNameTree(x))
return FieldVectorView(keys_of_fields, entries)
end
names_of_fields(x, name_tree) =
unrolled_mapflatten(top_level_names(x)) do name
unrolled_flatmap(top_level_names(x)) do name
entry = get_field(x, name)
if entry isa Fields.Field
(name,)
Expand Down Expand Up @@ -332,7 +331,7 @@ Base.Broadcast.broadcasted(
arg3,
args...,
) =
unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
Base.Broadcast.broadcasted(f, arg1′, arg2′)
end

Expand Down
Loading

0 comments on commit 61037c4

Please sign in to comment.