Skip to content

Commit

Permalink
Merge pull request #12 from ZIB-IOL/seb
Browse files Browse the repository at this point in the history
Clean branch structure
  • Loading branch information
sebastiendesignolle authored Sep 25, 2024
2 parents 08fd228 + 27fb554 commit ae91e83
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://zib-iol.github.io/BellPolytopes.jl/dev/)
[![Build Status](https://github.com/zib-iol/BellPolytopes.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/zib-iol/BellPolytopes.jl/actions/workflows/CI.yml?query=branch%3Amain)

This package addresses the membership problem for local polytopes: it constructs Bell inequalities and local models in multipartite Bell scenarios with binary outcomes.
This package addresses the membership problem for local polytopes: it constructs Bell inequalities and local models in multipartite Bell scenarios with arbitrary settings.

The original article for which it was written can be found here:

Expand Down
8 changes: 0 additions & 8 deletions src/bell_frank_wolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,4 @@ function bell_frank_wolfe(
return x, ds, primal, dual_gap, as, M, β
end
end
function bell_frank_wolfe(
p::Array{T, N},
build_reduce_inflate::Function;
kwargs...,
) where {T <: Number, N}
reduce, inflate = build_reduce_inflate(p)
return bell_frank_wolfe(p; sym=true, reduce, inflate, kwargs...)
end
export bell_frank_wolfe
38 changes: 34 additions & 4 deletions src/fw_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,15 @@ function FrankWolfe.compute_extreme_point(
lmo::BellProbabilitiesLMO{T, 4, 1},
A::Array{T, 4};
verbose=false,
count=false,
kwargs...,
) where {T <: Number}
ax = [ones(Int, lmo.m[n]) for n in 1:2]
sc = zero(T)
axm = [zeros(Int, lmo.m[n]) for n in 1:2]
scm = typemax(T)
# set containing all optimal strategies when count=true
setm = Set{Array{T, 4}}()
for λa2 in 0:lmo.o[2]^lmo.m[2]-1
digits!(ax[2], λa2; base=lmo.o[2])
ax[2] .+= 1
Expand All @@ -385,10 +388,17 @@ function FrankWolfe.compute_extreme_point(
for n in 1:2
axm[n] .= ax[n]
end
empty!(setm)
end
if verbose && sc scm
println(rpad(string([λa2]), 2 + ndigits(2^(sum(lmo.m)÷2))), " ", string(-scm))
println(rpad(string([λa2]), 2 + ndigits(lmo.o[2]^lmo.m[2])), " ", string(-scm))
end
if count && sc scm
push!(setm, collect(BellProbabilitiesDS(ax, lmo)))
end
end
if count
println(length(setm))
end
dsm = BellProbabilitiesDS(axm, lmo)
lmo.cnt += 1
Expand All @@ -399,12 +409,15 @@ function FrankWolfe.compute_extreme_point(
lmo::BellProbabilitiesLMO{T, 4, 2},
A::Array{T, 4};
verbose=false,
count=false,
kwargs...,
) where {T <: Number}
ax = [ones(Int, lmo.m[n]) for n in 1:2]
sc = zero(T)
axm = [zeros(Int, lmo.m[n]) for n in 1:2]
scm = typemax(T)
# set containing all optimal strategies when count=true
setm = Set{Array{T, 4}}()
for λa1 in 0:lmo.o[1]^lmo.m[1]-1
digits!(ax[1], λa1; base=lmo.o[1])
ax[1] .+= 1
Expand All @@ -429,10 +442,17 @@ function FrankWolfe.compute_extreme_point(
for n in 1:2
axm[n] .= ax[n]
end
empty!(setm)
end
if verbose && sc scm
println(rpad(string([λa2]), 2 + ndigits(2^(sum(lmo.m)÷2))), " ", string(-scm))
println(rpad(string([λa2]), 2 + ndigits(lmo.o[1]^lmo.m[1])), " ", string(-scm))
end
if count && sc scm
push!(setm, collect(BellProbabilitiesDS(ax, lmo)))
end
end
if count
println(length(setm))
end
dsm = BellProbabilitiesDS(axm, lmo)
lmo.cnt += 1
Expand All @@ -443,13 +463,16 @@ function FrankWolfe.compute_extreme_point(
lmo::BellProbabilitiesLMO{T, 6, 1},
A::Array{T, 6};
verbose=false,
sym = false,
count=false,
sym=false,
kwargs...,
) where {T <: Number}
ax = [ones(Int, lmo.m[n]) for n in 1:3]
sc = zero(T)
axm = [zeros(Int, lmo.m[n]) for n in 1:3]
scm = typemax(T)
# set containing all optimal strategies when count=true
setm = Set{Array{T, 4}}()
for λa3 in 0:lmo.o[3]^lmo.m[3]-1
digits!(ax[3], λa3; base=lmo.o[3])
ax[3] .+= 1
Expand Down Expand Up @@ -477,12 +500,19 @@ function FrankWolfe.compute_extreme_point(
for n in 1:3
axm[n] .= ax[n]
end
empty!(setm)
end
if verbose && sc scm
println(rpad(string([λa3, λa2]), 4 + 2ndigits(2^(sum(lmo.m)÷3))), " ", string(-scm))
println(rpad(string([λa3, λa2]), 4 + ndigits(lmo.o[3]^lmo.m[3]) + ndigits(lmo.o[2]^lmo.m[2])), " ", string(-scm))
end
if count && sc scm
push!(setm, collect(BellProbabilitiesDS(ax, lmo)))
end
end
end
if count
println(length(setm))
end
dsm = BellProbabilitiesDS(axm, lmo)
lmo.cnt += 1
return dsm
Expand Down
8 changes: 4 additions & 4 deletions src/local_bound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ function local_bound_correlation(
marg::Bool=false,
mode::Int=0,
nb::Int=10^4,
verbose=false,
kwargs...
) where {T <: Number} where {N}
ds = FrankWolfe.compute_extreme_point(BellCorrelationsLMO(M, M; marg, mode, nb), -M; verbose)
ds = FrankWolfe.compute_extreme_point(BellCorrelationsLMO(M, M; marg, mode, nb), -M; kwargs...)
return FrankWolfe.fast_dot(M, ds), ds
end
export local_bound_correlation
Expand All @@ -27,9 +27,9 @@ function local_bound_probability(
M::Array{T, N};
mode::Int=0,
nb::Int=10^4,
verbose=false,
kwargs...
) where {T <: Number} where {N}
ds = FrankWolfe.compute_extreme_point(BellProbabilitiesLMO(M, M; mode, nb), -M; verbose)
ds = FrankWolfe.compute_extreme_point(BellProbabilitiesLMO(M, M; mode, nb), -M; kwargs...)
return FrankWolfe.fast_dot(M, ds), ds
end
export local_bound_probability
Expand Down
36 changes: 17 additions & 19 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ function build_reduce_inflate_permutedims(p::Array{T, N}) where {T <: Number, N}
orbs = [unique(permutations(c)) for c in with_replacement_combinations(Int8.(1:m), N)]
dimension = length(orbs)
mul = length.(orbs)
sqmul = sqrt.(mul)
sqmul = sqrt.(T.(mul))
function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: AbstractFloat}
vec = Vector{S}(undef, dimension)
@inbounds for i in 1:dimension
Expand Down Expand Up @@ -706,38 +706,36 @@ function build_reduce_inflate_permutedims(p::Array{T, N}) where {T <: Number, N}
return reduce, inflate
end

function build_reduce_inflate_unique(p::Array{T, N}; digits=9) where {T <: Number, N}
ptol = round.(p; digits)
ptol[ptol .== zero(T)] .= zero(T) # transform -0.0 into 0.0 as isequal(0.0, -0.0) is false
uniquetol = unique(ptol[:])
dim = length(uniquetol) # reduced dimension
indices = [ptol .== u for u in uniquetol]
mul = [sum(ind) for ind in indices] # multiplicities, used to have matching scalar products
sqmul = sqrt.(mul) # precomputed for speed
function build_reduce_inflate_q(::Type{T}, q::Array{<:Integer, N}) where {T <: Number, N}
dim = maximum(q) # reduced dimension
mul = zeros(Int, dim) # multiplicities, used to have matching scalar products
for qi in q
mul[qi] += 1
end
sqmul = sqrt.(T.(mul)) # precomputed for speed
function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: AbstractFloat}
vec = zeros(S, dim)
@inbounds for (i, ind) in enumerate(indices)
vec[i] = sum(A[ind]) / sqmul[i]
@inbounds for (i, qi) in enumerate(q)
vec[qi] += A[i]
end
vec ./= sqmul
return FrankWolfe.SymmetricArray(A, vec)
end
function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: Number}
vec = zeros(S, dim)
@inbounds for (i, ind) in enumerate(indices)
vec[i] = sum(A[ind]) / S(mul[i])
@inbounds for (i, qi) in enumerate(q)
vec[qi] += A[i]
end
vec ./= S.(mul)
return FrankWolfe.SymmetricArray(A, vec, mul)
end
function inflate(sa::FrankWolfe.SymmetricArray{false}, lmo=nothing)
@inbounds for (i, ind) in enumerate(indices)
@view(sa.data[ind]) .= sa.vec[i] / sqmul[i]
end
aux = sa.vec ./ sqmul
@inbounds sa.data .= aux[q]
return sa.data
end
function inflate(sa::FrankWolfe.SymmetricArray{true}, lmo=nothing)
@inbounds for (i, ind) in enumerate(indices)
@view(sa.data[ind]) .= sa.vec[i]
end
@inbounds sa.data .= sa.vec[q]
return sa.data
end
return reduce, inflate
Expand Down

0 comments on commit ae91e83

Please sign in to comment.