Skip to content

Commit

Permalink
Ad/documentation (#15)
Browse files Browse the repository at this point in the history
* docs compressions, contractions

* fix typos

* add $(TYPEDSIGNATURES)

* add docs

* work in progress

* work in progress

* Work in Progress

* work in progress

* somewhat done

---------

Co-authored-by: annamariadziubyna <[email protected]>
Co-authored-by: tomsmierz <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2023
1 parent dd89004 commit 60b93ce
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 4 deletions.
13 changes: 13 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Documenter, SpinGlassTensors

_pages = [
"Introduction" => "index.md",
"API Reference" => "api.md"
]
# ============================

makedocs(
sitename="SpinGlassTensors",
modules = [SpinGlassTensors],
pages = _pages
)
34 changes: 34 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Library

---
```@meta
CurrentModule = SpinGlassTensors
```
## Additional methods for `Base` and `LinearAlgebra`
```@docs
dot
norm
randn
rank
```

## MPS
```@docs
MPS
is_left_normalized
is_right_normalized
physical_dim
verify_bonds
verify_physical_dims
```

## Compresions and Contractions

```@docs
canonise!
compress!
left_env
right_env
truncate!
```
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SpinGlassTensor

Part of [SpinGlassPEPS](https://github.com/euro-hpc-pl/SpinGlassPEPS.jl) package. It constitutes the basis for the preparation of tensors and operations on them.

!!! info
We don't expect the user to interact with this package, as it is more of a "back-end" type. Nevertheless, we provide API references should the need arise.
81 changes: 81 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,59 @@ end
@inline Base.size(a::AbstractTensorNetwork) = (length(a.tensors),)
@inline Base.eachindex(a::AbstractTensorNetwork) = eachindex(a.tensors)

"""
LinearAlgebra.rank(ψ::AbstractMPS)
Returns rank of MPS tensors.
"""
@inline LinearAlgebra.rank::AbstractMPS) = Tuple(size(A, 2) for A ψ)

"""
physical_dim(ψ::AbstractMPS, i::Int)
Returns physical dimension of MPS tensors at given site i.
"""
@inline physical_dim::AbstractMPS, i::Int) = size(ψ[i], 2)



@inline MPS(A::AbstractArray) = MPS(A, :right)



"""
MPS(A::AbstractArray, s::Symbol, Dcut::Int = typemax(Int))
Construct a matrix product state (MPS) using the provided tensor array `A`.
## Arguments
- `A::AbstractArray`: The tensor array that defines the MPS.
- `s::Symbol`: The direction to canonically transform the MPS. Must be either `:left` or `:right`.
- `Dcut::Int`: The maximum bond dimension allowed during the truncation step.
## Returns
- `ψ::AbstractMPS`: The constructed MPS.
## Details
This function constructs a matrix product state (MPS) using the provided tensor array `A`,
and then canonically transforms it in the direction specified by the `s` argument. If `s` is `:right`,
the MPS is right-canonized, while if `s` is `:left`, the MPS is left-canonized.
The `Dcut` argument determines the maximum bond dimension allowed during the truncation step.
If neither `Dcut` nor `s` is specified, it will construct right-canonized MPS with default Dcut value.
## Example
```@repl
A = rand(2, 3, 2)
ψ = MPS(A, :left, 2);
typeof(ψ)
length(ψ)
bond_dimension(ψ)
```
"""
@inline function MPS(A::AbstractArray, s::Symbol, Dcut::Int = typemax(Int))
@assert s (:left, :right)
if s == :right
Expand All @@ -62,6 +110,12 @@ end

@inline dropindices::AbstractMPS, i::Int = 2) = (dropdims(A, dims = i) for A ψ)


"""
MPS(states::Vector{Vector{T}}) where {T<:Number}
Create a matrix product state (MPS) object from a vector of states.
"""
function MPS(states::Vector{Vector{T}}) where {T<:Number}
state_arrays = [reshape(copy(v), (1, length(v), 1)) for v states]
MPS(state_arrays)
Expand All @@ -76,6 +130,12 @@ function (::Type{T})(O::AbstractMPO) where {T<:AbstractMPS}
T([@cast A[x, (σ, η), y] := W[x, σ, y, η] for W in O])
end

"""
Base.randn(::Type{MPS{T}}, D::Int, rank::Union{Vector,NTuple}) where {T}
Create random MPS.The argument `D` specifies the physical dimension of the MPS
(i.e. the dimension of the vectors at each site), `rank` specifies rank of each site.
"""
function Base.randn(::Type{MPS{T}}, D::Int, rank::Union{Vector,NTuple}) where {T}
MPS([
randn(T, 1, first(rank), D),
Expand All @@ -88,6 +148,7 @@ function Base.randn(::Type{MPS{T}}, L::Int, D::Int, d::Int) where {T}
MPS([randn(T, 1, d, D), (randn(T, D, d, D) for _ = 2:L-1)..., randn(T, D, d, 1)])
end


Base.randn(::Type{MPS}, args...) = randn(MPS{Float64}, args...)

function Base.randn(::Type{MPO{T}}, L::Int, D::Int, d::Int) where {T}
Expand All @@ -98,11 +159,21 @@ function Base.randn(::Type{MPO{T}}, D::Int, rank::Union{Vector,NTuple}) where {T
MPO(randn(MPS{T}, D, rank .^ 2))
end

"""
is_left_normalized(ψ::MPS)
Check whether MPS is left normalized.
"""
is_left_normalized::MPS) = all(
I(size(A, 3)) @tensor Id[x, y] := conj(A[α, σ, x]) * A[α, σ, y] order = (α, σ) for
A ψ
)

"""
is_right_normalized(ϕ::MPS)
Check whether MPS is right normalized.
"""
is_right_normalized::MPS) = all(
I(size(B, 1)) @tensor Id[x, y] := B[x, σ, α] * conj(B[y, σ, α]) order = (α, σ) for
B in ϕ
Expand All @@ -113,12 +184,22 @@ function _verify_square(ψ::AbstractMPS)
@assert isqrt.(dims) .^ 2 == dims "Incorrect MPS dimensions"
end

"""
verify_physical_dims(ψ::AbstractMPS, dims::NTuple)
Check whether MPS has correct physical dimension at given site.
"""
function verify_physical_dims::AbstractMPS, dims::NTuple)
for i eachindex(ψ)
@assert physical_dim(ψ, i) == dims[i] "Incorrect physical dim at site $(i)."
end
end

"""
verify_bonds(ψ::AbstractMPS)
Check whether MPS has correct sizes.
"""
function verify_bonds::AbstractMPS)
L = length(ψ)

Expand Down
43 changes: 42 additions & 1 deletion src/compressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,25 @@ function compress(
ψ
end


"""
compress!(
ϕ::AbstractMPS,
Dcut::Int,
tol::Number = 1E-8,
max_sweeps::Int = 4,
args...,
)
# Arguments
- `ϕ::AbstractMPS`: the input MPS to be compressed
- `Dcut::Int`: the maximum bond dimension of the compressed MPS
- `tol::Number = 1E-8`: the tolerance threshold for convergence of the iterative compression process (default value: 1E-8)
- `max_sweeps::Int = 4`: the maximum number of iterations allowed for the compression process (default value: 4)
# Output
- `overlap`: The overlap of the compressed MPS with the original input MPS.
"""
function compress!(
ϕ::AbstractMPS,
Dcut::Int,
Expand Down Expand Up @@ -55,7 +73,21 @@ function compress!(
overlap
end

"""
truncate!(ψ::AbstractMPS, s::Symbol, Dcut::Int = typemax(Int), args...)
Truncate the bond dimension of a matrix product state (MPS) in either
the left or right canonical form, depending on the value of the `s` input argument.
# Arguments
- `ψ::AbstractMPS`: the input MPS to be truncated
- `s::Symbol`: determines whether to truncate the MPS in the left or right canonical form. Must be one of the following values:
- `:left`: truncate in left canonical form
- `:right`: truncate in right canonical form
- `Dcut::Int`: the maximum bond dimension to which the MPS should be truncated.
"""
function truncate!::AbstractMPS, s::Symbol, Dcut::Int = typemax(Int), args...)
@assert s (:left, :right)
if s == :right
Expand All @@ -67,8 +99,17 @@ function truncate!(ψ::AbstractMPS, s::Symbol, Dcut::Int = typemax(Int), args...
end
end

"""
canonise!(ψ::AbstractMPS, s::Symbol)
canonizes a matrix product state (MPS) in either the left or right canonical form,
depending on the value of the `s` input argument. Must be one of the following values:
- `:left`: canonize in left canonical form
- `:right`: canonize in right canonical form
"""
canonise!::AbstractMPS, s::Symbol) = canonise!(ψ, Val(s))

canonise!::AbstractMPS, ::Val{:right}) = _left_sweep!(ψ, typemax(Int))
canonise!::AbstractMPS, ::Val{:left}) = _right_sweep!(ψ, typemax(Int))

Expand Down
7 changes: 4 additions & 3 deletions src/contractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ function LinearAlgebra.dot(ϕ::AbstractMPS, ψ::AbstractMPS)
end

"""
left_env(ϕ::AbstractMPS, ψ::AbstractMPS)
Creates left environment (ϕ - bra, ψ - ket)
"""
function left_env::AbstractMPS, ψ::AbstractMPS)
Expand Down Expand Up @@ -50,6 +52,8 @@ end
end

"""
right_env(ϕ::AbstractMPS, ψ::AbstractMPS)
Creates right environment (ϕ - bra, ψ - ket)
"""
function right_env::AbstractMPS, ψ::AbstractMPS)
Expand Down Expand Up @@ -90,15 +94,13 @@ end
R
end


"""
$(TYPEDSIGNATURES)
Calculates the norm of an MPS \$\\ket{\\phi}\$
"""
LinearAlgebra.norm::AbstractMPS) = sqrt(abs(dot(ψ, ψ)))


"""
$(TYPEDSIGNATURES)
Expand All @@ -123,7 +125,6 @@ function LinearAlgebra.dot(ϕ::AbstractMPS, O::Union{Vector,NTuple}, ψ::Abstrac
tr(C)
end


function LinearAlgebra.dot(O::AbstractMPO, ψ::AbstractMPS)
S = promote_type(eltype(ψ), eltype(O))
T = typeof(ψ)
Expand Down

0 comments on commit 60b93ce

Please sign in to comment.