From 6bae509eb75ca76fe2b34ba3430daf2c9f25f5fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tomasz=20=C5=9Amierzchalski?= <53766192+tomsmierz@users.noreply.github.com>
Date: Thu, 14 Nov 2024 20:35:41 +0100
Subject: [PATCH] Zipper cpu 2 (#27)

* added changes from Julia 1.11 branch

* MWE is passing

* WIP

* changed version for new release

---------

Co-authored-by: tomsmierz
---
 Project.toml                |  2 +-
 src/SpinGlassTensors.jl     |  2 ++
 src/base.jl                 |  8 ++++----
 src/contractions/central.jl | 28 ++++++++++++++++------------
 src/zipper.jl               |  2 ++
 5 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/Project.toml b/Project.toml
index 614577b..5184345 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SpinGlassTensors"
 uuid = "7584fc6a-5a23-4eeb-8277-827aab0146ea"
 authors = ["Anna Maria Dziubyna ", "Tomasz Śmierzchalski ", "Bartłomiej Gardas ", "Konrad Jałowiecki ", "Łukasz Pawela ", "Marek M. Rams "]
-version = "1.1.3"
+version = "1.2.0"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/SpinGlassTensors.jl b/src/SpinGlassTensors.jl
index 636a463..3107b5b 100644
--- a/src/SpinGlassTensors.jl
+++ b/src/SpinGlassTensors.jl
@@ -15,6 +15,8 @@ import Base.Prehashed
 
 CUDA.allowscalar(false)
 
+ArrayorCuArray(A::AbstractArray, onGPU) = onGPU ? CuArray(A) : A
+
 include("projectors.jl")
 include("base.jl")
 include("linear_algebra_ext.jl")
diff --git a/src/base.jl b/src/base.jl
index 27c2629..1cd402c 100644
--- a/src/base.jl
+++ b/src/base.jl
@@ -68,10 +68,10 @@ const MatOrCentral{T,N} = Union{AbstractMatrix{T},CentralTensor{T,N}}
 function dense_central(ten::CentralTensor)
     # @cast V[(u1, u2), (d1, d2)] :=
     #     ten.e11[u1, d1] * ten.e21[u2, d1] * ten.e12[u1, d2] * ten.e22[u2, d2]
-    a11 = reshape(CuArray(ten.e11), size(ten.e11, 1), :, size(ten.e11, 2))
-    a21 = reshape(CuArray(ten.e21), :, size(ten.e21, 1), size(ten.e21, 2))
-    a12 = reshape(CuArray(ten.e12), size(ten.e12, 1), 1, 1, size(ten.e12, 2))
-    a22 = reshape(CuArray(ten.e22), 1, size(ten.e22, 1), 1, size(ten.e22, 2))
+    a11 = reshape(ten.e11, size(ten.e11, 1), :, size(ten.e11, 2))
+    a21 = reshape(ten.e21, :, size(ten.e21, 1), size(ten.e21, 2))
+    a12 = reshape(ten.e12, size(ten.e12, 1), 1, 1, size(ten.e12, 2))
+    a22 = reshape(ten.e22, 1, size(ten.e22, 1), 1, size(ten.e22, 2))
     V = @__dot__(a11 * a21 * a12 * a22)
     V = reshape(V, size(V, 1) * size(V, 2), size(V, 3) * size(V, 4))
     V ./ maximum(V)
diff --git a/src/contractions/central.jl b/src/contractions/central.jl
index c0e7f01..a3ed82d 100644
--- a/src/contractions/central.jl
+++ b/src/contractions/central.jl
@@ -4,11 +4,13 @@ export contract_tensor3_matrix, contract_matrix_tensor3, update_reduced_env_righ
 # my_batched_mul!
 
 function contract_tensor3_matrix(LE::Tensor{R,3}, M::CentralTensor{R,2}) where {R<:Real}
-    contract_tensor3_central(LE, M.e11, M.e12, M.e21, M.e22)
+    onGPU = typeof(LE) <: CuArray ? true : false
+    contract_tensor3_central(LE, M.e11, M.e12, M.e21, M.e22, onGPU)
 end
 
 function contract_matrix_tensor3(M::CentralTensor{R,2}, RE::Tensor{R,3}) where {R<:Real}
-    contract_tensor3_central(RE, M.e11', M.e21', M.e12', M.e22')
+    onGPU = typeof(RE) <: CuArray ? true : false
+    contract_tensor3_central(RE, M.e11', M.e21', M.e12', M.e22', onGPU)
 end
 
 function update_reduced_env_right(RR::Tensor{R,2}, M::CentralTensor{R,2}) where {R<:Real}
@@ -17,7 +19,7 @@ function update_reduced_env_right(RR::Tensor{R,2}, M::CentralTensor{R,2}) where
 end
 
 
-function contract_tensor3_central(LE, e11, e12, e21, e22)
+function contract_tensor3_central(LE, e11, e12, e21, e22, onGPU)
     sb, st = size(LE)
     sbt = sb * st
     sl1, sl2, sr1, sr2 = size(e11, 1), size(e22, 1), size(e11, 2), size(e22, 2)
@@ -25,10 +27,10 @@ function contract_tensor3_central(LE, e11, e12, e21, e22)
     if sl1 * sl2 * sr1 * sr2 < sinter
         # @cast E[(l1, l2), (r1, r2)] := e11[l1, r1] * e21[l2, r1] * e12[l1, r2] * e22[l2, r2]
         # TODO: terrible hack, remove when TensorCast is updated
-        a11 = reshape(CuArray(e11), size(e11, 1), :, size(e11, 2))
-        a21 = reshape(CuArray(e21), :, size(e21, 1), size(e21, 2))
-        a12 = reshape(CuArray(e12), size(e12, 1), 1, 1, size(e12, 2))
-        a22 = reshape(CuArray(e22), 1, size(e22, 1), 1, size(e22, 2))
+        a11 = reshape(ArrayorCuArray(e11, onGPU), size(e11, 1), :, size(e11, 2))
+        a21 = reshape(ArrayorCuArray(e21, onGPU), :, size(e21, 1), size(e21, 2))
+        a12 = reshape(ArrayorCuArray(e12, onGPU), size(e12, 1), 1, 1, size(e12, 2))
+        a22 = reshape(ArrayorCuArray(e22, onGPU), 1, size(e22, 1), 1, size(e22, 2))
         E = @__dot__(a11 * a21 * a12 * a22)
         E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))
         return reshape(reshape(LE, (sbt, sl1 * sl2)) * E, (sb, st, sr1 * sr2))
@@ -67,8 +69,9 @@ function batched_mul!(
     LE::Tensor{R,3},
     M::AbstractArray{R,2},
 ) where {R<:Real}
+    onGPU = typeof(newLE) <: CuArray ? true : false
     N1, N2 = size(M)
-    new_M = CUDA.CuArray(M) # TODO: this is a hack to solve problem with types;
+    new_M = ArrayorCuArray(M, onGPU) # TODO: this is a hack to solve a problem with types;
     new_M = reshape(new_M, (N1, N2, 1))
     NNlib.batched_mul!(newLE, LE, new_M)
 end
@@ -78,16 +81,17 @@ function batched_mul!(
     LE::Tensor{R,3},
     M::CentralTensor{R,2},
 ) where {R<:Real}
+    onGPU = typeof(newLE) <: CuArray ? true : false
     sb, _, st = size(LE)
     sl1, sl2, sr1, sr2 = size(M.e11, 1), size(M.e22, 1), size(M.e11, 2), size(M.e22, 2)
     sinter = sb * st * max(sl1 * sl2 * min(sr1, sr2), sr1 * sr2 * min(sl1, sl2))
     if sl1 * sl2 * sr1 * sr2 < sinter
         # @cast E[(l1, l2), (r1, r2)] :=
         #     M.e11[l1, r1] * M.e21[l2, r1] * M.e12[l1, r2] * M.e22[l2, r2]
-        a11 = reshape(CuArray(M.e11), size(M.e11, 1), :, size(M.e11, 2))
-        a21 = reshape(CuArray(M.e21), :, size(M.e21, 1), size(M.e21, 2))
-        a12 = reshape(CuArray(M.e12), size(M.e12, 1), 1, 1, size(M.e12, 2))
-        a22 = reshape(CuArray(M.e22), 1, size(M.e22, 1), 1, size(M.e22, 2))
+        a11 = reshape(ArrayorCuArray(M.e11, onGPU), size(M.e11, 1), :, size(M.e11, 2))
+        a21 = reshape(ArrayorCuArray(M.e21, onGPU), :, size(M.e21, 1), size(M.e21, 2))
+        a12 = reshape(ArrayorCuArray(M.e12, onGPU), size(M.e12, 1), 1, 1, size(M.e12, 2))
+        a22 = reshape(ArrayorCuArray(M.e22, onGPU), 1, size(M.e22, 1), 1, size(M.e22, 2))
         E = @__dot__(a11 * a21 * a12 * a22)
         E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))
         E = reshape(E, (sl1 * sl2, sr1 * sr2, 1))
diff --git a/src/zipper.jl b/src/zipper.jl
index 48e7223..a8a6a02 100644
--- a/src/zipper.jl
+++ b/src/zipper.jl
@@ -163,6 +163,7 @@ function _left_sweep_var_site!(env::EnvironmentMixed, site; kwargs...) # site:
     _, Q = rq_fact(B; toGPU = env.onGPU, kwargs...)
     # @cast C[l, r, t] := Q[l, (r, t)] (t ∈ 1:size(A, 3))
     C = reshape(Q, size(Q, 1), size(Q, 2) ÷ size(A, 3), size(A, 3))
+    !env.onGPU && (C = collect(C))
     if site == :central
         env.C = C
     else
@@ -181,6 +182,7 @@ function _right_sweep_var_site!(env::EnvironmentMixed, site; kwargs...)
     # @cast C[l, t, r] := Q[(l, t), r] (t ∈ 1:size(A, 3))
     C = reshape(Q, size(Q, 1) ÷ size(A, 3), size(A, 3), size(Q, 2))
     C = permutedims(C, (1, 3, 2)) # [l, r, t]
+    !env.onGPU && (C = collect(C))
    if site == :central
        env.C = C
    else
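
The core of the patch is the `ArrayorCuArray` helper added in src/SpinGlassTensors.jl: the `CentralTensor` factors are moved to the GPU only when the environment tensor already lives there, instead of being unconditionally wrapped in `CuArray` as before. A minimal sketch of the pattern, assuming CUDA.jl is installed; the arrays and sizes below are hypothetical:

    using CUDA

    # Helper from the patch: upload to the GPU only when asked to.
    ArrayorCuArray(A::AbstractArray, onGPU) = onGPU ? CuArray(A) : A

    e11 = rand(4, 5)        # one CentralTensor factor, on the CPU
    LE  = rand(2, 3, 4)     # left environment tensor, also on the CPU

    onGPU = LE isa CuArray  # same test as typeof(LE) <: CuArray ? true : false
    a11 = reshape(ArrayorCuArray(e11, onGPU), size(e11, 1), :, size(e11, 2))
    # For CPU inputs this is a pass-through; for a CuArray LE the factor is
    # uploaded, so the later broadcast runs entirely on one device.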
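
The reshape-plus-broadcast blocks are a stand-in for the commented-out `@cast` contraction E[(l1, l2), (r1, r2)] := e11[l1, r1] * e21[l2, r1] * e12[l1, r2] * e22[l2, r2]: singleton dimensions line the four factors up so that one broadcast builds the 4-index tensor, which is then flattened column-major. A self-contained check of that equivalence, with arbitrary small sizes:

    e11, e21, e12, e22 = rand(2, 5), rand(3, 5), rand(2, 7), rand(3, 7)

    a11 = reshape(e11, size(e11, 1), :, size(e11, 2))     # [l1, 1, r1]
    a21 = reshape(e21, :, size(e21, 1), size(e21, 2))     # [1, l2, r1]
    a12 = reshape(e12, size(e12, 1), 1, 1, size(e12, 2))  # [l1, 1, 1, r2]
    a22 = reshape(e22, 1, size(e22, 1), 1, size(e22, 2))  # [1, l2, 1, r2]
    E = a11 .* a21 .* a12 .* a22                          # [l1, l2, r1, r2]
    E = reshape(E, size(E, 1) * size(E, 2), size(E, 3) * size(E, 4))

    # Reference value via an explicit comprehension over all four indices.
    Eref = [e11[l1, r1] * e21[l2, r1] * e12[l1, r2] * e22[l2, r2]
            for l1 in 1:2, l2 in 1:3, r1 in 1:5, r2 in 1:7]
    @assert E ≈ reshape(Eref, 2 * 3, 5 * 7)

The `@__dot__(a11 * a21 * a12 * a22)` in the patch is just the macro form of the `.* ` broadcast written out above.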
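
In both `batched_mul!` methods the matrix is reshaped to `(N1, N2, 1)` before the NNlib call: `NNlib.batched_mul!` multiplies 3-dimensional arrays slice by slice along the last axis and broadcasts a singleton batch, so the single matrix is applied to every slice of `LE`. A small usage sketch with hypothetical sizes:

    using NNlib

    LE    = rand(2, 4, 6)   # (left, shared, batch)
    M     = rand(4, 3)
    newLE = similar(LE, 2, 3, 6)

    # The size-1 batch dimension of the reshaped M is broadcast over all
    # 6 slices of LE, i.e. newLE[:, :, k] = LE[:, :, k] * M for every k.
    NNlib.batched_mul!(newLE, LE, reshape(M, 4, 3, 1))
    @assert newLE[:, :, 5] ≈ LE[:, :, 5] * M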