From c6f7c15ad36fdae21a10cd8947d79ee5a0dfe5f8 Mon Sep 17 00:00:00 2001
From: Lukas Devos <ldevos98@gmail.com>
Date: Thu, 16 Jan 2025 20:59:08 -0500
Subject: [PATCH] Add scheduler support in `mul!`

---
 src/tensors/blockiterator.jl |  2 ++
 src/tensors/linalg.jl        | 56 +++++++++++++++++++++++++++++++++---
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl
index b4ec4b87..06576dc6 100644
--- a/src/tensors/blockiterator.jl
+++ b/src/tensors/blockiterator.jl
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
 Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
 Base.length(iter::BlockIterator) = length(iter.structure)
 Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
+
+Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)
diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl
index a3f84adb..2d897309 100644
--- a/src/tensors/linalg.jl
+++ b/src/tensors/linalg.jl
@@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
 end
 
 # TensorMap multiplication
-function LinearAlgebra.mul!(tC::AbstractTensorMap,
-                            tA::AbstractTensorMap,
-                            tB::AbstractTensorMap, α=true, β=false)
+function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                            tB::AbstractTensorMap,
+                            α::Number, β::Number,
+                            backend=TO.DefaultBackend())
+    if backend isa TO.DefaultBackend
+        newbackend = TO.select_backend(mul!, tC, tA, tB)
+        return mul!(tC, tA, tB, α, β, newbackend)
+    elseif backend isa TO.NoBackend # error for missing backend
+        TC = typeof(tC)
+        TA = typeof(tA)
+        TB = typeof(tB)
+        throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB"))
+    else # error for unknown backend
+        TC = typeof(tC)
+        TA = typeof(tA)
+        TB = typeof(tB)
+        throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB"))
+    end
+end
+
+function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap,
+                           B::AbstractTensorMap)
+    return SerialScheduler()
+end
+
+function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                            tB::AbstractTensorMap, α::Number, β::Number,
+                            scheduler::Union{Nothing,Scheduler})
+    if isnothing(scheduler)
+        return sequential_mul!(tC, tA, tB, α, β)
+    else
+        return threaded_mul!(tC, tA, tB, α, β, scheduler)
+    end
+end
+
+function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                         tB::AbstractTensorMap, α::Number, β::Number)
     compose(space(tA), space(tB)) == space(tC) ||
         throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
 
@@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
     return tC
 end
 
-# TODO: consider spawning threads for different blocks, support backends
+function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap,
+                       α::Number, β::Number, scheduler::Scheduler)
+    # obtain cached data before multithreading
+    bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
+
+    tforeach(blocksectors(tC); scheduler) do c
+        if haskey(bAs, c) # then also bBs should have it
+            mul!(bCs[c], bAs[c], bBs[c], α, β)
+        elseif !isone(β)
+            scale!(bCs[c], β)
+        end
+    end
+
+    return tC
+end
 
 # TensorMap inverse
 function Base.inv(t::AbstractTensorMap)