diff --git a/src/bell_frank_wolfe.jl b/src/bell_frank_wolfe.jl index 29a79ac..09d9636 100644 --- a/src/bell_frank_wolfe.jl +++ b/src/bell_frank_wolfe.jl @@ -239,12 +239,4 @@ function bell_frank_wolfe( return x, ds, primal, dual_gap, as, M, β end end -function bell_frank_wolfe( - p::Array{T, N}, - build_reduce_inflate::Function; - kwargs..., -) where {T <: Number, N} - reduce, inflate = build_reduce_inflate(p) - return bell_frank_wolfe(p; sym=true, reduce, inflate, kwargs...) -end export bell_frank_wolfe diff --git a/src/utils.jl b/src/utils.jl index 8b8b222..0a7349c 100755 --- a/src/utils.jl +++ b/src/utils.jl @@ -706,38 +706,36 @@ function build_reduce_inflate_permutedims(p::Array{T, N}) where {T <: Number, N} return reduce, inflate end -function build_reduce_inflate_unique(p::Array{T, N}; digits=9) where {T <: Number, N} - ptol = round.(p; digits) - ptol[ptol .== zero(T)] .= zero(T) # transform -0.0 into 0.0 as isequal(0.0, -0.0) is false - uniquetol = unique(ptol[:]) - dim = length(uniquetol) # reduced dimension - indices = [ptol .== u for u in uniquetol] - mul = [sum(ind) for ind in indices] # multiplicities, used to have matching scalar products +function build_reduce_inflate_q(::Type{T}, q::Array{<:Integer, N}) where {T <: Number, N} + dim = maximum(q) # reduced dimension + mul = zeros(Int, dim) # multiplicities, used to have matching scalar products + for qi in q + mul[qi] += 1 + end sqmul = sqrt.(T.(mul)) # precomputed for speed function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: AbstractFloat} vec = zeros(S, dim) - @inbounds for (i, ind) in enumerate(indices) - vec[i] = sum(A[ind]) / sqmul[i] + @inbounds for (i, qi) in enumerate(q) + vec[qi] += A[i] end + vec ./= sqmul return FrankWolfe.SymmetricArray(A, vec) end function reduce(A::AbstractArray{S, N}, lmo=nothing) where {S <: Number} vec = zeros(S, dim) - @inbounds for (i, ind) in enumerate(indices) - vec[i] = sum(A[ind]) / S(mul[i]) + @inbounds for (i, qi) in enumerate(q) + vec[qi] += A[i] end + vec ./= S.(mul) return FrankWolfe.SymmetricArray(A, vec, mul) end function inflate(sa::FrankWolfe.SymmetricArray{false}, lmo=nothing) - @inbounds for (i, ind) in enumerate(indices) - @view(sa.data[ind]) .= sa.vec[i] / sqmul[i] - end + aux = sa.vec ./ sqmul + @inbounds sa.data .= aux[q] return sa.data end function inflate(sa::FrankWolfe.SymmetricArray{true}, lmo=nothing) - @inbounds for (i, ind) in enumerate(indices) - @view(sa.data[ind]) .= sa.vec[i] - end + @inbounds sa.data .= sa.vec[q] return sa.data end return reduce, inflate