Skip to content

Commit

Permalink
Add 4-partite alternating oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Nov 27, 2024
1 parent 55c683e commit c6f5e2f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ function build_callback(
if save && mod(state.t, save_interval) == 0
serialize(file * "_tmp.dat", ActiveSetStorage(active_set))
end
# if state.dual_gap < state.primal / 2
# return false
# end
return state.primal > epsilon
end
return callback
Expand Down
10 changes: 5 additions & 5 deletions src/quantum_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,17 @@ function probability_tensor(
return p
end

# convert a N sets of m d-outcome POVMs acting on C^e into a ex...xexmx...xm probability array
# convert a N sets of m o-outcome POVMs acting on C^d into a dx...xdxmx...xm probability array
function probability_tensor(
Aax::Vector{TB},
N::Int;
rho = rho_GHZ(N; d = size(Aax[1], 1), type = T),
) where {TB <: AbstractArray{Complex{T}, 4}} where {T <: Number}
e, _, d, m = size(Aax[1])
d, _, o, m = size(Aax[1])
@assert length(Aax) == N
@assert size(rho) == (e^N, e^N)
p = zeros(T, e * ones(Int, N)..., m * ones(Int, N)...)
cia = CartesianIndices(Tuple(e * ones(Int, N)))
@assert size(rho) == (d^N, d^N)
p = zeros(T, o * ones(Int, N)..., m * ones(Int, N)...)
cia = CartesianIndices(Tuple(o * ones(Int, N)))
cix = CartesianIndices(Tuple(m * ones(Int, N)))
for a in cia, x in cix
p[a, x] = real(tr(kron([Aax[n][:, :, a[n], x[n]] for n in 1:N]...) * rho))
Expand Down
67 changes: 67 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,73 @@ function alternating_minimisation!(
return sc1
end

function alternating_minimisation!(
ax::Vector{Vector{Int}},
lmo::BellProbabilitiesLMO{T, 8, 0},
A::Array{T, 8},
) where {T <: Number}
sc1 = zero(T)
sc2 = one(T)
@inbounds while sc1 < sc2
sc2 = sc1
for x4 in 1:length(ax[4])
for a4 in 1:lmo.o[4]
s = zero(T)
for x1 in 1:length(ax[1]), x2 in 1:length(ax[2]), x3 in 1:length(ax[3])
s += A[ax[1][x1], ax[2][x2], ax[3][x3], a4, x1, x2, x3, x4]
end
lmo.tmp[4][x4, a4] = s
end
end
for x4 in 1:length(ax[4])
ax[4][x4] = argmin(lmo.tmp[4][x4, :])[1]
end
for x3 in 1:length(ax[3])
for a3 in 1:lmo.o[3]
s = zero(T)
for x1 in 1:length(ax[1]), x2 in 1:length(ax[2]), x4 in 1:length(ax[4])
s += A[ax[1][x1], ax[2][x2], ax[4][x4], a3, x1, x2, x3, x4]
end
lmo.tmp[3][x3, a3] = s
end
end
for x3 in 1:length(ax[3])
ax[3][x3] = argmin(lmo.tmp[3][x3, :])[1]
end
for x2 in 1:length(ax[2])
for a2 in 1:lmo.o[2]
s = zero(T)
for x1 in 1:length(ax[1]), x3 in 1:length(ax[3]), x4 in 1:length(ax[4])
s += A[ax[1][x1], a2, ax[3][x3], ax[4][x4], x1, x2, x3, x4]
end
lmo.tmp[2][x2, a2] = s
end
end
for x2 in 1:length(ax[2])
ax[2][x2] = argmin(lmo.tmp[2][x2, :])[1]
end
for x1 in 1:length(ax[1])
for a1 in 1:lmo.o[1]
s = zero(T)
for x2 in 1:length(ax[2]), x3 in 1:length(ax[3]), x4 in 1:length(ax[4])
s += A[a1, ax[2][x2], ax[3][x3], ax[4][x4], x1, x2, x3, x4]
end
lmo.tmp[1][x1, a1] = s
end
end
for x1 in 1:length(ax[1])
ax[1][x1] = argmin(lmo.tmp[1][x1, :])[1]
end
# uses the precomputed sum of lines to compute the scalar product
sc1 = zero(T)
for x1 in 1:length(ax[1])
sc1 += lmo.tmp[1][x1, ax[1][x1]]
end
end
return sc1
end


##############
# ACTIVE SET #
##############
Expand Down

0 comments on commit c6f5e2f

Please sign in to comment.