Skip to content

Commit

Permalink
Moving all abstract functionality to QuantumInterface.jl (#100)
Browse files Browse the repository at this point in the history
* avoid piracy of SparseArrays.permutedims(::AbstractSparseMatrix, ...)

use the new _permutedims in operators_sparse

* new abstract ParticleBasis to avoid piracy of QuantumInterface.Basis

* move all operations on abstract types to QuantumInterface.jl

* document the purposeful piracy of identityoperator

* test with QuantumInterface v0.2.0

* bump version number
  • Loading branch information
Krastanov authored Jun 9, 2023
1 parent c5543d8 commit d9b5ae1
Show file tree
Hide file tree
Showing 13 changed files with 26 additions and 553 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "QuantumOpticsBase"
uuid = "4f57444f-1401-5e15-980d-4471b28d5678"
version = "0.4.1"
version = "0.4.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -19,7 +19,7 @@ Adapt = "1, 2, 3.3"
FFTW = "1.2"
FillArrays = "0.13, 1"
LRUCache = "1"
QuantumInterface = "0.1.0"
QuantumInterface = "0.2.0"
Strided = "1, 2"
UnsafeArrays = "1"
julia = "1.3"
Expand Down
8 changes: 5 additions & 3 deletions src/QuantumOpticsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ module QuantumOpticsBase
using SparseArrays, LinearAlgebra, LRUCache, Strided, UnsafeArrays, FillArrays
import LinearAlgebra: mul!, rmul!

import QuantumInterface: dagger, directsum, , dm, embed, expect, permutesystems,
projector, ptrace, reduced, tensor,
import QuantumInterface: dagger, directsum, , dm, embed, expect, identityoperator,
permutesystems, projector, ptrace, reduced, tensor, , variance

# index helpers
import QuantumInterface: complement, remove, shiftremove, reducedindices!, check_indices, check_sortedindices, check_embed_indices

export Basis, GenericBasis, CompositeBasis, basis,
tensor, , permutesystems, @samebases,
Expand Down Expand Up @@ -55,7 +58,6 @@ export Basis, GenericBasis, CompositeBasis, basis,
SumBasis, directsum, , LazyDirectSum, getblock, setblock!,
qfunc, wigner, coherentspinstate, qfuncsu2, wignersu2

include("sortedindices.jl")
include("polynomials.jl")
include("bases.jl")
include("states.jl")
Expand Down
232 changes: 1 addition & 231 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,94 +13,7 @@ abstract type DataOperator{BL,BR} <: AbstractOperator{BL,BR} end


# Common error messages
arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse()."))
arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse()."))
addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an operator. You probably want 'op + identityoperator(op)*x'."))

length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l)

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)

# Arithmetic operations
+(a::AbstractOperator, b::AbstractOperator) = arithmetic_binary_error("Addition", a, b)
+(a::Number, b::AbstractOperator) = addnumbererror()
+(a::AbstractOperator, b::Number) = addnumbererror()
+(a::AbstractOperator) = a

-(a::AbstractOperator) = arithmetic_unary_error("Negation", a)
-(a::AbstractOperator, b::AbstractOperator) = arithmetic_binary_error("Subtraction", a, b)
-(a::Number, b::AbstractOperator) = addnumbererror()
-(a::AbstractOperator, b::Number) = addnumbererror()

*(a::AbstractOperator, b::AbstractOperator) = arithmetic_binary_error("Multiplication", a, b)
^(a::AbstractOperator, b::Integer) = Base.power_by_squaring(a, b)


dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a)
Base.adjoint(a::AbstractOperator) = dagger(a)

conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a)
conj!(a::AbstractOperator) = conj(a::AbstractOperator)

# dense(a::AbstractOperator) = arithmetic_unary_error("Conversion to dense", a)

transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)

"""
ishermitian(op::AbstractOperator)
Check if an operator is Hermitian.
"""
ishermitian(op::AbstractOperator) = arithmetic_unary_error(ishermitian, op)


"""
tensor(x::AbstractOperator, y::AbstractOperator, z::AbstractOperator...)
Tensor product ``\\hat{x}⊗\\hat{y}⊗\\hat{z}⊗…`` of the given operators.
"""
tensor(a::AbstractOperator, b::AbstractOperator) = arithmetic_binary_error("Tensor product", a, b)
tensor(op::AbstractOperator) = op
tensor(operators::AbstractOperator...) = reduce(tensor, operators)


"""
embed(basis1[, basis2], indices::Vector, operators::Vector)
Tensor product of operators where missing indices are filled up with identity operators.
"""
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
indices, operators::Vector{T}) where T<:AbstractOperator

@assert check_embed_indices(indices)

N = length(basis_l.bases)
@assert length(basis_r.bases) == N
@assert length(indices) == length(operators)

# Embed all single-subspace operators.
idxop_sb = [x for x in zip(indices, operators) if x[1] isa Integer]
indices_sb = [x[1] for x in idxop_sb]
ops_sb = [x[2] for x in idxop_sb]

for (idxsb, opsb) in zip(indices_sb, ops_sb)
(opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases())
(opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases())
end

S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any
embed_op = tensor([i indices_sb ? ops_sb[indexin(i, indices_sb)[1]] : identityoperator(T, S, basis_l.bases[i], basis_r.bases[i]) for i=1:N]...)

# Embed all joint-subspace operators.
idxop_comp = [x for x in zip(indices, operators) if x[1] isa Array]
for (idxs, op) in idxop_comp
embed_op *= embed(basis_l, basis_r, idxs, op)
end

return embed_op
end
using QuantumInterface: arithmetic_binary_error, arithmetic_unary_error, addnumbererror


"""
Expand Down Expand Up @@ -156,12 +69,6 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,

return unpermuted_op
end
# The dictionary implementation works for non-DataOperators
embed(basis_l::CompositeBasis, basis_r::CompositeBasis, indices, op::T) where T<:AbstractOperator = embed(basis_l, basis_r, Dict(indices=>op))

embed(basis_l::CompositeBasis, basis_r::CompositeBasis, index::Integer, op::AbstractOperator) = embed(basis_l, basis_r, index, [op])
embed(basis::CompositeBasis, indices, operators::Vector{T}) where {T<:AbstractOperator} = embed(basis, basis, indices, operators)
embed(basis::CompositeBasis, indices, op::AbstractOperator) = embed(basis, basis, indices, op)

function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
index::Integer, op::T) where T<:DataOperator
Expand Down Expand Up @@ -195,70 +102,6 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
return Operator(basis_l, basis_r, data)
end

"""
embed(basis1[, basis2], operators::Dict)
`operators` is a dictionary `Dict{Vector{Int}, AbstractOperator}`. The integer vector
specifies in which subsystems the corresponding operator is defined.
"""
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
operators::Dict{<:Vector{<:Integer}, T}) where T<:AbstractOperator
@assert length(basis_l.bases) == length(basis_r.bases)
N = length(basis_l.bases)::Int # type assertion to help type inference
if length(operators) == 0
return identityoperator(T, basis_l, basis_r)
end
indices, operator_list = zip(operators...)
operator_list = [operator_list...;]
S = mapreduce(eltype, promote_type, operator_list)
indices_flat = [indices...;]::Vector{Int} # type assertion to help type inference
start_indices_flat = [i[1] for i in indices]
complement_indices_flat = Int[i for i=1:N if i indices_flat]
operators_flat = AbstractOperator[]
if all(([minimum(I):maximum(I);]==I)::Bool for I in indices) # type assertion to help type inference
for i in 1:N
if i in complement_indices_flat
push!(operators_flat, identityoperator(T, S, basis_l.bases[i], basis_r.bases[i]))
elseif i in start_indices_flat
push!(operators_flat, operator_list[indexin(i, start_indices_flat)[1]])
end
end
return tensor(operators_flat...)
else
complement_operators = [identityoperator(T, S, basis_l.bases[i], basis_r.bases[i]) for i in complement_indices_flat]
op = tensor([operator_list; complement_operators]...)
perm = sortperm([indices_flat; complement_indices_flat])
return permutesystems(op, perm)
end
end
embed(basis_l::CompositeBasis, basis_r::CompositeBasis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(basis_l, basis_r, Dict([i]=>op_i for (i, op_i) in operators); kwargs...)
embed(basis::CompositeBasis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(basis, basis, operators; kwargs...)
embed(basis::CompositeBasis, operators::Dict{<:Vector{<:Integer}, T}; kwargs...) where {T<:AbstractOperator} = embed(basis, basis, operators; kwargs...)


"""
tr(x::AbstractOperator)
Trace of the given operator.
"""
tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x)

ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a)

"""
normalize(op)
Return the normalized operator so that its `tr(op)` is one.
"""
normalize(op::AbstractOperator) = op/tr(op)

"""
normalize!(op)
In-place normalization of the given operator so that its `tr(x)` is one.
"""
normalize!(op::AbstractOperator) = throw(ArgumentError("normalize! is not defined for this type of operator: $(typeof(op)).\n You may have to fall back to the non-inplace version 'normalize()'."))

"""
expect(op, state)
Expand All @@ -267,30 +110,15 @@ Expectation value of the given operator `op` for the specified `state`.
`state` can either be a (density) operator or a ket.
"""
expect(op::AbstractOperator{B,B}, state::Ket{B}) where B = dot(state.data, (op * state).data)
expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state)

"""
expect(index, op, state)
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number.
"""
function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
expect(op, ptrace(state, indices_))
end
function expect(indices, op::AbstractOperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis}
N = length(state.basis.shape)
indices_ = complement(N, indices)
expect(op, ptrace(state, indices_))
end

expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state)
expect(index::Integer, op::AbstractOperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis} = expect([index], op, state)

expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states]
expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states]

"""
variance(op, state)
Expand All @@ -302,60 +130,14 @@ function variance(op::AbstractOperator{B,B}, state::Ket{B}) where B
x = op*state
state.data'*(op*x).data - (state.data'*x.data)^2
end
function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B
expect(op*op, state) - expect(op, state)^2
end

"""
variance(index, op, state)

If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number
"""
function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
variance(op, ptrace(state, indices_))
end
function variance(indices, op::AbstractOperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis}
N = length(state.basis.shape)
indices_ = complement(N, indices)
variance(op, ptrace(state, indices_))
end

variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(index::Integer, op::AbstractOperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states]
variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states]


"""
exp(op::AbstractOperator)
Operator exponential.
"""
exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense()."))

permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a)

"""
identityoperator(a::Basis[, b::Basis])
identityoperator(::Type{<:AbstractOperator}, a::Basis[, b::Basis])
identityoperator(::Type{<:Number}, a::Basis[, b::Basis])
identityoperator(::Type{<:AbstractOperator}, ::Type{<:Number}, a::Basis[, b::Basis])
Return an identityoperator in the given bases. One can optionally specify the container
type which has to a subtype of [`AbstractOperator`](@ref) as well as the number type
to be used in the identity matrix.
"""
identityoperator(::Type{T}, ::Type{S}, b1::Basis, b2::Basis) where {T<:AbstractOperator,S} = throw(ArgumentError("Identity operator not defined for operator type $T."))
identityoperator(::Type{T}, ::Type{S}, b::Basis) where {T<:AbstractOperator,S} = identityoperator(T,S,b,b)
identityoperator(::Type{T}, bases::Basis...) where T<:AbstractOperator = identityoperator(T,eltype(T),bases...)
identityoperator(op::T) where {T<:AbstractOperator} = identityoperator(T, op.basis_l, op.basis_r)

# Catch case where eltype cannot be inferred from type; this is a bit hacky
identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:AbstractOperator = identityoperator(T, ComplexF64, b1, b2)

one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x)

# Helper functions to check validity of arguments
function check_ptrace_arguments(a::AbstractOperator, indices)
Expand Down Expand Up @@ -383,17 +165,5 @@ function check_ptrace_arguments(a::StateVector, indices)
check_indices(length(basis(a).shape), indices)
end

samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool
check_samebases(a::AbstractOperator) = check_samebases(a.basis_l, a.basis_r)

multiplicable(a::AbstractOperator, b::Ket) = multiplicable(a.basis_r, b.basis)
multiplicable(a::Bra, b::AbstractOperator) = multiplicable(a.basis, b.basis_l)
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)

Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
function Base.size(op::AbstractOperator, i::Int)
i < 1 && throw(ErrorException("dimension index is < 1"))
i > 2 && return 1
i==1 ? length(op.basis_l) : length(op.basis_r)
end
3 changes: 0 additions & 3 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,6 @@ function ptrace(psi::Bra, indices)
return Operator(b_, b_, result)
end

_index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices)
reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices))

normalize!(op::Operator) = (rmul!(op.data, 1.0/tr(op)); op)

function expect(op::DataOperator{B,B}, state::Ket{B}) where B
Expand Down
13 changes: 2 additions & 11 deletions src/operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ SparseOperator(::Type{T},b::Basis) where T = SparseOperator(b,b,spzeros(T,length
SparseOperator(b1::Basis, b2::Basis) = SparseOperator(ComplexF64, b1, b2)
SparseOperator(b::Basis) = SparseOperator(ComplexF64, b, b)

"""
sparse(op::AbstractOperator)
Convert an arbitrary operator into a [`SparseOperator`](@ref).
"""
sparse(a::AbstractOperator) = throw(ArgumentError("Direct conversion from $(typeof(a)) not implemented. Use sparse(full(op)) instead."))
sparse(a::DataOperator) = Operator(a.basis_l, a.basis_r, sparse(a.data))

function ptrace(op::SparseOpPureType, indices)
Expand All @@ -56,7 +50,7 @@ function permutesystems(rho::SparseOpPureType{B1,B2}, perm) where {B1<:Composite
@assert length(rho.basis_l.bases) == length(rho.basis_r.bases) == length(perm)
@assert isperm(perm)
shape = [rho.basis_l.shape; rho.basis_r.shape]
data = permutedims(rho.data, shape, [perm; perm .+ length(perm)])
data = _permutedims(rho.data, shape, [perm; perm .+ length(perm)])
SparseOperator(permutesystems(rho.basis_l, perm), permutesystems(rho.basis_r, perm), data)
end

Expand All @@ -74,10 +68,7 @@ identityoperator(::Type{T}, ::Type{S}, b1::Basis, b2::Basis) where {T<:DataOpera
identityoperator(::Type{T}, ::Type{S}, b::Basis) where {T<:DataOperator,S<:Number} =
Operator(b, b, Eye{S}(length(b)))

identityoperator(::Type{T}, b1::Basis, b2::Basis) where T<:Number = identityoperator(DataOperator, T, b1, b2)
identityoperator(::Type{T}, b::Basis) where T<:Number = identityoperator(DataOperator, T, b)
identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2)
identityoperator(b::Basis) = identityoperator(ComplexF64, b)
identityoperator(::Type{T}, b1::Basis, b2::Basis) where T<:Number = identityoperator(DataOperator, T, b1, b2) # XXX This is purposeful type piracy over QuantumInterface, that hardcodes the use of QuantumOpticsBase.DataOperator in identityoperator. Also necessary for backward compatibility.

"""
diagonaloperator(b::Basis)
Expand Down
Loading

2 comments on commit d9b5ae1

@Krastanov
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85266

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.2 -m "<description of version>" d9b5ae13745cfa95f735c3ec005fabba7169869c
git push origin v0.4.2

Please sign in to comment.