Merge branch 'master' into lp/julia-1.11
lpawela committed Nov 15, 2024
2 parents 43ce99e + 6bae509 · commit 048b5ae
Showing 3 changed files with 22 additions and 16 deletions.
src/SpinGlassTensors.jl: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ import Base.Prehashed
 
 CUDA.allowscalar(false)
 
+ArrayorCuArray(A::AbstractArray, onGPU) = onGPU ? CuArray(A) : A
+
 include("projectors.jl")
 include("base.jl")
 include("linear_algebra_ext.jl")
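The new helper is what the rest of the diff threads through: every call site that unconditionally copied data to the GPU now dispatches on an `onGPU` flag instead. A minimal sketch of its behavior (standalone copy of the one-liner, for illustration only; `CUDA.functional()` guards the device branch on machines without a GPU):

```julia
using CUDA

# Standalone copy of the new helper, for illustration only.
ArrayorCuArray(A::AbstractArray, onGPU) = onGPU ? CuArray(A) : A

A = rand(Float64, 4, 4)
B = ArrayorCuArray(A, false)   # CPU branch: returns the input unchanged
@assert B === A                # no copy, no device transfer
if CUDA.functional()           # device branch only makes sense with a GPU
    C = ArrayorCuArray(A, true)
    @assert C isa CuArray
end
```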
src/base.jl: 4 additions & 4 deletions
@@ -68,10 +68,10 @@ const MatOrCentral{T,N} = Union{AbstractMatrix{T},CentralTensor{T,N}}
 function dense_central(ten::CentralTensor)
     # @cast V[(u1, u2), (d1, d2)] :=
     #     ten.e11[u1, d1] * ten.e21[u2, d1] * ten.e12[u1, d2] * ten.e22[u2, d2]
-    a11 = reshape(CuArray(ten.e11), size(ten.e11, 1), :, size(ten.e11, 2))
-    a21 = reshape(CuArray(ten.e21), :, size(ten.e21, 1), size(ten.e21, 2))
-    a12 = reshape(CuArray(ten.e12), size(ten.e12, 1), 1, 1, size(ten.e12, 2))
-    a22 = reshape(CuArray(ten.e22), 1, size(ten.e22, 1), 1, size(ten.e22, 2))
+    a11 = reshape(ten.e11, size(ten.e11, 1), :, size(ten.e11, 2))
+    a21 = reshape(ten.e21, :, size(ten.e21, 1), size(ten.e21, 2))
+    a12 = reshape(ten.e12, size(ten.e12, 1), 1, 1, size(ten.e12, 2))
+    a22 = reshape(ten.e22, 1, size(ten.e22, 1), 1, size(ten.e22, 2))
     V = @__dot__(a11 * a21 * a12 * a22)
     V = reshape(V, size(V, 1) * size(V, 2), size(V, 3) * size(V, 4))
     V ./ maximum(V)
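The reshape-and-broadcast pattern above stands in for the commented-out `@cast` contraction V[(u1, u2), (d1, d2)] = e11[u1, d1] * e21[u2, d1] * e12[u1, d2] * e22[u2, d2]. A CPU-only sketch with hypothetical sizes, checking the broadcast against an explicit loop version (`@.` is the short form of `@__dot__`):

```julia
# Hypothetical sizes: u1, u2, d1, d2 = 2, 4, 3, 5.
e11, e21 = rand(2, 3), rand(4, 3)   # share the d1 index (size 3)
e12, e22 = rand(2, 5), rand(4, 5)   # share the d2 index (size 5)

# Singleton dimensions make broadcasting build the 4D outer product.
a11 = reshape(e11, 2, 1, 3)         # (u1, 1, d1)
a21 = reshape(e21, 1, 4, 3)         # (1, u2, d1)
a12 = reshape(e12, 2, 1, 1, 5)      # (u1, 1, 1, d2)
a22 = reshape(e22, 1, 4, 1, 5)      # (1, u2, 1, d2)
V4 = @. a11 * a21 * a12 * a22       # (u1, u2, d1, d2)
V = reshape(V4, 2 * 4, 3 * 5)       # fuse (u1, u2) rows and (d1, d2) columns

# Reference: the same contraction written as explicit loops.
Vref = [e11[u1, d1] * e21[u2, d1] * e12[u1, d2] * e22[u2, d2]
        for u1 in 1:2, u2 in 1:4, d1 in 1:3, d2 in 1:5]
@assert V ≈ reshape(Vref, 8, 15)
```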
src/contractions/central.jl: 16 additions & 12 deletions
@@ -4,11 +4,13 @@ export contract_tensor3_matrix, contract_matrix_tensor3, update_reduced_env_right
 # my_batched_mul!
 
 function contract_tensor3_matrix(LE::Tensor{R,3}, M::CentralTensor{R,2}) where {R<:Real}
-    contract_tensor3_central(LE, M.e11, M.e12, M.e21, M.e22)
+    onGPU = typeof(LE) <: CuArray ? true : false
+    contract_tensor3_central(LE, M.e11, M.e12, M.e21, M.e22, onGPU)
 end
 
 function contract_matrix_tensor3(M::CentralTensor{R,2}, RE::Tensor{R,3}) where {R<:Real}
-    contract_tensor3_central(RE, M.e11', M.e21', M.e12', M.e22')
+    onGPU = typeof(RE) <: CuArray ? true : false
+    contract_tensor3_central(RE, M.e11', M.e21', M.e12', M.e22', onGPU)
 end
 
 function update_reduced_env_right(RR::Tensor{R,2}, M::CentralTensor{R,2}) where {R<:Real}
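The residency check feeding the new `onGPU` argument is just a type test on the environment tensor. A small CPU-side sketch with a hypothetical array (`x isa CuArray` is the shorter equivalent of the ternary used in the diff):

```julia
using CUDA

LE = rand(2, 3, 4)                            # host-resident example tensor
onGPU = typeof(LE) <: CuArray ? true : false  # the idiom used in the diff
@assert onGPU == (LE isa CuArray)             # same test, written directly
@assert !onGPU                                # plain Arrays stay on the CPU
```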
@@ -17,18 +19,18 @@ function update_reduced_env_right(RR::Tensor{R,2}, M::CentralTensor{R,2}) where {R<:Real}
 end
 
 
-function contract_tensor3_central(LE, e11, e12, e21, e22)
+function contract_tensor3_central(LE, e11, e12, e21, e22, onGPU)
     sb, st = size(LE)
     sbt = sb * st
     sl1, sl2, sr1, sr2 = size(e11, 1), size(e22, 1), size(e11, 2), size(e22, 2)
     sinter = sbt * max(sl1 * sl2 * min(sr1, sr2), sr1 * sr2 * min(sl1, sl2))
     if sl1 * sl2 * sr1 * sr2 < sinter
         # @cast E[(l1, l2), (r1, r2)] := e11[l1, r1] * e21[l2, r1] * e12[l1, r2] * e22[l2, r2]
         # TODO: terrible hack, remove when TensorCast is updated
-        a11 = reshape(CuArray(e11), size(e11, 1), :, size(e11, 2))
-        a21 = reshape(CuArray(e21), :, size(e21, 1), size(e21, 2))
-        a12 = reshape(CuArray(e12), size(e12, 1), 1, 1, size(e12, 2))
-        a22 = reshape(CuArray(e22), 1, size(e22, 1), 1, size(e22, 2))
+        a11 = reshape(ArrayorCuArray(e11, onGPU), size(e11, 1), :, size(e11, 2))
+        a21 = reshape(ArrayorCuArray(e21, onGPU), :, size(e21, 1), size(e21, 2))
+        a12 = reshape(ArrayorCuArray(e12, onGPU), size(e12, 1), 1, 1, size(e12, 2))
+        a22 = reshape(ArrayorCuArray(e22, onGPU), 1, size(e22, 1), 1, size(e22, 2))
         E = @__dot__(a11 * a21 * a12 * a22)
         E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))
         return reshape(reshape(LE, (sbt, sl1 * sl2)) * E, (sb, st, sr1 * sr2))
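The guard above compares the size of the dense (sl1·sl2) × (sr1·sr2) kernel E against `sinter`, which the code takes as the worst intermediate of contracting the four factors sequentially against the environment; the dense path runs only when E is the smaller object. A numeric sketch with hypothetical bond dimensions:

```julia
# Hypothetical bond dimensions, for illustration only.
sb, st = 16, 16
sbt = sb * st                     # flattened environment size
sl1, sl2, sr1, sr2 = 8, 8, 8, 8

# Worst intermediate of the sequential (factorized) contraction route.
sinter = sbt * max(sl1 * sl2 * min(sr1, sr2), sr1 * sr2 * min(sl1, sl2))

# Size of the dense kernel E built by the broadcast branch.
sdense = sl1 * sl2 * sr1 * sr2

@assert sdense < sinter           # 4096 < 131072: here the dense branch wins
```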
@@ -67,8 +69,9 @@ function batched_mul!(
     LE::Tensor{R,3},
     M::AbstractArray{R,2},
 ) where {R<:Real}
+    onGPU = typeof(newLE) <: CuArray ? true : false
     N1, N2 = size(M)
-    new_M = CUDA.CuArray(M) # TODO: this is a hack to solve problem with types;
+    new_M = ArrayorCuArray(M, onGPU) # TODO: this is a hack to solve problem with types;
     new_M = reshape(new_M, (N1, N2, 1))
     NNlib.batched_mul!(newLE, LE, new_M)
 end
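The reshape to (N1, N2, 1) gives the matrix a batch dimension of size 1, which NNlib broadcasts across every batch slice of LE; this mirrors the call pattern in the method above. A small sketch with hypothetical sizes:

```julia
using NNlib

sb, N1, N2, st = 3, 4, 5, 6
LE = rand(sb, N1, st)             # batch of st matrices, each sb × N1
M = rand(N1, N2)                  # one shared matrix
newLE = similar(LE, sb, N2, st)   # preallocated output

# A batch dimension of size 1 on the right operand is broadcast by NNlib.
NNlib.batched_mul!(newLE, LE, reshape(M, N1, N2, 1))
@assert newLE[:, :, end] ≈ LE[:, :, end] * M
```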
@@ -78,16 +81,17 @@ function batched_mul!(
     LE::Tensor{R,3},
     M::CentralTensor{R,2},
 ) where {R<:Real}
+    onGPU = typeof(newLE) <: CuArray ? true : false
     sb, _, st = size(LE)
     sl1, sl2, sr1, sr2 = size(M.e11, 1), size(M.e22, 1), size(M.e11, 2), size(M.e22, 2)
     sinter = sb * st * max(sl1 * sl2 * min(sr1, sr2), sr1 * sr2 * min(sl1, sl2))
     if sl1 * sl2 * sr1 * sr2 < sinter
         # @cast E[(l1, l2), (r1, r2)] :=
         #     M.e11[l1, r1] * M.e21[l2, r1] * M.e12[l1, r2] * M.e22[l2, r2]
-        a11 = reshape(CuArray(M.e11), size(M.e11, 1), :, size(M.e11, 2))
-        a21 = reshape(CuArray(M.e21), :, size(M.e21, 1), size(M.e21, 2))
-        a12 = reshape(CuArray(M.e12), size(M.e12, 1), 1, 1, size(M.e12, 2))
-        a22 = reshape(CuArray(M.e22), 1, size(M.e22, 1), 1, size(M.e22, 2))
+        a11 = reshape(ArrayorCuArray(M.e11, onGPU), size(M.e11, 1), :, size(M.e11, 2))
+        a21 = reshape(ArrayorCuArray(M.e21, onGPU), :, size(M.e21, 1), size(M.e21, 2))
+        a12 = reshape(ArrayorCuArray(M.e12, onGPU), size(M.e12, 1), 1, 1, size(M.e12, 2))
+        a22 = reshape(ArrayorCuArray(M.e22, onGPU), 1, size(M.e22, 1), 1, size(M.e22, 2))
         E = @__dot__(a11 * a21 * a12 * a22)
         E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))
         E = reshape(E, (sl1 * sl2, sr1 * sr2, 1))
