Skip to content

Commit

Permalink
Fix similar (#195)
Browse files Browse the repository at this point in the history
* fix similar

* add test
  • Loading branch information
Jutho authored Jan 10, 2025
1 parent 6aa0c7b commit 582b6d7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorKit"
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
authors = ["Jutho Haegeman"]
version = "0.14.1"
version = "0.14.2"

[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
Expand Down
11 changes: 7 additions & 4 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ field(::Type{TT}) where {TT<:AbstractTensorMap} = field(spacetype(TT))
Return the type of vector that stores the data of a tensor.
""" storagetype

similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))

function similarstoragetype(TT::Type{<:AbstractTensorMap}, ::Type{T}) where {T}
return Core.Compiler.return_type(similar, Tuple{storagetype(TT),Type{T}})
end
Expand Down Expand Up @@ -193,7 +195,7 @@ sectortype(t::AbstractTensorMap) = sectortype(typeof(t))
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
field(t::AbstractTensorMap) = field(typeof(t))
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
similarstoragetype(t::AbstractTensorMap, TT) = similarstoragetype(typeof(t), TT)
similarstoragetype(t::AbstractTensorMap, T=scalartype(t)) = similarstoragetype(typeof(t), T)

numout(t::AbstractTensorMap) = numout(typeof(t))
numin(t::AbstractTensorMap) = numin(typeof(t))
Expand Down Expand Up @@ -382,19 +384,19 @@ end
# 3 arguments
function Base.similar(t::AbstractTensorMap, codomain::TensorSpace{S},
domain::TensorSpace{S}) where {S}
return similar(t, storagetype(t), codomain domain)
return similar(t, similarstoragetype(t), codomain domain)
end
function Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T}
return similar(t, T, codomain one(codomain))
end
# 2 arguments
function Base.similar(t::AbstractTensorMap, codomain::TensorSpace)
return similar(t, storagetype(t), codomain one(codomain))
return similar(t, similarstoragetype(t), codomain one(codomain))
end
Base.similar(t::AbstractTensorMap, P::TensorMapSpace) = similar(t, storagetype(t), P)
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
# 1 argument
Base.similar(t::AbstractTensorMap) = similar(t, storagetype(t), space(t))
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))

# generic implementation for AbstractTensorMap -> returns `TensorMap`
function Base.similar(t::AbstractTensorMap, ::Type{TorA},
Expand All @@ -408,6 +410,7 @@ function Base.similar(t::AbstractTensorMap, ::Type{TorA},
else
throw(ArgumentError("Type $TorA not supported for similar"))
end

N₁ = length(codomain(P))
N₂ = length(domain(P))
return TensorMap{T,S,N₁,N₂,A}(undef, P)
Expand Down
14 changes: 14 additions & 0 deletions test/bugfixes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,18 @@
a = convert(Array, t)
@test a == zeros(size(a))
end

# https://github.com/Jutho/TensorKit.jl/issues/194
@testset "Issue #194" begin
t1 = rand(ℂ^4 ^4)
t2 = tensoralloc(typeof(t1), space(t1), Val(true),
TensorOperations.ManualAllocator())
t3 = similar(t2, ComplexF64, space(t1))
@test storagetype(t3) == Vector{ComplexF64}
t4 = similar(t2, domain(t1))
@test storagetype(t4) == Vector{Float64}
t5 = similar(t2)
@test storagetype(t5) == Vector{Float64}
tensorfree!(t2)
end
end

0 comments on commit 582b6d7

Please sign in to comment.