diff --git a/src/Utilities/struct_of_constraints.jl b/src/Utilities/struct_of_constraints.jl index dd38494ea2..197baae977 100644 --- a/src/Utilities/struct_of_constraints.jl +++ b/src/Utilities/struct_of_constraints.jl @@ -2,8 +2,12 @@ abstract type StructOfConstraints <: MOI.ModelLike end function _throw_if_cannot_delete(model::StructOfConstraints, vis, fast_in_vis) broadcastcall(model) do constrs - return _throw_if_cannot_delete(constrs, vis, fast_in_vis) + if constrs !== nothing + _throw_if_cannot_delete(constrs, vis, fast_in_vis) + end + return end + return end function _deleted_constraints( callback::Function, @@ -11,24 +15,23 @@ function _deleted_constraints( vi, ) broadcastcall(model) do constrs - return _deleted_constraints(callback, constrs, vi) + if constrs !== nothing + _deleted_constraints(callback, constrs, vi) + end + return end + return end function MOI.add_constraint( model::StructOfConstraints, - func::MOI.AbstractFunction, - set::MOI.AbstractSet, -) - if MOI.supports_constraint(model, typeof(func), typeof(set)) - return MOI.add_constraint( - constraints(model, typeof(func), typeof(set)), - func, - set, - ) - else - throw(MOI.UnsupportedConstraint{typeof(func),typeof(set)}()) + func::F, + set::S, +) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet} + if !MOI.supports_constraint(model, F, S) + throw(MOI.UnsupportedConstraint{F,S}()) end + return MOI.add_constraint(constraints(model, F, S), func, set) end function constraints( @@ -49,18 +52,18 @@ function MOI.get( end function MOI.delete(model::StructOfConstraints, ci::MOI.ConstraintIndex) - return MOI.delete(constraints(model, ci), ci) + MOI.delete(constraints(model, ci), ci) + return end function MOI.is_valid( model::StructOfConstraints, ci::MOI.ConstraintIndex{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.is_valid(constraints(model, ci), ci) - else + if !MOI.supports_constraint(model, F, S) return false end + return MOI.is_valid(constraints(model, ci), ci) end function MOI.modify( @@ -68,7 +71,8 @@ function MOI.modify( ci::MOI.ConstraintIndex, change::MOI.AbstractFunctionModification, ) - return MOI.modify(constraints(model, ci), ci, change) + MOI.modify(constraints(model, ci), ci, change) + return end function MOI.set( @@ -77,45 +81,55 @@ function MOI.set( ci::MOI.ConstraintIndex, func_or_set, ) - return MOI.set(constraints(model, ci), attr, ci, func_or_set) + MOI.set(constraints(model, ci), attr, ci, func_or_set) + return end function MOI.get( model::StructOfConstraints, - loc::MOI.ListOfConstraintTypesPresent, -) where {T} - return broadcastvcat(model) do v - return MOI.get(v, loc) + attr::MOI.ListOfConstraintTypesPresent, +) + return broadcastvcat(model) do constrs + if constrs === nothing + return Tuple{DataType,DataType}[] + end + return MOI.get(constrs, attr) end end function MOI.get( model::StructOfConstraints, - noc::MOI.NumberOfConstraints{F,S}, + attr::MOI.NumberOfConstraints{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.get(constraints(model, F, S), noc) - else + if !MOI.supports_constraint(model, F, S) return 0 end + return MOI.get(constraints(model, F, S), attr) end function MOI.get( model::StructOfConstraints, - loc::MOI.ListOfConstraintIndices{F,S}, + attr::MOI.ListOfConstraintIndices{F,S}, ) where {F,S} - if MOI.supports_constraint(model, F, S) - return MOI.get(constraints(model, F, S), loc) - else + if !MOI.supports_constraint(model, F, S) return MOI.ConstraintIndex{F,S}[] end + return MOI.get(constraints(model, F, S), attr) end function MOI.is_empty(model::StructOfConstraints) - return mapreduce_constraints(MOI.is_empty, &, model, true) + return mapreduce_constraints(&, model, true) do constrs + return constrs === nothing || MOI.is_empty(constrs) + end end function MOI.empty!(model::StructOfConstraints) - return broadcastcall(MOI.empty!, model) + broadcastcall(model) do constrs + if constrs !== nothing + MOI.empty!(constrs) + end + return + end + return end # Can be used to access constraints of a model @@ -178,33 +192,13 @@ end # (MOI, :Zeros) -> :(MOI.Zeros) # (:Zeros) -> :(MOI.Zeros) -_set(s::SymbolSet) = esc(s.s) -_fun(s::SymbolFun) = esc(s.s) -function _typedset(s::SymbolSet) - if s.typed - T = esc(:T) - :($(_set(s)){$T}) - else - _set(s) - end -end -function _typedfun(s::SymbolFun) - if s.typed - T = esc(:T) - :($(_fun(s)){$T}) - else - _fun(s) - end -end +_typedset(s::SymbolSet) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s) +_typedfun(s::SymbolFun) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s) # Base.lowercase is moved to Unicode.lowercase in Julia v0.7 using Unicode _field(s::SymbolFS) = Symbol(replace(lowercase(string(s.s)), "." => "_")) - -_getC(s::SymbolSet) = :(VectorOfConstraints{F,$(_typedset(s))}) -_getC(s::SymbolFun) = _typedfun(s) - _callfield(f, s::SymbolFS) = :($f(model.$(_field(s)))) _broadcastfield(b, s::SymbolFS) = :($b(f, model.$(_field(s)))) _mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s))))) @@ -223,62 +217,61 @@ If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints of that function (resp. set) type are stored in the corresponding field. """ function struct_of_constraint_code(struct_name, types, field_types = nothing) - esc_struct_name = struct_name T = esc(:T) - typed_struct = :($(esc_struct_name){$T}) + typed_struct = :($(struct_name){$T}) type_parametrized = field_types === nothing if type_parametrized field_types = [Symbol("C$i") for i in eachindex(types)] append!(typed_struct.args, field_types) end + struct_def = :(mutable struct $typed_struct <: StructOfConstraints end) - struct_def = :(struct $typed_struct <: StructOfConstraints end) - - for (t, field_type) in zip(types, field_types) - field = _field(t) - push!(struct_def.args[3].args, :($field::$field_type)) - end code = quote - function $MOIU.broadcastcall(f::Function, model::$esc_struct_name) - return $(Expr(:block, _callfield.(Ref(:f), types)...)) + function $MOIU.broadcastcall(f::Function, model::$struct_name) + $(Expr(:block, _callfield.(Ref(:f), types)...)) + return end - function $MOIU.broadcastvcat(f::Function, model::$esc_struct_name) + + function $MOIU.broadcastvcat(f::Function, model::$struct_name) return vcat($(_callfield.(Ref(:f), types)...)) end + function $MOIU.mapreduce_constraints( f::Function, op::Function, - model::$esc_struct_name, + model::$struct_name, cur, ) return $(Expr(:block, _mapreduce_field.(types)...)) end end - for t in types - if t isa SymbolFun - fun = _fun(t) - set = :(MOI.AbstractSet) - else - fun = :(MOI.AbstractFunction) - set = _set(t) - end + for (t, field_type) in zip(types, field_types) field = _field(t) - code = quote - $code + push!(struct_def.args[3].args, :($field::Union{Nothing,$field_type})) + fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction) + set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s) + constraints_code = :( function $MOIU.constraints( - model::$esc_struct_name, + model::$typed_struct, ::Type{<:$fun}, ::Type{<:$set}, - ) + ) where {$T} + if model.$field === nothing + model.$field = $(field_type)() + end return model.$field end + ) + if type_parametrized + append!(constraints_code.args[1].args, field_types) end + push!(code.args, constraints_code) end supports_code = if eltype(types) <: SymbolFun quote function $MOI.supports_constraint( - model::$esc_struct_name{$T}, + model::$struct_name{$T}, ::Type{F}, ::Type{S}, ) where { @@ -293,7 +286,7 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing) @assert eltype(types) <: SymbolSet quote function $MOI.supports_constraint( - model::$esc_struct_name{$T}, + model::$struct_name{$T}, ::Type{F}, ::Type{S}, ) where { @@ -307,7 +300,8 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing) end expr = Expr(:block, struct_def, supports_code, code) if !isempty(field_types) - constructors = [:($field_type()) for field_type in field_types] + constructors = [:(nothing) for field_type in field_types] + # constructors = [:($field_type()) for field_type in field_types] # If there is no field type, the default constructor is sufficient and # adding this constructor will make a `StackOverflow`. constructor_code = :(function $typed_struct() where {$T}