Skip to content

Commit

Permalink
Add build_reduce_inflate_q
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Sep 25, 2024
1 parent 9cfa41d commit 79e50fd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
8 changes: 0 additions & 8 deletions src/bell_frank_wolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,4 @@ function bell_frank_wolfe(
return x, ds, primal, dual_gap, as, M, β
end
end
function bell_frank_wolfe(
p::Array{T, N},
build_reduce_inflate::Function;
kwargs...,
) where {T <: Number, N}
reduce, inflate = build_reduce_inflate(p)
return bell_frank_wolfe(p; sym=true, reduce, inflate, kwargs...)
end
export bell_frank_wolfe
32 changes: 15 additions & 17 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,38 +706,36 @@ function build_reduce_inflate_permutedims(p::Array{T, N}) where {T <: Number, N}
return reduce, inflate
end

function build_reduce_inflate_unique(p::Array{T, N}; digits=9) where {T <: Number, N}
ptol = round.(p; digits)
ptol[ptol .== zero(T)] .= zero(T) # transform -0.0 into 0.0 as isequal(0.0, -0.0) is false
uniquetol = unique(ptol[:])
dim = length(uniquetol) # reduced dimension
indices = [ptol .== u for u in uniquetol]
mul = [sum(ind) for ind in indices] # multiplicities, used to have matching scalar products
function build_reduce_inflate_q(::Type{T}, q::Array{<:Integer, N}) where {T <: Number, N}
dim = maximum(q) # reduced dimension
mul = zeros(Int, dim) # multiplicities, used to have matching scalar products
for qi in q
mul[qi] += 1
end
sqmul = sqrt.(T.(mul)) # precomputed for speed
function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: AbstractFloat}
vec = zeros(S, dim)
@inbounds for (i, ind) in enumerate(indices)
vec[i] = sum(A[ind]) / sqmul[i]
@inbounds for (i, qi) in enumerate(q)
vec[qi] += A[i]
end
vec ./= sqmul
return FrankWolfe.SymmetricArray(A, vec)
end
function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: Number}
vec = zeros(S, dim)
@inbounds for (i, ind) in enumerate(indices)
vec[i] = sum(A[ind]) / S(mul[i])
@inbounds for (i, qi) in enumerate(q)
vec[qi] += A[i]
end
vec ./= S.(mul)
return FrankWolfe.SymmetricArray(A, vec, mul)
end
function inflate(sa::FrankWolfe.SymmetricArray{false}, lmo=nothing)
@inbounds for (i, ind) in enumerate(indices)
@view(sa.data[ind]) .= sa.vec[i] / sqmul[i]
end
aux = sa.vec ./ sqmul
@inbounds sa.data .= aux[q]
return sa.data
end
function inflate(sa::FrankWolfe.SymmetricArray{true}, lmo=nothing)
@inbounds for (i, ind) in enumerate(indices)
@view(sa.data[ind]) .= sa.vec[i]
end
@inbounds sa.data .= sa.vec[q]
return sa.data
end
return reduce, inflate
Expand Down

0 comments on commit 79e50fd

Please sign in to comment.