From c6f7c15ad36fdae21a10cd8947d79ee5a0dfe5f8 Mon Sep 17 00:00:00 2001 From: Lukas Devos <ldevos98@gmail.com> Date: Thu, 16 Jan 2025 20:59:08 -0500 Subject: [PATCH] Add scheduler support in `mul!` --- src/tensors/blockiterator.jl | 2 ++ src/tensors/linalg.jl | 56 +++++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index b4ec4b87..06576dc6 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype() Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T) Base.length(iter::BlockIterator) = length(iter.structure) Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...) + +Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index a3f84adb..2d897309 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap) end # TensorMap multiplication -function LinearAlgebra.mul!(tC::AbstractTensorMap, - tA::AbstractTensorMap, - tB::AbstractTensorMap, α=true, β=false) +function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, + α::Number, β::Number, + backend=TO.DefaultBackend()) + if backend isa TO.DefaultBackend + newbackend = TO.select_backend(mul!, tC, tA, tB) + return mul!(tC, tA, tB, α, β, newbackend) + elseif backend isa TO.NoBackend # error for missing backend + TC = typeof(tC) + TA = typeof(tA) + TB = typeof(tB) + throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB")) + else # error for unknown backend + TC = typeof(tC) + TA = typeof(tA) + TB = typeof(tB) + throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB")) + end +end + +function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap, + B::AbstractTensorMap) + return SerialScheduler() +end + +function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, α::Number, β::Number, + scheduler::Union{Nothing,Scheduler}) + if isnothing(scheduler) + return sequential_mul!(tC, tA, tB, α, β) + else + return threaded_mul!(tC, tA, tB, α, β, scheduler) + end +end + +function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, + tB::AbstractTensorMap, α::Number, β::Number) compose(space(tA), space(tB)) == space(tC) || throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) @@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, return tC end -# TODO: consider spawning threads for different blocks, support backends +function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap, + α::Number, β::Number, scheduler::Scheduler) + # obtain cached data before multithreading + bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB) + + tforeach(blocksectors(tC); scheduler) do c + if haskey(bAs, c) # then also bBs should have it + mul!(bCs[c], bAs[c], bBs[c], α, β) + elseif !isone(β) + scale!(bCs[c], β) + end + end + + return tC +end # TensorMap inverse function Base.inv(t::AbstractTensorMap)