Skip to content

Commit

Permalink
Fix nested broadcast of AbstractBlockTuple (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe authored Jan 14, 2025
1 parent b876749 commit 5323bc4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
12 changes: 12 additions & 0 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ end
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
return bt[Block(bi)][only(bi.indices)]
end
# needed for nested broadcast in Julia < 1.11
Base.getindex(bt::AbstractBlockTuple, ci::CartesianIndex{1}) = bt[only(Tuple(ci))]

Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)
Expand All @@ -53,6 +55,14 @@ function Base.map(f, bt::AbstractBlockTuple)
return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL))
end

function Base.show(io::IO, bt::AbstractBlockTuple)
return print(io, nameof(typeof(bt)), blocks(bt))
end
function Base.show(io::IO, ::MIME"text/plain", bt::AbstractBlockTuple)
println(io, typeof(bt))
return print(io, blocks(bt))
end

# Broadcast interface
Base.broadcastable(bt::AbstractBlockTuple) = bt
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end
Expand All @@ -72,6 +82,8 @@ function Base.copy(
return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
end

Base.ndims(::Type{<:AbstractBlockTuple}) = 1 # needed in nested broadcast

# BlockArrays interface
BlockArrays.blockfirsts(::AbstractBlockTuple{0}) = ()
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
Expand Down
1 change: 1 addition & 0 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using TensorAlgebra: BlockedTuple, blockeachindex, tuplemortar

bt = tuplemortar(((1:2, 1:2), (1:3,)))
@test length.(bt) == tuplemortar(((2, 2), (3,)))
@test length.(length.(bt)) == tuplemortar(((1, 1), (1,)))

# empty blocks
bt = tuplemortar(((1,), (), (5, 3)))
Expand Down

0 comments on commit 5323bc4

Please sign in to comment.