Skip to content

Commit

Permalink
Define zero on some MatrixField types
Browse files Browse the repository at this point in the history
Define zero for some MatrixField structs
  • Loading branch information
charleskawczynski committed Jan 18, 2025
1 parent dceaa19 commit a32ab63
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ end
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) =
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, nothing, nothing, nothing, nothing)

Base.zero(lsc::LazySchurComplement) =
LazySchurComplement(map(fn -> zero(getfield(lsc, fn)), fieldnames(lsc))...)

NVTX.@annotate function lazy_mul(A₂₂′::LazySchurComplement, x₂)
(; A₁₁, A₁₂, A₂₁, A₂₂, alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂) = A₂₂′
zero_rows = setdiff(keys(A₁₂_x₂), matrix_row_keys(keys(A₁₂)))
Expand Down Expand Up @@ -229,6 +232,8 @@ partial pivoting matrix).
"""
struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end

Base.zero(alg::BlockDiagonalSolve) = alg

function field_matrix_solver_cache(::BlockDiagonalSolve, A, b)
caches = map(matrix_row_keys(keys(A))) do name
single_field_solver_cache(A[name, name], b[name])
Expand Down Expand Up @@ -315,6 +320,9 @@ BlockLowerTriangularSolve(
alg₂ = BlockDiagonalSolve(),
) = BlockLowerTriangularSolve(names₁, alg₁, alg₂)

Base.zero(alg::BlockLowerTriangularSolve) =
BlockLowerTriangularSolve(alg.names₁, zero(alg.alg₁), zero(alg.alg₂))

function field_matrix_solver_cache(alg::BlockLowerTriangularSolve, A, b)
A₁₁, _, _, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b)
cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁)
Expand Down Expand Up @@ -448,6 +456,9 @@ end
SchurComplementReductionSolve(names₁...; alg₁ = BlockDiagonalSolve(), alg₂) =
SchurComplementReductionSolve(names₁, alg₁, alg₂)

Base.zero(alg::SchurComplementReductionSolve) =
SchurComplementReductionSolve(alg.names₁, zero(alg.alg₁), zero(alg.alg₂))

function field_matrix_solver_cache(alg::SchurComplementReductionSolve, A, b)
A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b)
b₁′ = similar(b₁)
Expand Down
3 changes: 3 additions & 0 deletions src/MatrixFields/field_matrix_with_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Base.:(==)(A1::FieldMatrixWithSolver, A2::FieldMatrixWithSolver) =
Base.similar(A::FieldMatrixWithSolver) =
FieldMatrixWithSolver(similar(A.matrix), A.solver)

Base.zero(A::FieldMatrixWithSolver) =
FieldMatrixWithSolver(zero(A.matrix), A.solver)

ldiv!(x::Fields.FieldVector, A::FieldMatrixWithSolver, b::Fields.FieldVector) =
field_matrix_solve!(A.solver, x, A.matrix, b)

Expand Down
7 changes: 7 additions & 0 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ function Base.similar(dict::FieldNameDict)
return FieldNameDict(keys(dict), entries)
end

function Base.zero(dict::FieldNameDict)
entries = unrolled_map(values(dict)) do entry
entry isa UniformScaling ? entry : zero(entry)
end
return FieldNameDict(keys(dict), entries)
end

# Note: This assumes that the matrix has the same row and column units, since I
# cannot be multiplied by anything other than a scalar.
function Base.one(matrix::FieldMatrix)
Expand Down
2 changes: 2 additions & 0 deletions test/MatrixFields/field_matrix_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
@testset "$test_name" begin
x = similar(b)
A′ = FieldMatrixWithSolver(A, b, alg)
@test zero(A′) isa typeof(A′)
solve_time =
@benchmark ClimaComms.@cuda_sync comms_device ldiv!(x, A′, b)

b_test = similar(b)
@test zero(b) isa typeof(b)
mul_time =
@benchmark ClimaComms.@cuda_sync comms_device mul!(b_test, A′, x)

Expand Down

0 comments on commit a32ab63

Please sign in to comment.