diff --git a/src/ConditionalJuMP.jl b/src/ConditionalJuMP.jl index 57b6b4b..8df9611 100644 --- a/src/ConditionalJuMP.jl +++ b/src/ConditionalJuMP.jl @@ -211,10 +211,14 @@ end newbinaryvar(m::Model, args...) = newbinaryvar(getindmap!(m), args...) -function getindicator!(m::IndicatorMap, c::Conditional) +function getindicator!(m::IndicatorMap, c::Conditional, can_create=true) if haskey(m.indicators, c) return m.indicators[c] else + if !can_create + @show c + error("Not allowed to create a new variable here. Something has gone wrong") + end z = newbinaryvar(m) implies!(m.model, z, c) m.indicators[c] = z @@ -242,6 +246,7 @@ end function disjunction!(indmap::IndicatorMap, imps::NTuple{2, Implication}) z = getindicator!(indmap, first(imps[1])) implies!(indmap.model, z, second(imps[1])) + indmap.indicators[first(imps[2])] = 1 - z implies!(indmap.model, 1 - z, first(imps[2])) implies!(indmap.model, 1 - z, second(imps[2])) push!(indmap.disjunctions, Implication[imps...]) @@ -389,7 +394,7 @@ function warmstart!(m::Model, fix=false) for i in eachindex(violations) imp = implications[i] lhs, rhs = imp - z = getindicator!(indmap, lhs) + z = getindicator!(indmap, lhs, false) satisfied = i == best_match if fix if !isfixed(z) diff --git a/test/runtests.jl b/test/runtests.jl index f6066c8..535b530 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -458,6 +458,21 @@ end @test getvalue(x) ≈ [0.8, 0.6] end + @testset "test that warmstart does not introduce new variables" begin + m = Model(solver=CbcSolver()) + @variable m -1 <= x <= 1 + @disjunction m (x == 0.5) (x == -0.5) + @objective m Min x + @test length(m.colCat) == 2 + warmstart!(m, false) + @test length(m.colCat) == 2 + solve(m) + @test getvalue(x) ≈ -0.5 + warmstart!(m, true) + @test length(m.colCat) == 2 + solve(m) + @test getvalue(x) ≈ -0.5 + end end @testset "examples" begin