
Add scheduler support in mul!
lkdvos committed Jan 17, 2025
1 parent 4999722 commit c6f7c15
Showing 2 changed files with 54 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/tensors/blockiterator.jl
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
 Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
 Base.length(iter::BlockIterator) = length(iter.structure)
 Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
+
+Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)
56 changes: 52 additions & 4 deletions src/tensors/linalg.jl
@@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
 end
 
 # TensorMap multiplication
-function LinearAlgebra.mul!(tC::AbstractTensorMap,
-                            tA::AbstractTensorMap,
-                            tB::AbstractTensorMap, α=true, β=false)
+function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                            tB::AbstractTensorMap,
+                            α::Number, β::Number,
+                            backend=TO.DefaultBackend())
+    if backend isa TO.DefaultBackend
+        newbackend = TO.select_backend(mul!, tC, tA, tB)
+        return mul!(tC, tA, tB, α, β, newbackend)
+    elseif backend isa TO.NoBackend # error for missing backend
+        TC = typeof(tC)
+        TA = typeof(tA)
+        TB = typeof(tB)
+        throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB"))
+    else # error for unknown backend
+        TC = typeof(tC)
+        TA = typeof(tA)
+        TB = typeof(tB)
+        throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB"))
+    end
+end
+
+function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap,
+                           B::AbstractTensorMap)
+    return SerialScheduler()
+end
+
+function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                            tB::AbstractTensorMap, α::Number, β::Number,
+                            scheduler::Union{Nothing,Scheduler})
+    if isnothing(scheduler)
+        return sequential_mul!(tC, tA, tB, α, β)
+    else
+        return threaded_mul!(tC, tA, tB, α, β, scheduler)
+    end
+end
+
+function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
+                         tB::AbstractTensorMap, α::Number, β::Number)
     compose(space(tA), space(tB)) == space(tC) ||
         throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
 
@@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
     return tC
 end
 
-# TODO: consider spawning threads for different blocks, support backends
+function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap,
+                       α::Number, β::Number, scheduler::Scheduler)
+    # obtain cached data before multithreading
+    bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
+
+    tforeach(blocksectors(tC); scheduler) do c
+        if haskey(bAs, c) # then also bBs should have it
+            mul!(bCs[c], bAs[c], bBs[c], α, β)
+        elseif !isone(β)
+            scale!(bCs[c], β)
+        end
+    end
+
+    return tC
+end
 
 # TensorMap inverse
 function Base.inv(t::AbstractTensorMap)

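For reference, a minimal usage sketch of the new interface (not part of this commit). It assumes TensorKit.jl with the change above, that the `Scheduler`/`SerialScheduler` types referenced in the diff come from OhMyThreads.jl (consistent with the `tforeach` call in `threaded_mul!`), and that the `randn` tensor constructor and spaces used here are available in the installed TensorKit version.

# Minimal usage sketch (not part of this commit); scheduler types assumed to
# come from OhMyThreads.jl, tensor constructor assumed from recent TensorKit.jl.
using TensorKit, LinearAlgebra
using OhMyThreads: DynamicScheduler

V = ℂ^2 ⊗ ℂ^3
A = randn(ComplexF64, V ← V)   # assumed constructor: random TensorMap on V ← V
B = randn(ComplexF64, V ← V)
C = similar(A)

# Default path: TO.select_backend returns SerialScheduler(), so the blocks are
# multiplied sequentially.
mul!(C, A, B, true, false)

# Explicitly request a threaded schedule: each block sector is handled via tforeach.
mul!(C, A, B, true, false, DynamicScheduler())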